mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2025-02-10 16:08:46 +00:00
accuracy score with test
This commit is contained in:
parent
0b6dc42807
commit
f1c81638d6
39
src/Phpml/Metric/Accuracy.php
Normal file
39
src/Phpml/Metric/Accuracy.php
Normal file
@ -0,0 +1,39 @@
|
||||
<?php
|
||||
declare(strict_types = 1);
|
||||
|
||||
namespace Phpml\Metric;
|
||||
|
||||
use Phpml\Exception\InvalidArgumentException;
|
||||
|
||||
class Accuracy
|
||||
{
|
||||
|
||||
/**
|
||||
* @param array $actualLabels
|
||||
* @param array $predictedLabels
|
||||
* @param bool $normalize
|
||||
*
|
||||
* @return float|int
|
||||
*
|
||||
* @throws InvalidArgumentException
|
||||
*/
|
||||
public static function score(array $actualLabels, array $predictedLabels, bool $normalize = true)
|
||||
{
|
||||
if (count($actualLabels) != count($predictedLabels)) {
|
||||
throw InvalidArgumentException::sizeNotMatch();
|
||||
}
|
||||
|
||||
$score = 0;
|
||||
foreach ($actualLabels as $index => $label) {
|
||||
if($label===$predictedLabels[$index]) {
|
||||
$score++;
|
||||
}
|
||||
}
|
||||
|
||||
if($normalize) {
|
||||
$score = $score / count($actualLabels);
|
||||
}
|
||||
|
||||
return $score;
|
||||
}
|
||||
}
|
38
tests/Phpml/Metric/AccuracyTest.php
Normal file
38
tests/Phpml/Metric/AccuracyTest.php
Normal file
@ -0,0 +1,38 @@
|
||||
<?php
|
||||
declare(strict_types = 1);
|
||||
|
||||
namespace tests\Phpml\Metric;
|
||||
|
||||
use Phpml\Metric\Accuracy;
|
||||
|
||||
class AccuracyTest extends \PHPUnit_Framework_TestCase
|
||||
{
|
||||
|
||||
/**
|
||||
* @expectedException \Phpml\Exception\InvalidArgumentException
|
||||
*/
|
||||
public function testThrowExceptionOnInvalidArguments()
|
||||
{
|
||||
$actualLabels = ['a', 'b', 'a', 'b'];
|
||||
$predictedLabels = ['a', 'a'];
|
||||
|
||||
Accuracy::score($actualLabels, $predictedLabels);
|
||||
}
|
||||
|
||||
public function testCalculateNormalizedScore()
|
||||
{
|
||||
$actualLabels = ['a', 'b', 'a', 'b'];
|
||||
$predictedLabels = ['a', 'a', 'b', 'b'];
|
||||
|
||||
$this->assertEquals(0.5, Accuracy::score($actualLabels, $predictedLabels));
|
||||
}
|
||||
|
||||
public function testCalculateNotNormalizedScore()
|
||||
{
|
||||
$actualLabels = ['a', 'b', 'a', 'b'];
|
||||
$predictedLabels = ['a', 'b', 'b', 'b'];
|
||||
|
||||
$this->assertEquals(3, Accuracy::score($actualLabels, $predictedLabels, false));
|
||||
}
|
||||
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user