mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2025-01-10 00:37:55 +00:00
Fix division by zero in ClassificationReport #21
This commit is contained in:
parent
1ce6bb544b
commit
84af842f04
@ -37,7 +37,7 @@ class ClassificationReport
|
||||
*/
|
||||
public function __construct(array $actualLabels, array $predictedLabels)
|
||||
{
|
||||
$truePositive = $falsePositive = $falseNegative = $this->support = self::getLabelIndexedArray($actualLabels);
|
||||
$truePositive = $falsePositive = $falseNegative = $this->support = self::getLabelIndexedArray($actualLabels, $predictedLabels);
|
||||
|
||||
foreach ($actualLabels as $index => $actual) {
|
||||
$predicted = $predictedLabels[$index];
|
||||
@ -103,8 +103,8 @@ class ClassificationReport
|
||||
private function computeMetrics(array $truePositive, array $falsePositive, array $falseNegative)
|
||||
{
|
||||
foreach ($truePositive as $label => $tp) {
|
||||
$this->precision[$label] = $tp / ($tp + $falsePositive[$label]);
|
||||
$this->recall[$label] = $tp / ($tp + $falseNegative[$label]);
|
||||
$this->precision[$label] = $this->computePrecision($tp, $falsePositive[$label]);
|
||||
$this->recall[$label] = $this->computeRecall($tp, $falseNegative[$label]);
|
||||
$this->f1score[$label] = $this->computeF1Score((float) $this->precision[$label], (float) $this->recall[$label]);
|
||||
}
|
||||
}
|
||||
@ -117,6 +117,36 @@ class ClassificationReport
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @param int $truePositive
|
||||
* @param int $falsePositive
|
||||
*
|
||||
* @return float|string
|
||||
*/
|
||||
private function computePrecision(int $truePositive, int $falsePositive)
|
||||
{
|
||||
if (0 == ($divider = $truePositive + $falsePositive)) {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
return $truePositive / $divider;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param int $truePositive
|
||||
* @param int $falseNegative
|
||||
*
|
||||
* @return float|string
|
||||
*/
|
||||
private function computeRecall(int $truePositive, int $falseNegative)
|
||||
{
|
||||
if (0 == ($divider = $truePositive + $falseNegative)) {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
return $truePositive / $divider;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param float $precision
|
||||
* @param float $recall
|
||||
@ -133,13 +163,14 @@ class ClassificationReport
|
||||
}
|
||||
|
||||
/**
|
||||
* @param array $labels
|
||||
* @param array $actualLabels
|
||||
* @param array $predictedLabels
|
||||
*
|
||||
* @return array
|
||||
*/
|
||||
private static function getLabelIndexedArray(array $labels): array
|
||||
private static function getLabelIndexedArray(array $actualLabels, array $predictedLabels): array
|
||||
{
|
||||
$labels = array_values(array_unique($labels));
|
||||
$labels = array_values(array_unique(array_merge($actualLabels, $predictedLabels)));
|
||||
sort($labels);
|
||||
$labels = array_combine($labels, array_fill(0, count($labels), 0));
|
||||
|
||||
|
@ -47,4 +47,24 @@ class ClassificationReportTest extends \PHPUnit_Framework_TestCase
|
||||
$this->assertEquals($support, $report->getSupport(), '', 0.01);
|
||||
$this->assertEquals($average, $report->getAverage(), '', 0.01);
|
||||
}
|
||||
|
||||
public function testPreventDivideByZeroWhenTruePositiveAndFalsePositiveSumEqualsZero()
|
||||
{
|
||||
$labels = [1, 2];
|
||||
$predicted = [2, 2];
|
||||
|
||||
$report = new ClassificationReport($labels, $predicted);
|
||||
|
||||
$this->assertEquals([1 => 0.0, 2 => 0.5], $report->getPrecision(), '', 0.01);
|
||||
}
|
||||
|
||||
public function testPreventDivideByZeroWhenTruePositiveAndFalseNegativeSumEqualsZero()
|
||||
{
|
||||
$labels = [2, 2, 1];
|
||||
$predicted = [2, 2, 3];
|
||||
|
||||
$report = new ClassificationReport($labels, $predicted);
|
||||
|
||||
$this->assertEquals([1 => 0.0, 2 => 1, 3 => 0], $report->getPrecision(), '', 0.01);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user