refactor kmeans subclasses

This commit is contained in:
Arkadiusz Kondas 2016-05-01 23:36:33 +02:00
parent c0513e9b82
commit 7572304d50
7 changed files with 465 additions and 340 deletions

View File

@ -1,5 +1,6 @@
<?php <?php
declare(strict_types = 1);
declare (strict_types = 1);
namespace Phpml\Clustering; namespace Phpml\Clustering;
@ -8,23 +9,33 @@ use Phpml\Exception\InvalidArgumentException;
class KMeans implements Clusterer class KMeans implements Clusterer
{ {
const INIT_RANDOM = 1;
const INIT_KMEANS_PLUS_PLUS = 2;
/** /**
* @var int * @var int
*/ */
private $clustersNumber; private $clustersNumber;
/**
* @var int
*/
private $initialization;
/**
 * Store the clustering configuration.
 *
 * @param int $clustersNumber number of clusters to build; must be at least 1
 * @param int $initialization seeding strategy (self::INIT_RANDOM or
 *                            self::INIT_KMEANS_PLUS_PLUS); stored as given,
 *                            it is not validated here
 *
 * @throws InvalidArgumentException when $clustersNumber is zero or negative
 */
public function __construct(int $clustersNumber, int $initialization = self::INIT_KMEANS_PLUS_PLUS)
{
    // The argument is type-hinted as int, so "< 1" is the same test as "<= 0".
    if ($clustersNumber < 1) {
        throw InvalidArgumentException::invalidClustersNumber();
    }

    $this->initialization = $initialization;
    $this->clustersNumber = $clustersNumber;
}
/** /**
@ -38,14 +49,12 @@ class KMeans implements Clusterer
foreach ($samples as $sample) { foreach ($samples as $sample) {
$space->addPoint($sample); $space->addPoint($sample);
} }
$clusters = []; $clusters = [];
foreach ($space->solve($this->clustersNumber) as $cluster) foreach ($space->solve($this->clustersNumber, $this->initialization) as $cluster) {
{
$clusters[] = $cluster->getPoints(); $clusters[] = $cluster->getPoints();
} }
return $clusters; return $clusters;
} }
} }

View File

@ -1,101 +1,137 @@
<?php <?php
declare(strict_types = 1); declare (strict_types = 1);
namespace Phpml\Clustering\KMeans; namespace Phpml\Clustering\KMeans;
use \IteratorAggregate; use IteratorAggregate;
use \Countable; use Countable;
use \SplObjectStorage; use SplObjectStorage;
use \LogicException; use LogicException;
/**
 * A cluster: a centroid (the coordinates inherited from Point) together with
 * the set of points currently assigned to it.
 */
class Cluster extends Point implements IteratorAggregate, Countable
{
    /**
     * Space used to build the scratch point in updateCentroid().
     *
     * @var Space
     */
    protected $space;

    /**
     * Points currently assigned to this cluster.
     *
     * @var SplObjectStorage|Point[]
     */
    protected $points;

    /**
     * @param Space $space       space this cluster belongs to
     * @param array $coordinates initial centroid coordinates
     */
    public function __construct(Space $space, array $coordinates)
    {
        parent::__construct($coordinates);

        $this->points = new SplObjectStorage();
        $this->space = $space;
    }

    /**
     * Coordinates of every assigned point, one array per point.
     *
     * @return array
     */
    public function getPoints()
    {
        $members = [];

        foreach ($this->points as $member) {
            $members[] = $member->toArray();
        }

        return $members;
    }

    /**
     * Export the centroid coordinates and the assigned points.
     *
     * @return array with the keys 'centroid' and 'points'
     */
    public function toArray()
    {
        return [
            'centroid' => parent::toArray(),
            'points' => $this->getPoints(),
        ];
    }

    /**
     * Assign a point to this cluster.
     *
     * @param Point $point
     *
     * @return Point the point that was passed in
     *
     * @throws LogicException when $point is itself a Cluster
     */
    public function attach(Point $point)
    {
        if (!$point instanceof self) {
            $this->points->attach($point);

            return $point;
        }

        throw new LogicException('cannot attach a cluster to another');
    }

    /**
     * Remove a point from this cluster.
     *
     * @param Point $point
     *
     * @return Point the point that was passed in
     */
    public function detach(Point $point)
    {
        $this->points->detach($point);

        return $point;
    }

    /**
     * Assign every point of the given storage to this cluster.
     *
     * @param SplObjectStorage $points
     */
    public function attachAll(SplObjectStorage $points)
    {
        $this->points->addAll($points);
    }

    /**
     * Remove every point of the given storage from this cluster.
     *
     * @param SplObjectStorage $points
     */
    public function detachAll(SplObjectStorage $points)
    {
        $this->points->removeAll($points);
    }

    /**
     * Move the centroid to the arithmetic mean of the assigned points.
     *
     * Does nothing when no point is assigned, so the centroid keeps its
     * previous coordinates.
     */
    public function updateCentroid()
    {
        $size = count($this->points);
        if ($size === 0) {
            return;
        }

        // Scratch point that accumulates the per-axis sums.
        $sum = $this->space->newPoint(array_fill(0, $this->dimension, 0));

        foreach ($this->points as $member) {
            for ($axis = 0; $axis < $this->dimension; ++$axis) {
                $sum->coordinates[$axis] += $member->coordinates[$axis];
            }
        }

        for ($axis = 0; $axis < $this->dimension; ++$axis) {
            $this->coordinates[$axis] = $sum->coordinates[$axis] / $size;
        }
    }

    /**
     * Iterating a cluster walks its assigned points.
     *
     * @return Point[]|SplObjectStorage
     */
    public function getIterator()
    {
        return $this->points;
    }

    /**
     * Number of points assigned to this cluster.
     *
     * @return int
     */
    public function count()
    {
        return count($this->points);
    }
}

View File

@ -1,95 +1,124 @@
<?php <?php
declare(strict_types = 1);
declare (strict_types = 1);
namespace Phpml\Clustering\KMeans; namespace Phpml\Clustering\KMeans;
use \ArrayAccess; use ArrayAccess;
use \LogicException;
/**
 * A point described by a list of coordinates.
 *
 * Coordinates can be read and written with array syntax ($point[0]).
 */
class Point implements ArrayAccess
{
    /**
     * Number of coordinates given to the constructor.
     *
     * @var int
     */
    protected $dimension;

    /**
     * @var array
     */
    protected $coordinates;

    /**
     * @param array $coordinates one value per axis; the loops in this class
     *                           index them from 0 to dimension - 1
     */
    public function __construct(array $coordinates)
    {
        $this->dimension = count($coordinates);
        $this->coordinates = $coordinates;
    }

    /**
     * @return array the coordinates
     */
    public function toArray()
    {
        return $this->coordinates;
    }

    /**
     * Euclidean distance between this point and another one.
     *
     * Only the first $this->dimension coordinates of $point are read.
     *
     * @param Point $point
     * @param bool  $precise when false the square root is skipped and the
     *                       squared distance is returned (enough for comparisons)
     *
     * @return int|float
     */
    public function getDistanceWith(self $point, $precise = true)
    {
        $distance = 0;
        for ($n = 0; $n < $this->dimension; ++$n) {
            $difference = $this->coordinates[$n] - $point->coordinates[$n];
            $distance += $difference * $difference;
        }

        return $precise ? sqrt($distance) : $distance;
    }

    /**
     * Find the candidate that is nearest to this point.
     *
     * On a tie the candidate met first wins.
     *
     * @param Point[]|\Traversable $points candidates to compare against
     *
     * @return Point|null the nearest candidate, or null when $points is empty
     */
    public function getClosest($points)
    {
        // Both start as null so an empty candidate list returns null instead
        // of reading an undefined variable.
        $minPoint = null;
        $minDistance = null;

        foreach ($points as $point) {
            // Squared distances order the same way and avoid the sqrt() call.
            $distance = $this->getDistanceWith($point, false);

            if ($minDistance === null || $distance < $minDistance) {
                $minDistance = $distance;
                $minPoint = $point;
            }
        }

        return $minPoint;
    }

    /**
     * @return array the coordinates
     */
    public function getCoordinates()
    {
        return $this->coordinates;
    }

    /**
     * @param mixed $offset
     *
     * @return bool
     */
    public function offsetExists($offset)
    {
        return isset($this->coordinates[$offset]);
    }

    /**
     * @param mixed $offset
     *
     * @return mixed
     */
    public function offsetGet($offset)
    {
        return $this->coordinates[$offset];
    }

    /**
     * @param mixed $offset
     * @param mixed $value
     */
    public function offsetSet($offset, $value)
    {
        $this->coordinates[$offset] = $value;
    }

    /**
     * @param mixed $offset
     */
    public function offsetUnset($offset)
    {
        unset($this->coordinates[$offset]);
    }
}

View File

@ -1,216 +1,271 @@
<?php <?php
declare(strict_types = 1);
declare (strict_types = 1);
namespace Phpml\Clustering\KMeans; namespace Phpml\Clustering\KMeans;
use \SplObjectStorage; use Phpml\Clustering\KMeans;
use \LogicException; use SplObjectStorage;
use \InvalidArgumentException; use LogicException;
use InvalidArgumentException;
/**
 * A fixed-dimension collection of points that can be partitioned into
 * clusters with the k-means algorithm.
 */
class Space extends SplObjectStorage
{
    /**
     * Number of coordinates every point of this space must have.
     *
     * @var int
     */
    protected $dimension;

    /**
     * @param int $dimension number of coordinates per point, at least 1
     *
     * @throws LogicException when $dimension is lower than 1
     */
    public function __construct($dimension)
    {
        if ($dimension < 1) {
            throw new LogicException('a space dimension cannot be null or negative');
        }

        $this->dimension = $dimension;
    }

    /**
     * Export the coordinates of every stored point.
     *
     * @return array with the single key 'points'
     */
    public function toArray()
    {
        $points = [];
        foreach ($this as $point) {
            $points[] = $point->toArray();
        }

        return ['points' => $points];
    }

    /**
     * Build a point for this space without storing it.
     *
     * @param array $coordinates
     *
     * @return Point
     *
     * @throws LogicException when the number of coordinates differs from the space dimension
     */
    public function newPoint(array $coordinates)
    {
        // Loose comparison on purpose: the constructor does not force $dimension to be an int.
        if (count($coordinates) != $this->dimension) {
            throw new LogicException('('.implode(',', $coordinates).') is not a point of this space');
        }

        return new Point($coordinates);
    }

    /**
     * Build a point from coordinates and store it in this space.
     *
     * @param array $coordinates
     * @param mixed $data        optional data associated with the point
     *
     * @throws LogicException when the coordinates do not match the space dimension
     */
    public function addPoint(array $coordinates, $data = null)
    {
        return $this->attach($this->newPoint($coordinates), $data);
    }

    /**
     * Store a point in this space.
     *
     * @param object $point must be a Point
     * @param mixed  $data  optional data associated with the point
     *
     * @throws InvalidArgumentException when $point is not a Point
     */
    public function attach($point, $data = null)
    {
        if (!$point instanceof Point) {
            throw new InvalidArgumentException('can only attach points to spaces');
        }

        return parent::attach($point, $data);
    }

    /**
     * @return int
     */
    public function getDimension()
    {
        return $this->dimension;
    }

    /**
     * Compute the per-axis minimum and maximum over all stored points.
     *
     * @return array|bool [Point $min, Point $max], or false when the space is empty
     */
    public function getBoundaries()
    {
        if (!count($this)) {
            return false;
        }

        // null marks "no value seen yet" on an axis.
        $min = $this->newPoint(array_fill(0, $this->dimension, null));
        $max = $this->newPoint(array_fill(0, $this->dimension, null));

        foreach ($this as $point) {
            for ($n = 0; $n < $this->dimension; ++$n) {
                if ($min[$n] === null || $point[$n] < $min[$n]) {
                    $min[$n] = $point[$n];
                }
                if ($max[$n] === null || $point[$n] > $max[$n]) {
                    $max[$n] = $point[$n];
                }
            }
        }

        return [$min, $max];
    }

    /**
     * Draw a random point inside the box delimited by $min and $max.
     *
     * @param Point $min lower bound on every axis
     * @param Point $max upper bound on every axis
     *
     * @return Point
     */
    public function getRandomPoint(Point $min, Point $max)
    {
        $point = $this->newPoint(array_fill(0, $this->dimension, null));

        for ($n = 0; $n < $this->dimension; ++$n) {
            $low = $min[$n];
            $high = $max[$n];

            if (is_int($low) && is_int($high)) {
                $point[$n] = rand($low, $high);
            } else {
                // rand() only accepts integers and raises a TypeError under
                // strict_types for anything else, so non-integer bounds are
                // interpolated instead.
                $point[$n] = $low + (mt_rand() / mt_getrandmax()) * ($high - $low);
            }
        }

        return $point;
    }

    /**
     * Partition the stored points into clusters with k-means.
     *
     * @param int           $nbClusters        number of clusters to build
     * @param int           $seed              KMeans::INIT_RANDOM or KMeans::INIT_KMEANS_PLUS_PLUS
     * @param callable|null $iterationCallback called as ($space, $clusters) before every iteration
     *
     * @return Cluster[]
     *
     * @throws InvalidArgumentException when the callback is not callable, the
     *                                  clusters number is not positive or the seed is unknown
     * @throws LogicException           when the space holds no point
     */
    public function solve($nbClusters, $seed = KMeans::INIT_RANDOM, $iterationCallback = null)
    {
        if ($iterationCallback && !is_callable($iterationCallback)) {
            throw new InvalidArgumentException('invalid iteration callback');
        }

        // initialize K clusters
        $clusters = $this->initializeClusters($nbClusters, $seed);

        // With a single cluster there is nothing to iterate on. Still return
        // an array so callers can loop over the result as in the general case.
        if (count($clusters) == 1) {
            return $clusters;
        }

        // until convergence is reached
        do {
            $iterationCallback && $iterationCallback($this, $clusters);
        } while ($this->iterate($clusters));

        // clustering is done.
        return $clusters;
    }

    /**
     * Create the initial clusters and assign every point to the first one.
     *
     * @param int $nbClusters
     * @param int $seed       KMeans::INIT_RANDOM or KMeans::INIT_KMEANS_PLUS_PLUS
     *
     * @return Cluster[]
     *
     * @throws InvalidArgumentException when $nbClusters is not positive or $seed is unknown
     * @throws LogicException           when the space holds no point
     */
    protected function initializeClusters($nbClusters, $seed)
    {
        if ($nbClusters <= 0) {
            throw new InvalidArgumentException('invalid clusters number');
        }

        // Both seeding strategies need at least one stored point.
        if (!count($this)) {
            throw new LogicException('cannot initialize clusters in an empty space');
        }

        $clusters = [];

        switch ($seed) {
            // the default seeding method chooses completely random centroids
            case KMeans::INIT_RANDOM:
                // get the space boundaries to avoid placing centroids too far from points
                list($min, $max) = $this->getBoundaries();

                // initialize N clusters with a random point within space boundaries
                for ($n = 0; $n < $nbClusters; ++$n) {
                    $clusters[] = new Cluster($this, $this->getRandomPoint($min, $max)->getCoordinates());
                }

                break;

            // the DASV seeding method (David Arthur and Sergei Vassilvitskii,
            // see http://en.wikipedia.org/wiki/K-means++) looks for good initial centroids
            case KMeans::INIT_KMEANS_PLUS_PLUS:
                // find a random point: advance the internal iterator to a random position
                $position = rand(1, count($this));
                for ($i = 1, $this->rewind(); $i < $position && $this->valid(); $i++, $this->next());
                $clusters[] = new Cluster($this, $this->current()->getCoordinates());

                // retains the distances between points and their closest clusters
                $distances = new SplObjectStorage();

                // create the remaining k - 1 clusters
                for ($i = 1; $i < $nbClusters; ++$i) {
                    $sum = 0;

                    // for each point, get the distance to the closest centroid already chosen
                    foreach ($this as $point) {
                        $distance = $point->getDistanceWith($point->getClosest($clusters));
                        $sum += $distances[$point] = $distance;
                    }

                    // choose a new point at random, weighted by those distances:
                    // walk the points until the drawn threshold is used up
                    $sum = rand(0, (int) $sum);
                    foreach ($this as $point) {
                        if (($sum -= $distances[$point]) > 0) {
                            continue;
                        }

                        $clusters[] = new Cluster($this, $point->getCoordinates());
                        break;
                    }
                }

                break;

            default:
                // Without this branch $clusters would stay empty and the
                // attachAll() call below would fail on a missing first cluster.
                throw new InvalidArgumentException('invalid initialization method');
        }

        // assign all points to the first cluster
        $clusters[0]->attachAll($this);

        return $clusters;
    }

    /**
     * Run one k-means step: move every point to its closest cluster, then
     * recompute the centroids.
     *
     * @param Cluster[] $clusters
     *
     * @return bool true when at least one point changed cluster
     */
    protected function iterate($clusters)
    {
        $continue = false;

        // Migration storages, keyed by cluster. Moves are collected first and
        // applied after the scan, so clusters are not modified while iterated.
        $attach = new SplObjectStorage();
        $detach = new SplObjectStorage();

        // calculate proximity amongst points and clusters
        foreach ($clusters as $cluster) {
            foreach ($cluster as $point) {
                // find the closest cluster
                $closest = $point->getClosest($clusters);

                if ($closest === $cluster) {
                    continue;
                }

                // move the point from its old cluster to its closest
                if (!isset($attach[$closest])) {
                    $attach[$closest] = new SplObjectStorage();
                }
                if (!isset($detach[$cluster])) {
                    $detach[$cluster] = new SplObjectStorage();
                }

                $attach[$closest]->attach($point);
                $detach[$cluster]->attach($point);

                $continue = true;
            }
        }

        // perform points migrations
        foreach ($attach as $cluster) {
            $cluster->attachAll($attach[$cluster]);
        }

        foreach ($detach as $cluster) {
            $cluster->detachAll($detach[$cluster]);
        }

        // update all clusters' centroids
        foreach ($clusters as $cluster) {
            $cluster->updateCentroid();
        }

        return $continue;
    }
}

View File

@ -65,5 +65,4 @@ class InvalidArgumentException extends \Exception
{ {
return new self('Invalid clusters number'); return new self('Invalid clusters number');
} }
} }

View File

@ -20,7 +20,6 @@ class DBSCANTest extends \PHPUnit_Framework_TestCase
$this->assertEquals($clustered, $dbscan->cluster($samples)); $this->assertEquals($clustered, $dbscan->cluster($samples));
$samples = [[1, 1], [6, 6], [1, -1], [5, 6], [-1, -1], [7, 8], [-1, 1], [7, 7]]; $samples = [[1, 1], [6, 6], [1, -1], [5, 6], [-1, -1], [7, 8], [-1, 1], [7, 7]];
$clustered = [ $clustered = [
[[1, 1], [1, -1], [-1, -1], [-1, 1]], [[1, 1], [1, -1], [-1, -1], [-1, 1]],
@ -31,5 +30,4 @@ class DBSCANTest extends \PHPUnit_Framework_TestCase
$this->assertEquals($clustered, $dbscan->cluster($samples)); $this->assertEquals($clustered, $dbscan->cluster($samples));
} }
} }

View File

@ -1,5 +1,6 @@
<?php <?php
declare(strict_types = 1);
declare (strict_types = 1);
namespace tests\Clustering; namespace tests\Clustering;
@ -7,7 +8,6 @@ use Phpml\Clustering\KMeans;
class KMeansTest extends \PHPUnit_Framework_TestCase class KMeansTest extends \PHPUnit_Framework_TestCase
{ {
public function testKMeansSamplesClustering() public function testKMeansSamplesClustering()
{ {
$samples = [[1, 1], [8, 7], [1, 2], [7, 8], [2, 1], [8, 9]]; $samples = [[1, 1], [8, 7], [1, 2], [7, 8], [2, 1], [8, 9]];
@ -18,7 +18,7 @@ class KMeansTest extends \PHPUnit_Framework_TestCase
$this->assertEquals(2, count($clusters)); $this->assertEquals(2, count($clusters));
foreach ($samples as $index => $sample) { foreach ($samples as $index => $sample) {
if(in_array($sample, $clusters[0]) || in_array($sample, $clusters[1])) { if (in_array($sample, $clusters[0]) || in_array($sample, $clusters[1])) {
unset($samples[$index]); unset($samples[$index]);
} }
} }
@ -28,16 +28,16 @@ class KMeansTest extends \PHPUnit_Framework_TestCase
public function testKMeansMoreSamplesClustering() public function testKMeansMoreSamplesClustering()
{ {
$samples = [ $samples = [
[80,55],[86,59],[19,85],[41,47],[57,58], [80, 55], [86, 59], [19, 85], [41, 47], [57, 58],
[76,22],[94,60],[13,93],[90,48],[52,54], [76, 22], [94, 60], [13, 93], [90, 48], [52, 54],
[62,46],[88,44],[85,24],[63,14],[51,40], [62, 46], [88, 44], [85, 24], [63, 14], [51, 40],
[75,31],[86,62],[81,95],[47,22],[43,95], [75, 31], [86, 62], [81, 95], [47, 22], [43, 95],
[71,19],[17,65],[69,21],[59,60],[59,12], [71, 19], [17, 65], [69, 21], [59, 60], [59, 12],
[15,22],[49,93],[56,35],[18,20],[39,59], [15, 22], [49, 93], [56, 35], [18, 20], [39, 59],
[50,15],[81,36],[67,62],[32,15],[75,65], [50, 15], [81, 36], [67, 62], [32, 15], [75, 65],
[10,47],[75,18],[13,45],[30,62],[95,79], [10, 47], [75, 18], [13, 45], [30, 62], [95, 79],
[64,11],[92,14],[94,49],[39,13],[60,68], [64, 11], [92, 14], [94, 49], [39, 13], [60, 68],
[62,10],[74,44],[37,42],[97,60],[47,73], [62, 10], [74, 44], [37, 42], [97, 60], [47, 73],
]; ];
$kmeans = new KMeans(4); $kmeans = new KMeans(4);
@ -46,13 +46,12 @@ class KMeansTest extends \PHPUnit_Framework_TestCase
$this->assertEquals(4, count($clusters)); $this->assertEquals(4, count($clusters));
foreach ($samples as $index => $sample) { foreach ($samples as $index => $sample) {
for($i=0; $i<4; $i++) { for ($i = 0; $i < 4; ++$i) {
if(in_array($sample, $clusters[$i])) { if (in_array($sample, $clusters[$i])) {
unset($samples[$index]); unset($samples[$index]);
} }
} }
} }
$this->assertEquals(0, count($samples)); $this->assertEquals(0, count($samples));
} }
} }