diff --git a/docs/machine-learning/metric/classification-report.md b/docs/machine-learning/metric/classification-report.md
index 53490b2..a0a6acc 100644
--- a/docs/machine-learning/metric/classification-report.md
+++ b/docs/machine-learning/metric/classification-report.md
@@ -18,6 +18,19 @@ $predictedLabels = ['cat', 'cat', 'bird', 'bird', 'ant'];
 $report = new ClassificationReport($actualLabels, $predictedLabels);
 ```
 
+Optionally, you can provide the following parameter:
+
+* $average - (int) averaging method for multi-class classification
+    * `ClassificationReport::MICRO_AVERAGE` = 1
+    * `ClassificationReport::MACRO_AVERAGE` = 2 (default)
+    * `ClassificationReport::WEIGHTED_AVERAGE` = 3
+
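+For example, to get micro-averaged metrics instead of the default macro average:
+
+```php
+$report = new ClassificationReport($actualLabels, $predictedLabels, ClassificationReport::MICRO_AVERAGE);
+```
+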
 ### Metrics
 
 After creating the report you can draw its individual metrics:
diff --git a/src/Phpml/Metric/ClassificationReport.php b/src/Phpml/Metric/ClassificationReport.php
index 0c3198f..755d78b 100644
--- a/src/Phpml/Metric/ClassificationReport.php
+++ b/src/Phpml/Metric/ClassificationReport.php
@@ -4,8 +4,36 @@ declare(strict_types=1);
 
 namespace Phpml\Metric;
 
+use Phpml\Exception\InvalidArgumentException;
+
 class ClassificationReport
 {
+    public const MICRO_AVERAGE = 1;
+
+    public const MACRO_AVERAGE = 2;
+
+    public const WEIGHTED_AVERAGE = 3;
+
+    /**
+     * @var array
+     */
+    private $truePositive = [];
+
+    /**
+     * @var array
+     */
+    private $falsePositive = [];
+
+    /**
+     * @var array
+     */
+    private $falseNegative = [];
+
+    /**
+     * @var array
+     */
+    private $support = [];
+
     /**
      * @var array
      */
@@ -21,34 +49,21 @@ class ClassificationReport
      */
     private $f1score = [];
 
-    /**
-     * @var array
-     */
-    private $support = [];
-
     /**
      * @var array
      */
     private $average = [];
 
-    public function __construct(array $actualLabels, array $predictedLabels)
+    public function __construct(array $actualLabels, array $predictedLabels, int $average = self::MACRO_AVERAGE)
     {
-        $truePositive = $falsePositive = $falseNegative = $this->support = self::getLabelIndexedArray($actualLabels, $predictedLabels);
-
-        foreach ($actualLabels as $index => $actual) {
-            $predicted = $predictedLabels[$index];
-            ++$this->support[$actual];
-
-            if ($actual === $predicted) {
-                ++$truePositive[$actual];
-            } else {
-                ++$falsePositive[$predicted];
-                ++$falseNegative[$actual];
-            }
+        $averagingMethods = range(self::MICRO_AVERAGE, self::WEIGHTED_AVERAGE);
+        if (!in_array($average, $averagingMethods, true)) {
+            throw new InvalidArgumentException('Averaging method must be MICRO_AVERAGE, MACRO_AVERAGE or WEIGHTED_AVERAGE');
         }
 
-        $this->computeMetrics($truePositive, $falsePositive, $falseNegative);
-        $this->computeAverage();
+        $this->aggregateClassificationResults($actualLabels, $predictedLabels);
+        $this->computeMetrics();
+        $this->computeAverage($average);
     }
 
     public function getPrecision(): array
@@ -76,20 +91,73 @@ class ClassificationReport
         return $this->average;
     }
 
-    private function computeMetrics(array $truePositive, array $falsePositive, array $falseNegative): void
+    private function aggregateClassificationResults(array $actualLabels, array $predictedLabels): void
     {
-        foreach ($truePositive as $label => $tp) {
-            $this->precision[$label] = $this->computePrecision($tp, $falsePositive[$label]);
-            $this->recall[$label] = $this->computeRecall($tp, $falseNegative[$label]);
+        $truePositive = $falsePositive = $falseNegative = $support = self::getLabelIndexedArray($actualLabels, $predictedLabels);
+
+        foreach ($actualLabels as $index => $actual) {
+            $predicted = $predictedLabels[$index];
+            ++$support[$actual];
+
+            if ($actual === $predicted) {
+                ++$truePositive[$actual];
+            } else {
+                ++$falsePositive[$predicted];
+                ++$falseNegative[$actual];
+            }
+        }
+
+        $this->truePositive = $truePositive;
+        $this->falsePositive = $falsePositive;
+        $this->falseNegative = $falseNegative;
+        $this->support = $support;
+    }
+
+    private function computeMetrics(): void
+    {
+        foreach ($this->truePositive as $label => $tp) {
+            $this->precision[$label] = $this->computePrecision($tp, $this->falsePositive[$label]);
+            $this->recall[$label] = $this->computeRecall($tp, $this->falseNegative[$label]);
             $this->f1score[$label] = $this->computeF1Score((float) $this->precision[$label], (float) $this->recall[$label]);
         }
     }
 
-    private function computeAverage(): void
+    private function computeAverage(int $average): void
+    {
+        switch ($average) {
+            case self::MICRO_AVERAGE:
+                $this->computeMicroAverage();
+
+                return;
+            case self::MACRO_AVERAGE:
+                $this->computeMacroAverage();
+
+                return;
+            case self::WEIGHTED_AVERAGE:
+                $this->computeWeightedAverage();
+
+                return;
+        }
+    }
+
+    private function computeMicroAverage(): void
+    {
+        $truePositive = array_sum($this->truePositive);
+        $falsePositive = array_sum($this->falsePositive);
+        $falseNegative = array_sum($this->falseNegative);
+
+        $precision = $this->computePrecision($truePositive, $falsePositive);
+        $recall = $this->computeRecall($truePositive, $falseNegative);
+        $f1score = $this->computeF1Score((float) $precision, (float) $recall);
+
+        $this->average = compact('precision', 'recall', 'f1score');
+    }
+
+    private function computeMacroAverage(): void
     {
         foreach (['precision', 'recall', 'f1score'] as $metric) {
-            $values = array_filter($this->{$metric});
-            if (empty($values)) {
+            $values = $this->{$metric};
+            if (count($values) === 0) {
                 $this->average[$metric] = 0.0;
 
                 continue;
@@ -99,6 +167,30 @@ class ClassificationReport
         }
     }
 
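+    /**
+     * Weighted average: each per-class metric is weighted by the class
+     * support (the number of actual samples of that class) and the sum
+     * is normalized by the total support.
+     */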
+    private function computeWeightedAverage(): void
+    {
+        foreach (['precision', 'recall', 'f1score'] as $metric) {
+            $values = $this->{$metric};
+            if (count($values) === 0) {
+                $this->average[$metric] = 0.0;
+
+                continue;
+            }
+
+            $sum = 0;
+            foreach ($values as $i => $value) {
+                $sum += $value * $this->support[$i];
+            }
+
+            $this->average[$metric] = $sum / array_sum($this->support);
+        }
+    }
+
     /**
      * @return float|string
      */
diff --git a/tests/Phpml/Metric/ClassificationReportTest.php b/tests/Phpml/Metric/ClassificationReportTest.php
index 483f769..4c4f01f 100644
--- a/tests/Phpml/Metric/ClassificationReportTest.php
+++ b/tests/Phpml/Metric/ClassificationReportTest.php
@@ -4,6 +4,7 @@ declare(strict_types=1);
 
 namespace Phpml\Tests\Metric;
 
+use Phpml\Exception\InvalidArgumentException;
 use Phpml\Metric\ClassificationReport;
 use PHPUnit\Framework\TestCase;
 
@@ -36,10 +37,12 @@ class ClassificationReportTest extends TestCase
             'ant' => 1,
             'bird' => 3,
         ];
+
+        // ClassificationReport uses macro-averaging by default
         $average = [
-            'precision' => 0.75,
-            'recall' => 0.83,
-            'f1score' => 0.73,
+            'precision' => 0.5, // (1/2 + 0 + 1) / 3 = 1/2
+            'recall' => 0.56, // (1 + 0 + 2/3) / 3 = 5/9
+            'f1score' => 0.49, // (2/3 + 0 + 4/5) / 3 = 22/45
         ];
 
         $this->assertEquals($precision, $report->getPrecision(), '', 0.01);
@@ -77,9 +80,9 @@ class ClassificationReportTest extends TestCase
             2 => 3,
         ];
         $average = [
-            'precision' => 0.75,
-            'recall' => 0.83,
-            'f1score' => 0.73,
+            'precision' => 0.5,
+            'recall' => 0.56,
+            'f1score' => 0.49,
         ];
 
         $this->assertEquals($precision, $report->getPrecision(), '', 0.01);
@@ -89,6 +92,63 @@ class ClassificationReportTest extends TestCase
         $this->assertEquals($average, $report->getAverage(), '', 0.01);
     }
 
+    public function testClassificationReportAverageOutOfRange(): void
+    {
+        $labels = ['cat', 'ant', 'bird', 'bird', 'bird'];
+        $predicted = ['cat', 'cat', 'bird', 'bird', 'ant'];
+
+        $this->expectException(InvalidArgumentException::class);
+        new ClassificationReport($labels, $predicted, 0);
+    }
+
+    public function testClassificationReportMicroAverage(): void
+    {
+        $labels = ['cat', 'ant', 'bird', 'bird', 'bird'];
+        $predicted = ['cat', 'cat', 'bird', 'bird', 'ant'];
+
+        $report = new ClassificationReport($labels, $predicted, ClassificationReport::MICRO_AVERAGE);
+
+        $average = [
+            'precision' => 0.6, // TP / (TP + FP) = (1 + 0 + 2) / (2 + 1 + 2) = 3/5
+            'recall' => 0.6, // TP / (TP + FN) = (1 + 0 + 2) / (1 + 1 + 3) = 3/5
+            'f1score' => 0.6, // harmonic mean of precision and recall
+        ];
+
+        $this->assertEquals($average, $report->getAverage(), '', 0.01);
+    }
+
+    public function testClassificationReportMacroAverage(): void
+    {
+        $labels = ['cat', 'ant', 'bird', 'bird', 'bird'];
+        $predicted = ['cat', 'cat', 'bird', 'bird', 'ant'];
+
+        $report = new ClassificationReport($labels, $predicted, ClassificationReport::MACRO_AVERAGE);
+
+        $average = [
+            'precision' => 0.5, // (1/2 + 0 + 1) / 3 = 1/2
+            'recall' => 0.56, // (1 + 0 + 2/3) / 3 = 5/9
+            'f1score' => 0.49, // (2/3 + 0 + 4/5) / 3 = 22/45
+        ];
+
+        $this->assertEquals($average, $report->getAverage(), '', 0.01);
+    }
+
+    public function testClassificationReportWeightedAverage(): void
+    {
+        $labels = ['cat', 'ant', 'bird', 'bird', 'bird'];
+        $predicted = ['cat', 'cat', 'bird', 'bird', 'ant'];
+
+        $report = new ClassificationReport($labels, $predicted, ClassificationReport::WEIGHTED_AVERAGE);
+
+        $average = [
+            'precision' => 0.7, // (1/2 * 1 + 0 * 1 + 1 * 3) / 5 = 7/10
+            'recall' => 0.6, // (1 * 1 + 0 * 1 + 2/3 * 3) / 5 = 3/5
+            'f1score' => 0.61, // (2/3 * 1 + 0 * 1 + 4/5 * 3) / 5 = 46/75
+        ];
+
+        $this->assertEquals($average, $report->getAverage(), '', 0.01);
+    }
+
     public function testPreventDivideByZeroWhenTruePositiveAndFalsePositiveSumEqualsZero(): void
     {
         $labels = [1, 2];
@@ -129,4 +189,18 @@ class ClassificationReportTest extends TestCase
             'f1score' => 0,
         ], $report->getAverage(), '', 0.01);
     }
+
+    public function testPreventDivideByZeroWhenLabelsAreEmpty(): void
+    {
+        $labels = [];
+        $predicted = [];
+
+        $report = new ClassificationReport($labels, $predicted);
+
+        $this->assertEquals([
+            'precision' => 0,
+            'recall' => 0,
+            'f1score' => 0,
+        ], $report->getAverage(), '', 0.01);
+    }
 }