From bbcc8a3e685290bea034186d0458fb9c77846fd6 Mon Sep 17 00:00:00 2001 From: Arkadiusz Kondas Date: Thu, 7 Apr 2016 22:12:36 +0200 Subject: [PATCH] random split implementation and tests --- src/Phpml/CrossValidation/RandomSplit.php | 39 +++++++- src/Phpml/Dataset/ArrayDataset.php | 46 +++++++++ src/Phpml/Dataset/CsvDataset.php | 6 +- src/Phpml/Dataset/{ => Demo}/Iris.php | 4 +- .../Phpml/CrossValidation/RandomSplitTest.php | 95 +++++++++++++++++++ tests/Phpml/Dataset/{ => Demo}/IrisTest.php | 4 +- 6 files changed, 185 insertions(+), 9 deletions(-) create mode 100644 src/Phpml/Dataset/ArrayDataset.php rename src/Phpml/Dataset/{ => Demo}/Iris.php (79%) create mode 100644 tests/Phpml/CrossValidation/RandomSplitTest.php rename tests/Phpml/Dataset/{ => Demo}/IrisTest.php (86%) diff --git a/src/Phpml/CrossValidation/RandomSplit.php b/src/Phpml/CrossValidation/RandomSplit.php index 3cc6c05..4e2ecc9 100644 --- a/src/Phpml/CrossValidation/RandomSplit.php +++ b/src/Phpml/CrossValidation/RandomSplit.php @@ -29,13 +29,34 @@ class RandomSplit */ private $testLabels = []; - public function __construct(Dataset $dataset, float $testSize = 0.3) + /** + * @param Dataset $dataset + * @param float $testSize + * @param int $seed + * + * @throws InvalidArgumentException + */ + public function __construct(Dataset $dataset, float $testSize = 0.3, int $seed = null) { - if (0 > $testSize || 1 < $testSize) { + if (0 >= $testSize || 1 <= $testSize) { throw InvalidArgumentException::percentNotInRange('testSize'); } + $this->seedGenerator($seed); - // TODO: implement this ! + $samples = $dataset->getSamples(); + $labels = $dataset->getLabels(); + $datasetSize = count($samples); + + for($i=$datasetSize; $i>0; $i--) { + $key = mt_rand(0, $datasetSize-1); + $setName = count($this->testSamples) / $datasetSize >= $testSize ? 'train' : 'test'; + + $this->{$setName.'Samples'}[] = $samples[$key]; + $this->{$setName.'Labels'}[] = $labels[$key]; + + $samples = array_values($samples); + $labels = array_values($labels); + } } /** @@ -69,4 +90,16 @@ class RandomSplit { return $this->testLabels; } + + /** + * @param int|null $seed + */ + private function seedGenerator(int $seed = null) + { + if (null === $seed) { + mt_srand(); + } else { + mt_srand($seed); + } + } } diff --git a/src/Phpml/Dataset/ArrayDataset.php b/src/Phpml/Dataset/ArrayDataset.php new file mode 100644 index 0000000..85f9914 --- /dev/null +++ b/src/Phpml/Dataset/ArrayDataset.php @@ -0,0 +1,46 @@ +samples = $samples; + $this->labels = $labels; + } + + + /** + * @return array + */ + public function getSamples(): array + { + return $this->samples; + } + + /** + * @return array + */ + public function getLabels(): array + { + return $this->labels; + } + +} diff --git a/src/Phpml/Dataset/CsvDataset.php b/src/Phpml/Dataset/CsvDataset.php index 6fa6b42..c21ac97 100644 --- a/src/Phpml/Dataset/CsvDataset.php +++ b/src/Phpml/Dataset/CsvDataset.php @@ -21,7 +21,7 @@ abstract class CsvDataset implements Dataset /** * @var array */ - private $lables = []; + private $labels = []; public function __construct() { @@ -39,7 +39,7 @@ abstract class CsvDataset implements Dataset continue; } $this->samples[] = array_slice($data, 0, 4); - $this->lables[] = $data[4]; + $this->labels[] = $data[4]; } fclose($handle); } else { @@ -60,6 +60,6 @@ abstract class CsvDataset implements Dataset */ public function getLabels(): array { - return $this->lables; + return $this->labels; } } diff --git a/src/Phpml/Dataset/Iris.php b/src/Phpml/Dataset/Demo/Iris.php similarity index 79% rename from src/Phpml/Dataset/Iris.php rename to src/Phpml/Dataset/Demo/Iris.php index 1353989..d544a55 100644 --- a/src/Phpml/Dataset/Iris.php +++ b/src/Phpml/Dataset/Demo/Iris.php @@ -2,7 +2,9 @@ declare (strict_types = 1); -namespace Phpml\Dataset; +namespace Phpml\Dataset\Demo; + +use Phpml\Dataset\CsvDataset; /** * Classes: 3 diff --git a/tests/Phpml/CrossValidation/RandomSplitTest.php b/tests/Phpml/CrossValidation/RandomSplitTest.php new file mode 100644 index 0000000..cbbd497 --- /dev/null +++ b/tests/Phpml/CrossValidation/RandomSplitTest.php @@ -0,0 +1,95 @@ +assertEquals(2, count($randomSplit1->getTestSamples())); + $this->assertEquals(2, count($randomSplit1->getTrainSamples())); + + $randomSplit2 = new RandomSplit($dataset, 0.25); + + $this->assertEquals(1, count($randomSplit2->getTestSamples())); + $this->assertEquals(3, count($randomSplit2->getTrainSamples())); + } + + public function testDatasetRandomSplitWithSameSeed() + { + $dataset = new ArrayDataset( + $samples = [[1], [2], [3], [4], [5], [6], [7], [8]], + $labels = ['a', 'a', 'a', 'a', 'b', 'b', 'b', 'b'] + ); + + $seed = 123; + + $randomSplit1 = new RandomSplit($dataset, 0.5, $seed); + $randomSplit2 = new RandomSplit($dataset, 0.5, $seed); + + $this->assertEquals($randomSplit1->getTestLabels(), $randomSplit2->getTestLabels()); + $this->assertEquals($randomSplit1->getTestSamples(), $randomSplit2->getTestSamples()); + $this->assertEquals($randomSplit1->getTrainLabels(), $randomSplit2->getTrainLabels()); + $this->assertEquals($randomSplit1->getTrainSamples(), $randomSplit2->getTrainSamples()); + } + + public function testDatasetRandomSplitWithDifferentSeed() + { + $dataset = new ArrayDataset( + $samples = [[1], [2], [3], [4], [5], [6], [7], [8]], + $labels = ['a', 'a', 'a', 'a', 'b', 'b', 'b', 'b'] + ); + + $randomSplit1 = new RandomSplit($dataset, 0.5, 4321); + $randomSplit2 = new RandomSplit($dataset, 0.5, 1234); + + $this->assertNotEquals($randomSplit1->getTestLabels(), $randomSplit2->getTestLabels()); + $this->assertNotEquals($randomSplit1->getTestSamples(), $randomSplit2->getTestSamples()); + $this->assertNotEquals($randomSplit1->getTrainLabels(), $randomSplit2->getTrainLabels()); + $this->assertNotEquals($randomSplit1->getTrainSamples(), $randomSplit2->getTrainSamples()); + } + + public function testRandomSplitCorrectSampleAndLabelPosition() + { + $dataset = new ArrayDataset( + $samples = [[1], [2], [3], [4]], + $labels = [1, 2, 3, 4] + ); + + $randomSplit = new RandomSplit($dataset, 0.5); + + $this->assertEquals($randomSplit->getTestSamples()[0][0], $randomSplit->getTestLabels()[0]); + $this->assertEquals($randomSplit->getTestSamples()[1][0], $randomSplit->getTestLabels()[1]); + $this->assertEquals($randomSplit->getTrainSamples()[0][0], $randomSplit->getTrainLabels()[0]); + $this->assertEquals($randomSplit->getTrainSamples()[1][0], $randomSplit->getTrainLabels()[1]); + } + +} \ No newline at end of file diff --git a/tests/Phpml/Dataset/IrisTest.php b/tests/Phpml/Dataset/Demo/IrisTest.php similarity index 86% rename from tests/Phpml/Dataset/IrisTest.php rename to tests/Phpml/Dataset/Demo/IrisTest.php index 99e19ad..1f0da90 100644 --- a/tests/Phpml/Dataset/IrisTest.php +++ b/tests/Phpml/Dataset/Demo/IrisTest.php @@ -2,9 +2,9 @@ declare (strict_types = 1); -namespace tests\Phpml\Dataset; +namespace tests\Phpml\Dataset\Demo; -use Phpml\Dataset\Iris; +use Phpml\Dataset\Demo\Iris; class IrisTest extends \PHPUnit_Framework_TestCase {