diff --git a/src/Phpml/Classifier/Classifier.php b/src/Phpml/Classifier/Classifier.php
index face7b2..6fad67a 100644
--- a/src/Phpml/Classifier/Classifier.php
+++ b/src/Phpml/Classifier/Classifier.php
@@ -7,15 +7,15 @@ namespace Phpml\Classifier;
 interface Classifier
 {
     /**
-     * @param array $features
+     * @param array $samples
      * @param array $labels
      */
-    public function train(array $features, array $labels);
+    public function train(array $samples, array $labels);
 
     /**
-     * @param mixed $feature
+     * @param array $sample
      *
      * @return mixed
      */
-    public function predict($feature);
+    public function predict(array $sample);
 }
diff --git a/src/Phpml/Classifier/KNearestNeighbors.php b/src/Phpml/Classifier/KNearestNeighbors.php
index e028d61..f454f9d 100644
--- a/src/Phpml/Classifier/KNearestNeighbors.php
+++ b/src/Phpml/Classifier/KNearestNeighbors.php
@@ -4,6 +4,8 @@ declare (strict_types = 1);
 
 namespace Phpml\Classifier;
 
+use Phpml\Metric\Distance;
+
 class KNearestNeighbors implements Classifier
 {
     /**
@@ -14,7 +16,7 @@ class KNearestNeighbors implements Classifier
     /**
      * @var array
      */
-    private $features;
+    private $samples;
 
     /**
      * @var array
@@ -27,26 +29,64 @@ class KNearestNeighbors implements Classifier
     public function __construct(int $k = 3)
     {
         $this->k = $k;
-        $this->features = [];
+        $this->samples = [];
         $this->labels = [];
     }
 
     /**
-     * @param array $features
+     * @param array $samples
      * @param array $labels
      */
-    public function train(array $features, array $labels)
+    public function train(array $samples, array $labels)
     {
-        $this->features = $features;
+        $this->samples = $samples;
         $this->labels = $labels;
     }
 
     /**
-     * @param mixed $feature
+     * Predicts a label by majority vote among the k training samples nearest to $sample.
+     *
+     * @param array $sample
      *
      * @return mixed
      */
-    public function predict($feature)
+    public function predict(array $sample)
     {
+        $distances = $this->kNeighborsDistances($sample);
+
+        $predictions = [];
+        foreach ($distances as $index => $distance) {
+            $label = $this->labels[$index];
+            // Start each tally at zero; incrementing a missing key raises an undefined-index notice.
+            $predictions[$label] = ($predictions[$label] ?? 0) + 1;
+        }
+
+        arsort($predictions);
+        // array_shift() takes its argument by reference, so read the first key directly instead.
+        reset($predictions);
+
+        return key($predictions);
+    }
+
+    /**
+     * Computes the distance from $sample to every training sample and keeps the k smallest.
+     *
+     * @param array $sample
+     *
+     * @return array distances keyed by training-sample index, ascending
+     *
+     * @throws \Phpml\Exception\InvalidArgumentException
+     */
+    private function kNeighborsDistances(array $sample): array
+    {
+        $distances = [];
+        foreach ($this->samples as $index => $neighbor) {
+            $distances[$index] = Distance::euclidean($sample, $neighbor);
+        }
+
+        asort($distances);
+
+        // Truncate only after sorting so the k nearest are kept, not the first k seen.
+        return array_slice($distances, 0, $this->k, true);
     }
 }
diff --git a/src/Phpml/Classifier/NaiveBayes.php b/src/Phpml/Classifier/NaiveBayes.php
index ed6bb8c..c1cc902 100644
--- a/src/Phpml/Classifier/NaiveBayes.php
+++ b/src/Phpml/Classifier/NaiveBayes.php
@@ -7,19 +7,19 @@ namespace Phpml\Classifier;
 class NaiveBayes implements Classifier
 {
     /**
-     * @param array $features
+     * @param array $samples
      * @param array $labels
      */
-    public function train(array $features, array $labels)
+    public function train(array $samples, array $labels)
     {
     }
 
     /**
-     * @param mixed $feature
+     * @param array $sample
      *
      * @return mixed
      */
-    public function predict($feature)
+    public function predict(array $sample)
     {
     }
 }