2016-07-19 19:58:59 +00:00
|
|
|
<?php
|
2016-07-19 19:59:23 +00:00
|
|
|
|
2016-11-20 21:53:17 +00:00
|
|
|
declare(strict_types=1);
|
2016-07-19 19:58:59 +00:00
|
|
|
|
|
|
|
namespace Phpml\Metric;
|
|
|
|
|
2018-01-29 17:06:21 +00:00
|
|
|
use Phpml\Exception\InvalidArgumentException;
|
|
|
|
|
2016-07-19 19:58:59 +00:00
|
|
|
/**
 * Builds per-label precision, recall, F1 score and support from two parallel
 * arrays of labels, plus a single averaged summary of the three scores.
 *
 * All numbers are computed once in the constructor; the getters only return
 * the stored results.
 */
class ClassificationReport
{
    public const MICRO_AVERAGE = 1;

    public const MACRO_AVERAGE = 2;

    public const WEIGHTED_AVERAGE = 3;

    /**
     * @var int[] count of samples whose predicted label equals the actual label, keyed by label
     */
    private $truePositive = [];

    /**
     * @var int[] count of mismatched samples, keyed by the (wrong) predicted label
     */
    private $falsePositive = [];

    /**
     * @var int[] count of mismatched samples, keyed by the actual label
     */
    private $falseNegative = [];

    /**
     * @var int[] number of occurrences of each label in the actual labels
     */
    private $support = [];

    /**
     * @var float[] precision keyed by label
     */
    private $precision = [];

    /**
     * @var float[] recall keyed by label
     */
    private $recall = [];

    /**
     * @var float[] F1 score keyed by label
     */
    private $f1score = [];

    /**
     * @var float[] averaged scores under the keys 'precision', 'recall' and 'f1score'
     */
    private $average = [];

    /**
     * @param array $actualLabels    ground-truth labels
     * @param array $predictedLabels predicted labels, looked up by the same keys as $actualLabels
     * @param int   $average         one of MICRO_AVERAGE, MACRO_AVERAGE or WEIGHTED_AVERAGE
     *
     * @throws InvalidArgumentException when $average is not a known averaging method, or when
     *                                  a key of $actualLabels has no entry in $predictedLabels
     */
    public function __construct(array $actualLabels, array $predictedLabels, int $average = self::MACRO_AVERAGE)
    {
        $averagingMethods = [self::MICRO_AVERAGE, self::MACRO_AVERAGE, self::WEIGHTED_AVERAGE];
        if (!in_array($average, $averagingMethods, true)) {
            throw new InvalidArgumentException('Averaging method must be MICRO_AVERAGE, MACRO_AVERAGE or WEIGHTED_AVERAGE');
        }

        $this->aggregateClassificationResults($actualLabels, $predictedLabels);
        $this->computeMetrics();
        $this->computeAverage($average);
    }

    /**
     * @return float[] precision keyed by label
     */
    public function getPrecision(): array
    {
        return $this->precision;
    }

    /**
     * @return float[] recall keyed by label
     */
    public function getRecall(): array
    {
        return $this->recall;
    }

    /**
     * @return float[] F1 score keyed by label
     */
    public function getF1score(): array
    {
        return $this->f1score;
    }

    /**
     * @return int[] occurrences of each label among the actual labels
     */
    public function getSupport(): array
    {
        return $this->support;
    }

    /**
     * @return float[] averaged 'precision', 'recall' and 'f1score'
     */
    public function getAverage(): array
    {
        return $this->average;
    }

    /**
     * Counts true positives, false positives, false negatives and support for every label.
     *
     * @throws InvalidArgumentException when an actual label has no prediction at the same key
     */
    private function aggregateClassificationResults(array $actualLabels, array $predictedLabels): void
    {
        // Every counter starts from the same zero-filled, label-keyed array.
        $truePositive = $falsePositive = $falseNegative = $support = self::getLabelIndexedArray($actualLabels, $predictedLabels);

        foreach ($actualLabels as $index => $actual) {
            // Without this guard a missing prediction would be read as null and
            // silently counted as a false positive under a bogus label.
            if (!array_key_exists($index, $predictedLabels)) {
                throw new InvalidArgumentException(sprintf('Missing predicted label for sample at index "%s"', $index));
            }

            $predicted = $predictedLabels[$index];
            ++$support[$actual];

            if ($actual === $predicted) {
                ++$truePositive[$actual];
            } else {
                ++$falsePositive[$predicted];
                ++$falseNegative[$actual];
            }
        }

        $this->truePositive = $truePositive;
        $this->falsePositive = $falsePositive;
        $this->falseNegative = $falseNegative;
        $this->support = $support;
    }

    /**
     * Derives precision, recall and F1 score for every label from the aggregated counts.
     */
    private function computeMetrics(): void
    {
        foreach ($this->truePositive as $label => $tp) {
            $this->precision[$label] = $this->computePrecision($tp, $this->falsePositive[$label]);
            $this->recall[$label] = $this->computeRecall($tp, $this->falseNegative[$label]);
            $this->f1score[$label] = $this->computeF1Score($this->precision[$label], $this->recall[$label]);
        }
    }

    /**
     * Dispatches to the averaging strategy selected in the constructor.
     */
    private function computeAverage(int $average): void
    {
        switch ($average) {
            case self::MICRO_AVERAGE:
                $this->computeMicroAverage();

                break;
            case self::MACRO_AVERAGE:
                $this->computeMacroAverage();

                break;
            case self::WEIGHTED_AVERAGE:
                $this->computeWeightedAverage();

                break;
        }
    }

    /**
     * Micro average: scores computed from the counts summed over all labels.
     */
    private function computeMicroAverage(): void
    {
        $truePositive = (int) array_sum($this->truePositive);
        $falsePositive = (int) array_sum($this->falsePositive);
        $falseNegative = (int) array_sum($this->falseNegative);

        $precision = $this->computePrecision($truePositive, $falsePositive);
        $recall = $this->computeRecall($truePositive, $falseNegative);
        $f1score = $this->computeF1Score($precision, $recall);

        $this->average = compact('precision', 'recall', 'f1score');
    }

    /**
     * Macro average: unweighted mean of the per-label scores (0.0 when there are no labels).
     */
    private function computeMacroAverage(): void
    {
        foreach ($this->metricsByName() as $metric => $values) {
            $this->average[$metric] = $values === [] ? 0.0 : array_sum($values) / count($values);
        }
    }

    /**
     * Weighted average: mean of the per-label scores weighted by each label's support.
     *
     * Yields 0.0 when there are no labels or when the total support is zero
     * (e.g. no actual labels were given), instead of dividing by zero.
     */
    private function computeWeightedAverage(): void
    {
        // Identical for every metric, so compute it once.
        $totalSupport = array_sum($this->support);

        foreach ($this->metricsByName() as $metric => $values) {
            if ($values === [] || $totalSupport === 0) {
                $this->average[$metric] = 0.0;

                continue;
            }

            $weightedSum = 0.0;
            foreach ($values as $label => $value) {
                $weightedSum += $value * $this->support[$label];
            }

            $this->average[$metric] = $weightedSum / $totalSupport;
        }
    }

    /**
     * Per-label score arrays keyed by the names used in the averaged result.
     *
     * @return float[][]
     */
    private function metricsByName(): array
    {
        return [
            'precision' => $this->precision,
            'recall' => $this->recall,
            'f1score' => $this->f1score,
        ];
    }

    /**
     * @return float TP / (TP + FP), or 0.0 when the denominator is zero
     */
    private function computePrecision(int $truePositive, int $falsePositive): float
    {
        $divider = $truePositive + $falsePositive;
        if ($divider === 0) {
            return 0.0;
        }

        // An exact integer quotient is widened to float by the return type.
        return $truePositive / $divider;
    }

    /**
     * @return float TP / (TP + FN), or 0.0 when the denominator is zero
     */
    private function computeRecall(int $truePositive, int $falseNegative): float
    {
        $divider = $truePositive + $falseNegative;
        if ($divider === 0) {
            return 0.0;
        }

        // An exact integer quotient is widened to float by the return type.
        return $truePositive / $divider;
    }

    /**
     * @return float 2 * precision * recall / (precision + recall), or 0.0 when both are zero
     */
    private function computeF1Score(float $precision, float $recall): float
    {
        $divider = $precision + $recall;
        if ($divider === 0.0) {
            return 0.0;
        }

        return 2.0 * (($precision * $recall) / $divider);
    }

    /**
     * Returns an array keyed by every distinct label found in either input, sorted, with all values 0.
     *
     * @return int[]
     */
    private static function getLabelIndexedArray(array $actualLabels, array $predictedLabels): array
    {
        $labels = array_unique(array_merge($actualLabels, $predictedLabels));
        sort($labels);

        return array_fill_keys($labels, 0);
    }
}
|