2016-04-06 20:38:27 +00:00
|
|
|
<?php
|
|
|
|
|
|
|
|
declare (strict_types = 1);
|
|
|
|
|
|
|
|
namespace Phpml\CrossValidation;
|
|
|
|
|
|
|
|
use Phpml\Dataset\Dataset;
|
|
|
|
use Phpml\Exception\InvalidArgumentException;
|
|
|
|
|
|
|
|
class RandomSplit
{
    /**
     * @var array
     */
    private $trainSamples = [];

    /**
     * @var array
     */
    private $testSamples = [];

    /**
     * @var array
     */
    private $trainLabels = [];

    /**
     * @var array
     */
    private $testLabels = [];

    /**
     * Randomly partitions the dataset's samples (and their matching targets)
     * into a train set and a test set. Every sample is assigned to exactly one
     * of the two sets: samples are visited in a random order, the test set is
     * filled until its share of the dataset reaches $testSize, and all
     * remaining samples go to the train set.
     *
     * @param Dataset  $dataset
     * @param float    $testSize fraction of samples for the test set; must be strictly between 0 and 1
     * @param int|null $seed     seed for mt_rand (null reseeds randomly); note that this reseeds
     *                           the process-wide mt_rand generator
     *
     * @throws InvalidArgumentException when $testSize is not in the open interval (0, 1)
     */
    public function __construct(Dataset $dataset, float $testSize = 0.3, int $seed = null)
    {
        if (0 >= $testSize || 1 <= $testSize) {
            throw InvalidArgumentException::percentNotInRange('testSize');
        }

        $this->seedGenerator($seed);

        // Re-index both arrays so a sample and its label share the same position.
        $samples = array_values($dataset->getSamples());
        $labels = array_values($dataset->getTargets());
        $datasetSize = count($samples);

        // Visit every sample exactly once (sampling WITHOUT replacement), so no
        // sample is duplicated across/within the sets and none is dropped.
        foreach ($this->shuffledIndices($datasetSize) as $index) {
            // Same ratio rule as before: fill the test set until it reaches $testSize.
            if (count($this->testSamples) / $datasetSize < $testSize) {
                $this->testSamples[] = $samples[$index];
                $this->testLabels[] = $labels[$index];
            } else {
                $this->trainSamples[] = $samples[$index];
                $this->trainLabels[] = $labels[$index];
            }
        }
    }

    /**
     * @return array
     */
    public function getTrainSamples()
    {
        return $this->trainSamples;
    }

    /**
     * @return array
     */
    public function getTestSamples()
    {
        return $this->testSamples;
    }

    /**
     * @return array
     */
    public function getTrainLabels()
    {
        return $this->trainLabels;
    }

    /**
     * @return array
     */
    public function getTestLabels()
    {
        return $this->testLabels;
    }

    /**
     * Seeds the global mt_rand generator: with $seed when given, otherwise randomly.
     *
     * @param int|null $seed
     */
    private function seedGenerator(int $seed = null)
    {
        if (null === $seed) {
            mt_srand();
        } else {
            mt_srand($seed);
        }
    }

    /**
     * Returns the integers 0..$size-1 in random order, using a Fisher-Yates
     * shuffle driven by mt_rand so the order is reproducible for a given seed.
     *
     * @param int $size number of indices; 0 (or less) yields an empty array
     *
     * @return int[]
     */
    private function shuffledIndices(int $size): array
    {
        // range(0, -1) would yield [0, -1], so guard the empty case explicitly.
        if ($size <= 0) {
            return [];
        }

        $indices = range(0, $size - 1);
        for ($i = $size - 1; $i > 0; --$i) {
            $j = mt_rand(0, $i);
            $tmp = $indices[$i];
            $indices[$i] = $indices[$j];
            $indices[$j] = $tmp;
        }

        return $indices;
    }
}
|