mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2024-11-22 13:05:10 +00:00
implement k nearest neighbors classifier
This commit is contained in:
parent
4235f143bf
commit
469848ff49
@ -7,15 +7,15 @@ namespace Phpml\Classifier;
|
||||
interface Classifier
|
||||
{
|
||||
/**
|
||||
* @param array $features
|
||||
* @param array $samples
|
||||
* @param array $labels
|
||||
*/
|
||||
public function train(array $features, array $labels);
|
||||
public function train(array $samples, array $labels);
|
||||
|
||||
/**
|
||||
* @param mixed $feature
|
||||
* @param array $sample
|
||||
*
|
||||
* @return mixed
|
||||
*/
|
||||
public function predict($feature);
|
||||
public function predict(array $sample);
|
||||
}
|
||||
|
@ -4,6 +4,8 @@ declare (strict_types = 1);
|
||||
|
||||
namespace Phpml\Classifier;
|
||||
|
||||
use Phpml\Metric\Distance;
|
||||
|
||||
class KNearestNeighbors implements Classifier
|
||||
{
|
||||
/**
|
||||
@ -14,7 +16,7 @@ class KNearestNeighbors implements Classifier
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
private $features;
|
||||
private $samples;
|
||||
|
||||
/**
|
||||
* @var array
|
||||
@ -27,26 +29,57 @@ class KNearestNeighbors implements Classifier
|
||||
public function __construct(int $k = 3)
|
||||
{
|
||||
$this->k = $k;
|
||||
$this->features = [];
|
||||
$this->samples = [];
|
||||
$this->labels = [];
|
||||
}
|
||||
|
||||
/**
|
||||
* @param array $features
|
||||
* @param array $samples
|
||||
* @param array $labels
|
||||
*/
|
||||
public function train(array $features, array $labels)
|
||||
public function train(array $samples, array $labels)
|
||||
{
|
||||
$this->features = $features;
|
||||
$this->samples = $samples;
|
||||
$this->labels = $labels;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param mixed $feature
|
||||
* @param array $sample
|
||||
*
|
||||
* @return mixed
|
||||
*/
|
||||
public function predict($feature)
|
||||
public function predict(array $sample)
|
||||
{
|
||||
$distances = $this->kNeighborsDistances($sample);
|
||||
|
||||
$predictions = [];
|
||||
foreach ($distances as $index => $distance) {
|
||||
$predictions[$this->labels[$index]]++;
|
||||
}
|
||||
|
||||
arsort($predictions);
|
||||
|
||||
return array_shift(array_keys($predictions));
|
||||
}
|
||||
|
||||
/**
|
||||
* @param array $sample
|
||||
*
|
||||
* @return array
|
||||
*
|
||||
* @throws \Phpml\Exception\InvalidArgumentException
|
||||
*/
|
||||
private function kNeighborsDistances(array $sample): array
|
||||
{
|
||||
$distances = [];
|
||||
foreach($this->samples as $index => $neighbor) {
|
||||
$distances[$index] = Distance::euclidean($sample, $neighbor);
|
||||
if(count($distances)==$this->k) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
asort($distances);
|
||||
|
||||
return $distances;
|
||||
}
|
||||
}
|
||||
|
@ -7,19 +7,19 @@ namespace Phpml\Classifier;
|
||||
class NaiveBayes implements Classifier
|
||||
{
|
||||
/**
|
||||
* @param array $features
|
||||
* @param array $samples
|
||||
* @param array $labels
|
||||
*/
|
||||
public function train(array $features, array $labels)
|
||||
public function train(array $samples, array $labels)
|
||||
{
|
||||
}
|
||||
|
||||
/**
|
||||
* @param mixed $feature
|
||||
* @param array $sample
|
||||
*
|
||||
* @return mixed
|
||||
*/
|
||||
public function predict($feature)
|
||||
public function predict(array $sample)
|
||||
{
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user