Choose averaging method in classification report (#205)

* Fix testcases of ClassificationReport

* Fix averaging method in ClassificationReport

* Fix division by zero when labels are empty

* Fix calculation of f1score

* Add averaging methods (not completed)

* Implement weighted average method

* Extract counts to properties

* Fix default to macro average

* Implement micro average method

* Fix style

* Update docs

* Fix styles
This commit is contained in:
Yuji Uchiyama 2018-01-30 02:06:21 +09:00 committed by Arkadiusz Kondas
parent ba7114a3f7
commit 554c86af68
3 changed files with 201 additions and 33 deletions

View File

@ -18,6 +18,13 @@ $predictedLabels = ['cat', 'cat', 'bird', 'bird', 'ant'];
$report = new ClassificationReport($actualLabels, $predictedLabels); $report = new ClassificationReport($actualLabels, $predictedLabels);
``` ```
Optionally you can provide the following parameter:
* $average - (int) averaging method for multi-class classification
* `ClassificationReport::MICRO_AVERAGE` = 1
* `ClassificationReport::MACRO_AVERAGE` = 2 (default)
* `ClassificationReport::WEIGHTED_AVERAGE` = 3
### Metrics ### Metrics
After creating the report you can draw its individual metrics: After creating the report you can draw its individual metrics:

View File

@ -4,8 +4,36 @@ declare(strict_types=1);
namespace Phpml\Metric; namespace Phpml\Metric;
use Phpml\Exception\InvalidArgumentException;
class ClassificationReport class ClassificationReport
{ {
public const MICRO_AVERAGE = 1;
public const MACRO_AVERAGE = 2;
public const WEIGHTED_AVERAGE = 3;
/**
* @var array
*/
private $truePositive = [];
/**
* @var array
*/
private $falsePositive = [];
/**
* @var array
*/
private $falseNegative = [];
/**
* @var array
*/
private $support = [];
/** /**
* @var array * @var array
*/ */
@ -21,34 +49,21 @@ class ClassificationReport
*/ */
private $f1score = []; private $f1score = [];
/**
* @var array
*/
private $support = [];
/** /**
* @var array * @var array
*/ */
private $average = []; private $average = [];
public function __construct(array $actualLabels, array $predictedLabels) public function __construct(array $actualLabels, array $predictedLabels, int $average = self::MACRO_AVERAGE)
{ {
$truePositive = $falsePositive = $falseNegative = $this->support = self::getLabelIndexedArray($actualLabels, $predictedLabels); $averagingMethods = range(self::MICRO_AVERAGE, self::WEIGHTED_AVERAGE);
if (!in_array($average, $averagingMethods)) {
foreach ($actualLabels as $index => $actual) { throw new InvalidArgumentException('Averaging method must be MICRO_AVERAGE, MACRO_AVERAGE or WEIGHTED_AVERAGE');
$predicted = $predictedLabels[$index];
++$this->support[$actual];
if ($actual === $predicted) {
++$truePositive[$actual];
} else {
++$falsePositive[$predicted];
++$falseNegative[$actual];
}
} }
$this->computeMetrics($truePositive, $falsePositive, $falseNegative); $this->aggregateClassificationResults($actualLabels, $predictedLabels);
$this->computeAverage(); $this->computeMetrics();
$this->computeAverage($average);
} }
public function getPrecision(): array public function getPrecision(): array
@ -76,20 +91,73 @@ class ClassificationReport
return $this->average; return $this->average;
} }
private function computeMetrics(array $truePositive, array $falsePositive, array $falseNegative): void private function aggregateClassificationResults(array $actualLabels, array $predictedLabels): void
{ {
foreach ($truePositive as $label => $tp) { $truePositive = $falsePositive = $falseNegative = $support = self::getLabelIndexedArray($actualLabels, $predictedLabels);
$this->precision[$label] = $this->computePrecision($tp, $falsePositive[$label]);
$this->recall[$label] = $this->computeRecall($tp, $falseNegative[$label]); foreach ($actualLabels as $index => $actual) {
$predicted = $predictedLabels[$index];
++$support[$actual];
if ($actual === $predicted) {
++$truePositive[$actual];
} else {
++$falsePositive[$predicted];
++$falseNegative[$actual];
}
}
$this->truePositive = $truePositive;
$this->falsePositive = $falsePositive;
$this->falseNegative = $falseNegative;
$this->support = $support;
}
private function computeMetrics(): void
{
foreach ($this->truePositive as $label => $tp) {
$this->precision[$label] = $this->computePrecision($tp, $this->falsePositive[$label]);
$this->recall[$label] = $this->computeRecall($tp, $this->falseNegative[$label]);
$this->f1score[$label] = $this->computeF1Score((float) $this->precision[$label], (float) $this->recall[$label]); $this->f1score[$label] = $this->computeF1Score((float) $this->precision[$label], (float) $this->recall[$label]);
} }
} }
private function computeAverage(): void private function computeAverage(int $average): void
{
switch ($average) {
case self::MICRO_AVERAGE:
$this->computeMicroAverage();
return;
case self::MACRO_AVERAGE:
$this->computeMacroAverage();
return;
case self::WEIGHTED_AVERAGE:
$this->computeWeightedAverage();
return;
}
}
/**
 * Micro-averaging: pool the TP/FP/FN counts over all labels first, then
 * derive a single precision, recall and F1 from the pooled totals.
 */
private function computeMicroAverage(): void
{
    $tpTotal = array_sum($this->truePositive);
    $fpTotal = array_sum($this->falsePositive);
    $fnTotal = array_sum($this->falseNegative);

    $precision = $this->computePrecision($tpTotal, $fpTotal);
    $recall = $this->computeRecall($tpTotal, $fnTotal);

    // Same key order as the macro/weighted variants: precision, recall, f1score.
    $this->average = [
        'precision' => $precision,
        'recall' => $recall,
        'f1score' => $this->computeF1Score((float) $precision, (float) $recall),
    ];
}
private function computeMacroAverage(): void
{ {
foreach (['precision', 'recall', 'f1score'] as $metric) { foreach (['precision', 'recall', 'f1score'] as $metric) {
$values = array_filter($this->{$metric}); $values = $this->{$metric};
if (empty($values)) { if (count($values) == 0) {
$this->average[$metric] = 0.0; $this->average[$metric] = 0.0;
continue; continue;
@ -99,6 +167,25 @@ class ClassificationReport
} }
} }
/**
 * Weighted-averaging: average each per-label metric weighted by that label's
 * support (its count in the actual labels), so frequent labels dominate.
 *
 * Falls back to 0.0 for a metric with no per-label values, avoiding a
 * division by zero when the report was built from empty label arrays.
 */
private function computeWeightedAverage(): void
{
    // Loop-invariant: total support does not change between metrics.
    $totalSupport = array_sum($this->support);

    foreach (['precision', 'recall', 'f1score'] as $metric) {
        $perLabel = $this->{$metric};
        if ($perLabel === []) {
            $this->average[$metric] = 0.0;
            continue;
        }

        $weightedSum = 0;
        foreach ($perLabel as $label => $value) {
            $weightedSum += $value * $this->support[$label];
        }

        $this->average[$metric] = $weightedSum / $totalSupport;
    }
}
/** /**
* @return float|string * @return float|string
*/ */

View File

@ -4,6 +4,7 @@ declare(strict_types=1);
namespace Phpml\Tests\Metric; namespace Phpml\Tests\Metric;
use Phpml\Exception\InvalidArgumentException;
use Phpml\Metric\ClassificationReport; use Phpml\Metric\ClassificationReport;
use PHPUnit\Framework\TestCase; use PHPUnit\Framework\TestCase;
@ -36,10 +37,12 @@ class ClassificationReportTest extends TestCase
'ant' => 1, 'ant' => 1,
'bird' => 3, 'bird' => 3,
]; ];
// ClassificationReport uses macro-averaging as default
$average = [ $average = [
'precision' => 0.75, 'precision' => 0.5, // (1/2 + 0 + 1) / 3 = 1/2
'recall' => 0.83, 'recall' => 0.56, // (1 + 0 + 2/3) / 3 = 5/9
'f1score' => 0.73, 'f1score' => 0.49, // (2/3 + 0 + 4/5) / 3 = 22/45
]; ];
$this->assertEquals($precision, $report->getPrecision(), '', 0.01); $this->assertEquals($precision, $report->getPrecision(), '', 0.01);
@ -77,9 +80,9 @@ class ClassificationReportTest extends TestCase
2 => 3, 2 => 3,
]; ];
$average = [ $average = [
'precision' => 0.75, 'precision' => 0.5,
'recall' => 0.83, 'recall' => 0.56,
'f1score' => 0.73, 'f1score' => 0.49,
]; ];
$this->assertEquals($precision, $report->getPrecision(), '', 0.01); $this->assertEquals($precision, $report->getPrecision(), '', 0.01);
@ -89,6 +92,63 @@ class ClassificationReportTest extends TestCase
$this->assertEquals($average, $report->getAverage(), '', 0.01); $this->assertEquals($average, $report->getAverage(), '', 0.01);
} }
/**
 * An averaging method outside MICRO_AVERAGE..WEIGHTED_AVERAGE (1..3) must
 * be rejected with an InvalidArgumentException.
 */
public function testClassificationReportAverageOutOfRange(): void
{
    $labels = ['cat', 'ant', 'bird', 'bird', 'bird'];
    $predicted = ['cat', 'cat', 'bird', 'bird', 'ant'];

    $this->expectException(InvalidArgumentException::class);
    // 0 is below the valid range; no need to keep the (never-used) instance.
    new ClassificationReport($labels, $predicted, 0);
}
/**
 * Micro-averaging pools TP/FP/FN across labels before computing the metrics.
 */
public function testClassificationReportMicroAverage(): void
{
    $actual = ['cat', 'ant', 'bird', 'bird', 'bird'];
    $guessed = ['cat', 'cat', 'bird', 'bird', 'ant'];

    $report = new ClassificationReport($actual, $guessed, ClassificationReport::MICRO_AVERAGE);

    $expected = [
        'precision' => 0.6, // TP / (TP + FP) = (1 + 0 + 2) / (2 + 1 + 2) = 3/5
        'recall' => 0.6,    // TP / (TP + FN) = (1 + 0 + 2) / (1 + 1 + 3) = 3/5
        'f1score' => 0.6,   // harmonic mean of precision and recall
    ];

    $this->assertEquals($expected, $report->getAverage(), '', 0.01);
}
/**
 * Macro-averaging takes the unweighted mean of each per-label metric.
 */
public function testClassificationReportMacroAverage(): void
{
    $actual = ['cat', 'ant', 'bird', 'bird', 'bird'];
    $guessed = ['cat', 'cat', 'bird', 'bird', 'ant'];

    $report = new ClassificationReport($actual, $guessed, ClassificationReport::MACRO_AVERAGE);

    $expected = [
        'precision' => 0.5, // (1/2 + 0 + 1) / 3 = 1/2
        'recall' => 0.56,   // (1 + 0 + 2/3) / 3 = 5/9
        'f1score' => 0.49,  // (2/3 + 0 + 4/5) / 3 = 22/45
    ];

    $this->assertEquals($expected, $report->getAverage(), '', 0.01);
}
/**
 * Weighted-averaging weights each per-label metric by the label's support.
 */
public function testClassificationReportWeightedAverage(): void
{
    $actual = ['cat', 'ant', 'bird', 'bird', 'bird'];
    $guessed = ['cat', 'cat', 'bird', 'bird', 'ant'];

    $report = new ClassificationReport($actual, $guessed, ClassificationReport::WEIGHTED_AVERAGE);

    $expected = [
        'precision' => 0.7, // (1/2 * 1 + 0 * 1 + 1 * 3) / 5 = 7/10
        'recall' => 0.6,    // (1 * 1 + 0 * 1 + 2/3 * 3) / 5 = 3/5
        'f1score' => 0.61,  // (2/3 * 1 + 0 * 1 + 4/5 * 3) / 5 = 46/75
    ];

    $this->assertEquals($expected, $report->getAverage(), '', 0.01);
}
public function testPreventDivideByZeroWhenTruePositiveAndFalsePositiveSumEqualsZero(): void public function testPreventDivideByZeroWhenTruePositiveAndFalsePositiveSumEqualsZero(): void
{ {
$labels = [1, 2]; $labels = [1, 2];
@ -129,4 +189,18 @@ class ClassificationReportTest extends TestCase
'f1score' => 0, 'f1score' => 0,
], $report->getAverage(), '', 0.01); ], $report->getAverage(), '', 0.01);
} }
/**
 * Empty input must yield all-zero averages instead of dividing by zero.
 */
public function testPreventDividedByZeroWhenLabelsAreEmpty(): void
{
    $report = new ClassificationReport([], []);

    $expected = [
        'precision' => 0,
        'recall' => 0,
        'f1score' => 0,
    ];

    $this->assertEquals($expected, $report->getAverage(), '', 0.01);
}
} }