mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2025-01-23 23:28:24 +00:00
implement ConfusionMatrix metric
This commit is contained in:
parent
cce68997a1
commit
6c7416a9c4
71
src/Phpml/Metric/ConfusionMatrix.php
Normal file
71
src/Phpml/Metric/ConfusionMatrix.php
Normal file
@ -0,0 +1,71 @@
|
||||
<?php
|
||||
|
||||
declare (strict_types = 1);
|
||||
|
||||
namespace Phpml\Metric;
|
||||
|
||||
class ConfusionMatrix
|
||||
{
|
||||
/**
|
||||
* @param array $actualLabels
|
||||
* @param array $predictedLabels
|
||||
* @param array $labels
|
||||
*
|
||||
* @return array
|
||||
*/
|
||||
public static function compute(array $actualLabels, array $predictedLabels, array $labels = null): array
|
||||
{
|
||||
$labels = $labels ? array_flip($labels) : self::getUniqueLabels($actualLabels);
|
||||
$matrix = self::generateMatrixWithZeros($labels);
|
||||
|
||||
foreach ($actualLabels as $index => $actual) {
|
||||
$predicted = $predictedLabels[$index];
|
||||
|
||||
if (!isset($labels[$actual]) || !isset($labels[$predicted])) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if ($predicted === $actual) {
|
||||
$row = $column = $labels[$actual];
|
||||
} else {
|
||||
$row = $labels[$actual];
|
||||
$column = $labels[$predicted];
|
||||
}
|
||||
|
||||
$matrix[$row][$column] += 1;
|
||||
}
|
||||
|
||||
return $matrix;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param array $labels
|
||||
*
|
||||
* @return array
|
||||
*/
|
||||
private static function generateMatrixWithZeros(array $labels): array
|
||||
{
|
||||
$count = count($labels);
|
||||
$matrix = [];
|
||||
|
||||
for ($i = 0; $i < $count; ++$i) {
|
||||
$matrix[$i] = array_fill(0, $count, 0);
|
||||
}
|
||||
|
||||
return $matrix;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param array $labels
|
||||
*
|
||||
* @return array
|
||||
*/
|
||||
private static function getUniqueLabels(array $labels): array
|
||||
{
|
||||
$labels = array_values(array_unique($labels));
|
||||
sort($labels);
|
||||
$labels = array_flip($labels);
|
||||
|
||||
return $labels;
|
||||
}
|
||||
}
|
61
tests/Phpml/Metric/ConfusionMatrixTest.php
Normal file
61
tests/Phpml/Metric/ConfusionMatrixTest.php
Normal file
@ -0,0 +1,61 @@
|
||||
<?php
|
||||
|
||||
declare (strict_types = 1);
|
||||
|
||||
namespace tests\Phpml\Metric;
|
||||
|
||||
use Phpml\Metric\ConfusionMatrix;
|
||||
|
||||
class ConfusionMatrixTest extends \PHPUnit_Framework_TestCase
|
||||
{
|
||||
public function testComputeConfusionMatrixOnNumericLabels()
|
||||
{
|
||||
$actualLabels = [2, 0, 2, 2, 0, 1];
|
||||
$predictedLabels = [0, 0, 2, 2, 0, 2];
|
||||
|
||||
$confusionMatrix = [
|
||||
[2, 0, 0],
|
||||
[0, 0, 1],
|
||||
[1, 0, 2],
|
||||
];
|
||||
|
||||
$this->assertEquals($confusionMatrix, ConfusionMatrix::compute($actualLabels, $predictedLabels));
|
||||
}
|
||||
|
||||
public function testComputeConfusionMatrixOnStringLabels()
|
||||
{
|
||||
$actualLabels = ['cat', 'ant', 'cat', 'cat', 'ant', 'bird'];
|
||||
$predictedLabels = ['ant', 'ant', 'cat', 'cat', 'ant', 'cat'];
|
||||
|
||||
$confusionMatrix = [
|
||||
[2, 0, 0],
|
||||
[0, 0, 1],
|
||||
[1, 0, 2],
|
||||
];
|
||||
|
||||
$this->assertEquals($confusionMatrix, ConfusionMatrix::compute($actualLabels, $predictedLabels));
|
||||
}
|
||||
|
||||
public function testComputeConfusionMatrixOnLabelsWithSubset()
|
||||
{
|
||||
$actualLabels = ['cat', 'ant', 'cat', 'cat', 'ant', 'bird'];
|
||||
$predictedLabels = ['ant', 'ant', 'cat', 'cat', 'ant', 'cat'];
|
||||
$labels = ['ant', 'bird'];
|
||||
|
||||
$confusionMatrix = [
|
||||
[2, 0],
|
||||
[0, 0],
|
||||
];
|
||||
|
||||
$this->assertEquals($confusionMatrix, ConfusionMatrix::compute($actualLabels, $predictedLabels, $labels));
|
||||
|
||||
$labels = ['bird', 'ant'];
|
||||
|
||||
$confusionMatrix = [
|
||||
[0, 0],
|
||||
[0, 2],
|
||||
];
|
||||
|
||||
$this->assertEquals($confusionMatrix, ConfusionMatrix::compute($actualLabels, $predictedLabels, $labels));
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user