diff --git a/src/Phpml/Metric/ClassificationReport.php b/src/Phpml/Metric/ClassificationReport.php index 31ee0b7..b1b135f 100644 --- a/src/Phpml/Metric/ClassificationReport.php +++ b/src/Phpml/Metric/ClassificationReport.php @@ -37,7 +37,7 @@ class ClassificationReport */ public function __construct(array $actualLabels, array $predictedLabels) { - $truePositive = $falsePositive = $falseNegative = $this->support = self::getLabelIndexedArray($actualLabels); + $truePositive = $falsePositive = $falseNegative = $this->support = self::getLabelIndexedArray($actualLabels, $predictedLabels); foreach ($actualLabels as $index => $actual) { $predicted = $predictedLabels[$index]; @@ -103,8 +103,8 @@ class ClassificationReport private function computeMetrics(array $truePositive, array $falsePositive, array $falseNegative) { foreach ($truePositive as $label => $tp) { - $this->precision[$label] = $tp / ($tp + $falsePositive[$label]); - $this->recall[$label] = $tp / ($tp + $falseNegative[$label]); + $this->precision[$label] = $this->computePrecision($tp, $falsePositive[$label]); + $this->recall[$label] = $this->computeRecall($tp, $falseNegative[$label]); $this->f1score[$label] = $this->computeF1Score((float) $this->precision[$label], (float) $this->recall[$label]); } } @@ -117,6 +117,36 @@ class ClassificationReport } } + /** + * @param int $truePositive + * @param int $falsePositive + * + * @return float|string + */ + private function computePrecision(int $truePositive, int $falsePositive) + { + if (0 == ($divider = $truePositive + $falsePositive)) { + return 0.0; + } + + return $truePositive / $divider; + } + + /** + * @param int $truePositive + * @param int $falseNegative + * + * @return float|string + */ + private function computeRecall(int $truePositive, int $falseNegative) + { + if (0 == ($divider = $truePositive + $falseNegative)) { + return 0.0; + } + + return $truePositive / $divider; + } + /** * @param float $precision * @param float $recall @@ -133,13 +163,14 @@ class ClassificationReport } /** - * @param array $labels + * @param array $actualLabels + * @param array $predictedLabels * * @return array */ - private static function getLabelIndexedArray(array $labels): array + private static function getLabelIndexedArray(array $actualLabels, array $predictedLabels): array { - $labels = array_values(array_unique($labels)); + $labels = array_values(array_unique(array_merge($actualLabels, $predictedLabels))); sort($labels); $labels = array_combine($labels, array_fill(0, count($labels), 0)); diff --git a/tests/Phpml/Metric/ClassificationReportTest.php b/tests/Phpml/Metric/ClassificationReportTest.php index f0f1cd3..6ccc21d 100644 --- a/tests/Phpml/Metric/ClassificationReportTest.php +++ b/tests/Phpml/Metric/ClassificationReportTest.php @@ -47,4 +47,24 @@ class ClassificationReportTest extends \PHPUnit_Framework_TestCase $this->assertEquals($support, $report->getSupport(), '', 0.01); $this->assertEquals($average, $report->getAverage(), '', 0.01); } + + public function testPreventDivideByZeroWhenTruePositiveAndFalsePositiveSumEqualsZero() + { + $labels = [1, 2]; + $predicted = [2, 2]; + + $report = new ClassificationReport($labels, $predicted); + + $this->assertEquals([1 => 0.0, 2 => 0.5], $report->getPrecision(), '', 0.01); + } + + public function testPreventDivideByZeroWhenTruePositiveAndFalseNegativeSumEqualsZero() + { + $labels = [2, 2, 1]; + $predicted = [2, 2, 3]; + + $report = new ClassificationReport($labels, $predicted); + + $this->assertEquals([1 => 0.0, 2 => 1, 3 => 0], $report->getPrecision(), '', 0.01); + } }