mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2024-09-22 12:19:02 +00:00
653c7c772d
* upgrade to PHP 7.1 * bump travis and composer to PHP 7.1 * fix tests
76 lines
1.7 KiB
PHP
76 lines
1.7 KiB
PHP
<?php
|
|
|
|
declare(strict_types=1);
|
|
|
|
namespace Phpml\Classification;
|
|
|
|
use Phpml\Helper\Predictable;
|
|
use Phpml\Helper\Trainable;
|
|
use Phpml\Math\Distance;
|
|
use Phpml\Math\Distance\Euclidean;
|
|
|
|
/**
 * k-nearest-neighbours classifier: a sample is labelled by a majority vote
 * among the k training samples closest to it under the configured metric.
 */
class KNearestNeighbors implements Classifier
{
    use Trainable, Predictable;

    /**
     * How many of the closest training samples take part in each vote.
     *
     * @var int
     */
    private $k;

    /**
     * Metric used to measure how far each training sample is from the query.
     *
     * @var Distance
     */
    private $distanceMetric;

    /**
     * @param int           $k              number of neighbours consulted per prediction
     * @param Distance|null $distanceMetric metric to use; Euclidean distance is used when null
     */
    public function __construct(int $k = 3, ?Distance $distanceMetric = null)
    {
        $this->k = $k;
        $this->samples = [];
        $this->targets = [];
        $this->distanceMetric = $distanceMetric ?? new Euclidean();
    }

    /**
     * Labels one sample by majority vote of its k nearest training samples.
     *
     * Every label seen in the training targets starts at zero votes (in the
     * order the labels first appear), each selected neighbour adds one vote
     * to its own label, and the label with the highest tally is returned.
     *
     * NOTE(review): labels are used as array keys, so PHP key coercion applies
     * (e.g. a numeric-string label "1" comes back as int 1). When several
     * labels tie, the winner is whatever arsort() puts first; sorting is only
     * guaranteed stable from PHP 8.0 onwards.
     *
     * @return mixed the winning label, or null when there are no training targets
     */
    protected function predictSample(array $sample)
    {
        $labels = array_values($this->targets);
        $votes = array_combine($labels, array_fill(0, count($labels), 0));

        // Only the keys (training-sample indices) of the nearest neighbours matter here.
        foreach (array_keys($this->kNeighborsDistances($sample)) as $neighborIndex) {
            $votes[$this->targets[$neighborIndex]] += 1;
        }

        // Highest tally first; rewind so key() reads the top-ranked label.
        arsort($votes);
        reset($votes);

        return key($votes);
    }

    /**
     * Computes the distance from the query sample to every training sample and
     * keeps the k smallest.
     *
     * @return array distances keyed by training-sample index, nearest first
     *
     * @throws \Phpml\Exception\InvalidArgumentException presumably raised by the
     *         distance metric (e.g. for samples of differing dimension) — confirm
     */
    private function kNeighborsDistances(array $sample) : array
    {
        // array_map over a single array keeps the original sample indices as keys.
        $distanceTo = array_map(
            function ($neighbor) use ($sample) {
                return $this->distanceMetric->distance($sample, $neighbor);
            },
            $this->samples
        );

        asort($distanceTo);

        // preserve_keys = true so callers can map each distance back to its target.
        return array_slice($distanceTo, 0, $this->k, true);
    }
}
|