Choose averaging method in classification report (#205)

* Fix testcases of ClassificationReport

* Fix averaging method in ClassificationReport

* Fix divided by zero if labels are empty

* Fix calculation of f1score

* Add averaging methods (not completed)

* Implement weighted average method

* Extract counts to properties

* Fix default to macro average

* Implement micro average method

* Fix style

* Update docs

* Fix styles
This commit is contained in:
Yuji Uchiyama 2018-01-30 02:06:21 +09:00 committed by Arkadiusz Kondas
parent ba7114a3f7
commit 554c86af68
3 changed files with 201 additions and 33 deletions

View File

@ -18,6 +18,13 @@ $predictedLabels = ['cat', 'cat', 'bird', 'bird', 'ant'];
$report = new ClassificationReport($actualLabels, $predictedLabels);
```
Optionally you can provide the following parameter:
* $average - (int) averaging method for multi-class classification
* `ClassificationReport::MICRO_AVERAGE` = 1
* `ClassificationReport::MACRO_AVERAGE` = 2 (default)
* `ClassificationReport::WEIGHTED_AVERAGE` = 3
### Metrics
After creating the report you can draw its individual metrics:

View File

@ -4,8 +4,36 @@ declare(strict_types=1);
namespace Phpml\Metric;
use Phpml\Exception\InvalidArgumentException;
class ClassificationReport
{
public const MICRO_AVERAGE = 1;
public const MACRO_AVERAGE = 2;
public const WEIGHTED_AVERAGE = 3;
/**
* @var array
*/
private $truePositive = [];
/**
* @var array
*/
private $falsePositive = [];
/**
* @var array
*/
private $falseNegative = [];
/**
* @var array
*/
private $support = [];
/**
* @var array
*/
@ -21,34 +49,21 @@ class ClassificationReport
*/
private $f1score = [];
/**
* @var array
*/
private $support = [];
/**
* @var array
*/
private $average = [];
public function __construct(array $actualLabels, array $predictedLabels)
public function __construct(array $actualLabels, array $predictedLabels, int $average = self::MACRO_AVERAGE)
{
$truePositive = $falsePositive = $falseNegative = $this->support = self::getLabelIndexedArray($actualLabels, $predictedLabels);
foreach ($actualLabels as $index => $actual) {
$predicted = $predictedLabels[$index];
++$this->support[$actual];
if ($actual === $predicted) {
++$truePositive[$actual];
} else {
++$falsePositive[$predicted];
++$falseNegative[$actual];
}
$averagingMethods = range(self::MICRO_AVERAGE, self::WEIGHTED_AVERAGE);
if (!in_array($average, $averagingMethods)) {
throw new InvalidArgumentException('Averaging method must be MICRO_AVERAGE, MACRO_AVERAGE or WEIGHTED_AVERAGE');
}
$this->computeMetrics($truePositive, $falsePositive, $falseNegative);
$this->computeAverage();
$this->aggregateClassificationResults($actualLabels, $predictedLabels);
$this->computeMetrics();
$this->computeAverage($average);
}
public function getPrecision(): array
@ -76,20 +91,73 @@ class ClassificationReport
return $this->average;
}
private function computeMetrics(array $truePositive, array $falsePositive, array $falseNegative): void
private function aggregateClassificationResults(array $actualLabels, array $predictedLabels): void
{
foreach ($truePositive as $label => $tp) {
$this->precision[$label] = $this->computePrecision($tp, $falsePositive[$label]);
$this->recall[$label] = $this->computeRecall($tp, $falseNegative[$label]);
$truePositive = $falsePositive = $falseNegative = $support = self::getLabelIndexedArray($actualLabels, $predictedLabels);
foreach ($actualLabels as $index => $actual) {
$predicted = $predictedLabels[$index];
++$support[$actual];
if ($actual === $predicted) {
++$truePositive[$actual];
} else {
++$falsePositive[$predicted];
++$falseNegative[$actual];
}
}
$this->truePositive = $truePositive;
$this->falsePositive = $falsePositive;
$this->falseNegative = $falseNegative;
$this->support = $support;
}
private function computeMetrics(): void
{
foreach ($this->truePositive as $label => $tp) {
$this->precision[$label] = $this->computePrecision($tp, $this->falsePositive[$label]);
$this->recall[$label] = $this->computeRecall($tp, $this->falseNegative[$label]);
$this->f1score[$label] = $this->computeF1Score((float) $this->precision[$label], (float) $this->recall[$label]);
}
}
private function computeAverage(): void
private function computeAverage(int $average): void
{
switch ($average) {
case self::MICRO_AVERAGE:
$this->computeMicroAverage();
return;
case self::MACRO_AVERAGE:
$this->computeMacroAverage();
return;
case self::WEIGHTED_AVERAGE:
$this->computeWeightedAverage();
return;
}
}
private function computeMicroAverage(): void
{
$truePositive = array_sum($this->truePositive);
$falsePositive = array_sum($this->falsePositive);
$falseNegative = array_sum($this->falseNegative);
$precision = $this->computePrecision($truePositive, $falsePositive);
$recall = $this->computeRecall($truePositive, $falseNegative);
$f1score = $this->computeF1Score((float) $precision, (float) $recall);
$this->average = compact('precision', 'recall', 'f1score');
}
private function computeMacroAverage(): void
{
foreach (['precision', 'recall', 'f1score'] as $metric) {
$values = array_filter($this->{$metric});
if (empty($values)) {
$values = $this->{$metric};
if (count($values) == 0) {
$this->average[$metric] = 0.0;
continue;
@ -99,6 +167,25 @@ class ClassificationReport
}
}
private function computeWeightedAverage(): void
{
foreach (['precision', 'recall', 'f1score'] as $metric) {
$values = $this->{$metric};
if (count($values) == 0) {
$this->average[$metric] = 0.0;
continue;
}
$sum = 0;
foreach ($values as $i => $value) {
$sum += $value * $this->support[$i];
}
$this->average[$metric] = $sum / array_sum($this->support);
}
}
/**
* @return float|string
*/

View File

@ -4,6 +4,7 @@ declare(strict_types=1);
namespace Phpml\Tests\Metric;
use Phpml\Exception\InvalidArgumentException;
use Phpml\Metric\ClassificationReport;
use PHPUnit\Framework\TestCase;
@ -36,10 +37,12 @@ class ClassificationReportTest extends TestCase
'ant' => 1,
'bird' => 3,
];
// ClassificationReport uses macro-averaging as default
$average = [
'precision' => 0.75,
'recall' => 0.83,
'f1score' => 0.73,
'precision' => 0.5, // (1/2 + 0 + 1) / 3 = 1/2
'recall' => 0.56, // (1 + 0 + 2/3) / 3 = 5/9
'f1score' => 0.49, // (2/3 + 0 + 4/5) / 3 = 22/45
];
$this->assertEquals($precision, $report->getPrecision(), '', 0.01);
@ -77,9 +80,9 @@ class ClassificationReportTest extends TestCase
2 => 3,
];
$average = [
'precision' => 0.75,
'recall' => 0.83,
'f1score' => 0.73,
'precision' => 0.5,
'recall' => 0.56,
'f1score' => 0.49,
];
$this->assertEquals($precision, $report->getPrecision(), '', 0.01);
@ -89,6 +92,63 @@ class ClassificationReportTest extends TestCase
$this->assertEquals($average, $report->getAverage(), '', 0.01);
}
public function testClassificationReportAverageOutOfRange(): void
{
$labels = ['cat', 'ant', 'bird', 'bird', 'bird'];
$predicted = ['cat', 'cat', 'bird', 'bird', 'ant'];
$this->expectException(InvalidArgumentException::class);
$report = new ClassificationReport($labels, $predicted, 0);
}
public function testClassificationReportMicroAverage(): void
{
$labels = ['cat', 'ant', 'bird', 'bird', 'bird'];
$predicted = ['cat', 'cat', 'bird', 'bird', 'ant'];
$report = new ClassificationReport($labels, $predicted, ClassificationReport::MICRO_AVERAGE);
$average = [
'precision' => 0.6, // TP / (TP + FP) = (1 + 0 + 2) / (2 + 1 + 2) = 3/5
'recall' => 0.6, // TP / (TP + FN) = (1 + 0 + 2) / (1 + 1 + 3) = 3/5
'f1score' => 0.6, // Harmonic mean of precision and recall
];
$this->assertEquals($average, $report->getAverage(), '', 0.01);
}
public function testClassificationReportMacroAverage(): void
{
$labels = ['cat', 'ant', 'bird', 'bird', 'bird'];
$predicted = ['cat', 'cat', 'bird', 'bird', 'ant'];
$report = new ClassificationReport($labels, $predicted, ClassificationReport::MACRO_AVERAGE);
$average = [
'precision' => 0.5, // (1/2 + 0 + 1) / 3 = 1/2
'recall' => 0.56, // (1 + 0 + 2/3) / 3 = 5/9
'f1score' => 0.49, // (2/3 + 0 + 4/5) / 3 = 22/45
];
$this->assertEquals($average, $report->getAverage(), '', 0.01);
}
public function testClassificationReportWeightedAverage(): void
{
$labels = ['cat', 'ant', 'bird', 'bird', 'bird'];
$predicted = ['cat', 'cat', 'bird', 'bird', 'ant'];
$report = new ClassificationReport($labels, $predicted, ClassificationReport::WEIGHTED_AVERAGE);
$average = [
'precision' => 0.7, // (1/2 * 1 + 0 * 1 + 1 * 3) / 5 = 7/10
'recall' => 0.6, // (1 * 1 + 0 * 1 + 2/3 * 3) / 5 = 3/5
'f1score' => 0.61, // (2/3 * 1 + 0 * 1 + 4/5 * 3) / 5 = 46/75
];
$this->assertEquals($average, $report->getAverage(), '', 0.01);
}
public function testPreventDivideByZeroWhenTruePositiveAndFalsePositiveSumEqualsZero(): void
{
$labels = [1, 2];
@ -129,4 +189,18 @@ class ClassificationReportTest extends TestCase
'f1score' => 0,
], $report->getAverage(), '', 0.01);
}
public function testPreventDividedByZeroWhenLabelsAreEmpty(): void
{
$labels = [];
$predicted = [];
$report = new ClassificationReport($labels, $predicted);
$this->assertEquals([
'precision' => 0,
'recall' => 0,
'f1score' => 0,
], $report->getAverage(), '', 0.01);
}
}