refactor kmeans subclasses

This commit is contained in:
Arkadiusz Kondas 2016-05-01 23:36:33 +02:00
parent c0513e9b82
commit 7572304d50
7 changed files with 465 additions and 340 deletions

View File

@ -1,4 +1,5 @@
<?php <?php
declare (strict_types = 1); declare (strict_types = 1);
namespace Phpml\Clustering; namespace Phpml\Clustering;
@ -8,23 +9,33 @@ use Phpml\Exception\InvalidArgumentException;
class KMeans implements Clusterer class KMeans implements Clusterer
{ {
const INIT_RANDOM = 1;
const INIT_KMEANS_PLUS_PLUS = 2;
/** /**
* @var int * @var int
*/ */
private $clustersNumber; private $clustersNumber;
/**
* @var int
*/
private $initialization;
/** /**
* @param int $clustersNumber * @param int $clustersNumber
* @param int $initialization
* *
* @throws InvalidArgumentException * @throws InvalidArgumentException
*/ */
public function __construct(int $clustersNumber) public function __construct(int $clustersNumber, int $initialization = self::INIT_KMEANS_PLUS_PLUS)
{ {
if ($clustersNumber <= 0) { if ($clustersNumber <= 0) {
throw InvalidArgumentException::invalidClustersNumber(); throw InvalidArgumentException::invalidClustersNumber();
} }
$this->clustersNumber = $clustersNumber; $this->clustersNumber = $clustersNumber;
$this->initialization = $initialization;
} }
/** /**
@ -40,12 +51,10 @@ class KMeans implements Clusterer
} }
$clusters = []; $clusters = [];
foreach ($space->solve($this->clustersNumber) as $cluster) foreach ($space->solve($this->clustersNumber, $this->initialization) as $cluster) {
{
$clusters[] = $cluster->getPoints(); $clusters[] = $cluster->getPoints();
} }
return $clusters; return $clusters;
} }
} }

View File

@ -4,13 +4,16 @@ declare(strict_types = 1);
namespace Phpml\Clustering\KMeans; namespace Phpml\Clustering\KMeans;
use \IteratorAggregate; use IteratorAggregate;
use \Countable; use Countable;
use \SplObjectStorage; use SplObjectStorage;
use \LogicException; use LogicException;
class Cluster extends Point implements IteratorAggregate, Countable class Cluster extends Point implements IteratorAggregate, Countable
{ {
/**
* @var Space
*/
protected $space; protected $space;
/** /**
@ -18,10 +21,15 @@ class Cluster extends Point implements IteratorAggregate, Countable
*/ */
protected $points; protected $points;
/**
* @param Space $space
* @param array $coordinates
*/
public function __construct(Space $space, array $coordinates) public function __construct(Space $space, array $coordinates)
{ {
parent::__construct($space, $coordinates); parent::__construct($coordinates);
$this->points = new SplObjectStorage; $this->space = $space;
$this->points = new SplObjectStorage();
} }
/** /**
@ -37,38 +45,56 @@ class Cluster extends Point implements IteratorAggregate, Countable
return $points; return $points;
} }
/**
* @return array
*/
public function toArray() public function toArray()
{ {
$points = array();
foreach ($this->points as $point)
$points[] = $point->toArray();
return array( return array(
'centroid' => parent::toArray(), 'centroid' => parent::toArray(),
'points' => $points, 'points' => $this->getPoints(),
); );
} }
/**
* @param Point $point
*
* @return Point
*/
public function attach(Point $point) public function attach(Point $point)
{ {
if ($point instanceof self) if ($point instanceof self) {
throw new LogicException("cannot attach a cluster to another"); throw new LogicException('cannot attach a cluster to another');
}
$this->points->attach($point); $this->points->attach($point);
return $point; return $point;
} }
/**
* @param Point $point
*
* @return Point
*/
public function detach(Point $point) public function detach(Point $point)
{ {
$this->points->detach($point); $this->points->detach($point);
return $point; return $point;
} }
/**
* @param SplObjectStorage $points
*/
public function attachAll(SplObjectStorage $points) public function attachAll(SplObjectStorage $points)
{ {
$this->points->addAll($points); $this->points->addAll($points);
} }
/**
* @param SplObjectStorage $points
*/
public function detachAll(SplObjectStorage $points) public function detachAll(SplObjectStorage $points)
{ {
$this->points->removeAll($points); $this->points->removeAll($points);
@ -76,24 +102,34 @@ class Cluster extends Point implements IteratorAggregate, Countable
public function updateCentroid() public function updateCentroid()
{ {
if (!$count = count($this->points)) if (!$count = count($this->points)) {
return; return;
$centroid = $this->space->newPoint(array_fill(0, $this->dimention, 0));
foreach ($this->points as $point)
for ($n=0; $n<$this->dimention; $n++)
$centroid->coordinates[$n] += $point->coordinates[$n];
for ($n=0; $n<$this->dimention; $n++)
$this->coordinates[$n] = $centroid->coordinates[$n] / $count;
} }
$centroid = $this->space->newPoint(array_fill(0, $this->dimension, 0));
foreach ($this->points as $point) {
for ($n = 0; $n < $this->dimension; ++$n) {
$centroid->coordinates[$n] += $point->coordinates[$n];
}
}
for ($n = 0; $n < $this->dimension; ++$n) {
$this->coordinates[$n] = $centroid->coordinates[$n] / $count;
}
}
/**
* @return Point[]|SplObjectStorage
*/
public function getIterator() public function getIterator()
{ {
return $this->points; return $this->points;
} }
/**
* @return mixed
*/
public function count() public function count()
{ {
return count($this->points); return count($this->points);

View File

@ -1,36 +1,50 @@
<?php <?php
declare (strict_types = 1); declare (strict_types = 1);
namespace Phpml\Clustering\KMeans; namespace Phpml\Clustering\KMeans;
use \ArrayAccess; use ArrayAccess;
use \LogicException;
class Point implements ArrayAccess class Point implements ArrayAccess
{ {
protected $space; /**
protected $dimention; * @var int
*/
protected $dimension;
/**
* @var array
*/
protected $coordinates; protected $coordinates;
public function __construct(Space $space, array $coordinates) /**
* @param array $coordinates
*/
public function __construct(array $coordinates)
{ {
$this->space = $space; $this->dimension = count($coordinates);
$this->dimention = $space->getDimention();
$this->coordinates = $coordinates; $this->coordinates = $coordinates;
} }
/**
* @return array
*/
public function toArray() public function toArray()
{ {
return $this->coordinates; return $this->coordinates;
} }
/**
* @param Point $point
* @param bool $precise
*
* @return int|mixed
*/
public function getDistanceWith(self $point, $precise = true) public function getDistanceWith(self $point, $precise = true)
{ {
if ($point->space !== $this->space)
throw new LogicException("can only calculate distances from points in the same space");
$distance = 0; $distance = 0;
for ($n=0; $n<$this->dimention; $n++) { for ($n = 0; $n < $this->dimension; ++$n) {
$difference = $this->coordinates[$n] - $point->coordinates[$n]; $difference = $this->coordinates[$n] - $point->coordinates[$n];
$distance += $difference * $difference; $distance += $difference * $difference;
} }
@ -38,6 +52,11 @@ class Point implements ArrayAccess
return $precise ? sqrt($distance) : $distance; return $precise ? sqrt($distance) : $distance;
} }
/**
* @param $points
*
* @return mixed
*/
public function getClosest($points) public function getClosest($points)
{ {
foreach ($points as $point) { foreach ($points as $point) {
@ -58,36 +77,46 @@ class Point implements ArrayAccess
return $minPoint; return $minPoint;
} }
public function belongsTo(Space $space) /**
{ * @return array
return $this->space === $space; */
}
public function getSpace()
{
return $this->space;
}
public function getCoordinates() public function getCoordinates()
{ {
return $this->coordinates; return $this->coordinates;
} }
/**
* @param mixed $offset
*
* @return bool
*/
public function offsetExists($offset) public function offsetExists($offset)
{ {
return isset($this->coordinates[$offset]); return isset($this->coordinates[$offset]);
} }
/**
* @param mixed $offset
*
* @return mixed
*/
public function offsetGet($offset) public function offsetGet($offset)
{ {
return $this->coordinates[$offset]; return $this->coordinates[$offset];
} }
/**
* @param mixed $offset
* @param mixed $value
*/
public function offsetSet($offset, $value) public function offsetSet($offset, $value)
{ {
$this->coordinates[$offset] = $value; $this->coordinates[$offset] = $value;
} }
/**
* @param mixed $offset
*/
public function offsetUnset($offset) public function offsetUnset($offset)
{ {
unset($this->coordinates[$offset]); unset($this->coordinates[$offset]);

View File

@ -1,76 +1,104 @@
<?php <?php
declare (strict_types = 1); declare (strict_types = 1);
namespace Phpml\Clustering\KMeans; namespace Phpml\Clustering\KMeans;
use \SplObjectStorage; use Phpml\Clustering\KMeans;
use \LogicException; use SplObjectStorage;
use \InvalidArgumentException; use LogicException;
use InvalidArgumentException;
class Space extends SplObjectStorage class Space extends SplObjectStorage
{ {
// Default seeding method, initial cluster centroid are randomly choosen /**
const SEED_DEFAULT = 1; * @var int
*/
protected $dimension;
// Alternative seeding method by David Arthur and Sergei Vassilvitskii /**
// (see http://en.wikipedia.org/wiki/K-means++) * @param $dimension
const SEED_DASV = 2; */
public function __construct($dimension)
protected $dimention;
public function __construct($dimention)
{ {
if ($dimention < 1) if ($dimension < 1) {
throw new LogicException("a space dimention cannot be null or negative"); throw new LogicException('a space dimension cannot be null or negative');
$this->dimention = $dimention;
} }
$this->dimension = $dimension;
}
/**
* @return array
*/
public function toArray() public function toArray()
{ {
$points = array(); $points = [];
foreach ($this as $point) foreach ($this as $point) {
$points[] = $point->toArray(); $points[] = $point->toArray();
return array('points' => $points);
} }
return ['points' => $points];
}
/**
* @param array $coordinates
*
* @return Point
*/
public function newPoint(array $coordinates) public function newPoint(array $coordinates)
{ {
if (count($coordinates) != $this->dimention) if (count($coordinates) != $this->dimension) {
throw new LogicException("(" . implode(',', $coordinates) . ") is not a point of this space"); throw new LogicException('('.implode(',', $coordinates).') is not a point of this space');
return new Point($this, $coordinates);
} }
return new Point($coordinates);
}
/**
* @param array $coordinates
* @param null $data
*/
public function addPoint(array $coordinates, $data = null) public function addPoint(array $coordinates, $data = null)
{ {
return $this->attach($this->newPoint($coordinates), $data); return $this->attach($this->newPoint($coordinates), $data);
} }
/**
* @param object $point
* @param null $data
*/
public function attach($point, $data = null) public function attach($point, $data = null)
{ {
if (!$point instanceof Point) if (!$point instanceof Point) {
throw new InvalidArgumentException("can only attach points to spaces"); throw new InvalidArgumentException('can only attach points to spaces');
}
return parent::attach($point, $data); return parent::attach($point, $data);
} }
public function getDimention() /**
* @return int
*/
public function getDimension()
{ {
return $this->dimention; return $this->dimension;
} }
/**
* @return array|bool
*/
public function getBoundaries() public function getBoundaries()
{ {
if (!count($this)) if (!count($this)) {
return false; return false;
}
$min = $this->newPoint(array_fill(0, $this->dimention, null)); $min = $this->newPoint(array_fill(0, $this->dimension, null));
$max = $this->newPoint(array_fill(0, $this->dimention, null)); $max = $this->newPoint(array_fill(0, $this->dimension, null));
foreach ($this as $point) { foreach ($this as $point) {
for ($n=0; $n < $this->dimention; $n++) { for ($n = 0; $n < $this->dimension; ++$n) {
($min[$n] > $point[$n] || $min[$n] === null) && $min[$n] = $point[$n]; ($min[$n] > $point[$n] || $min[$n] === null) && $min[$n] = $point[$n];
($max[$n] < $point[$n] || $max[$n] === null) && $max[$n] = $point[$n]; ($max[$n] < $point[$n] || $max[$n] === null) && $max[$n] = $point[$n];
} }
@ -79,12 +107,19 @@ class Space extends SplObjectStorage
return array($min, $max); return array($min, $max);
} }
/**
* @param Point $min
* @param Point $max
*
* @return Point
*/
public function getRandomPoint(Point $min, Point $max) public function getRandomPoint(Point $min, Point $max)
{ {
$point = $this->newPoint(array_fill(0, $this->dimention, null)); $point = $this->newPoint(array_fill(0, $this->dimension, null));
for ($n=0; $n < $this->dimention; $n++) for ($n = 0; $n < $this->dimension; ++$n) {
$point[$n] = rand($min[$n], $max[$n]); $point[$n] = rand($min[$n], $max[$n]);
}
return $point; return $point;
} }
@ -93,19 +128,22 @@ class Space extends SplObjectStorage
* @param $nbClusters * @param $nbClusters
* @param int $seed * @param int $seed
* @param null $iterationCallback * @param null $iterationCallback
*
* @return array|Cluster[] * @return array|Cluster[]
*/ */
public function solve($nbClusters, $seed = self::SEED_DEFAULT, $iterationCallback = null) public function solve($nbClusters, $seed = KMeans::INIT_RANDOM, $iterationCallback = null)
{ {
if ($iterationCallback && !is_callable($iterationCallback)) if ($iterationCallback && !is_callable($iterationCallback)) {
throw new InvalidArgumentException("invalid iteration callback"); throw new InvalidArgumentException('invalid iteration callback');
}
// initialize K clusters // initialize K clusters
$clusters = $this->initializeClusters($nbClusters, $seed); $clusters = $this->initializeClusters($nbClusters, $seed);
// there's only one cluster, clusterization has no meaning // there's only one cluster, clusterization has no meaning
if (count($clusters) == 1) if (count($clusters) == 1) {
return $clusters[0]; return $clusters[0];
}
// until convergence is reached // until convergence is reached
do { do {
@ -116,35 +154,43 @@ class Space extends SplObjectStorage
return $clusters; return $clusters;
} }
/**
* @param $nbClusters
* @param $seed
*
* @return array
*/
protected function initializeClusters($nbClusters, $seed) protected function initializeClusters($nbClusters, $seed)
{ {
if ($nbClusters <= 0) if ($nbClusters <= 0) {
throw new InvalidArgumentException("invalid clusters number"); throw new InvalidArgumentException('invalid clusters number');
}
switch ($seed) { switch ($seed) {
// the default seeding method chooses completely random centroid // the default seeding method chooses completely random centroid
case self::SEED_DEFAULT: case KMeans::INIT_RANDOM:
// get the space boundaries to avoid placing clusters centroid too far from points // get the space boundaries to avoid placing clusters centroid too far from points
list($min, $max) = $this->getBoundaries(); list($min, $max) = $this->getBoundaries();
// initialize N clusters with a random point within space boundaries // initialize N clusters with a random point within space boundaries
for ($n=0; $n<$nbClusters; $n++) for ($n = 0; $n < $nbClusters; ++$n) {
$clusters[] = new Cluster($this, $this->getRandomPoint($min, $max)->getCoordinates()); $clusters[] = new Cluster($this, $this->getRandomPoint($min, $max)->getCoordinates());
}
break; break;
// the DASV seeding method consists of finding good initial centroids for the clusters // the DASV seeding method consists of finding good initial centroids for the clusters
case self::SEED_DASV: case KMeans::INIT_KMEANS_PLUS_PLUS:
// find a random point // find a random point
$position = rand(1, count($this)); $position = rand(1, count($this));
for ($i = 1, $this->rewind(); $i < $position && $this->valid(); $i++, $this->next()); for ($i = 1, $this->rewind(); $i < $position && $this->valid(); $i++, $this->next());
$clusters[] = new Cluster($this, $this->current()->getCoordinates()); $clusters[] = new Cluster($this, $this->current()->getCoordinates());
// retains the distances between points and their closest clusters // retains the distances between points and their closest clusters
$distances = new SplObjectStorage; $distances = new SplObjectStorage();
// create k clusters // create k clusters
for ($i=1; $i<$nbClusters; $i++) { for ($i = 1; $i < $nbClusters; ++$i) {
$sum = 0; $sum = 0;
// for each points, get the distance with the closest centroid already choosen // for each points, get the distance with the closest centroid already choosen
@ -154,10 +200,11 @@ class Space extends SplObjectStorage
} }
// choose a new random point using a weighted probability distribution // choose a new random point using a weighted probability distribution
$sum = rand(0, $sum); $sum = rand(0, (int) $sum);
foreach ($this as $point) { foreach ($this as $point) {
if (($sum -= $distances[$point]) > 0) if (($sum -= $distances[$point]) > 0) {
continue; continue;
}
$clusters[] = new Cluster($this, $point->getCoordinates()); $clusters[] = new Cluster($this, $point->getCoordinates());
break; break;
@ -173,13 +220,18 @@ class Space extends SplObjectStorage
return $clusters; return $clusters;
} }
/**
* @param $clusters
*
* @return bool
*/
protected function iterate($clusters) protected function iterate($clusters)
{ {
$continue = false; $continue = false;
// migration storages // migration storages
$attach = new SplObjectStorage; $attach = new SplObjectStorage();
$detach = new SplObjectStorage; $detach = new SplObjectStorage();
// calculate proximity amongst points and clusters // calculate proximity amongst points and clusters
foreach ($clusters as $cluster) { foreach ($clusters as $cluster) {
@ -189,8 +241,8 @@ class Space extends SplObjectStorage
// move the point from its old cluster to its closest // move the point from its old cluster to its closest
if ($closest !== $cluster) { if ($closest !== $cluster) {
isset($attach[$closest]) || $attach[$closest] = new SplObjectStorage; isset($attach[$closest]) || $attach[$closest] = new SplObjectStorage();
isset($detach[$cluster]) || $detach[$cluster] = new SplObjectStorage; isset($detach[$cluster]) || $detach[$cluster] = new SplObjectStorage();
$attach[$closest]->attach($point); $attach[$closest]->attach($point);
$detach[$cluster]->attach($point); $detach[$cluster]->attach($point);
@ -201,15 +253,18 @@ class Space extends SplObjectStorage
} }
// perform points migrations // perform points migrations
foreach ($attach as $cluster) foreach ($attach as $cluster) {
$cluster->attachAll($attach[$cluster]); $cluster->attachAll($attach[$cluster]);
}
foreach ($detach as $cluster) foreach ($detach as $cluster) {
$cluster->detachAll($detach[$cluster]); $cluster->detachAll($detach[$cluster]);
}
// update all cluster's centroids // update all cluster's centroids
foreach ($clusters as $cluster) foreach ($clusters as $cluster) {
$cluster->updateCentroid(); $cluster->updateCentroid();
}
return $continue; return $continue;
} }

View File

@ -65,5 +65,4 @@ class InvalidArgumentException extends \Exception
{ {
return new self('Invalid clusters number'); return new self('Invalid clusters number');
} }
} }

View File

@ -20,7 +20,6 @@ class DBSCANTest extends \PHPUnit_Framework_TestCase
$this->assertEquals($clustered, $dbscan->cluster($samples)); $this->assertEquals($clustered, $dbscan->cluster($samples));
$samples = [[1, 1], [6, 6], [1, -1], [5, 6], [-1, -1], [7, 8], [-1, 1], [7, 7]]; $samples = [[1, 1], [6, 6], [1, -1], [5, 6], [-1, -1], [7, 8], [-1, 1], [7, 7]];
$clustered = [ $clustered = [
[[1, 1], [1, -1], [-1, -1], [-1, 1]], [[1, 1], [1, -1], [-1, -1], [-1, 1]],
@ -31,5 +30,4 @@ class DBSCANTest extends \PHPUnit_Framework_TestCase
$this->assertEquals($clustered, $dbscan->cluster($samples)); $this->assertEquals($clustered, $dbscan->cluster($samples));
} }
} }

View File

@ -1,4 +1,5 @@
<?php <?php
declare (strict_types = 1); declare (strict_types = 1);
namespace tests\Clustering; namespace tests\Clustering;
@ -7,7 +8,6 @@ use Phpml\Clustering\KMeans;
class KMeansTest extends \PHPUnit_Framework_TestCase class KMeansTest extends \PHPUnit_Framework_TestCase
{ {
public function testKMeansSamplesClustering() public function testKMeansSamplesClustering()
{ {
$samples = [[1, 1], [8, 7], [1, 2], [7, 8], [2, 1], [8, 9]]; $samples = [[1, 1], [8, 7], [1, 2], [7, 8], [2, 1], [8, 9]];
@ -46,7 +46,7 @@ class KMeansTest extends \PHPUnit_Framework_TestCase
$this->assertEquals(4, count($clusters)); $this->assertEquals(4, count($clusters));
foreach ($samples as $index => $sample) { foreach ($samples as $index => $sample) {
for($i=0; $i<4; $i++) { for ($i = 0; $i < 4; ++$i) {
if (in_array($sample, $clusters[$i])) { if (in_array($sample, $clusters[$i])) {
unset($samples[$index]); unset($samples[$index]);
} }
@ -54,5 +54,4 @@ class KMeansTest extends \PHPUnit_Framework_TestCase
} }
$this->assertEquals(0, count($samples)); $this->assertEquals(0, count($samples));
} }
} }