create StratifiedRandomSplit for cross validation

This commit is contained in:
Arkadiusz Kondas 2016-07-10 14:13:35 +02:00
parent 0213208a96
commit f04cc04da5
4 changed files with 225 additions and 79 deletions

View File

@ -5,101 +5,26 @@ declare (strict_types = 1);
namespace Phpml\CrossValidation;
use Phpml\Dataset\Dataset;
use Phpml\Exception\InvalidArgumentException;
class RandomSplit
class RandomSplit extends Split
{
/**
* @var array
*/
private $trainSamples = [];
/**
* @var array
*/
private $testSamples = [];
/**
* @var array
*/
private $trainLabels = [];
/**
* @var array
*/
private $testLabels = [];
/**
* @param Dataset $dataset
* @param float $testSize
* @param int $seed
*
* @throws InvalidArgumentException
*/
public function __construct(Dataset $dataset, float $testSize = 0.3, int $seed = null)
protected function splitDataset(Dataset $dataset, float $testSize)
{
if (0 >= $testSize || 1 <= $testSize) {
throw InvalidArgumentException::percentNotInRange('testSize');
}
$this->seedGenerator($seed);
$samples = $dataset->getSamples();
$labels = $dataset->getTargets();
$datasetSize = count($samples);
$testCount = count($this->testSamples);
for ($i = $datasetSize; $i > 0; --$i) {
$key = mt_rand(0, $datasetSize - 1);
$setName = count($this->testSamples) / $datasetSize >= $testSize ? 'train' : 'test';
$setName = (count($this->testSamples) - $testCount) / $datasetSize >= $testSize ? 'train' : 'test';
$this->{$setName.'Samples'}[] = $samples[$key];
$this->{$setName.'Labels'}[] = $labels[$key];
$samples = array_values($samples);
$labels = array_values($labels);
}
}
/**
* @return array
*/
public function getTrainSamples()
{
return $this->trainSamples;
}
/**
* @return array
*/
public function getTestSamples()
{
return $this->testSamples;
}
/**
* @return array
*/
public function getTrainLabels()
{
return $this->trainLabels;
}
/**
* @return array
*/
public function getTestLabels()
{
return $this->testLabels;
}
/**
* @param int|null $seed
*/
private function seedGenerator(int $seed = null)
{
if (null === $seed) {
mt_srand();
} else {
mt_srand($seed);
}
}
}

View File

@ -0,0 +1,94 @@
<?php
declare (strict_types = 1);
namespace Phpml\CrossValidation;
use Phpml\Dataset\Dataset;
use Phpml\Exception\InvalidArgumentException;
abstract class Split
{
/**
* @var array
*/
protected $trainSamples = [];
/**
* @var array
*/
protected $testSamples = [];
/**
* @var array
*/
protected $trainLabels = [];
/**
* @var array
*/
protected $testLabels = [];
/**
* @param Dataset $dataset
* @param float $testSize
* @param int $seed
*
* @throws InvalidArgumentException
*/
public function __construct(Dataset $dataset, float $testSize = 0.3, int $seed = null)
{
if (0 >= $testSize || 1 <= $testSize) {
throw InvalidArgumentException::percentNotInRange('testSize');
}
$this->seedGenerator($seed);
$this->splitDataset($dataset, $testSize);
}
abstract protected function splitDataset(Dataset $dataset, float $testSize);
/**
* @return array
*/
public function getTrainSamples()
{
return $this->trainSamples;
}
/**
* @return array
*/
public function getTestSamples()
{
return $this->testSamples;
}
/**
* @return array
*/
public function getTrainLabels()
{
return $this->trainLabels;
}
/**
* @return array
*/
public function getTestLabels()
{
return $this->testLabels;
}
/**
* @param int|null $seed
*/
protected function seedGenerator(int $seed = null)
{
if (null === $seed) {
mt_srand();
} else {
mt_srand($seed);
}
}
}

View File

@ -0,0 +1,62 @@
<?php
declare (strict_types = 1);
namespace Phpml\CrossValidation;
use Phpml\Dataset\ArrayDataset;
use Phpml\Dataset\Dataset;
class StratifiedRandomSplit extends RandomSplit
{
/**
* @param Dataset $dataset
* @param float $testSize
*/
protected function splitDataset(Dataset $dataset, float $testSize)
{
$datasets = $this->splitByTarget($dataset);
foreach ($datasets as $targetSet) {
parent::splitDataset($targetSet, $testSize);
}
}
/**
* @param Dataset $dataset
*
* @return Dataset[]|array
*/
private function splitByTarget(Dataset $dataset): array
{
$targets = $dataset->getTargets();
$samples = $dataset->getSamples();
$uniqueTargets = array_unique($targets);
$split = array_combine($uniqueTargets, array_fill(0, count($uniqueTargets), []));
foreach ($samples as $key => $sample) {
$split[$targets[$key]][] = $sample;
}
$datasets = $this->createDatasets($uniqueTargets, $split);
return $datasets;
}
/**
* @param array $uniqueTargets
* @param array $split
*
* @return array
*/
private function createDatasets(array $uniqueTargets, array $split): array
{
$datasets = [];
foreach ($uniqueTargets as $target) {
$datasets[$target] = new ArrayDataset($split[$target], array_fill(0, count($split[$target]), $target));
}
return $datasets;
}
}

View File

@ -0,0 +1,65 @@
<?php
declare (strict_types = 1);
namespace tests\Phpml\CrossValidation;
use Phpml\CrossValidation\StratifiedRandomSplit;
use Phpml\Dataset\ArrayDataset;
class StratifiedRandomSplitTest extends \PHPUnit_Framework_TestCase
{
public function testDatasetStratifiedRandomSplitWithEvenDistribution()
{
$dataset = new ArrayDataset(
$samples = [[1], [2], [3], [4], [5], [6], [7], [8]],
$labels = ['a', 'a', 'a', 'a', 'b', 'b', 'b', 'b']
);
$split = new StratifiedRandomSplit($dataset, 0.5);
$this->assertEquals(2, $this->countSamplesByTarget($split->getTestLabels(), 'a'));
$this->assertEquals(2, $this->countSamplesByTarget($split->getTestLabels(), 'b'));
$split = new StratifiedRandomSplit($dataset, 0.25);
$this->assertEquals(1, $this->countSamplesByTarget($split->getTestLabels(), 'a'));
$this->assertEquals(1, $this->countSamplesByTarget($split->getTestLabels(), 'b'));
}
public function testDatasetStratifiedRandomSplitWithEvenDistributionAndNumericTargets()
{
$dataset = new ArrayDataset(
$samples = [[1], [2], [3], [4], [5], [6], [7], [8]],
$labels = [1, 2, 1, 2, 1, 2, 1, 2]
);
$split = new StratifiedRandomSplit($dataset, 0.5);
$this->assertEquals(2, $this->countSamplesByTarget($split->getTestLabels(), 1));
$this->assertEquals(2, $this->countSamplesByTarget($split->getTestLabels(), 2));
$split = new StratifiedRandomSplit($dataset, 0.25);
$this->assertEquals(1, $this->countSamplesByTarget($split->getTestLabels(), 1));
$this->assertEquals(1, $this->countSamplesByTarget($split->getTestLabels(), 2));
}
/**
* @param $splitTargets
* @param $countTarget
*
* @return int
*/
private function countSamplesByTarget($splitTargets, $countTarget): int
{
$count = 0;
foreach ($splitTargets as $target) {
if ($target === $countTarget) {
++$count;
}
}
return $count;
}
}