kmeans clustering

This commit is contained in:
Arkadiusz Kondas 2016-05-01 23:17:09 +02:00
parent 01a2499754
commit c0513e9b82
7 changed files with 532 additions and 5 deletions

View File

@ -0,0 +1,51 @@
<?php
declare(strict_types = 1);
namespace Phpml\Clustering;
use Phpml\Clustering\KMeans\Space;
use Phpml\Exception\InvalidArgumentException;
class KMeans implements Clusterer
{
/**
* @var int
*/
private $clustersNumber;
/**
* @param int $clustersNumber
*
* @throws InvalidArgumentException
*/
public function __construct(int $clustersNumber)
{
if($clustersNumber <= 0) {
throw InvalidArgumentException::invalidClustersNumber();
}
$this->clustersNumber = $clustersNumber;
}
/**
* @param array $samples
*
* @return array
*/
public function cluster(array $samples)
{
$space = new Space(count($samples[0]));
foreach ($samples as $sample) {
$space->addPoint($sample);
}
$clusters = [];
foreach ($space->solve($this->clustersNumber) as $cluster)
{
$clusters[] = $cluster->getPoints();
}
return $clusters;
}
}

View File

@ -0,0 +1,101 @@
<?php
declare(strict_types = 1);
namespace Phpml\Clustering\KMeans;
use \IteratorAggregate;
use \Countable;
use \SplObjectStorage;
use \LogicException;
class Cluster extends Point implements IteratorAggregate, Countable
{
protected $space;
/**
* @var SplObjectStorage|Point[]
*/
protected $points;
public function __construct(Space $space, array $coordinates)
{
parent::__construct($space, $coordinates);
$this->points = new SplObjectStorage;
}
/**
* @return array
*/
public function getPoints()
{
$points = [];
foreach ($this->points as $point) {
$points[] = $point->toArray();
}
return $points;
}
public function toArray()
{
$points = array();
foreach ($this->points as $point)
$points[] = $point->toArray();
return array(
'centroid' => parent::toArray(),
'points' => $points,
);
}
public function attach(Point $point)
{
if ($point instanceof self)
throw new LogicException("cannot attach a cluster to another");
$this->points->attach($point);
return $point;
}
public function detach(Point $point)
{
$this->points->detach($point);
return $point;
}
public function attachAll(SplObjectStorage $points)
{
$this->points->addAll($points);
}
public function detachAll(SplObjectStorage $points)
{
$this->points->removeAll($points);
}
public function updateCentroid()
{
if (!$count = count($this->points))
return;
$centroid = $this->space->newPoint(array_fill(0, $this->dimention, 0));
foreach ($this->points as $point)
for ($n=0; $n<$this->dimention; $n++)
$centroid->coordinates[$n] += $point->coordinates[$n];
for ($n=0; $n<$this->dimention; $n++)
$this->coordinates[$n] = $centroid->coordinates[$n] / $count;
}
public function getIterator()
{
return $this->points;
}
public function count()
{
return count($this->points);
}
}

View File

@ -0,0 +1,95 @@
<?php
declare(strict_types = 1);
namespace Phpml\Clustering\KMeans;
use \ArrayAccess;
use \LogicException;
class Point implements ArrayAccess
{
protected $space;
protected $dimention;
protected $coordinates;
public function __construct(Space $space, array $coordinates)
{
$this->space = $space;
$this->dimention = $space->getDimention();
$this->coordinates = $coordinates;
}
public function toArray()
{
return $this->coordinates;
}
public function getDistanceWith(self $point, $precise = true)
{
if ($point->space !== $this->space)
throw new LogicException("can only calculate distances from points in the same space");
$distance = 0;
for ($n=0; $n<$this->dimention; $n++) {
$difference = $this->coordinates[$n] - $point->coordinates[$n];
$distance += $difference * $difference;
}
return $precise ? sqrt($distance) : $distance;
}
public function getClosest($points)
{
foreach($points as $point) {
$distance = $this->getDistanceWith($point, false);
if (!isset($minDistance)) {
$minDistance = $distance;
$minPoint = $point;
continue;
}
if ($distance < $minDistance) {
$minDistance = $distance;
$minPoint = $point;
}
}
return $minPoint;
}
public function belongsTo(Space $space)
{
return $this->space === $space;
}
public function getSpace()
{
return $this->space;
}
public function getCoordinates()
{
return $this->coordinates;
}
public function offsetExists($offset)
{
return isset($this->coordinates[$offset]);
}
public function offsetGet($offset)
{
return $this->coordinates[$offset];
}
public function offsetSet($offset, $value)
{
$this->coordinates[$offset] = $value;
}
public function offsetUnset($offset)
{
unset($this->coordinates[$offset]);
}
}

View File

@ -0,0 +1,216 @@
<?php
declare(strict_types = 1);
namespace Phpml\Clustering\KMeans;
use \SplObjectStorage;
use \LogicException;
use \InvalidArgumentException;
class Space extends SplObjectStorage
{
// Default seeding method, initial cluster centroid are randomly choosen
const SEED_DEFAULT = 1;
// Alternative seeding method by David Arthur and Sergei Vassilvitskii
// (see http://en.wikipedia.org/wiki/K-means++)
const SEED_DASV = 2;
protected $dimention;
public function __construct($dimention)
{
if ($dimention < 1)
throw new LogicException("a space dimention cannot be null or negative");
$this->dimention = $dimention;
}
public function toArray()
{
$points = array();
foreach ($this as $point)
$points[] = $point->toArray();
return array('points' => $points);
}
public function newPoint(array $coordinates)
{
if (count($coordinates) != $this->dimention)
throw new LogicException("(" . implode(',', $coordinates) . ") is not a point of this space");
return new Point($this, $coordinates);
}
public function addPoint(array $coordinates, $data = null)
{
return $this->attach($this->newPoint($coordinates), $data);
}
public function attach($point, $data = null)
{
if (!$point instanceof Point)
throw new InvalidArgumentException("can only attach points to spaces");
return parent::attach($point, $data);
}
public function getDimention()
{
return $this->dimention;
}
public function getBoundaries()
{
if (!count($this))
return false;
$min = $this->newPoint(array_fill(0, $this->dimention, null));
$max = $this->newPoint(array_fill(0, $this->dimention, null));
foreach ($this as $point) {
for ($n=0; $n < $this->dimention; $n++) {
($min[$n] > $point[$n] || $min[$n] === null) && $min[$n] = $point[$n];
($max[$n] < $point[$n] || $max[$n] === null) && $max[$n] = $point[$n];
}
}
return array($min, $max);
}
public function getRandomPoint(Point $min, Point $max)
{
$point = $this->newPoint(array_fill(0, $this->dimention, null));
for ($n=0; $n < $this->dimention; $n++)
$point[$n] = rand($min[$n], $max[$n]);
return $point;
}
/**
* @param $nbClusters
* @param int $seed
* @param null $iterationCallback
* @return array|Cluster[]
*/
public function solve($nbClusters, $seed = self::SEED_DEFAULT, $iterationCallback = null)
{
if ($iterationCallback && !is_callable($iterationCallback))
throw new InvalidArgumentException("invalid iteration callback");
// initialize K clusters
$clusters = $this->initializeClusters($nbClusters, $seed);
// there's only one cluster, clusterization has no meaning
if (count($clusters) == 1)
return $clusters[0];
// until convergence is reached
do {
$iterationCallback && $iterationCallback($this, $clusters);
} while ($this->iterate($clusters));
// clustering is done.
return $clusters;
}
protected function initializeClusters($nbClusters, $seed)
{
if ($nbClusters <= 0)
throw new InvalidArgumentException("invalid clusters number");
switch ($seed) {
// the default seeding method chooses completely random centroid
case self::SEED_DEFAULT:
// get the space boundaries to avoid placing clusters centroid too far from points
list($min, $max) = $this->getBoundaries();
// initialize N clusters with a random point within space boundaries
for ($n=0; $n<$nbClusters; $n++)
$clusters[] = new Cluster($this, $this->getRandomPoint($min, $max)->getCoordinates());
break;
// the DASV seeding method consists of finding good initial centroids for the clusters
case self::SEED_DASV:
// find a random point
$position = rand(1, count($this));
for ($i=1, $this->rewind(); $i<$position && $this->valid(); $i++, $this->next());
$clusters[] = new Cluster($this, $this->current()->getCoordinates());
// retains the distances between points and their closest clusters
$distances = new SplObjectStorage;
// create k clusters
for ($i=1; $i<$nbClusters; $i++) {
$sum = 0;
// for each points, get the distance with the closest centroid already choosen
foreach ($this as $point) {
$distance = $point->getDistanceWith($point->getClosest($clusters));
$sum += $distances[$point] = $distance;
}
// choose a new random point using a weighted probability distribution
$sum = rand(0, $sum);
foreach ($this as $point) {
if (($sum -= $distances[$point]) > 0)
continue;
$clusters[] = new Cluster($this, $point->getCoordinates());
break;
}
}
break;
}
// assing all points to the first cluster
$clusters[0]->attachAll($this);
return $clusters;
}
protected function iterate($clusters)
{
$continue = false;
// migration storages
$attach = new SplObjectStorage;
$detach = new SplObjectStorage;
// calculate proximity amongst points and clusters
foreach ($clusters as $cluster) {
foreach ($cluster as $point) {
// find the closest cluster
$closest = $point->getClosest($clusters);
// move the point from its old cluster to its closest
if ($closest !== $cluster) {
isset($attach[$closest]) || $attach[$closest] = new SplObjectStorage;
isset($detach[$cluster]) || $detach[$cluster] = new SplObjectStorage;
$attach[$closest]->attach($point);
$detach[$cluster]->attach($point);
$continue = true;
}
}
}
// perform points migrations
foreach ($attach as $cluster)
$cluster->attachAll($attach[$cluster]);
foreach ($detach as $cluster)
$cluster->detachAll($detach[$cluster]);
// update all cluster's centroids
foreach ($clusters as $cluster)
$cluster->updateCentroid();
return $continue;
}
}

View File

@ -57,4 +57,13 @@ class InvalidArgumentException extends \Exception
{
return new self('Inconsistent matrix aupplied');
}
/**
* @return InvalidArgumentException
*/
public static function invalidClustersNumber()
{
return new self('Invalid clusters number');
}
}

View File

@ -11,7 +11,6 @@ class DBSCANTest extends \PHPUnit_Framework_TestCase
public function testDBSCANSamplesClustering()
{
$samples = [[1, 1], [8, 7], [1, 2], [7, 8], [2, 1], [8, 9]];
$clustered = [
[[1, 1], [1, 2], [2, 1]],
[[8, 7], [7, 8], [8, 9]],
@ -20,12 +19,9 @@ class DBSCANTest extends \PHPUnit_Framework_TestCase
$dbscan = new DBSCAN($epsilon = 2, $minSamples = 3);
$this->assertEquals($clustered, $dbscan->cluster($samples));
}
public function testDBSCANSamplesInCircleClustering()
{
$samples = [[1, 1], [6, 6], [1, -1], [5, 6], [-1, -1], [7, 8], [-1, 1], [7, 7]];
$clustered = [
[[1, 1], [1, -1], [-1, -1], [-1, 1]],
[[6, 6], [5, 6], [7, 8], [7, 7]],
@ -35,4 +31,5 @@ class DBSCANTest extends \PHPUnit_Framework_TestCase
$this->assertEquals($clustered, $dbscan->cluster($samples));
}
}

View File

@ -0,0 +1,58 @@
<?php
declare(strict_types = 1);
namespace tests\Clustering;
use Phpml\Clustering\KMeans;
class KMeansTest extends \PHPUnit_Framework_TestCase
{
public function testKMeansSamplesClustering()
{
$samples = [[1, 1], [8, 7], [1, 2], [7, 8], [2, 1], [8, 9]];
$kmeans = new KMeans(2);
$clusters = $kmeans->cluster($samples);
$this->assertEquals(2, count($clusters));
foreach ($samples as $index => $sample) {
if(in_array($sample, $clusters[0]) || in_array($sample, $clusters[1])) {
unset($samples[$index]);
}
}
$this->assertEquals(0, count($samples));
}
public function testKMeansMoreSamplesClustering()
{
$samples = [
[80,55],[86,59],[19,85],[41,47],[57,58],
[76,22],[94,60],[13,93],[90,48],[52,54],
[62,46],[88,44],[85,24],[63,14],[51,40],
[75,31],[86,62],[81,95],[47,22],[43,95],
[71,19],[17,65],[69,21],[59,60],[59,12],
[15,22],[49,93],[56,35],[18,20],[39,59],
[50,15],[81,36],[67,62],[32,15],[75,65],
[10,47],[75,18],[13,45],[30,62],[95,79],
[64,11],[92,14],[94,49],[39,13],[60,68],
[62,10],[74,44],[37,42],[97,60],[47,73],
];
$kmeans = new KMeans(4);
$clusters = $kmeans->cluster($samples);
$this->assertEquals(4, count($clusters));
foreach ($samples as $index => $sample) {
for($i=0; $i<4; $i++) {
if(in_array($sample, $clusters[$i])) {
unset($samples[$index]);
}
}
}
$this->assertEquals(0, count($samples));
}
}