implement k nearest neighbors classifier

This commit is contained in:
Arkadiusz Kondas 2016-04-05 21:06:53 +02:00
parent 4235f143bf
commit 469848ff49
3 changed files with 48 additions and 15 deletions

View File

@ -7,15 +7,15 @@ namespace Phpml\Classifier;
interface Classifier interface Classifier
{ {
/** /**
* @param array $features * @param array $samples
* @param array $labels * @param array $labels
*/ */
public function train(array $features, array $labels); public function train(array $samples, array $labels);
/** /**
* @param mixed $feature * @param array $sample
* *
* @return mixed * @return mixed
*/ */
public function predict($feature); public function predict(array $sample);
} }

View File

@ -4,6 +4,8 @@ declare (strict_types = 1);
namespace Phpml\Classifier; namespace Phpml\Classifier;
use Phpml\Metric\Distance;
class KNearestNeighbors implements Classifier class KNearestNeighbors implements Classifier
{ {
/** /**
@ -14,7 +16,7 @@ class KNearestNeighbors implements Classifier
/** /**
* @var array * @var array
*/ */
private $features; private $samples;
/** /**
* @var array * @var array
@ -27,26 +29,57 @@ class KNearestNeighbors implements Classifier
public function __construct(int $k = 3) public function __construct(int $k = 3)
{ {
$this->k = $k; $this->k = $k;
$this->features = []; $this->samples = [];
$this->labels = []; $this->labels = [];
} }
/** /**
* @param array $features * @param array $samples
* @param array $labels * @param array $labels
*/ */
public function train(array $features, array $labels) public function train(array $samples, array $labels)
{ {
$this->features = $features; $this->samples = $samples;
$this->labels = $labels; $this->labels = $labels;
} }
/** /**
* @param mixed $feature * @param array $sample
* *
* @return mixed * @return mixed
*/ */
public function predict($feature) public function predict(array $sample)
{ {
$distances = $this->kNeighborsDistances($sample);
$predictions = [];
foreach ($distances as $index => $distance) {
$predictions[$this->labels[$index]]++;
}
arsort($predictions);
return array_shift(array_keys($predictions));
}
/**
* @param array $sample
*
* @return array
*
* @throws \Phpml\Exception\InvalidArgumentException
*/
private function kNeighborsDistances(array $sample): array
{
$distances = [];
foreach($this->samples as $index => $neighbor) {
$distances[$index] = Distance::euclidean($sample, $neighbor);
if(count($distances)==$this->k) {
break;
}
}
asort($distances);
return $distances;
} }
} }

View File

@ -7,19 +7,19 @@ namespace Phpml\Classifier;
class NaiveBayes implements Classifier class NaiveBayes implements Classifier
{ {
/** /**
* @param array $features * @param array $samples
* @param array $labels * @param array $labels
*/ */
public function train(array $features, array $labels) public function train(array $samples, array $labels)
{ {
} }
/** /**
* @param mixed $feature * @param array $sample
* *
* @return mixed * @return mixed
*/ */
public function predict($feature) public function predict(array $sample)
{ {
} }
} }