diff --git a/src/Association/Apriori.php b/src/Association/Apriori.php index c3f9c91..abbdf47 100644 --- a/src/Association/Apriori.php +++ b/src/Association/Apriori.php @@ -212,9 +212,9 @@ class Apriori implements Associator */ private function frequent(array $samples): array { - return array_filter($samples, function ($entry) { + return array_values(array_filter($samples, function ($entry) { return $this->support($entry) >= $this->support; - }); + })); } /** @@ -234,7 +234,7 @@ class Apriori implements Associator continue; } - $candidate = array_unique(array_merge($p, $q)); + $candidate = array_values(array_unique(array_merge($p, $q))); if ($this->contains($candidates, $candidate)) { continue; diff --git a/tests/Association/AprioriTest.php b/tests/Association/AprioriTest.php index 5ed4f8c..81a6ce6 100644 --- a/tests/Association/AprioriTest.php +++ b/tests/Association/AprioriTest.php @@ -101,6 +101,18 @@ class AprioriTest extends TestCase $this->assertEquals([['a']], $L[1]); } + public function testAprioriL3(): void + { + $sample = [['a', 'b', 'c']]; + + $apriori = new Apriori(0, 0); + $apriori->train($sample, []); + + $L = $apriori->apriori(); + + $this->assertEquals([['a', 'b', 'c']], $L[3]); + } + public function testGetRules(): void { $apriori = new Apriori(0.4, 0.8);