From 96654571592eb1e1ae3fe4cc230c47312c777c97 Mon Sep 17 00:00:00 2001 From: Arkadiusz Kondas Date: Tue, 19 Jul 2016 21:58:59 +0200 Subject: [PATCH] implement ClassificationReport class --- src/Phpml/Metric/ClassificationReport.php | 148 ++++++++++++++++++ .../Phpml/Metric/ClassificationReportTest.php | 32 ++++ 2 files changed, 180 insertions(+) create mode 100644 src/Phpml/Metric/ClassificationReport.php create mode 100644 tests/Phpml/Metric/ClassificationReportTest.php diff --git a/src/Phpml/Metric/ClassificationReport.php b/src/Phpml/Metric/ClassificationReport.php new file mode 100644 index 0000000..69209fd --- /dev/null +++ b/src/Phpml/Metric/ClassificationReport.php @@ -0,0 +1,148 @@ +support = self::getLabelIndexedArray($actualLabels); + + foreach ($actualLabels as $index => $actual) { + $predicted = $predictedLabels[$index]; + $this->support[$actual]++; + + if($actual === $predicted) { + $truePositive[$actual]++; + } else { + $falsePositive[$predicted]++; + $falseNegative[$actual]++; + } + } + + $this->computeMetrics($truePositive, $falsePositive, $falseNegative); + $this->computeAverage(); + } + + /** + * @return array + */ + public function getPrecision() + { + return $this->precision; + } + + /** + * @return array + */ + public function getRecall() + { + return $this->recall; + } + + /** + * @return array + */ + public function getF1score() + { + return $this->f1score; + } + + /** + * @return array + */ + public function getSupport() + { + return $this->support; + } + + /** + * @return array + */ + public function getAverage() + { + return $this->average; + } + + /** + * @param array $truePositive + * @param array $falsePositive + * @param array $falseNegative + */ + private function computeMetrics(array $truePositive, array $falsePositive, array $falseNegative) + { + foreach ($truePositive as $label => $tp) { + $this->precision[$label] = $tp / ($tp + $falsePositive[$label]); + $this->recall[$label] = $tp / ($tp + $falseNegative[$label]); + $this->f1score[$label] = $this->computeF1Score((float)$this->precision[$label], (float)$this->recall[$label]); + } + } + + private function computeAverage() + { + foreach (['precision', 'recall', 'f1score'] as $metric) { + $values = array_filter($this->$metric); + $this->average[$metric] = array_sum($values) / count($values); + } + } + + /** + * @param float $precision + * @param float $recall + * + * @return float + */ + private function computeF1Score(float $precision, float $recall): float + { + if(0 == ($divider = $precision+$recall)) { + return 0.0; + } + + return 2.0 * (($precision * $recall) / ($divider)); + } + + /** + * @param array $labels + * + * @return array + */ + private static function getLabelIndexedArray(array $labels): array + { + $labels = array_values(array_unique($labels)); + sort($labels); + $labels = array_combine($labels, array_fill(0, count($labels), 0)); + + return $labels; + } + +} diff --git a/tests/Phpml/Metric/ClassificationReportTest.php b/tests/Phpml/Metric/ClassificationReportTest.php new file mode 100644 index 0000000..58520be --- /dev/null +++ b/tests/Phpml/Metric/ClassificationReportTest.php @@ -0,0 +1,32 @@ + 0.5, 'ant' => 0.0, 'bird' => 1.0]; + $recall = ['cat' => 1.0, 'ant' => 0.0, 'bird' => 0.67]; + $f1score = ['cat' => 0.67, 'ant' => 0.0, 'bird' => 0.80]; + $support = ['cat' => 1, 'ant' => 1, 'bird' => 3]; + $average = ['precision' => 0.75, 'recall' => 0.83, 'f1score' => 0.73]; + + + $this->assertEquals($precision, $report->getPrecision(), '', 0.01); + $this->assertEquals($recall, $report->getRecall(), '', 0.01); + $this->assertEquals($f1score, $report->getF1score(), '', 0.01); + $this->assertEquals($support, $report->getSupport(), '', 0.01); + $this->assertEquals($average, $report->getAverage(), '', 0.01); + } + +}