implement k nearest neighbors classifier

This commit is contained in:
Arkadiusz Kondas 2016-04-05 21:06:53 +02:00
parent 4235f143bf
commit 469848ff49
3 changed files with 48 additions and 15 deletions

View File

@ -7,15 +7,15 @@ namespace Phpml\Classifier;
interface Classifier
{
/**
* @param array $features
* @param array $samples
* @param array $labels
*/
public function train(array $features, array $labels);
public function train(array $samples, array $labels);
/**
* @param mixed $feature
* @param array $sample
*
* @return mixed
*/
public function predict($feature);
public function predict(array $sample);
}

View File

@ -4,6 +4,8 @@ declare (strict_types = 1);
namespace Phpml\Classifier;
use Phpml\Metric\Distance;
class KNearestNeighbors implements Classifier
{
/**
@ -14,7 +16,7 @@ class KNearestNeighbors implements Classifier
/**
* @var array
*/
private $features;
private $samples;
/**
* @var array
@ -27,26 +29,57 @@ class KNearestNeighbors implements Classifier
public function __construct(int $k = 3)
{
$this->k = $k;
$this->features = [];
$this->samples = [];
$this->labels = [];
}
/**
* @param array $features
* @param array $samples
* @param array $labels
*/
public function train(array $features, array $labels)
public function train(array $samples, array $labels)
{
$this->features = $features;
$this->samples = $samples;
$this->labels = $labels;
}
/**
* @param mixed $feature
* @param array $sample
*
* @return mixed
*/
public function predict($feature)
public function predict(array $sample)
{
$distances = $this->kNeighborsDistances($sample);
$predictions = [];
foreach ($distances as $index => $distance) {
$predictions[$this->labels[$index]]++;
}
arsort($predictions);
return array_shift(array_keys($predictions));
}
/**
* @param array $sample
*
* @return array
*
* @throws \Phpml\Exception\InvalidArgumentException
*/
private function kNeighborsDistances(array $sample): array
{
$distances = [];
foreach($this->samples as $index => $neighbor) {
$distances[$index] = Distance::euclidean($sample, $neighbor);
if(count($distances)==$this->k) {
break;
}
}
asort($distances);
return $distances;
}
}

View File

@ -7,19 +7,19 @@ namespace Phpml\Classifier;
class NaiveBayes implements Classifier
{
/**
* @param array $features
* @param array $samples
* @param array $labels
*/
public function train(array $features, array $labels)
public function train(array $samples, array $labels)
{
}
/**
* @param mixed $feature
* @param array $sample
*
* @return mixed
*/
public function predict($feature)
public function predict(array $sample)
{
}
}