php-ml/src/Phpml/CrossValidation/RandomSplit.php

106 lines
2.1 KiB
PHP
Raw Normal View History

2016-04-06 20:38:27 +00:00
<?php
declare (strict_types = 1);
namespace Phpml\CrossValidation;
use Phpml\Dataset\Dataset;
use Phpml\Exception\InvalidArgumentException;
class RandomSplit
{
    /**
     * @var array
     */
    private $trainSamples = [];

    /**
     * @var array
     */
    private $testSamples = [];

    /**
     * @var array
     */
    private $trainLabels = [];

    /**
     * @var array
     */
    private $testLabels = [];

    /**
     * Randomly partitions a dataset into a train part and a test part.
     *
     * Every sample of the dataset is placed in exactly one of the two parts
     * (sampling without replacement), and each label stays paired with its
     * sample. Samples are visited in a random order; the test part is filled
     * until its share of the dataset reaches $testSize, and all remaining
     * samples go to the train part.
     *
     * @param Dataset  $dataset
     * @param float    $testSize fraction of samples for the test part, must lie strictly between 0 and 1
     * @param int|null $seed     seed passed to mt_srand(); null seeds randomly
     *
     * @throws InvalidArgumentException when $testSize is not strictly between 0 and 1
     */
    public function __construct(Dataset $dataset, float $testSize = 0.3, int $seed = null)
    {
        if (0 >= $testSize || 1 <= $testSize) {
            throw InvalidArgumentException::percentNotInRange('testSize');
        }

        $this->seedGenerator($seed);

        // Reindex both arrays so that a sample and its label share the same
        // position 0..n-1, independent of how the dataset keyed them.
        $samples = array_values($dataset->getSamples());
        $labels = array_values($dataset->getTargets());
        $datasetSize = count($samples);

        // Each index is visited exactly once, so no sample is duplicated or lost.
        foreach ($this->shuffledIndexes($datasetSize) as $key) {
            // Fill the test part until it holds the requested fraction of the
            // dataset, then route everything else to the train part.
            $setName = count($this->testSamples) / $datasetSize >= $testSize ? 'train' : 'test';
            $this->{$setName.'Samples'}[] = $samples[$key];
            $this->{$setName.'Labels'}[] = $labels[$key];
        }
    }

    /**
     * @return array
     */
    public function getTrainSamples()
    {
        return $this->trainSamples;
    }

    /**
     * @return array
     */
    public function getTestSamples()
    {
        return $this->testSamples;
    }

    /**
     * @return array
     */
    public function getTrainLabels()
    {
        return $this->trainLabels;
    }

    /**
     * @return array
     */
    public function getTestLabels()
    {
        return $this->testLabels;
    }

    /**
     * Seeds the Mersenne Twister generator used for the split.
     *
     * @param int|null $seed explicit seed, or null to let mt_srand() pick one
     */
    private function seedGenerator(int $seed = null)
    {
        if (null === $seed) {
            mt_srand();
        } else {
            mt_srand($seed);
        }
    }

    /**
     * Returns the integers 0..$size-1 in random order.
     *
     * Uses a Fisher-Yates shuffle driven by mt_rand(), so the permutation is
     * reproducible after mt_srand() has been called with a fixed seed.
     *
     * @param int $size number of indexes to permute; 0 or less yields []
     *
     * @return int[]
     */
    private function shuffledIndexes(int $size)
    {
        // range(0, -1) would return [0, -1], so handle the empty case explicitly.
        if ($size <= 0) {
            return [];
        }

        $indexes = range(0, $size - 1);
        for ($i = $size - 1; $i > 0; --$i) {
            $j = mt_rand(0, $i);
            $tmp = $indexes[$i];
            $indexes[$i] = $indexes[$j];
            $indexes[$j] = $tmp;
        }

        return $indexes;
    }
}