Fix division by zero in ClassificationReport #21

This commit is contained in:
Arkadiusz Kondas 2016-09-27 20:07:21 +02:00
parent 1ce6bb544b
commit 84af842f04
2 changed files with 57 additions and 6 deletions

View File

@ -37,7 +37,7 @@ class ClassificationReport
*/
public function __construct(array $actualLabels, array $predictedLabels)
{
$truePositive = $falsePositive = $falseNegative = $this->support = self::getLabelIndexedArray($actualLabels);
$truePositive = $falsePositive = $falseNegative = $this->support = self::getLabelIndexedArray($actualLabels, $predictedLabels);
foreach ($actualLabels as $index => $actual) {
$predicted = $predictedLabels[$index];
@ -103,8 +103,8 @@ class ClassificationReport
private function computeMetrics(array $truePositive, array $falsePositive, array $falseNegative)
{
foreach ($truePositive as $label => $tp) {
$this->precision[$label] = $tp / ($tp + $falsePositive[$label]);
$this->recall[$label] = $tp / ($tp + $falseNegative[$label]);
$this->precision[$label] = $this->computePrecision($tp, $falsePositive[$label]);
$this->recall[$label] = $this->computeRecall($tp, $falseNegative[$label]);
$this->f1score[$label] = $this->computeF1Score((float) $this->precision[$label], (float) $this->recall[$label]);
}
}
@ -117,6 +117,36 @@ class ClassificationReport
}
}
/**
* @param int $truePositive
* @param int $falsePositive
*
* @return float|string
*/
private function computePrecision(int $truePositive, int $falsePositive)
{
if (0 == ($divider = $truePositive + $falsePositive)) {
return 0.0;
}
return $truePositive / $divider;
}
/**
* @param int $truePositive
* @param int $falseNegative
*
* @return float|string
*/
private function computeRecall(int $truePositive, int $falseNegative)
{
if (0 == ($divider = $truePositive + $falseNegative)) {
return 0.0;
}
return $truePositive / $divider;
}
/**
* @param float $precision
* @param float $recall
@ -133,13 +163,14 @@ class ClassificationReport
}
/**
* @param array $labels
* @param array $actualLabels
* @param array $predictedLabels
*
* @return array
*/
private static function getLabelIndexedArray(array $labels): array
private static function getLabelIndexedArray(array $actualLabels, array $predictedLabels): array
{
$labels = array_values(array_unique($labels));
$labels = array_values(array_unique(array_merge($actualLabels, $predictedLabels)));
sort($labels);
$labels = array_combine($labels, array_fill(0, count($labels), 0));

View File

@ -47,4 +47,24 @@ class ClassificationReportTest extends \PHPUnit_Framework_TestCase
$this->assertEquals($support, $report->getSupport(), '', 0.01);
$this->assertEquals($average, $report->getAverage(), '', 0.01);
}
public function testPreventDivideByZeroWhenTruePositiveAndFalsePositiveSumEqualsZero()
{
$labels = [1, 2];
$predicted = [2, 2];
$report = new ClassificationReport($labels, $predicted);
$this->assertEquals([1 => 0.0, 2 => 0.5], $report->getPrecision(), '', 0.01);
}
public function testPreventDivideByZeroWhenTruePositiveAndFalseNegativeSumEqualsZero()
{
$labels = [2, 2, 1];
$predicted = [2, 2, 3];
$report = new ClassificationReport($labels, $predicted);
$this->assertEquals([1 => 0.0, 2 => 1, 3 => 0], $report->getPrecision(), '', 0.01);
}
}