diff --git a/src/Phpml/Metric/ClassificationReport.php b/src/Phpml/Metric/ClassificationReport.php index b1b135f..f56fb2a 100644 --- a/src/Phpml/Metric/ClassificationReport.php +++ b/src/Phpml/Metric/ClassificationReport.php @@ -113,6 +113,10 @@ class ClassificationReport { foreach (['precision', 'recall', 'f1score'] as $metric) { $values = array_filter($this->$metric); + if(0==count($values)) { + $this->average[$metric] = 0.0; + continue; + } $this->average[$metric] = array_sum($values) / count($values); } } diff --git a/tests/Phpml/Metric/ClassificationReportTest.php b/tests/Phpml/Metric/ClassificationReportTest.php index 6ccc21d..515f97c 100644 --- a/tests/Phpml/Metric/ClassificationReportTest.php +++ b/tests/Phpml/Metric/ClassificationReportTest.php @@ -67,4 +67,19 @@ class ClassificationReportTest extends \PHPUnit_Framework_TestCase $this->assertEquals([1 => 0.0, 2 => 1, 3 => 0], $report->getPrecision(), '', 0.01); } + + public function testPreventDividedByZeroWhenPredictedLabelsAllNotMatch() + { + $labels = [1,2,3,4,5]; + $predicted = [2,3,4,5,6]; + + $report = new ClassificationReport($labels, $predicted); + + $this->assertEquals([ + 'precision' => 0, + 'recall' => 0, + 'f1score' => 0 + ], $report->getAverage(), '', 0.01); + } + }