diff --git a/src/Phpml/CrossValidation/RandomSplit.php b/src/Phpml/CrossValidation/RandomSplit.php index 92de976..c1e709e 100644 --- a/src/Phpml/CrossValidation/RandomSplit.php +++ b/src/Phpml/CrossValidation/RandomSplit.php @@ -5,101 +5,26 @@ declare (strict_types = 1); namespace Phpml\CrossValidation; use Phpml\Dataset\Dataset; -use Phpml\Exception\InvalidArgumentException; -class RandomSplit +class RandomSplit extends Split { - /** - * @var array - */ - private $trainSamples = []; - - /** - * @var array - */ - private $testSamples = []; - - /** - * @var array - */ - private $trainLabels = []; - - /** - * @var array - */ - private $testLabels = []; - /** * @param Dataset $dataset * @param float $testSize - * @param int $seed - * - * @throws InvalidArgumentException */ - public function __construct(Dataset $dataset, float $testSize = 0.3, int $seed = null) + protected function splitDataset(Dataset $dataset, float $testSize) { - if (0 >= $testSize || 1 <= $testSize) { - throw InvalidArgumentException::percentNotInRange('testSize'); - } - $this->seedGenerator($seed); - $samples = $dataset->getSamples(); $labels = $dataset->getTargets(); $datasetSize = count($samples); + $testCount = count($this->testSamples); for ($i = $datasetSize; $i > 0; --$i) { $key = mt_rand(0, $datasetSize - 1); - $setName = count($this->testSamples) / $datasetSize >= $testSize ? 'train' : 'test'; + $setName = (count($this->testSamples) - $testCount) / $datasetSize >= $testSize ? 'train' : 'test'; $this->{$setName.'Samples'}[] = $samples[$key]; $this->{$setName.'Labels'}[] = $labels[$key]; - - $samples = array_values($samples); - $labels = array_values($labels); - } - } - - /** - * @return array - */ - public function getTrainSamples() - { - return $this->trainSamples; - } - - /** - * @return array - */ - public function getTestSamples() - { - return $this->testSamples; - } - - /** - * @return array - */ - public function getTrainLabels() - { - return $this->trainLabels; - } - - /** - * @return array - */ - public function getTestLabels() - { - return $this->testLabels; - } - - /** - * @param int|null $seed - */ - private function seedGenerator(int $seed = null) - { - if (null === $seed) { - mt_srand(); - } else { - mt_srand($seed); } } } diff --git a/src/Phpml/CrossValidation/Split.php b/src/Phpml/CrossValidation/Split.php new file mode 100644 index 0000000..3478f54 --- /dev/null +++ b/src/Phpml/CrossValidation/Split.php @@ -0,0 +1,94 @@ += $testSize || 1 <= $testSize) { + throw InvalidArgumentException::percentNotInRange('testSize'); + } + $this->seedGenerator($seed); + + $this->splitDataset($dataset, $testSize); + } + + abstract protected function splitDataset(Dataset $dataset, float $testSize); + + /** + * @return array + */ + public function getTrainSamples() + { + return $this->trainSamples; + } + + /** + * @return array + */ + public function getTestSamples() + { + return $this->testSamples; + } + + /** + * @return array + */ + public function getTrainLabels() + { + return $this->trainLabels; + } + + /** + * @return array + */ + public function getTestLabels() + { + return $this->testLabels; + } + + /** + * @param int|null $seed + */ + protected function seedGenerator(int $seed = null) + { + if (null === $seed) { + mt_srand(); + } else { + mt_srand($seed); + } + } +} diff --git a/src/Phpml/CrossValidation/StratifiedRandomSplit.php b/src/Phpml/CrossValidation/StratifiedRandomSplit.php new file mode 100644 index 0000000..10af303 --- /dev/null +++ b/src/Phpml/CrossValidation/StratifiedRandomSplit.php @@ -0,0 +1,62 @@ +splitByTarget($dataset); + + foreach ($datasets as $targetSet) { + parent::splitDataset($targetSet, $testSize); + } + } + + /** + * @param Dataset $dataset + * + * @return Dataset[]|array + */ + private function splitByTarget(Dataset $dataset): array + { + $targets = $dataset->getTargets(); + $samples = $dataset->getSamples(); + + $uniqueTargets = array_unique($targets); + $split = array_combine($uniqueTargets, array_fill(0, count($uniqueTargets), [])); + + foreach ($samples as $key => $sample) { + $split[$targets[$key]][] = $sample; + } + + $datasets = $this->createDatasets($uniqueTargets, $split); + + return $datasets; + } + + /** + * @param array $uniqueTargets + * @param array $split + * + * @return array + */ + private function createDatasets(array $uniqueTargets, array $split): array + { + $datasets = []; + foreach ($uniqueTargets as $target) { + $datasets[$target] = new ArrayDataset($split[$target], array_fill(0, count($split[$target]), $target)); + } + + return $datasets; + } +} diff --git a/tests/Phpml/CrossValidation/StratifiedRandomSplitTest.php b/tests/Phpml/CrossValidation/StratifiedRandomSplitTest.php new file mode 100644 index 0000000..14802de --- /dev/null +++ b/tests/Phpml/CrossValidation/StratifiedRandomSplitTest.php @@ -0,0 +1,65 @@ +assertEquals(2, $this->countSamplesByTarget($split->getTestLabels(), 'a')); + $this->assertEquals(2, $this->countSamplesByTarget($split->getTestLabels(), 'b')); + + $split = new StratifiedRandomSplit($dataset, 0.25); + + $this->assertEquals(1, $this->countSamplesByTarget($split->getTestLabels(), 'a')); + $this->assertEquals(1, $this->countSamplesByTarget($split->getTestLabels(), 'b')); + } + + public function testDatasetStratifiedRandomSplitWithEvenDistributionAndNumericTargets() + { + $dataset = new ArrayDataset( + $samples = [[1], [2], [3], [4], [5], [6], [7], [8]], + $labels = [1, 2, 1, 2, 1, 2, 1, 2] + ); + + $split = new StratifiedRandomSplit($dataset, 0.5); + + $this->assertEquals(2, $this->countSamplesByTarget($split->getTestLabels(), 1)); + $this->assertEquals(2, $this->countSamplesByTarget($split->getTestLabels(), 2)); + + $split = new StratifiedRandomSplit($dataset, 0.25); + + $this->assertEquals(1, $this->countSamplesByTarget($split->getTestLabels(), 1)); + $this->assertEquals(1, $this->countSamplesByTarget($split->getTestLabels(), 2)); + } + + /** + * @param $splitTargets + * @param $countTarget + * + * @return int + */ + private function countSamplesByTarget($splitTargets, $countTarget): int + { + $count = 0; + foreach ($splitTargets as $target) { + if ($target === $countTarget) { + ++$count; + } + } + + return $count; + } +}