diff --git a/src/Phpml/Classifier/KNearestNeighbors.php b/src/Phpml/Classifier/KNearestNeighbors.php index 53c66b0..f913488 100644 --- a/src/Phpml/Classifier/KNearestNeighbors.php +++ b/src/Phpml/Classifier/KNearestNeighbors.php @@ -50,7 +50,7 @@ class KNearestNeighbors implements Classifier */ public function predict(array $samples) { - if(!is_array($samples[0])) { + if (!is_array($samples[0])) { $predicted = $this->predictSample($samples); } else { $predicted = []; diff --git a/src/Phpml/Metric/Accuracy.php b/src/Phpml/Metric/Accuracy.php index 878cadd..d871e85 100644 --- a/src/Phpml/Metric/Accuracy.php +++ b/src/Phpml/Metric/Accuracy.php @@ -1,5 +1,6 @@ $label) { - if($label===$predictedLabels[$index]) { - $score++; + if ($label === $predictedLabels[$index]) { + ++$score; } } - if($normalize) { + if ($normalize) { $score = $score / count($actualLabels); } diff --git a/tests/Phpml/Classifier/KNearestNeighborsTest.php b/tests/Phpml/Classifier/KNearestNeighborsTest.php index 06ae42f..1050607 100644 --- a/tests/Phpml/Classifier/KNearestNeighborsTest.php +++ b/tests/Phpml/Classifier/KNearestNeighborsTest.php @@ -5,10 +5,13 @@ declare (strict_types = 1); namespace tests\Classifier; use Phpml\Classifier\KNearestNeighbors; +use Phpml\CrossValidation\RandomSplit; +use Phpml\Dataset\Demo\Iris; +use Phpml\Metric\Accuracy; class KNearestNeighborsTest extends \PHPUnit_Framework_TestCase { - public function testPredictSimpleSampleWithDefaultK() + public function testPredictSingleSampleWithDefaultK() { $samples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]]; $labels = ['a', 'a', 'a', 'b', 'b', 'b']; @@ -33,7 +36,7 @@ class KNearestNeighborsTest extends \PHPUnit_Framework_TestCase $trainLabels = ['a', 'a', 'a', 'b', 'b', 'b']; $testSamples = [[3, 2], [5, 1], [4, 3], [4, -5], [2, 3], [1, 2], [1, 5], [3, 10]]; - $testLabels = ['b', 'b', 'b', 'b', 'a', 'a', 'a', 'a',]; + $testLabels = ['b', 'b', 'b', 'b', 'a', 'a', 'a', 'a']; $classifier = new KNearestNeighbors(); $classifier->train($trainSamples, $trainLabels); @@ -41,4 +44,15 @@ class KNearestNeighborsTest extends \PHPUnit_Framework_TestCase $this->assertEquals($testLabels, $predicted); } + + public function testAccuracyOnIrisDataset() + { + $dataset = new RandomSplit(new Iris(), $testSize = 0.5, $seed = 123); + $classifier = new KNearestNeighbors($k = 4); + $classifier->train($dataset->getTrainSamples(), $dataset->getTrainLabels()); + $predicted = $classifier->predict($dataset->getTestSamples()); + $score = Accuracy::score($dataset->getTestLabels(), $predicted); + + $this->assertEquals(0.96, $score); + } } diff --git a/tests/Phpml/Metric/AccuracyTest.php b/tests/Phpml/Metric/AccuracyTest.php index 31bb0fc..aa68b22 100644 --- a/tests/Phpml/Metric/AccuracyTest.php +++ b/tests/Phpml/Metric/AccuracyTest.php @@ -1,5 +1,6 @@ assertEquals(3, Accuracy::score($actualLabels, $predictedLabels, false)); } - }