php-ml/src/Phpml/Classifier/KNearestNeighbors.php

86 lines
1.6 KiB
PHP
Raw Normal View History

2016-04-04 20:25:27 +00:00
<?php
2016-04-04 20:49:54 +00:00
declare (strict_types = 1);
2016-04-04 20:25:27 +00:00
namespace Phpml\Classifier;
use Phpml\Metric\Distance;
2016-04-04 20:25:27 +00:00
class KNearestNeighbors implements Classifier
{
    /**
     * Number of nearest training samples that vote on each prediction.
     *
     * @var int
     */
    private $k;

    /**
     * Training samples, keyed the same way as $labels.
     *
     * @var array
     */
    private $samples;

    /**
     * Training labels: $labels[$i] is the label of $samples[$i].
     *
     * @var array
     */
    private $labels;

    /**
     * @param int $k number of nearest neighbours consulted by predict();
     *               values below 1 mean no neighbours vote and predict() returns null
     */
    public function __construct(int $k = 3)
    {
        $this->k = $k;
        $this->samples = [];
        $this->labels = [];
    }

    /**
     * Stores the training data. No model is built here; all distance work
     * is deferred to predict().
     *
     * @param array $samples training samples
     * @param array $labels  label of each sample, stored under the same key as the sample
     */
    public function train(array $samples, array $labels)
    {
        $this->samples = $samples;
        $this->labels = $labels;
    }

    /**
     * Predicts the label of $sample by majority vote among its k nearest
     * training samples (Euclidean distance).
     *
     * On a tie, the label whose closest neighbour is nearest to $sample wins.
     * Labels are used as array keys internally, so they should be ints or
     * strings (numeric-string labels are returned as ints, per PHP key casting).
     *
     * @param array $sample
     *
     * @return mixed the predicted label, or null when no neighbours are available
     *
     * @throws \Phpml\Exception\InvalidArgumentException
     */
    public function predict(array $sample)
    {
        $distances = $this->kNeighborsDistances($sample);

        // $distances is ordered nearest-first, so each label enters $votes in
        // the order of its closest neighbour; this ordering drives tie-breaking.
        $votes = [];
        foreach (array_keys($distances) as $index) {
            $label = $this->labels[$index];
            $votes[$label] = ($votes[$label] ?? 0) + 1;
        }

        $predictedLabel = null;
        $maxVotes = 0;
        foreach ($votes as $label => $count) {
            // Strict ">" keeps the earliest (closest) label when counts tie.
            if ($count > $maxVotes) {
                $maxVotes = $count;
                $predictedLabel = $label;
            }
        }

        return $predictedLabel;
    }

    /**
     * Computes the distance from $sample to every training sample and returns
     * the k smallest, ordered nearest-first and keyed by training-sample index.
     *
     * @param array $sample
     *
     * @return array up to k distances, keyed by the index of the training sample
     *
     * @throws \Phpml\Exception\InvalidArgumentException
     */
    private function kNeighborsDistances(array $sample): array
    {
        $distances = [];
        foreach ($this->samples as $index => $neighbor) {
            $distances[$index] = Distance::euclidean($sample, $neighbor);
        }

        asort($distances);

        // preserve_keys=true so each distance can still be mapped to its label;
        // max() guards against a negative k, which array_slice would treat as
        // "drop elements from the end".
        return array_slice($distances, 0, max(0, $this->k), true);
    }
}