php-ml/src/Phpml/Classification/KNearestNeighbors.php

83 lines
1.8 KiB
PHP
Raw Normal View History

2016-04-04 20:25:27 +00:00
<?php
2016-04-04 20:49:54 +00:00
declare (strict_types = 1);
2016-04-04 20:25:27 +00:00
namespace Phpml\Classification;
2016-04-04 20:25:27 +00:00
use Phpml\Classification\Traits\Predictable;
use Phpml\Classification\Traits\Trainable;
2016-04-20 21:56:33 +00:00
use Phpml\Math\Distance;
use Phpml\Math\Distance\Euclidean;
2016-04-04 20:25:27 +00:00
class KNearestNeighbors implements Classifier
{
    use Trainable, Predictable;

    /**
     * Number of nearest neighbours that vote on each prediction.
     *
     * @var int
     */
    private $k;

    /**
     * Metric used to rank training samples by their distance to the queried sample.
     *
     * @var Distance
     */
    private $distanceMetric;

    /**
     * @param int           $k              number of neighbours consulted per prediction; must be >= 1
     * @param Distance|null $distanceMetric metric used to rank neighbours (Euclidean distance when null)
     *
     * @throws \InvalidArgumentException when $k is smaller than 1
     */
    public function __construct(int $k = 3, Distance $distanceMetric = null)
    {
        // A non-positive k makes array_slice() in kNeighborsDistances() return no neighbours (k = 0)
        // or "all but the farthest |k|" (k < 0), so predictions would be silently meaningless.
        if ($k < 1) {
            throw new \InvalidArgumentException(sprintf('k must be a positive integer, %d given', $k));
        }

        if (null === $distanceMetric) {
            $distanceMetric = new Euclidean();
        }

        $this->k = $k;
        $this->samples = [];
        $this->labels = [];
        $this->distanceMetric = $distanceMetric;
    }

    /**
     * Predicts the label of a single sample by majority vote among its k nearest training samples.
     *
     * Every label seen during training starts with zero votes, in first-seen order, and each of
     * the k nearest neighbours adds one vote to its own label. Ties go to whichever tied label
     * arsort() places first (on PHP >= 8.0, where sorting is stable, the one seen first in training).
     *
     * @param array $sample
     *
     * @return mixed the winning label exactly as it was given to training, or null when nothing has been trained
     */
    protected function predictSample(array $sample)
    {
        $neighbors = $this->kNeighborsDistances($sample);

        $votes = [];
        // PHP casts array keys (e.g. the label '1' becomes the int key 1), so remember the original
        // label behind each key and return that instead of the coerced key.
        $labelByKey = [];
        foreach ($this->labels as $label) {
            if (!isset($votes[$label])) {
                $votes[$label] = 0;
                $labelByKey[$label] = $label;
            }
        }

        // Keys of $neighbors are indices into the training set, shared by $this->samples and $this->labels.
        foreach (array_keys($neighbors) as $index) {
            ++$votes[$this->labels[$index]];
        }

        if ([] === $votes) {
            // Nothing trained yet: same result as key() on an empty array.
            return null;
        }

        arsort($votes);
        reset($votes);

        return $labelByKey[key($votes)];
    }

    /**
     * Measures the distance from $sample to every training sample and keeps the k closest.
     *
     * @param array $sample
     *
     * @return array distances keyed by training-sample index, nearest first, at most k entries
     *
     * @throws \Phpml\Exception\InvalidArgumentException presumably propagated from the distance metric — confirm
     */
    private function kNeighborsDistances(array $sample): array
    {
        $distances = [];

        foreach ($this->samples as $index => $neighbor) {
            $distances[$index] = $this->distanceMetric->distance($sample, $neighbor);
        }

        // asort() keeps the index => distance association while ordering by ascending distance.
        asort($distances);

        // preserve_keys = true so the surviving keys still index into $this->labels.
        return array_slice($distances, 0, $this->k, true);
    }
}