mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2024-09-28 15:09:01 +00:00
create StratifiedRandomSplit for cross validation
This commit is contained in:
parent
0213208a96
commit
f04cc04da5
@ -5,101 +5,26 @@ declare (strict_types = 1);
|
||||
namespace Phpml\CrossValidation;
|
||||
|
||||
use Phpml\Dataset\Dataset;
|
||||
use Phpml\Exception\InvalidArgumentException;
|
||||
|
||||
class RandomSplit
|
||||
class RandomSplit extends Split
|
||||
{
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
private $trainSamples = [];
|
||||
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
private $testSamples = [];
|
||||
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
private $trainLabels = [];
|
||||
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
private $testLabels = [];
|
||||
|
||||
/**
|
||||
* @param Dataset $dataset
|
||||
* @param float $testSize
|
||||
* @param int $seed
|
||||
*
|
||||
* @throws InvalidArgumentException
|
||||
*/
|
||||
public function __construct(Dataset $dataset, float $testSize = 0.3, int $seed = null)
|
||||
protected function splitDataset(Dataset $dataset, float $testSize)
|
||||
{
|
||||
if (0 >= $testSize || 1 <= $testSize) {
|
||||
throw InvalidArgumentException::percentNotInRange('testSize');
|
||||
}
|
||||
$this->seedGenerator($seed);
|
||||
|
||||
$samples = $dataset->getSamples();
|
||||
$labels = $dataset->getTargets();
|
||||
$datasetSize = count($samples);
|
||||
$testCount = count($this->testSamples);
|
||||
|
||||
for ($i = $datasetSize; $i > 0; --$i) {
|
||||
$key = mt_rand(0, $datasetSize - 1);
|
||||
$setName = count($this->testSamples) / $datasetSize >= $testSize ? 'train' : 'test';
|
||||
$setName = (count($this->testSamples) - $testCount) / $datasetSize >= $testSize ? 'train' : 'test';
|
||||
|
||||
$this->{$setName.'Samples'}[] = $samples[$key];
|
||||
$this->{$setName.'Labels'}[] = $labels[$key];
|
||||
|
||||
$samples = array_values($samples);
|
||||
$labels = array_values($labels);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @return array
|
||||
*/
|
||||
public function getTrainSamples()
|
||||
{
|
||||
return $this->trainSamples;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return array
|
||||
*/
|
||||
public function getTestSamples()
|
||||
{
|
||||
return $this->testSamples;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return array
|
||||
*/
|
||||
public function getTrainLabels()
|
||||
{
|
||||
return $this->trainLabels;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return array
|
||||
*/
|
||||
public function getTestLabels()
|
||||
{
|
||||
return $this->testLabels;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param int|null $seed
|
||||
*/
|
||||
private function seedGenerator(int $seed = null)
|
||||
{
|
||||
if (null === $seed) {
|
||||
mt_srand();
|
||||
} else {
|
||||
mt_srand($seed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
94
src/Phpml/CrossValidation/Split.php
Normal file
94
src/Phpml/CrossValidation/Split.php
Normal file
@ -0,0 +1,94 @@
|
||||
<?php
|
||||
|
||||
declare (strict_types = 1);
|
||||
|
||||
namespace Phpml\CrossValidation;
|
||||
|
||||
use Phpml\Dataset\Dataset;
|
||||
use Phpml\Exception\InvalidArgumentException;
|
||||
|
||||
abstract class Split
|
||||
{
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
protected $trainSamples = [];
|
||||
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
protected $testSamples = [];
|
||||
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
protected $trainLabels = [];
|
||||
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
protected $testLabels = [];
|
||||
|
||||
/**
|
||||
* @param Dataset $dataset
|
||||
* @param float $testSize
|
||||
* @param int $seed
|
||||
*
|
||||
* @throws InvalidArgumentException
|
||||
*/
|
||||
public function __construct(Dataset $dataset, float $testSize = 0.3, int $seed = null)
|
||||
{
|
||||
if (0 >= $testSize || 1 <= $testSize) {
|
||||
throw InvalidArgumentException::percentNotInRange('testSize');
|
||||
}
|
||||
$this->seedGenerator($seed);
|
||||
|
||||
$this->splitDataset($dataset, $testSize);
|
||||
}
|
||||
|
||||
abstract protected function splitDataset(Dataset $dataset, float $testSize);
|
||||
|
||||
/**
|
||||
* @return array
|
||||
*/
|
||||
public function getTrainSamples()
|
||||
{
|
||||
return $this->trainSamples;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return array
|
||||
*/
|
||||
public function getTestSamples()
|
||||
{
|
||||
return $this->testSamples;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return array
|
||||
*/
|
||||
public function getTrainLabels()
|
||||
{
|
||||
return $this->trainLabels;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return array
|
||||
*/
|
||||
public function getTestLabels()
|
||||
{
|
||||
return $this->testLabels;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param int|null $seed
|
||||
*/
|
||||
protected function seedGenerator(int $seed = null)
|
||||
{
|
||||
if (null === $seed) {
|
||||
mt_srand();
|
||||
} else {
|
||||
mt_srand($seed);
|
||||
}
|
||||
}
|
||||
}
|
62
src/Phpml/CrossValidation/StratifiedRandomSplit.php
Normal file
62
src/Phpml/CrossValidation/StratifiedRandomSplit.php
Normal file
@ -0,0 +1,62 @@
|
||||
<?php
|
||||
|
||||
declare (strict_types = 1);
|
||||
|
||||
namespace Phpml\CrossValidation;
|
||||
|
||||
use Phpml\Dataset\ArrayDataset;
|
||||
use Phpml\Dataset\Dataset;
|
||||
|
||||
class StratifiedRandomSplit extends RandomSplit
|
||||
{
|
||||
/**
|
||||
* @param Dataset $dataset
|
||||
* @param float $testSize
|
||||
*/
|
||||
protected function splitDataset(Dataset $dataset, float $testSize)
|
||||
{
|
||||
$datasets = $this->splitByTarget($dataset);
|
||||
|
||||
foreach ($datasets as $targetSet) {
|
||||
parent::splitDataset($targetSet, $testSize);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @param Dataset $dataset
|
||||
*
|
||||
* @return Dataset[]|array
|
||||
*/
|
||||
private function splitByTarget(Dataset $dataset): array
|
||||
{
|
||||
$targets = $dataset->getTargets();
|
||||
$samples = $dataset->getSamples();
|
||||
|
||||
$uniqueTargets = array_unique($targets);
|
||||
$split = array_combine($uniqueTargets, array_fill(0, count($uniqueTargets), []));
|
||||
|
||||
foreach ($samples as $key => $sample) {
|
||||
$split[$targets[$key]][] = $sample;
|
||||
}
|
||||
|
||||
$datasets = $this->createDatasets($uniqueTargets, $split);
|
||||
|
||||
return $datasets;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param array $uniqueTargets
|
||||
* @param array $split
|
||||
*
|
||||
* @return array
|
||||
*/
|
||||
private function createDatasets(array $uniqueTargets, array $split): array
|
||||
{
|
||||
$datasets = [];
|
||||
foreach ($uniqueTargets as $target) {
|
||||
$datasets[$target] = new ArrayDataset($split[$target], array_fill(0, count($split[$target]), $target));
|
||||
}
|
||||
|
||||
return $datasets;
|
||||
}
|
||||
}
|
65
tests/Phpml/CrossValidation/StratifiedRandomSplitTest.php
Normal file
65
tests/Phpml/CrossValidation/StratifiedRandomSplitTest.php
Normal file
@ -0,0 +1,65 @@
|
||||
<?php
|
||||
|
||||
declare (strict_types = 1);
|
||||
|
||||
namespace tests\Phpml\CrossValidation;
|
||||
|
||||
use Phpml\CrossValidation\StratifiedRandomSplit;
|
||||
use Phpml\Dataset\ArrayDataset;
|
||||
|
||||
class StratifiedRandomSplitTest extends \PHPUnit_Framework_TestCase
|
||||
{
|
||||
public function testDatasetStratifiedRandomSplitWithEvenDistribution()
|
||||
{
|
||||
$dataset = new ArrayDataset(
|
||||
$samples = [[1], [2], [3], [4], [5], [6], [7], [8]],
|
||||
$labels = ['a', 'a', 'a', 'a', 'b', 'b', 'b', 'b']
|
||||
);
|
||||
|
||||
$split = new StratifiedRandomSplit($dataset, 0.5);
|
||||
|
||||
$this->assertEquals(2, $this->countSamplesByTarget($split->getTestLabels(), 'a'));
|
||||
$this->assertEquals(2, $this->countSamplesByTarget($split->getTestLabels(), 'b'));
|
||||
|
||||
$split = new StratifiedRandomSplit($dataset, 0.25);
|
||||
|
||||
$this->assertEquals(1, $this->countSamplesByTarget($split->getTestLabels(), 'a'));
|
||||
$this->assertEquals(1, $this->countSamplesByTarget($split->getTestLabels(), 'b'));
|
||||
}
|
||||
|
||||
public function testDatasetStratifiedRandomSplitWithEvenDistributionAndNumericTargets()
|
||||
{
|
||||
$dataset = new ArrayDataset(
|
||||
$samples = [[1], [2], [3], [4], [5], [6], [7], [8]],
|
||||
$labels = [1, 2, 1, 2, 1, 2, 1, 2]
|
||||
);
|
||||
|
||||
$split = new StratifiedRandomSplit($dataset, 0.5);
|
||||
|
||||
$this->assertEquals(2, $this->countSamplesByTarget($split->getTestLabels(), 1));
|
||||
$this->assertEquals(2, $this->countSamplesByTarget($split->getTestLabels(), 2));
|
||||
|
||||
$split = new StratifiedRandomSplit($dataset, 0.25);
|
||||
|
||||
$this->assertEquals(1, $this->countSamplesByTarget($split->getTestLabels(), 1));
|
||||
$this->assertEquals(1, $this->countSamplesByTarget($split->getTestLabels(), 2));
|
||||
}
|
||||
|
||||
/**
|
||||
* @param $splitTargets
|
||||
* @param $countTarget
|
||||
*
|
||||
* @return int
|
||||
*/
|
||||
private function countSamplesByTarget($splitTargets, $countTarget): int
|
||||
{
|
||||
$count = 0;
|
||||
foreach ($splitTargets as $target) {
|
||||
if ($target === $countTarget) {
|
||||
++$count;
|
||||
}
|
||||
}
|
||||
|
||||
return $count;
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user