From c0513e9b8234a217d5b89ed6f57fcee488003cdd Mon Sep 17 00:00:00 2001 From: Arkadiusz Kondas Date: Sun, 1 May 2016 23:17:09 +0200 Subject: [PATCH] kmeans clustering --- src/Phpml/Clustering/KMeans.php | 51 +++++ src/Phpml/Clustering/KMeans/Cluster.php | 101 ++++++++ src/Phpml/Clustering/KMeans/Point.php | 95 ++++++++ src/Phpml/Clustering/KMeans/Space.php | 216 ++++++++++++++++++ .../Exception/InvalidArgumentException.php | 9 + tests/Phpml/Clustering/DBSCANTest.php | 7 +- tests/Phpml/Clustering/KMeansTest.php | 58 +++++ 7 files changed, 532 insertions(+), 5 deletions(-) create mode 100644 src/Phpml/Clustering/KMeans.php create mode 100644 src/Phpml/Clustering/KMeans/Cluster.php create mode 100644 src/Phpml/Clustering/KMeans/Point.php create mode 100644 src/Phpml/Clustering/KMeans/Space.php create mode 100644 tests/Phpml/Clustering/KMeansTest.php diff --git a/src/Phpml/Clustering/KMeans.php b/src/Phpml/Clustering/KMeans.php new file mode 100644 index 0000000..7bebe7e --- /dev/null +++ b/src/Phpml/Clustering/KMeans.php @@ -0,0 +1,51 @@ +clustersNumber = $clustersNumber; + } + + /** + * @param array $samples + * + * @return array + */ + public function cluster(array $samples) + { + $space = new Space(count($samples[0])); + foreach ($samples as $sample) { + $space->addPoint($sample); + } + + $clusters = []; + foreach ($space->solve($this->clustersNumber) as $cluster) + { + $clusters[] = $cluster->getPoints(); + } + + return $clusters; + } + +} diff --git a/src/Phpml/Clustering/KMeans/Cluster.php b/src/Phpml/Clustering/KMeans/Cluster.php new file mode 100644 index 0000000..fec6d07 --- /dev/null +++ b/src/Phpml/Clustering/KMeans/Cluster.php @@ -0,0 +1,101 @@ +points = new SplObjectStorage; + } + + /** + * @return array + */ + public function getPoints() + { + $points = []; + foreach ($this->points as $point) { + $points[] = $point->toArray(); + } + + return $points; + } + + public function toArray() + { + $points = array(); + foreach ($this->points as $point) + $points[] 
= $point->toArray(); + + return array( + 'centroid' => parent::toArray(), + 'points' => $points, + ); + } + + public function attach(Point $point) + { + if ($point instanceof self) + throw new LogicException("cannot attach a cluster to another"); + + $this->points->attach($point); + return $point; + } + + public function detach(Point $point) + { + $this->points->detach($point); + return $point; + } + + public function attachAll(SplObjectStorage $points) + { + $this->points->addAll($points); + } + + public function detachAll(SplObjectStorage $points) + { + $this->points->removeAll($points); + } + + public function updateCentroid() + { + if (!$count = count($this->points)) + return; + + $centroid = $this->space->newPoint(array_fill(0, $this->dimention, 0)); + + foreach ($this->points as $point) + for ($n=0; $n<$this->dimention; $n++) + $centroid->coordinates[$n] += $point->coordinates[$n]; + + for ($n=0; $n<$this->dimention; $n++) + $this->coordinates[$n] = $centroid->coordinates[$n] / $count; + } + + public function getIterator() + { + return $this->points; + } + + public function count() + { + return count($this->points); + } +} diff --git a/src/Phpml/Clustering/KMeans/Point.php b/src/Phpml/Clustering/KMeans/Point.php new file mode 100644 index 0000000..4d888c3 --- /dev/null +++ b/src/Phpml/Clustering/KMeans/Point.php @@ -0,0 +1,95 @@ +space = $space; + $this->dimention = $space->getDimention(); + $this->coordinates = $coordinates; + } + + public function toArray() + { + return $this->coordinates; + } + + public function getDistanceWith(self $point, $precise = true) + { + if ($point->space !== $this->space) + throw new LogicException("can only calculate distances from points in the same space"); + + $distance = 0; + for ($n=0; $n<$this->dimention; $n++) { + $difference = $this->coordinates[$n] - $point->coordinates[$n]; + $distance += $difference * $difference; + } + + return $precise ? 
sqrt($distance) : $distance; + } + + public function getClosest($points) + { + foreach($points as $point) { + $distance = $this->getDistanceWith($point, false); + + if (!isset($minDistance)) { + $minDistance = $distance; + $minPoint = $point; + continue; + } + + if ($distance < $minDistance) { + $minDistance = $distance; + $minPoint = $point; + } + } + /* $minPoint is never assigned when $points is empty: return null explicitly */ + return isset($minPoint) ? $minPoint : null; + } + + public function belongsTo(Space $space) + { + return $this->space === $space; + } + + public function getSpace() + { + return $this->space; + } + + public function getCoordinates() + { + return $this->coordinates; + } + + public function offsetExists($offset) + { + return isset($this->coordinates[$offset]); + } + + public function offsetGet($offset) + { + return $this->coordinates[$offset]; + } + + public function offsetSet($offset, $value) + { + $this->coordinates[$offset] = $value; + } + + public function offsetUnset($offset) + { + unset($this->coordinates[$offset]); + } +} diff --git a/src/Phpml/Clustering/KMeans/Space.php b/src/Phpml/Clustering/KMeans/Space.php new file mode 100644 index 0000000..090a48b --- /dev/null +++ b/src/Phpml/Clustering/KMeans/Space.php @@ -0,0 +1,216 @@ +dimention = $dimention; + } + + public function toArray() + { + $points = array(); + foreach ($this as $point) + $points[] = $point->toArray(); + + return array('points' => $points); + } + + public function newPoint(array $coordinates) + { + if (count($coordinates) != $this->dimention) + throw new LogicException("(" . implode(',', $coordinates) . 
") is not a point of this space"); + + return new Point($this, $coordinates); + } + + public function addPoint(array $coordinates, $data = null) + { + return $this->attach($this->newPoint($coordinates), $data); + } + + public function attach($point, $data = null) + { + if (!$point instanceof Point) + throw new InvalidArgumentException("can only attach points to spaces"); + + return parent::attach($point, $data); + } + + public function getDimention() + { + return $this->dimention; + } + + public function getBoundaries() + { + if (!count($this)) + return false; + + $min = $this->newPoint(array_fill(0, $this->dimention, null)); + $max = $this->newPoint(array_fill(0, $this->dimention, null)); + + foreach ($this as $point) { + for ($n=0; $n < $this->dimention; $n++) { + ($min[$n] > $point[$n] || $min[$n] === null) && $min[$n] = $point[$n]; + ($max[$n] < $point[$n] || $max[$n] === null) && $max[$n] = $point[$n]; + } + } + + return array($min, $max); + } + + public function getRandomPoint(Point $min, Point $max) + { + $point = $this->newPoint(array_fill(0, $this->dimention, null)); + + for ($n=0; $n < $this->dimention; $n++) + $point[$n] = rand($min[$n], $max[$n]); + + return $point; + } + + /** + * @param int $nbClusters number of clusters to build; must be greater than zero + * @param int $seed seeding strategy, self::SEED_DEFAULT or self::SEED_DASV + * @param callable|null $iterationCallback called with ($this, $clusters) before every iteration + * @return Cluster[] always an array of clusters, including when only one cluster is requested + */ + public function solve($nbClusters, $seed = self::SEED_DEFAULT, $iterationCallback = null) + { + if ($iterationCallback && !is_callable($iterationCallback)) + throw new InvalidArgumentException("invalid iteration callback"); + + // initialize K clusters + $clusters = $this->initializeClusters($nbClusters, $seed); + + // there's only one cluster: nothing to iterate, but callers still expect an array of clusters + if (count($clusters) === 1) + return $clusters; + + // until convergence is reached + do { + $iterationCallback && $iterationCallback($this, $clusters); + } while ($this->iterate($clusters)); + + // clustering is done. 
+ return $clusters; + } + + protected function initializeClusters($nbClusters, $seed) + { + if ($nbClusters <= 0) + throw new InvalidArgumentException("invalid clusters number"); + + switch ($seed) { + // the default seeding method chooses completely random centroid + case self::SEED_DEFAULT: + // get the space boundaries to avoid placing clusters centroid too far from points + list($min, $max) = $this->getBoundaries(); + + // initialize N clusters with a random point within space boundaries + for ($n=0; $n<$nbClusters; $n++) + $clusters[] = new Cluster($this, $this->getRandomPoint($min, $max)->getCoordinates()); + + break; + + // the DASV seeding method consists of finding good initial centroids for the clusters + case self::SEED_DASV: + // find a random point + $position = rand(1, count($this)); + for ($i=1, $this->rewind(); $i<$position && $this->valid(); $i++, $this->next()); + $clusters[] = new Cluster($this, $this->current()->getCoordinates()); + + // retains the distances between points and their closest clusters + $distances = new SplObjectStorage; + + // create k clusters + for ($i=1; $i<$nbClusters; $i++) { + $sum = 0; + + // for each point, get the distance to the closest centroid already chosen + foreach ($this as $point) { + $distance = $point->getDistanceWith($point->getClosest($clusters)); + $sum += $distances[$point] = $distance; + } + + // choose a new random point using a weighted probability distribution + $sum = rand(0, $sum); + foreach ($this as $point) { + if (($sum -= $distances[$point]) > 0) + continue; + + $clusters[] = new Cluster($this, $point->getCoordinates()); + break; + } + } + + break; + } + + // assign all points to the first cluster + $clusters[0]->attachAll($this); + + return $clusters; + } + + protected function iterate($clusters) + { + $continue = false; + + // migration storages + $attach = new SplObjectStorage; + $detach = new SplObjectStorage; + + // calculate proximity amongst points and clusters + foreach ($clusters as 
$cluster) { + foreach ($cluster as $point) { + // find the closest cluster + $closest = $point->getClosest($clusters); + + // move the point from its old cluster to its closest + if ($closest !== $cluster) { + isset($attach[$closest]) || $attach[$closest] = new SplObjectStorage; + isset($detach[$cluster]) || $detach[$cluster] = new SplObjectStorage; + + $attach[$closest]->attach($point); + $detach[$cluster]->attach($point); + + $continue = true; + } + } + } + + // perform points migrations + foreach ($attach as $cluster) + $cluster->attachAll($attach[$cluster]); + + foreach ($detach as $cluster) + $cluster->detachAll($detach[$cluster]); + + // update all cluster's centroids + foreach ($clusters as $cluster) + $cluster->updateCentroid(); + + return $continue; + } +} diff --git a/src/Phpml/Exception/InvalidArgumentException.php b/src/Phpml/Exception/InvalidArgumentException.php index 9e88250..3185205 100644 --- a/src/Phpml/Exception/InvalidArgumentException.php +++ b/src/Phpml/Exception/InvalidArgumentException.php @@ -57,4 +57,13 @@ class InvalidArgumentException extends \Exception { return new self('Inconsistent matrix aupplied'); } + + /** + * @return InvalidArgumentException + */ + public static function invalidClustersNumber() + { + return new self('Invalid clusters number'); + } + } diff --git a/tests/Phpml/Clustering/DBSCANTest.php b/tests/Phpml/Clustering/DBSCANTest.php index 5952636..7be5331 100644 --- a/tests/Phpml/Clustering/DBSCANTest.php +++ b/tests/Phpml/Clustering/DBSCANTest.php @@ -11,7 +11,6 @@ class DBSCANTest extends \PHPUnit_Framework_TestCase public function testDBSCANSamplesClustering() { $samples = [[1, 1], [8, 7], [1, 2], [7, 8], [2, 1], [8, 9]]; - $clustered = [ [[1, 1], [1, 2], [2, 1]], [[8, 7], [7, 8], [8, 9]], @@ -20,12 +19,9 @@ class DBSCANTest extends \PHPUnit_Framework_TestCase $dbscan = new DBSCAN($epsilon = 2, $minSamples = 3); $this->assertEquals($clustered, $dbscan->cluster($samples)); - } - public function 
testDBSCANSamplesInCircleClustering() - { + $samples = [[1, 1], [6, 6], [1, -1], [5, 6], [-1, -1], [7, 8], [-1, 1], [7, 7]]; - $clustered = [ [[1, 1], [1, -1], [-1, -1], [-1, 1]], [[6, 6], [5, 6], [7, 8], [7, 7]], @@ -35,4 +31,5 @@ class DBSCANTest extends \PHPUnit_Framework_TestCase $this->assertEquals($clustered, $dbscan->cluster($samples)); } + } diff --git a/tests/Phpml/Clustering/KMeansTest.php b/tests/Phpml/Clustering/KMeansTest.php new file mode 100644 index 0000000..5c21c89 --- /dev/null +++ b/tests/Phpml/Clustering/KMeansTest.php @@ -0,0 +1,58 @@ +cluster($samples); + + $this->assertEquals(2, count($clusters)); + + foreach ($samples as $index => $sample) { + if(in_array($sample, $clusters[0]) || in_array($sample, $clusters[1])) { + unset($samples[$index]); + } + } + $this->assertEquals(0, count($samples)); + } + + public function testKMeansMoreSamplesClustering() + { + $samples = [ + [80,55],[86,59],[19,85],[41,47],[57,58], + [76,22],[94,60],[13,93],[90,48],[52,54], + [62,46],[88,44],[85,24],[63,14],[51,40], + [75,31],[86,62],[81,95],[47,22],[43,95], + [71,19],[17,65],[69,21],[59,60],[59,12], + [15,22],[49,93],[56,35],[18,20],[39,59], + [50,15],[81,36],[67,62],[32,15],[75,65], + [10,47],[75,18],[13,45],[30,62],[95,79], + [64,11],[92,14],[94,49],[39,13],[60,68], + [62,10],[74,44],[37,42],[97,60],[47,73], + ]; + + $kmeans = new KMeans(4); + $clusters = $kmeans->cluster($samples); + + $this->assertEquals(4, count($clusters)); + + foreach ($samples as $index => $sample) { + for($i=0; $i<4; $i++) { + if(in_array($sample, $clusters[$i])) { + unset($samples[$index]); + } + } + } + $this->assertEquals(0, count($samples)); + } + +}