2016-07-10 12:13:35 +00:00
|
|
|
<?php
|
|
|
|
|
2016-11-20 21:53:17 +00:00
|
|
|
declare(strict_types=1);
|
2016-07-10 12:13:35 +00:00
|
|
|
|
|
|
|
namespace Phpml\CrossValidation;
|
|
|
|
|
|
|
|
use Phpml\Dataset\ArrayDataset;
|
|
|
|
use Phpml\Dataset\Dataset;
|
|
|
|
|
|
|
|
/**
 * Random split that is performed separately for every distinct target value.
 *
 * The incoming dataset is partitioned by target, and each partition is passed
 * on its own to RandomSplit::splitDataset() with the same test size.
 * NOTE(review): the parent implementation is not visible here; presumably it
 * appends to shared train/test collections, which is what keeps the class
 * proportions similar in both sets - confirm against RandomSplit.
 */
class StratifiedRandomSplit extends RandomSplit
{
    /**
     * Splits the dataset one target class at a time.
     *
     * @param Dataset $dataset  Full dataset to split.
     * @param float   $testSize Test size forwarded unchanged to the parent split for each class.
     */
    protected function splitDataset(Dataset $dataset, float $testSize): void
    {
        foreach ($this->splitByTarget($dataset) as $targetSet) {
            parent::splitDataset($targetSet, $testSize);
        }
    }

    /**
     * Partitions a dataset into one dataset per distinct target value.
     *
     * @return Dataset[] One dataset per distinct target, in order of first appearance.
     */
    private function splitByTarget(Dataset $dataset): array
    {
        $targets = $dataset->getTargets();
        $samples = $dataset->getSamples();

        $uniqueTargets = array_unique($targets);

        // Pre-create one bucket per distinct target so that the bucket order
        // follows the order in which targets first appear.
        $split = [];
        foreach ($uniqueTargets as $target) {
            $split[self::groupKey($target)] = [];
        }

        foreach ($samples as $key => $sample) {
            $split[self::groupKey($targets[$key])][] = $sample;
        }

        return $this->createDatasets($uniqueTargets, $split);
    }

    /**
     * Wraps every bucket of samples in a dataset labelled with its target.
     *
     * @param array $uniqueTargets Distinct target values (original types preserved).
     * @param array $split         Samples grouped under the key returned by groupKey().
     *
     * @return Dataset[]
     */
    private function createDatasets(array $uniqueTargets, array $split): array
    {
        $datasets = [];
        foreach ($uniqueTargets as $target) {
            $groupKey = self::groupKey($target);
            $group = $split[$groupKey];

            // Labels reuse the original target value, so its type is not
            // affected by the string-based bucket key.
            $datasets[$groupKey] = new ArrayDataset($group, array_fill(0, count($group), $target));
        }

        return $datasets;
    }

    /**
     * Builds the array key under which samples of the given target are grouped.
     *
     * Using the raw target as an array key would truncate floats (1.2 and 1.5
     * would both end up under key 1) and merge distinct classes. The string
     * form matches the comparison array_unique() applies by default, so every
     * distinct target gets exactly one bucket.
     *
     * @param mixed $target
     */
    private static function groupKey($target): string
    {
        return (string) $target;
    }
}
|