From 9938cf29113d8797a58a9d2f526f286d8c65f53d Mon Sep 17 00:00:00 2001 From: Yuji Uchiyama Date: Tue, 9 Jan 2018 18:53:02 +0900 Subject: [PATCH] Rewrite DBSCAN (#185) * Add testcases to DBSCAN * Fix DBSCAN implementation * Refactoring DBSCAN implementation * Fix coding style --- src/Phpml/Clustering/DBSCAN.php | 112 ++++++++++++++++---------- tests/Phpml/Clustering/DBSCANTest.php | 34 ++++++++ 2 files changed, 103 insertions(+), 43 deletions(-) diff --git a/src/Phpml/Clustering/DBSCAN.php b/src/Phpml/Clustering/DBSCAN.php index 3546ebf..e96c5ff 100644 --- a/src/Phpml/Clustering/DBSCAN.php +++ b/src/Phpml/Clustering/DBSCAN.php @@ -4,12 +4,13 @@ declare(strict_types=1); namespace Phpml\Clustering; -use array_merge; use Phpml\Math\Distance; use Phpml\Math\Distance\Euclidean; class DBSCAN implements Clusterer { + private const NOISE = -1; + /** * @var float */ @@ -38,57 +39,82 @@ class DBSCAN implements Clusterer public function cluster(array $samples): array { - $clusters = []; - $visited = []; + $labels = []; + $n = 0; foreach ($samples as $index => $sample) { - if (isset($visited[$index])) { + if (isset($labels[$index])) { continue; } - $visited[$index] = true; + $neighborIndices = $this->getIndicesInRegion($sample, $samples); - $regionSamples = $this->getSamplesInRegion($sample, $samples); - if (count($regionSamples) >= $this->minSamples) { - $clusters[] = $this->expandCluster($regionSamples, $visited); + if (count($neighborIndices) < $this->minSamples) { + $labels[$index] = self::NOISE; + + continue; } + + $labels[$index] = $n; + + $this->expandCluster($samples, $neighborIndices, $labels, $n); + + ++$n; + } + + return $this->groupByCluster($samples, $labels, $n); + } + + private function expandCluster(array $samples, array $seeds, array &$labels, int $n): void + { + while (($index = array_pop($seeds)) !== null) { + if (isset($labels[$index])) { + if ($labels[$index] === self::NOISE) { + $labels[$index] = $n; + } + + continue; + } + + $labels[$index] = $n; + + $sample = $samples[$index]; + $neighborIndices = $this->getIndicesInRegion($sample, $samples); + + if (count($neighborIndices) >= $this->minSamples) { + $seeds = array_unique(array_merge($seeds, $neighborIndices)); + } + } + } + + private function getIndicesInRegion(array $center, array $samples): array + { + $indices = []; + + foreach ($samples as $index => $sample) { + if ($this->distanceMetric->distance($center, $sample) < $this->epsilon) { + $indices[] = $index; + } + } + + return $indices; + } + + private function groupByCluster(array $samples, array $labels, int $n): array + { + $clusters = array_fill(0, $n, []); + + foreach ($samples as $index => $sample) { + if ($labels[$index] !== self::NOISE) { + $clusters[$labels[$index]][$index] = $sample; + } + } + + // Reindex (i.e. to 0, 1, 2, ...) integer indices for backword compatibility + foreach ($clusters as $index => $cluster) { + $clusters[$index] = array_merge($cluster, []); } return $clusters; } - - private function getSamplesInRegion(array $localSample, array $samples): array - { - $region = []; - - foreach ($samples as $index => $sample) { - if ($this->distanceMetric->distance($localSample, $sample) < $this->epsilon) { - $region[$index] = $sample; - } - } - - return $region; - } - - private function expandCluster(array $samples, array &$visited): array - { - $cluster = []; - - $clusterMerge = [[]]; - foreach ($samples as $index => $sample) { - if (!isset($visited[$index])) { - $visited[$index] = true; - $regionSamples = $this->getSamplesInRegion($sample, $samples); - if (count($regionSamples) > $this->minSamples) { - $clusterMerge[] = $regionSamples; - } - } - - $cluster[$index] = $sample; - } - - $cluster = array_merge($cluster, ...$clusterMerge); - - return $cluster; - } } diff --git a/tests/Phpml/Clustering/DBSCANTest.php b/tests/Phpml/Clustering/DBSCANTest.php index c0d0401..3c6d08d 100644 --- a/tests/Phpml/Clustering/DBSCANTest.php +++ b/tests/Phpml/Clustering/DBSCANTest.php @@ -59,4 +59,38 @@ class DBSCANTest extends TestCase $this->assertEquals($clustered, $dbscan->cluster($samples)); } + + public function testClusterEpsilonSmall(): void + { + $samples = [[0], [1], [2]]; + $clustered = [ + ]; + + $dbscan = new DBSCAN($epsilon = 0.5, $minSamples = 2); + + $this->assertEquals($clustered, $dbscan->cluster($samples)); + } + + public function testClusterEpsilonBoundary(): void + { + $samples = [[0], [1], [2]]; + $clustered = [ + ]; + + $dbscan = new DBSCAN($epsilon = 1.0, $minSamples = 2); + + $this->assertEquals($clustered, $dbscan->cluster($samples)); + } + + public function testClusterEpsilonLarge(): void + { + $samples = [[0], [1], [2]]; + $clustered = [ + [[0], [1], [2]], + ]; + + $dbscan = new DBSCAN($epsilon = 1.5, $minSamples = 2); + + $this->assertEquals($clustered, $dbscan->cluster($samples)); + } }