From 7572304d5039585538bc63804e7c0a8dda2a6e75 Mon Sep 17 00:00:00 2001 From: Arkadiusz Kondas Date: Sun, 1 May 2016 23:36:33 +0200 Subject: [PATCH] refactor kmeans subclasses --- src/Phpml/Clustering/KMeans.php | 27 +- src/Phpml/Clustering/KMeans/Cluster.php | 188 +++++---- src/Phpml/Clustering/KMeans/Point.php | 175 ++++---- src/Phpml/Clustering/KMeans/Space.php | 381 ++++++++++-------- .../Exception/InvalidArgumentException.php | 1 - tests/Phpml/Clustering/DBSCANTest.php | 2 - tests/Phpml/Clustering/KMeansTest.php | 31 +- 7 files changed, 465 insertions(+), 340 deletions(-) diff --git a/src/Phpml/Clustering/KMeans.php b/src/Phpml/Clustering/KMeans.php index 7bebe7e..cdae3b5 100644 --- a/src/Phpml/Clustering/KMeans.php +++ b/src/Phpml/Clustering/KMeans.php @@ -1,5 +1,6 @@ clustersNumber = $clustersNumber; + $this->initialization = $initialization; } /** @@ -38,14 +49,12 @@ class KMeans implements Clusterer foreach ($samples as $sample) { $space->addPoint($sample); } - + $clusters = []; - foreach ($space->solve($this->clustersNumber) as $cluster) - { + foreach ($space->solve($this->clustersNumber, $this->initialization) as $cluster) { $clusters[] = $cluster->getPoints(); } - + return $clusters; } - } diff --git a/src/Phpml/Clustering/KMeans/Cluster.php b/src/Phpml/Clustering/KMeans/Cluster.php index fec6d07..5cd974d 100644 --- a/src/Phpml/Clustering/KMeans/Cluster.php +++ b/src/Phpml/Clustering/KMeans/Cluster.php @@ -1,101 +1,137 @@ points = new SplObjectStorage; - } + /** + * @param Space $space + * @param array $coordinates + */ + public function __construct(Space $space, array $coordinates) + { + parent::__construct($coordinates); + $this->space = $space; + $this->points = new SplObjectStorage(); + } - /** - * @return array - */ - public function getPoints() - { - $points = []; - foreach ($this->points as $point) { - $points[] = $point->toArray(); - } + /** + * @return array + */ + public function getPoints() + { + $points = []; + foreach ($this->points as $point) { + $points[] = $point->toArray(); + } - return $points; - } - - public function toArray() - { - $points = array(); - foreach ($this->points as $point) - $points[] = $point->toArray(); + return $points; + } - return array( - 'centroid' => parent::toArray(), - 'points' => $points, - ); - } + /** + * @return array + */ + public function toArray() + { + return array( + 'centroid' => parent::toArray(), + 'points' => $this->getPoints(), + ); + } - public function attach(Point $point) - { - if ($point instanceof self) - throw new LogicException("cannot attach a cluster to another"); + /** + * @param Point $point + * + * @return Point + */ + public function attach(Point $point) + { + if ($point instanceof self) { + throw new LogicException('cannot attach a cluster to another'); + } - $this->points->attach($point); - return $point; - } + $this->points->attach($point); - public function detach(Point $point) - { - $this->points->detach($point); - return $point; - } + return $point; + } - public function attachAll(SplObjectStorage $points) - { - $this->points->addAll($points); - } + /** + * @param Point $point + * + * @return Point + */ + public function detach(Point $point) + { + $this->points->detach($point); - public function detachAll(SplObjectStorage $points) - { - $this->points->removeAll($points); - } + return $point; + } - public function updateCentroid() - { - if (!$count = count($this->points)) - return; + /** + * @param SplObjectStorage $points + */ + public function attachAll(SplObjectStorage $points) + { + $this->points->addAll($points); + } - $centroid = $this->space->newPoint(array_fill(0, $this->dimention, 0)); + /** + * @param SplObjectStorage $points + */ + public function detachAll(SplObjectStorage $points) + { + $this->points->removeAll($points); + } - foreach ($this->points as $point) - for ($n=0; $n<$this->dimention; $n++) - $centroid->coordinates[$n] += $point->coordinates[$n]; + public function updateCentroid() + { + if (!$count = count($this->points)) { + return; + } - for ($n=0; $n<$this->dimention; $n++) - $this->coordinates[$n] = $centroid->coordinates[$n] / $count; - } + $centroid = $this->space->newPoint(array_fill(0, $this->dimension, 0)); - public function getIterator() - { - return $this->points; - } + foreach ($this->points as $point) { + for ($n = 0; $n < $this->dimension; ++$n) { + $centroid->coordinates[$n] += $point->coordinates[$n]; + } + } - public function count() - { - return count($this->points); - } + for ($n = 0; $n < $this->dimension; ++$n) { + $this->coordinates[$n] = $centroid->coordinates[$n] / $count; + } + } + + /** + * @return Point[]|SplObjectStorage + */ + public function getIterator() + { + return $this->points; + } + + /** + * @return mixed + */ + public function count() + { + return count($this->points); + } } diff --git a/src/Phpml/Clustering/KMeans/Point.php b/src/Phpml/Clustering/KMeans/Point.php index 4d888c3..9ff4b45 100644 --- a/src/Phpml/Clustering/KMeans/Point.php +++ b/src/Phpml/Clustering/KMeans/Point.php @@ -1,95 +1,124 @@ space = $space; - $this->dimention = $space->getDimention(); - $this->coordinates = $coordinates; - } + /** + * @var array + */ + protected $coordinates; - public function toArray() - { - return $this->coordinates; - } + /** + * @param array $coordinates + */ + public function __construct(array $coordinates) + { + $this->dimension = count($coordinates); + $this->coordinates = $coordinates; + } - public function getDistanceWith(self $point, $precise = true) - { - if ($point->space !== $this->space) - throw new LogicException("can only calculate distances from points in the same space"); + /** + * @return array + */ + public function toArray() + { + return $this->coordinates; + } - $distance = 0; - for ($n=0; $n<$this->dimention; $n++) { - $difference = $this->coordinates[$n] - $point->coordinates[$n]; - $distance += $difference * $difference; - } + /** + * @param Point $point + * @param bool $precise + * + * @return int|mixed + */ + public function getDistanceWith(self $point, $precise = true) + { + $distance = 0; + for ($n = 0; $n < $this->dimension; ++$n) { + $difference = $this->coordinates[$n] - $point->coordinates[$n]; + $distance += $difference * $difference; + } - return $precise ? sqrt($distance) : $distance; - } + return $precise ? sqrt($distance) : $distance; + } - public function getClosest($points) - { - foreach($points as $point) { - $distance = $this->getDistanceWith($point, false); + /** + * @param $points + * + * @return mixed + */ + public function getClosest($points) + { + foreach ($points as $point) { + $distance = $this->getDistanceWith($point, false); - if (!isset($minDistance)) { - $minDistance = $distance; - $minPoint = $point; - continue; - } + if (!isset($minDistance)) { + $minDistance = $distance; + $minPoint = $point; + continue; + } - if ($distance < $minDistance) { - $minDistance = $distance; - $minPoint = $point; - } - } + if ($distance < $minDistance) { + $minDistance = $distance; + $minPoint = $point; + } + } - return $minPoint; - } + return $minPoint; + } - public function belongsTo(Space $space) - { - return $this->space === $space; - } + /** + * @return array + */ + public function getCoordinates() + { + return $this->coordinates; + } - public function getSpace() - { - return $this->space; - } + /** + * @param mixed $offset + * + * @return bool + */ + public function offsetExists($offset) + { + return isset($this->coordinates[$offset]); + } - public function getCoordinates() - { - return $this->coordinates; - } + /** + * @param mixed $offset + * + * @return mixed + */ + public function offsetGet($offset) + { + return $this->coordinates[$offset]; + } - public function offsetExists($offset) - { - return isset($this->coordinates[$offset]); - } + /** + * @param mixed $offset + * @param mixed $value + */ + public function offsetSet($offset, $value) + { + $this->coordinates[$offset] = $value; + } - public function offsetGet($offset) - { - return $this->coordinates[$offset]; - } - - public function offsetSet($offset, $value) - { - $this->coordinates[$offset] = $value; - } - - public function offsetUnset($offset) - { - unset($this->coordinates[$offset]); - } + /** + * @param mixed $offset + */ + public function offsetUnset($offset) + { + unset($this->coordinates[$offset]); + } } diff --git a/src/Phpml/Clustering/KMeans/Space.php b/src/Phpml/Clustering/KMeans/Space.php index 090a48b..f4465cf 100644 --- a/src/Phpml/Clustering/KMeans/Space.php +++ b/src/Phpml/Clustering/KMeans/Space.php @@ -1,216 +1,271 @@ dimension = $dimension; + } - public function __construct($dimention) - { - if ($dimention < 1) - throw new LogicException("a space dimention cannot be null or negative"); + /** + * @return array + */ + public function toArray() + { + $points = []; + foreach ($this as $point) { + $points[] = $point->toArray(); + } - $this->dimention = $dimention; - } + return ['points' => $points]; + } - public function toArray() - { - $points = array(); - foreach ($this as $point) - $points[] = $point->toArray(); + /** + * @param array $coordinates + * + * @return Point + */ + public function newPoint(array $coordinates) + { + if (count($coordinates) != $this->dimension) { + throw new LogicException('('.implode(',', $coordinates).') is not a point of this space'); + } - return array('points' => $points); - } + return new Point($coordinates); + } - public function newPoint(array $coordinates) - { - if (count($coordinates) != $this->dimention) - throw new LogicException("(" . implode(',', $coordinates) . ") is not a point of this space"); + /** + * @param array $coordinates + * @param null $data + */ + public function addPoint(array $coordinates, $data = null) + { + return $this->attach($this->newPoint($coordinates), $data); + } - return new Point($this, $coordinates); - } + /** + * @param object $point + * @param null $data + */ + public function attach($point, $data = null) + { + if (!$point instanceof Point) { + throw new InvalidArgumentException('can only attach points to spaces'); + } - public function addPoint(array $coordinates, $data = null) - { - return $this->attach($this->newPoint($coordinates), $data); - } + return parent::attach($point, $data); + } - public function attach($point, $data = null) - { - if (!$point instanceof Point) - throw new InvalidArgumentException("can only attach points to spaces"); + /** + * @return int + */ + public function getDimension() + { + return $this->dimension; + } - return parent::attach($point, $data); - } + /** + * @return array|bool + */ + public function getBoundaries() + { + if (!count($this)) { + return false; + } - public function getDimention() - { - return $this->dimention; - } + $min = $this->newPoint(array_fill(0, $this->dimension, null)); + $max = $this->newPoint(array_fill(0, $this->dimension, null)); - public function getBoundaries() - { - if (!count($this)) - return false; + foreach ($this as $point) { + for ($n = 0; $n < $this->dimension; ++$n) { + ($min[$n] > $point[$n] || $min[$n] === null) && $min[$n] = $point[$n]; + ($max[$n] < $point[$n] || $max[$n] === null) && $max[$n] = $point[$n]; + } + } - $min = $this->newPoint(array_fill(0, $this->dimention, null)); - $max = $this->newPoint(array_fill(0, $this->dimention, null)); + return array($min, $max); + } - foreach ($this as $point) { - for ($n=0; $n < $this->dimention; $n++) { - ($min[$n] > $point[$n] || $min[$n] === null) && $min[$n] = $point[$n]; - ($max[$n] < $point[$n] || $max[$n] === null) && $max[$n] = $point[$n]; - } - } + /** + * @param Point $min + * @param Point $max + * + * @return Point + */ + public function getRandomPoint(Point $min, Point $max) + { + $point = $this->newPoint(array_fill(0, $this->dimension, null)); - return array($min, $max); - } + for ($n = 0; $n < $this->dimension; ++$n) { + $point[$n] = rand($min[$n], $max[$n]); + } - public function getRandomPoint(Point $min, Point $max) - { - $point = $this->newPoint(array_fill(0, $this->dimention, null)); + return $point; + } - for ($n=0; $n < $this->dimention; $n++) - $point[$n] = rand($min[$n], $max[$n]); + /** + * @param $nbClusters + * @param int $seed + * @param null $iterationCallback + * + * @return array|Cluster[] + */ + public function solve($nbClusters, $seed = KMeans::INIT_RANDOM, $iterationCallback = null) + { + if ($iterationCallback && !is_callable($iterationCallback)) { + throw new InvalidArgumentException('invalid iteration callback'); + } - return $point; - } + // initialize K clusters + $clusters = $this->initializeClusters($nbClusters, $seed); - /** - * @param $nbClusters - * @param int $seed - * @param null $iterationCallback - * @return array|Cluster[] - */ - public function solve($nbClusters, $seed = self::SEED_DEFAULT, $iterationCallback = null) - { - if ($iterationCallback && !is_callable($iterationCallback)) - throw new InvalidArgumentException("invalid iteration callback"); + // there's only one cluster, clusterization has no meaning + if (count($clusters) == 1) { + return $clusters[0]; + } - // initialize K clusters - $clusters = $this->initializeClusters($nbClusters, $seed); + // until convergence is reached + do { + $iterationCallback && $iterationCallback($this, $clusters); + } while ($this->iterate($clusters)); - // there's only one cluster, clusterization has no meaning - if (count($clusters) == 1) - return $clusters[0]; + // clustering is done. + return $clusters; + } - // until convergence is reached - do { - $iterationCallback && $iterationCallback($this, $clusters); - } while ($this->iterate($clusters)); + /** + * @param $nbClusters + * @param $seed + * + * @return array + */ + protected function initializeClusters($nbClusters, $seed) + { + if ($nbClusters <= 0) { + throw new InvalidArgumentException('invalid clusters number'); + } - // clustering is done. - return $clusters; - } + switch ($seed) { + // the default seeding method chooses completely random centroid + case KMeans::INIT_RANDOM: + // get the space boundaries to avoid placing clusters centroid too far from points + list($min, $max) = $this->getBoundaries(); - protected function initializeClusters($nbClusters, $seed) - { - if ($nbClusters <= 0) - throw new InvalidArgumentException("invalid clusters number"); + // initialize N clusters with a random point within space boundaries + for ($n = 0; $n < $nbClusters; ++$n) { + $clusters[] = new Cluster($this, $this->getRandomPoint($min, $max)->getCoordinates()); + } - switch ($seed) { - // the default seeding method chooses completely random centroid - case self::SEED_DEFAULT: - // get the space boundaries to avoid placing clusters centroid too far from points - list($min, $max) = $this->getBoundaries(); + break; - // initialize N clusters with a random point within space boundaries - for ($n=0; $n<$nbClusters; $n++) - $clusters[] = new Cluster($this, $this->getRandomPoint($min, $max)->getCoordinates()); + // the DASV seeding method consists of finding good initial centroids for the clusters + case KMeans::INIT_KMEANS_PLUS_PLUS: + // find a random point + $position = rand(1, count($this)); + for ($i = 1, $this->rewind(); $i < $position && $this->valid(); $i++, $this->next()); + $clusters[] = new Cluster($this, $this->current()->getCoordinates()); - break; + // retains the distances between points and their closest clusters + $distances = new SplObjectStorage(); - // the DASV seeding method consists of finding good initial centroids for the clusters - case self::SEED_DASV: - // find a random point - $position = rand(1, count($this)); - for ($i=1, $this->rewind(); $i<$position && $this->valid(); $i++, $this->next()); - $clusters[] = new Cluster($this, $this->current()->getCoordinates()); + // create k clusters + for ($i = 1; $i < $nbClusters; ++$i) { + $sum = 0; - // retains the distances between points and their closest clusters - $distances = new SplObjectStorage; + // for each points, get the distance with the closest centroid already choosen + foreach ($this as $point) { + $distance = $point->getDistanceWith($point->getClosest($clusters)); + $sum += $distances[$point] = $distance; + } - // create k clusters - for ($i=1; $i<$nbClusters; $i++) { - $sum = 0; + // choose a new random point using a weighted probability distribution + $sum = rand(0, (int) $sum); + foreach ($this as $point) { + if (($sum -= $distances[$point]) > 0) { + continue; + } - // for each points, get the distance with the closest centroid already choosen - foreach ($this as $point) { - $distance = $point->getDistanceWith($point->getClosest($clusters)); - $sum += $distances[$point] = $distance; - } + $clusters[] = new Cluster($this, $point->getCoordinates()); + break; + } + } - // choose a new random point using a weighted probability distribution - $sum = rand(0, $sum); - foreach ($this as $point) { - if (($sum -= $distances[$point]) > 0) - continue; + break; + } - $clusters[] = new Cluster($this, $point->getCoordinates()); - break; - } - } + // assing all points to the first cluster + $clusters[0]->attachAll($this); - break; - } + return $clusters; + } - // assing all points to the first cluster - $clusters[0]->attachAll($this); + /** + * @param $clusters + * + * @return bool + */ + protected function iterate($clusters) + { + $continue = false; - return $clusters; - } + // migration storages + $attach = new SplObjectStorage(); + $detach = new SplObjectStorage(); - protected function iterate($clusters) - { - $continue = false; + // calculate proximity amongst points and clusters + foreach ($clusters as $cluster) { + foreach ($cluster as $point) { + // find the closest cluster + $closest = $point->getClosest($clusters); - // migration storages - $attach = new SplObjectStorage; - $detach = new SplObjectStorage; + // move the point from its old cluster to its closest + if ($closest !== $cluster) { + isset($attach[$closest]) || $attach[$closest] = new SplObjectStorage(); + isset($detach[$cluster]) || $detach[$cluster] = new SplObjectStorage(); - // calculate proximity amongst points and clusters - foreach ($clusters as $cluster) { - foreach ($cluster as $point) { - // find the closest cluster - $closest = $point->getClosest($clusters); + $attach[$closest]->attach($point); + $detach[$cluster]->attach($point); - // move the point from its old cluster to its closest - if ($closest !== $cluster) { - isset($attach[$closest]) || $attach[$closest] = new SplObjectStorage; - isset($detach[$cluster]) || $detach[$cluster] = new SplObjectStorage; + $continue = true; + } + } + } - $attach[$closest]->attach($point); - $detach[$cluster]->attach($point); + // perform points migrations + foreach ($attach as $cluster) { + $cluster->attachAll($attach[$cluster]); + } - $continue = true; - } - } - } + foreach ($detach as $cluster) { + $cluster->detachAll($detach[$cluster]); + } - // perform points migrations - foreach ($attach as $cluster) - $cluster->attachAll($attach[$cluster]); + // update all cluster's centroids + foreach ($clusters as $cluster) { + $cluster->updateCentroid(); + } - foreach ($detach as $cluster) - $cluster->detachAll($detach[$cluster]); - - // update all cluster's centroids - foreach ($clusters as $cluster) - $cluster->updateCentroid(); - - return $continue; - } + return $continue; + } } diff --git a/src/Phpml/Exception/InvalidArgumentException.php b/src/Phpml/Exception/InvalidArgumentException.php index 3185205..45d532e 100644 --- a/src/Phpml/Exception/InvalidArgumentException.php +++ b/src/Phpml/Exception/InvalidArgumentException.php @@ -65,5 +65,4 @@ class InvalidArgumentException extends \Exception { return new self('Invalid clusters number'); } - } diff --git a/tests/Phpml/Clustering/DBSCANTest.php b/tests/Phpml/Clustering/DBSCANTest.php index 7be5331..be37fff 100644 --- a/tests/Phpml/Clustering/DBSCANTest.php +++ b/tests/Phpml/Clustering/DBSCANTest.php @@ -20,7 +20,6 @@ class DBSCANTest extends \PHPUnit_Framework_TestCase $this->assertEquals($clustered, $dbscan->cluster($samples)); - $samples = [[1, 1], [6, 6], [1, -1], [5, 6], [-1, -1], [7, 8], [-1, 1], [7, 7]]; $clustered = [ [[1, 1], [1, -1], [-1, -1], [-1, 1]], @@ -31,5 +30,4 @@ class DBSCANTest extends \PHPUnit_Framework_TestCase $this->assertEquals($clustered, $dbscan->cluster($samples)); } - } diff --git a/tests/Phpml/Clustering/KMeansTest.php b/tests/Phpml/Clustering/KMeansTest.php index 5c21c89..dae62fd 100644 --- a/tests/Phpml/Clustering/KMeansTest.php +++ b/tests/Phpml/Clustering/KMeansTest.php @@ -1,5 +1,6 @@ assertEquals(2, count($clusters)); foreach ($samples as $index => $sample) { - if(in_array($sample, $clusters[0]) || in_array($sample, $clusters[1])) { + if (in_array($sample, $clusters[0]) || in_array($sample, $clusters[1])) { unset($samples[$index]); } } @@ -28,16 +28,16 @@ class KMeansTest extends \PHPUnit_Framework_TestCase public function testKMeansMoreSamplesClustering() { $samples = [ - [80,55],[86,59],[19,85],[41,47],[57,58], - [76,22],[94,60],[13,93],[90,48],[52,54], - [62,46],[88,44],[85,24],[63,14],[51,40], - [75,31],[86,62],[81,95],[47,22],[43,95], - [71,19],[17,65],[69,21],[59,60],[59,12], - [15,22],[49,93],[56,35],[18,20],[39,59], - [50,15],[81,36],[67,62],[32,15],[75,65], - [10,47],[75,18],[13,45],[30,62],[95,79], - [64,11],[92,14],[94,49],[39,13],[60,68], - [62,10],[74,44],[37,42],[97,60],[47,73], + [80, 55], [86, 59], [19, 85], [41, 47], [57, 58], + [76, 22], [94, 60], [13, 93], [90, 48], [52, 54], + [62, 46], [88, 44], [85, 24], [63, 14], [51, 40], + [75, 31], [86, 62], [81, 95], [47, 22], [43, 95], + [71, 19], [17, 65], [69, 21], [59, 60], [59, 12], + [15, 22], [49, 93], [56, 35], [18, 20], [39, 59], + [50, 15], [81, 36], [67, 62], [32, 15], [75, 65], + [10, 47], [75, 18], [13, 45], [30, 62], [95, 79], + [64, 11], [92, 14], [94, 49], [39, 13], [60, 68], + [62, 10], [74, 44], [37, 42], [97, 60], [47, 73], ]; $kmeans = new KMeans(4); @@ -46,13 +46,12 @@ class KMeansTest extends \PHPUnit_Framework_TestCase $this->assertEquals(4, count($clusters)); foreach ($samples as $index => $sample) { - for($i=0; $i<4; $i++) { - if(in_array($sample, $clusters[$i])) { + for ($i = 0; $i < 4; ++$i) { + if (in_array($sample, $clusters[$i])) { unset($samples[$index]); } } } $this->assertEquals(0, count($samples)); } - }