php-ml/src/Metric/ClassificationReport.php

233 lines
5.9 KiB
PHP
Raw Normal View History

2016-07-19 19:58:59 +00:00
<?php
2016-07-19 19:59:23 +00:00
2016-11-20 21:53:17 +00:00
declare(strict_types=1);
2016-07-19 19:58:59 +00:00
namespace Phpml\Metric;
use Phpml\Exception\InvalidArgumentException;
2016-07-19 19:58:59 +00:00
/**
 * Per-label precision, recall, F1 score and support for a set of predictions,
 * together with one averaged summary of the three scores.
 *
 * Everything is computed once in the constructor; the getters only return the
 * stored results. Per-label arrays are keyed by label, and the labels are the
 * distinct values found in either input array, ordered with sort().
 */
class ClassificationReport
{
    /**
     * Average computed from the true/false positive and false negative counts
     * summed over all labels.
     */
    public const MICRO_AVERAGE = 1;

    /**
     * Unweighted mean of the per-label scores.
     */
    public const MACRO_AVERAGE = 2;

    /**
     * Mean of the per-label scores, each weighted by the label's support.
     */
    public const WEIGHTED_AVERAGE = 3;

    /**
     * @var array Number of correct predictions, keyed by label
     */
    private $truePositive = [];

    /**
     * @var array Number of times a label was predicted for a sample of another label, keyed by label
     */
    private $falsePositive = [];

    /**
     * @var array Number of samples of a label that were predicted as another label, keyed by label
     */
    private $falseNegative = [];

    /**
     * @var array Number of actual samples, keyed by label
     */
    private $support = [];

    /**
     * @var array Precision, keyed by label
     */
    private $precision = [];

    /**
     * @var array Recall, keyed by label
     */
    private $recall = [];

    /**
     * @var array F1 score, keyed by label
     */
    private $f1score = [];

    /**
     * @var array Averaged scores under the keys 'precision', 'recall' and 'f1score'
     */
    private $average = [];

    /**
     * @param array $actualLabels    Ground-truth labels
     * @param array $predictedLabels Predicted labels, paired with $actualLabels by array key
     * @param int   $average         One of the *_AVERAGE constants
     *
     * @throws InvalidArgumentException When $average is not a known averaging method, or when
     *                                  $predictedLabels has no entry for a key of $actualLabels
     */
    public function __construct(array $actualLabels, array $predictedLabels, int $average = self::MACRO_AVERAGE)
    {
        $averagingMethods = range(self::MICRO_AVERAGE, self::WEIGHTED_AVERAGE);
        if (!in_array($average, $averagingMethods, true)) {
            throw new InvalidArgumentException('Averaging method must be MICRO_AVERAGE, MACRO_AVERAGE or WEIGHTED_AVERAGE');
        }

        $this->aggregateClassificationResults($actualLabels, $predictedLabels);
        $this->computeMetrics();
        $this->computeAverage($average);
    }

    /**
     * @return array Precision keyed by label
     */
    public function getPrecision(): array
    {
        return $this->precision;
    }

    /**
     * @return array Recall keyed by label
     */
    public function getRecall(): array
    {
        return $this->recall;
    }

    /**
     * @return array F1 score keyed by label
     */
    public function getF1score(): array
    {
        return $this->f1score;
    }

    /**
     * @return array Number of actual samples keyed by label
     */
    public function getSupport(): array
    {
        return $this->support;
    }

    /**
     * @return array Averaged scores under the keys 'precision', 'recall' and 'f1score'
     */
    public function getAverage(): array
    {
        return $this->average;
    }

    /**
     * Counts true positives, false positives, false negatives and support per label.
     *
     * @throws InvalidArgumentException When a key of $actualLabels is absent from $predictedLabels
     */
    private function aggregateClassificationResults(array $actualLabels, array $predictedLabels): void
    {
        $truePositive = $falsePositive = $falseNegative = $support = self::getLabelIndexedArray($actualLabels, $predictedLabels);

        foreach ($actualLabels as $index => $actual) {
            // Without this guard a missing prediction would be read as null and
            // counted as a false positive for a label that does not exist.
            if (!array_key_exists($index, $predictedLabels)) {
                throw new InvalidArgumentException(sprintf(
                    'Predicted labels are missing an entry for index "%s" of the actual labels',
                    $index
                ));
            }

            $predicted = $predictedLabels[$index];

            ++$support[$actual];

            if ($actual === $predicted) {
                ++$truePositive[$actual];
            } else {
                ++$falsePositive[$predicted];
                ++$falseNegative[$actual];
            }
        }

        $this->truePositive = $truePositive;
        $this->falsePositive = $falsePositive;
        $this->falseNegative = $falseNegative;
        $this->support = $support;
    }

    /**
     * Derives precision, recall and F1 score for every label from the aggregated counts.
     */
    private function computeMetrics(): void
    {
        foreach ($this->truePositive as $label => $tp) {
            $this->precision[$label] = $this->computePrecision($tp, $this->falsePositive[$label]);
            $this->recall[$label] = $this->computeRecall($tp, $this->falseNegative[$label]);
            $this->f1score[$label] = $this->computeF1Score((float) $this->precision[$label], (float) $this->recall[$label]);
        }
    }

    /**
     * Dispatches to the averaging strategy selected in the constructor.
     * Values other than the three constants were already rejected there.
     */
    private function computeAverage(int $average): void
    {
        switch ($average) {
            case self::MICRO_AVERAGE:
                $this->computeMicroAverage();

                return;
            case self::MACRO_AVERAGE:
                $this->computeMacroAverage();

                return;
            case self::WEIGHTED_AVERAGE:
                $this->computeWeightedAverage();

                return;
        }
    }

    /**
     * Computes the scores from the counts summed over all labels.
     */
    private function computeMicroAverage(): void
    {
        $truePositive = (int) array_sum($this->truePositive);
        $falsePositive = (int) array_sum($this->falsePositive);
        $falseNegative = (int) array_sum($this->falseNegative);

        $precision = $this->computePrecision($truePositive, $falsePositive);
        $recall = $this->computeRecall($truePositive, $falseNegative);
        $f1score = $this->computeF1Score((float) $precision, (float) $recall);

        $this->average = compact('precision', 'recall', 'f1score');
    }

    /**
     * Computes the unweighted mean of each per-label score; 0.0 when there are no labels.
     */
    private function computeMacroAverage(): void
    {
        $metrics = [
            'precision' => $this->precision,
            'recall' => $this->recall,
            'f1score' => $this->f1score,
        ];

        foreach ($metrics as $metric => $values) {
            if (count($values) === 0) {
                $this->average[$metric] = 0.0;

                continue;
            }

            $this->average[$metric] = array_sum($values) / count($values);
        }
    }

    /**
     * Computes the support-weighted mean of each per-label score.
     *
     * Yields 0.0 when there are no labels or when the total support is zero
     * (labels that only occur among the predictions), so the division below
     * can never be by zero.
     */
    private function computeWeightedAverage(): void
    {
        $metrics = [
            'precision' => $this->precision,
            'recall' => $this->recall,
            'f1score' => $this->f1score,
        ];
        $totalSupport = (int) array_sum($this->support);

        foreach ($metrics as $metric => $values) {
            if (count($values) === 0 || $totalSupport === 0) {
                $this->average[$metric] = 0.0;

                continue;
            }

            $sum = 0;
            foreach ($values as $label => $value) {
                $sum += $value * $this->support[$label];
            }

            $this->average[$metric] = $sum / $totalSupport;
        }
    }

    /**
     * Returns TP / (TP + FP), or 0.0 when the label was never predicted.
     *
     * @return float|int int when the division is exact, float otherwise
     */
    private function computePrecision(int $truePositive, int $falsePositive)
    {
        $divider = $truePositive + $falsePositive;
        if ($divider === 0) {
            return 0.0;
        }

        return $truePositive / $divider;
    }

    /**
     * Returns TP / (TP + FN), or 0.0 when the label has no actual samples.
     *
     * @return float|int int when the division is exact, float otherwise
     */
    private function computeRecall(int $truePositive, int $falseNegative)
    {
        $divider = $truePositive + $falseNegative;
        if ($divider === 0) {
            return 0.0;
        }

        return $truePositive / $divider;
    }

    /**
     * Returns the harmonic mean of precision and recall, or 0.0 when both are zero.
     */
    private function computeF1Score(float $precision, float $recall): float
    {
        $divider = $precision + $recall;
        if ($divider === 0.0) {
            return 0.0;
        }

        return 2.0 * (($precision * $recall) / $divider);
    }

    /**
     * Builds a zero-initialised counter array keyed by every distinct label
     * found in either input, with the labels ordered by sort().
     */
    private static function getLabelIndexedArray(array $actualLabels, array $predictedLabels): array
    {
        // sort() reindexes the array, so no separate array_values() pass is needed.
        $labels = array_unique(array_merge($actualLabels, $predictedLabels));
        sort($labels);

        return (array) array_combine($labels, array_fill(0, count($labels), 0));
    }
}