mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2024-11-22 04:55:10 +00:00
Choose averaging method in classification report (#205)
* Fix testcases of ClassificationReport * Fix averaging method in ClassificationReport * Fix divided by zero if labels are empty * Fix calculation of f1score * Add averaging methods (not completed) * Implement weighted average method * Extract counts to properties * Fix default to macro average * Implement micro average method * Fix style * Update docs * Fix styles
This commit is contained in:
parent
ba7114a3f7
commit
554c86af68
@ -18,6 +18,13 @@ $predictedLabels = ['cat', 'cat', 'bird', 'bird', 'ant'];
|
||||
$report = new ClassificationReport($actualLabels, $predictedLabels);
|
||||
```
|
||||
|
||||
Optionally you can provide the following parameter:
|
||||
|
||||
* $average - (int) averaging method for multi-class classification
|
||||
* `ClassificationReport::MICRO_AVERAGE` = 1
|
||||
* `ClassificationReport::MACRO_AVERAGE` = 2 (default)
|
||||
* `ClassificationReport::WEIGHTED_AVERAGE` = 3
|
||||
|
||||
### Metrics
|
||||
|
||||
After creating the report you can draw its individual metrics:
|
||||
|
@ -4,8 +4,36 @@ declare(strict_types=1);
|
||||
|
||||
namespace Phpml\Metric;
|
||||
|
||||
use Phpml\Exception\InvalidArgumentException;
|
||||
|
||||
class ClassificationReport
|
||||
{
|
||||
public const MICRO_AVERAGE = 1;
|
||||
|
||||
public const MACRO_AVERAGE = 2;
|
||||
|
||||
public const WEIGHTED_AVERAGE = 3;
|
||||
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
private $truePositive = [];
|
||||
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
private $falsePositive = [];
|
||||
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
private $falseNegative = [];
|
||||
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
private $support = [];
|
||||
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
@ -21,34 +49,21 @@ class ClassificationReport
|
||||
*/
|
||||
private $f1score = [];
|
||||
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
private $support = [];
|
||||
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
private $average = [];
|
||||
|
||||
public function __construct(array $actualLabels, array $predictedLabels)
|
||||
public function __construct(array $actualLabels, array $predictedLabels, int $average = self::MACRO_AVERAGE)
|
||||
{
|
||||
$truePositive = $falsePositive = $falseNegative = $this->support = self::getLabelIndexedArray($actualLabels, $predictedLabels);
|
||||
|
||||
foreach ($actualLabels as $index => $actual) {
|
||||
$predicted = $predictedLabels[$index];
|
||||
++$this->support[$actual];
|
||||
|
||||
if ($actual === $predicted) {
|
||||
++$truePositive[$actual];
|
||||
} else {
|
||||
++$falsePositive[$predicted];
|
||||
++$falseNegative[$actual];
|
||||
}
|
||||
$averagingMethods = range(self::MICRO_AVERAGE, self::WEIGHTED_AVERAGE);
|
||||
if (!in_array($average, $averagingMethods)) {
|
||||
throw new InvalidArgumentException('Averaging method must be MICRO_AVERAGE, MACRO_AVERAGE or WEIGHTED_AVERAGE');
|
||||
}
|
||||
|
||||
$this->computeMetrics($truePositive, $falsePositive, $falseNegative);
|
||||
$this->computeAverage();
|
||||
$this->aggregateClassificationResults($actualLabels, $predictedLabels);
|
||||
$this->computeMetrics();
|
||||
$this->computeAverage($average);
|
||||
}
|
||||
|
||||
public function getPrecision(): array
|
||||
@ -76,20 +91,73 @@ class ClassificationReport
|
||||
return $this->average;
|
||||
}
|
||||
|
||||
private function computeMetrics(array $truePositive, array $falsePositive, array $falseNegative): void
|
||||
private function aggregateClassificationResults(array $actualLabels, array $predictedLabels): void
|
||||
{
|
||||
foreach ($truePositive as $label => $tp) {
|
||||
$this->precision[$label] = $this->computePrecision($tp, $falsePositive[$label]);
|
||||
$this->recall[$label] = $this->computeRecall($tp, $falseNegative[$label]);
|
||||
$truePositive = $falsePositive = $falseNegative = $support = self::getLabelIndexedArray($actualLabels, $predictedLabels);
|
||||
|
||||
foreach ($actualLabels as $index => $actual) {
|
||||
$predicted = $predictedLabels[$index];
|
||||
++$support[$actual];
|
||||
|
||||
if ($actual === $predicted) {
|
||||
++$truePositive[$actual];
|
||||
} else {
|
||||
++$falsePositive[$predicted];
|
||||
++$falseNegative[$actual];
|
||||
}
|
||||
}
|
||||
|
||||
$this->truePositive = $truePositive;
|
||||
$this->falsePositive = $falsePositive;
|
||||
$this->falseNegative = $falseNegative;
|
||||
$this->support = $support;
|
||||
}
|
||||
|
||||
private function computeMetrics(): void
|
||||
{
|
||||
foreach ($this->truePositive as $label => $tp) {
|
||||
$this->precision[$label] = $this->computePrecision($tp, $this->falsePositive[$label]);
|
||||
$this->recall[$label] = $this->computeRecall($tp, $this->falseNegative[$label]);
|
||||
$this->f1score[$label] = $this->computeF1Score((float) $this->precision[$label], (float) $this->recall[$label]);
|
||||
}
|
||||
}
|
||||
|
||||
private function computeAverage(): void
|
||||
private function computeAverage(int $average): void
|
||||
{
|
||||
switch ($average) {
|
||||
case self::MICRO_AVERAGE:
|
||||
$this->computeMicroAverage();
|
||||
|
||||
return;
|
||||
case self::MACRO_AVERAGE:
|
||||
$this->computeMacroAverage();
|
||||
|
||||
return;
|
||||
case self::WEIGHTED_AVERAGE:
|
||||
$this->computeWeightedAverage();
|
||||
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
private function computeMicroAverage(): void
|
||||
{
|
||||
$truePositive = array_sum($this->truePositive);
|
||||
$falsePositive = array_sum($this->falsePositive);
|
||||
$falseNegative = array_sum($this->falseNegative);
|
||||
|
||||
$precision = $this->computePrecision($truePositive, $falsePositive);
|
||||
$recall = $this->computeRecall($truePositive, $falseNegative);
|
||||
$f1score = $this->computeF1Score((float) $precision, (float) $recall);
|
||||
|
||||
$this->average = compact('precision', 'recall', 'f1score');
|
||||
}
|
||||
|
||||
private function computeMacroAverage(): void
|
||||
{
|
||||
foreach (['precision', 'recall', 'f1score'] as $metric) {
|
||||
$values = array_filter($this->{$metric});
|
||||
if (empty($values)) {
|
||||
$values = $this->{$metric};
|
||||
if (count($values) == 0) {
|
||||
$this->average[$metric] = 0.0;
|
||||
|
||||
continue;
|
||||
@ -99,6 +167,25 @@ class ClassificationReport
|
||||
}
|
||||
}
|
||||
|
||||
private function computeWeightedAverage(): void
|
||||
{
|
||||
foreach (['precision', 'recall', 'f1score'] as $metric) {
|
||||
$values = $this->{$metric};
|
||||
if (count($values) == 0) {
|
||||
$this->average[$metric] = 0.0;
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
$sum = 0;
|
||||
foreach ($values as $i => $value) {
|
||||
$sum += $value * $this->support[$i];
|
||||
}
|
||||
|
||||
$this->average[$metric] = $sum / array_sum($this->support);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @return float|string
|
||||
*/
|
||||
|
@ -4,6 +4,7 @@ declare(strict_types=1);
|
||||
|
||||
namespace Phpml\Tests\Metric;
|
||||
|
||||
use Phpml\Exception\InvalidArgumentException;
|
||||
use Phpml\Metric\ClassificationReport;
|
||||
use PHPUnit\Framework\TestCase;
|
||||
|
||||
@ -36,10 +37,12 @@ class ClassificationReportTest extends TestCase
|
||||
'ant' => 1,
|
||||
'bird' => 3,
|
||||
];
|
||||
|
||||
// ClassificationReport uses macro-averaging as default
|
||||
$average = [
|
||||
'precision' => 0.75,
|
||||
'recall' => 0.83,
|
||||
'f1score' => 0.73,
|
||||
'precision' => 0.5, // (1/2 + 0 + 1) / 3 = 1/2
|
||||
'recall' => 0.56, // (1 + 0 + 2/3) / 3 = 5/9
|
||||
'f1score' => 0.49, // (2/3 + 0 + 4/5) / 3 = 22/45
|
||||
];
|
||||
|
||||
$this->assertEquals($precision, $report->getPrecision(), '', 0.01);
|
||||
@ -77,9 +80,9 @@ class ClassificationReportTest extends TestCase
|
||||
2 => 3,
|
||||
];
|
||||
$average = [
|
||||
'precision' => 0.75,
|
||||
'recall' => 0.83,
|
||||
'f1score' => 0.73,
|
||||
'precision' => 0.5,
|
||||
'recall' => 0.56,
|
||||
'f1score' => 0.49,
|
||||
];
|
||||
|
||||
$this->assertEquals($precision, $report->getPrecision(), '', 0.01);
|
||||
@ -89,6 +92,63 @@ class ClassificationReportTest extends TestCase
|
||||
$this->assertEquals($average, $report->getAverage(), '', 0.01);
|
||||
}
|
||||
|
||||
public function testClassificationReportAverageOutOfRange(): void
|
||||
{
|
||||
$labels = ['cat', 'ant', 'bird', 'bird', 'bird'];
|
||||
$predicted = ['cat', 'cat', 'bird', 'bird', 'ant'];
|
||||
|
||||
$this->expectException(InvalidArgumentException::class);
|
||||
$report = new ClassificationReport($labels, $predicted, 0);
|
||||
}
|
||||
|
||||
public function testClassificationReportMicroAverage(): void
|
||||
{
|
||||
$labels = ['cat', 'ant', 'bird', 'bird', 'bird'];
|
||||
$predicted = ['cat', 'cat', 'bird', 'bird', 'ant'];
|
||||
|
||||
$report = new ClassificationReport($labels, $predicted, ClassificationReport::MICRO_AVERAGE);
|
||||
|
||||
$average = [
|
||||
'precision' => 0.6, // TP / (TP + FP) = (1 + 0 + 2) / (2 + 1 + 2) = 3/5
|
||||
'recall' => 0.6, // TP / (TP + FN) = (1 + 0 + 2) / (1 + 1 + 3) = 3/5
|
||||
'f1score' => 0.6, // Harmonic mean of precision and recall
|
||||
];
|
||||
|
||||
$this->assertEquals($average, $report->getAverage(), '', 0.01);
|
||||
}
|
||||
|
||||
public function testClassificationReportMacroAverage(): void
|
||||
{
|
||||
$labels = ['cat', 'ant', 'bird', 'bird', 'bird'];
|
||||
$predicted = ['cat', 'cat', 'bird', 'bird', 'ant'];
|
||||
|
||||
$report = new ClassificationReport($labels, $predicted, ClassificationReport::MACRO_AVERAGE);
|
||||
|
||||
$average = [
|
||||
'precision' => 0.5, // (1/2 + 0 + 1) / 3 = 1/2
|
||||
'recall' => 0.56, // (1 + 0 + 2/3) / 3 = 5/9
|
||||
'f1score' => 0.49, // (2/3 + 0 + 4/5) / 3 = 22/45
|
||||
];
|
||||
|
||||
$this->assertEquals($average, $report->getAverage(), '', 0.01);
|
||||
}
|
||||
|
||||
public function testClassificationReportWeightedAverage(): void
|
||||
{
|
||||
$labels = ['cat', 'ant', 'bird', 'bird', 'bird'];
|
||||
$predicted = ['cat', 'cat', 'bird', 'bird', 'ant'];
|
||||
|
||||
$report = new ClassificationReport($labels, $predicted, ClassificationReport::WEIGHTED_AVERAGE);
|
||||
|
||||
$average = [
|
||||
'precision' => 0.7, // (1/2 * 1 + 0 * 1 + 1 * 3) / 5 = 7/10
|
||||
'recall' => 0.6, // (1 * 1 + 0 * 1 + 2/3 * 3) / 5 = 3/5
|
||||
'f1score' => 0.61, // (2/3 * 1 + 0 * 1 + 4/5 * 3) / 5 = 46/75
|
||||
];
|
||||
|
||||
$this->assertEquals($average, $report->getAverage(), '', 0.01);
|
||||
}
|
||||
|
||||
public function testPreventDivideByZeroWhenTruePositiveAndFalsePositiveSumEqualsZero(): void
|
||||
{
|
||||
$labels = [1, 2];
|
||||
@ -129,4 +189,18 @@ class ClassificationReportTest extends TestCase
|
||||
'f1score' => 0,
|
||||
], $report->getAverage(), '', 0.01);
|
||||
}
|
||||
|
||||
public function testPreventDividedByZeroWhenLabelsAreEmpty(): void
|
||||
{
|
||||
$labels = [];
|
||||
$predicted = [];
|
||||
|
||||
$report = new ClassificationReport($labels, $predicted);
|
||||
|
||||
$this->assertEquals([
|
||||
'precision' => 0,
|
||||
'recall' => 0,
|
||||
'f1score' => 0,
|
||||
], $report->getAverage(), '', 0.01);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user