php-ml/src/Clustering/KMeans/Space.php

260 lines
6.3 KiB
PHP
Raw Normal View History

2016-05-01 21:17:09 +00:00
<?php
2016-05-01 21:36:33 +00:00
2016-11-20 21:53:17 +00:00
declare(strict_types=1);
2016-05-01 21:17:09 +00:00
namespace Phpml\Clustering\KMeans;
use InvalidArgumentException;
use LogicException;
2016-05-01 21:36:33 +00:00
use Phpml\Clustering\KMeans;
use SplObjectStorage;
2016-05-01 21:17:09 +00:00
class Space extends SplObjectStorage
{
2016-05-01 21:36:33 +00:00
/**
 * Number of coordinates every point created through this space must have.
 *
 * @var int
 */
protected $dimension;

/**
 * Creates an empty space of the given dimension.
 *
 * @throws LogicException when $dimension is lower than 1
 */
public function __construct(int $dimension)
{
    $isUsable = $dimension >= 1;

    if (!$isUsable) {
        throw new LogicException('a space dimension cannot be null or negative');
    }

    $this->dimension = $dimension;
}
/**
 * Exports the space as a plain array.
 *
 * @return array single-entry array whose 'points' key lists the toArray()
 *               result of every member of the space, in iteration order
 */
public function toArray(): array
{
    // Iterating the storage yields the attached objects themselves.
    $members = iterator_to_array($this, false);

    $exported = array_map(
        static function ($member) {
            return $member->toArray();
        },
        $members
    );

    return ['points' => $exported];
}
2018-10-28 06:44:52 +00:00
/**
 * Builds a point for this space without attaching it.
 *
 * @param array $coordinates must contain exactly as many entries as the space dimension
 * @param mixed $label       passed through to the Point constructor
 *
 * @throws LogicException when the number of coordinates differs from the space dimension
 */
public function newPoint(array $coordinates, $label = null): Point
{
    if (count($coordinates) === $this->dimension) {
        return new Point($coordinates, $label);
    }

    throw new LogicException('('.implode(',', $coordinates).') is not a point of this space');
}
/**
 * Creates a point from raw coordinates and attaches it to the space.
 *
 * @param array $coordinates validated by newPoint()
 * @param mixed $label       label given to the new point
 * @param mixed $data        data associated with the point in the storage
 */
public function addPoint(array $coordinates, $label = null, $data = null): void
{
    $created = $this->newPoint($coordinates, $label);

    $this->attach($created, $data);
}
/**
 * Attaches a point to the space; anything that is not a Point is rejected.
 *
 * @param object $point
 * @param mixed  $data  data associated with the point in the storage
 *
 * @throws InvalidArgumentException when $point is not a Point instance
 */
public function attach($point, $data = null): void
{
    if ($point instanceof Point) {
        parent::attach($point, $data);

        return;
    }

    throw new InvalidArgumentException('can only attach points to spaces');
}
/**
 * Returns the dimension stored by the constructor.
 */
public function getDimension(): int
{
    return $this->dimension;
}
/**
 * Computes, dimension by dimension, the smallest and largest coordinate
 * found among the members of the space.
 *
 * @return array|bool false when the space holds no member, otherwise a
 *                    two-element array [$min, $max] of Point instances
 */
public function getBoundaries()
{
    if (0 === $this->count()) {
        return false;
    }

    // Both bounds start with every coordinate unset (null) and are filled
    // in by the first member visited.
    $lowest = $this->newPoint(array_fill(0, $this->dimension, null));
    $highest = $this->newPoint(array_fill(0, $this->dimension, null));

    /** @var Point $member */
    foreach ($this as $member) {
        $axis = 0;
        while ($axis < $this->dimension) {
            if ($lowest[$axis] === null || $lowest[$axis] > $member[$axis]) {
                $lowest[$axis] = $member[$axis];
            }

            if ($highest[$axis] === null || $highest[$axis] < $member[$axis]) {
                $highest[$axis] = $member[$axis];
            }

            ++$axis;
        }
    }

    return [$lowest, $highest];
}
/**
 * Draws a random point whose coordinates lie between those of two points.
 *
 * For every dimension a value is picked between $min[$n] and $max[$n]
 * (both inclusive). When both bounds are integers the coordinate is an
 * integer obtained from random_int(). Otherwise the coordinate is a float:
 * random_int() only accepts integers and, with strict_types enabled, would
 * throw a TypeError on float coordinates.
 *
 * @param Point $min lower bound for each dimension
 * @param Point $max upper bound for each dimension
 */
public function getRandomPoint(Point $min, Point $max): Point
{
    $point = $this->newPoint(array_fill(0, $this->dimension, null));

    for ($n = 0; $n < $this->dimension; ++$n) {
        $low = $min[$n];
        $high = $max[$n];

        if (is_int($low) && is_int($high)) {
            $point[$n] = random_int($low, $high);

            continue;
        }

        // Non-integer bounds: scale a random fraction in [0, 1] onto [$low, $high].
        $fraction = random_int(0, PHP_INT_MAX) / PHP_INT_MAX;
        $point[$n] = $low + $fraction * ($high - $low);
    }

    return $point;
}
/**
 * Splits the members of the space into clusters.
 *
 * Clusters are seeded with initializeClusters() and then refined by calling
 * iterate() repeatedly until it reports that no point changed cluster.
 *
 * @return Cluster[]
 */
public function cluster(int $clustersNumber, int $initMethod = KMeans::INIT_RANDOM): array
{
    $clusters = $this->initializeClusters($clustersNumber, $initMethod);

    $converged = false;
    while (!$converged) {
        $converged = $this->iterate($clusters);
    }

    return $clusters;
}
/**
 * Creates the initial clusters with the requested strategy and puts every
 * member of the space into the first one.
 *
 * @return Cluster[] empty array when $initMethod is neither
 *                   KMeans::INIT_RANDOM nor KMeans::INIT_KMEANS_PLUS_PLUS
 */
protected function initializeClusters(int $clustersNumber, int $initMethod): array
{
    if ($initMethod === KMeans::INIT_RANDOM) {
        $seeded = $this->initializeRandomClusters($clustersNumber);
    } elseif ($initMethod === KMeans::INIT_KMEANS_PLUS_PLUS) {
        $seeded = $this->initializeKMPPClusters($clustersNumber);
    } else {
        return [];
    }

    // All points start in the first cluster; iterate() redistributes them.
    $seeded[0]->attachAll($this);

    return $seeded;
}
2018-10-28 06:44:52 +00:00
/**
 * Runs one reassignment pass over the given clusters.
 *
 * Each point is compared with the cluster returned by getClosest(); points
 * whose closest cluster is not their current one are moved. Moves are only
 * recorded while the clusters are being traversed and applied afterwards,
 * then every cluster's centroid is updated.
 *
 * @param Cluster[] $clusters
 *
 * @return bool true when no point had to be moved during this pass
 */
protected function iterate(array $clusters): bool
{
    $stable = true;

    // Pending moves, keyed by cluster: points to add to / remove from it.
    $incoming = new SplObjectStorage();
    $outgoing = new SplObjectStorage();

    foreach ($clusters as $owner) {
        foreach ($owner as $point) {
            $nearest = $point->getClosest($clusters);

            if ($nearest === $owner) {
                continue;
            }

            if (!isset($incoming[$nearest])) {
                $incoming[$nearest] = new SplObjectStorage();
            }

            if (!isset($outgoing[$owner])) {
                $outgoing[$owner] = new SplObjectStorage();
            }

            $incoming[$nearest]->attach($point);
            $outgoing[$owner]->attach($point);

            $stable = false;
        }
    }

    /** @var Cluster $receiver */
    foreach ($incoming as $receiver) {
        $receiver->attachAll($incoming[$receiver]);
    }

    /** @var Cluster $giver */
    foreach ($outgoing as $giver) {
        $giver->detachAll($outgoing[$giver]);
    }

    foreach ($clusters as $updated) {
        $updated->updateCentroid();
    }

    return $stable;
}
2016-05-02 21:36:58 +00:00
2018-10-28 06:44:52 +00:00
/**
 * Seeds clusters for the KMeans::INIT_KMEANS_PLUS_PLUS strategy.
 *
 * The first cluster is centred on the first member of the space. Every
 * further centre is drawn among the members with a probability weighted by
 * the member's distance to its closest already-created cluster.
 *
 * @return Cluster[]
 */
protected function initializeKMPPClusters(int $clustersNumber): array
{
    $clusters = [];
    $this->rewind();

    /** @var Point $current */
    $current = $this->current();
    $clusters[] = new Cluster($this, $current->getCoordinates());

    $distances = new SplObjectStorage();

    for ($i = 1; $i < $clustersNumber; ++$i) {
        $sum = 0;

        /** @var Point $point */
        foreach ($this as $point) {
            $closest = $point->getClosest($clusters);
            if ($closest === null) {
                continue;
            }

            $distance = $point->getDistanceWith($closest);
            $sum += $distances[$point] = $distance;
        }

        // Draw the threshold as a float in [0, $sum]. Truncating $sum to an
        // integer first would turn any total below 1 into 0, and the first
        // member (already a centre, at distance 0) would win every draw.
        $threshold = $sum * (random_int(0, PHP_INT_MAX) / PHP_INT_MAX);

        $chosen = null;

        /** @var Point $point */
        foreach ($this as $point) {
            // Remember the member so the last one is used if rounding keeps
            // the remainder marginally above zero after the final subtraction.
            $chosen = $point;

            $threshold -= $distances[$point];
            if ($threshold <= 0) {
                break;
            }
        }

        if ($chosen !== null) {
            $clusters[] = new Cluster($this, $chosen->getCoordinates());
        }
    }

    return $clusters;
}
2018-10-28 06:44:52 +00:00
/**
 * Seeds clusters for the KMeans::INIT_RANDOM strategy: each centre is a
 * random point drawn inside the boundaries of the space.
 *
 * @return Cluster[]
 *
 * @throws LogicException when the space holds no point, since it then has no boundaries
 */
private function initializeRandomClusters(int $clustersNumber): array
{
    $boundaries = $this->getBoundaries();

    // getBoundaries() returns false for an empty space; fail with a clear
    // message instead of passing null bounds on to getRandomPoint().
    if (!is_array($boundaries)) {
        throw new LogicException('cannot initialize random clusters in a space without points');
    }

    [$min, $max] = $boundaries;

    $clusters = [];
    for ($n = 0; $n < $clustersNumber; ++$n) {
        $clusters[] = new Cluster($this, $this->getRandomPoint($min, $max)->getCoordinates());
    }

    return $clusters;
}
2016-05-01 21:17:09 +00:00
}