php-ml/src/CrossValidation/RandomSplit.php

27 lines
732 B
PHP
Raw Normal View History

2016-04-06 20:38:27 +00:00
<?php
2016-11-20 21:53:17 +00:00
declare(strict_types=1);
2016-04-06 20:38:27 +00:00
namespace Phpml\CrossValidation;
use Phpml\Dataset\Dataset;
class RandomSplit extends Split
{
    /**
     * Randomly partitions a dataset into a test set and a train set.
     *
     * Every sample/label pair of the dataset is drawn exactly once (sampling
     * without replacement) and appended to either the test or the train
     * collections of this object. Pairs go to the test set until the share of
     * pairs added to it by this call reaches $testSize; all later draws go to
     * the train set.
     *
     * Results are appended to $this->testSamples / $this->testLabels and
     * $this->trainSamples / $this->trainLabels (presumably declared on the
     * parent Split class; not visible from this file).
     *
     * @param Dataset $dataset  Source of the samples and their targets.
     * @param float   $testSize Fraction of the dataset that should end up in the test set.
     */
    protected function splitDataset(Dataset $dataset, float $testSize): void
    {
        // Re-index so that positions 0..n-1 are always valid draw targets,
        // while keeping samples and labels aligned by position.
        $samples = array_values($dataset->getSamples());
        $labels = array_values($dataset->getTargets());
        $datasetSize = count($samples);

        // The test set may already contain entries; only what this call adds
        // counts towards the requested test share.
        $testCount = count($this->testSamples);

        for ($remaining = $datasetSize; $remaining > 0; --$remaining) {
            // Draw among the items that are still unassigned. mt_rand() is kept
            // on purpose so that a seed set via mt_srand() elsewhere keeps the
            // split reproducible.
            $key = mt_rand(0, $remaining - 1);
            $setName = (count($this->testSamples) - $testCount) / $datasetSize >= $testSize ? 'train' : 'test';

            $this->{$setName.'Samples'}[] = $samples[$key];
            $this->{$setName.'Labels'}[] = $labels[$key];

            // Remove the drawn pair and close the gap so the same item cannot
            // be picked again and the next draw range stays contiguous.
            array_splice($samples, $key, 1);
            array_splice($labels, $key, 1);
        }
    }
}