diff --git a/docs/machine-learning/clustering/k-means.md b/docs/machine-learning/clustering/k-means.md index 296feb1..661f717 100644 --- a/docs/machine-learning/clustering/k-means.md +++ b/docs/machine-learning/clustering/k-means.md @@ -19,10 +19,12 @@ To divide the samples into clusters simply use `cluster` method. It's return the ``` $samples = [[1, 1], [8, 7], [1, 2], [7, 8], [2, 1], [8, 9]]; +Or if you need to keep your indentifiers along with yours samples you can use array keys as labels. +$samples = [ 'Label1' => [1, 1], 'Label2' => [8, 7], 'Label3' => [1, 2]]; $kmeans = new KMeans(2); $kmeans->cluster($samples); -// return [0=>[[1, 1], ...], 1=>[[8, 7], ...]] +// return [0=>[[1, 1], ...], 1=>[[8, 7], ...]] or [0=>['Label1' => [1, 1], 'Label3' => [1, 2], ...], 1=>['Label2' => [8, 7], ...]] ``` ### Initialization methods diff --git a/src/Clustering/KMeans.php b/src/Clustering/KMeans.php index 86ad754..1aff1c4 100644 --- a/src/Clustering/KMeans.php +++ b/src/Clustering/KMeans.php @@ -35,9 +35,9 @@ class KMeans implements Clusterer public function cluster(array $samples): array { - $space = new Space(count($samples[0])); - foreach ($samples as $sample) { - $space->addPoint($sample); + $space = new Space(count(reset($samples))); + foreach ($samples as $key => $sample) { + $space->addPoint($sample, $key); } $clusters = []; diff --git a/src/Clustering/KMeans/Cluster.php b/src/Clustering/KMeans/Cluster.php index 8936926..731d79c 100644 --- a/src/Clustering/KMeans/Cluster.php +++ b/src/Clustering/KMeans/Cluster.php @@ -32,7 +32,11 @@ class Cluster extends Point implements IteratorAggregate, Countable { $points = []; foreach ($this->points as $point) { - $points[] = $point->toArray(); + if (!empty($point->label)) { + $points[$point->label] = $point->toArray(); + } else { + $points[] = $point->toArray(); + } } return $points; diff --git a/src/Clustering/KMeans/Point.php b/src/Clustering/KMeans/Point.php index 8c918a7..7d41093 100644 --- a/src/Clustering/KMeans/Point.php +++ b/src/Clustering/KMeans/Point.php @@ -18,10 +18,16 @@ class Point implements ArrayAccess */ protected $coordinates = []; - public function __construct(array $coordinates) + /** + * @var mixed + */ + protected $label; + + public function __construct(array $coordinates, $label = null) { $this->dimension = count($coordinates); $this->coordinates = $coordinates; + $this->label = $label; } public function toArray(): array diff --git a/src/Clustering/KMeans/Space.php b/src/Clustering/KMeans/Space.php index b85b329..8d80dc0 100644 --- a/src/Clustering/KMeans/Space.php +++ b/src/Clustering/KMeans/Space.php @@ -35,21 +35,21 @@ class Space extends SplObjectStorage return ['points' => $points]; } - public function newPoint(array $coordinates): Point + public function newPoint(array $coordinates, $label = null): Point { if (count($coordinates) != $this->dimension) { throw new LogicException('('.implode(',', $coordinates).') is not a point of this space'); } - return new Point($coordinates); + return new Point($coordinates, $label); } /** * @param null $data */ - public function addPoint(array $coordinates, $data = null): void + public function addPoint(array $coordinates, $label = null, $data = null): void { - $this->attach($this->newPoint($coordinates), $data); + $this->attach($this->newPoint($coordinates, $label), $data); } /** diff --git a/tests/Clustering/KMeansTest.php b/tests/Clustering/KMeansTest.php index 032c804..ba36bc6 100644 --- a/tests/Clustering/KMeansTest.php +++ b/tests/Clustering/KMeansTest.php @@ -28,6 +28,32 @@ class KMeansTest extends TestCase $this->assertCount(0, $samples); } + public function testKMeansSamplesLabeledClustering(): void + { + $samples = [ + '555' => [1, 1], + '666' => [8, 7], + 'ABC' => [1, 2], + 'DEF' => [7, 8], + 668 => [2, 1], + [8, 9], + ]; + + $kmeans = new KMeans(2); + $clusters = $kmeans->cluster($samples); + + $this->assertCount(2, $clusters); + + foreach ($samples as $index => $sample) { + if (in_array($sample, $clusters[0], true) || in_array($sample, $clusters[1], true)) { + $this->assertArrayHasKey($index, $clusters[0] + $clusters[1]); + unset($samples[$index]); + } + } + + $this->assertCount(0, $samples); + } + public function testKMeansInitializationMethods(): void { $samples = [