random split implementation and tests

This commit is contained in:
Arkadiusz Kondas 2016-04-07 22:12:36 +02:00
parent c3f98e4093
commit bbcc8a3e68
6 changed files with 185 additions and 9 deletions

View File

@ -29,13 +29,34 @@ class RandomSplit
*/ */
private $testLabels = []; private $testLabels = [];
public function __construct(Dataset $dataset, float $testSize = 0.3) /**
* @param Dataset $dataset
* @param float $testSize
* @param int $seed
*
* @throws InvalidArgumentException
*/
public function __construct(Dataset $dataset, float $testSize = 0.3, int $seed = null)
{ {
if (0 > $testSize || 1 < $testSize) { if (0 >= $testSize || 1 <= $testSize) {
throw InvalidArgumentException::percentNotInRange('testSize'); throw InvalidArgumentException::percentNotInRange('testSize');
} }
$this->seedGenerator($seed);
// TODO: implement this ! $samples = $dataset->getSamples();
$labels = $dataset->getLabels();
$datasetSize = count($samples);
for($i=$datasetSize; $i>0; $i--) {
$key = mt_rand(0, $datasetSize-1);
$setName = count($this->testSamples) / $datasetSize >= $testSize ? 'train' : 'test';
$this->{$setName.'Samples'}[] = $samples[$key];
$this->{$setName.'Labels'}[] = $labels[$key];
$samples = array_values($samples);
$labels = array_values($labels);
}
} }
/** /**
@ -69,4 +90,16 @@ class RandomSplit
{ {
return $this->testLabels; return $this->testLabels;
} }
/**
* @param int|null $seed
*/
private function seedGenerator(int $seed = null)
{
if (null === $seed) {
mt_srand();
} else {
mt_srand($seed);
}
}
} }

View File

@ -0,0 +1,46 @@
<?php
declare(strict_types = 1);
namespace Phpml\Dataset;
class ArrayDataset implements Dataset
{
/**
* @var array
*/
private $samples = [];
/**
* @var array
*/
private $labels = [];
/**
* @param array $samples
* @param array $labels
*/
public function __construct(array $samples, array $labels)
{
$this->samples = $samples;
$this->labels = $labels;
}
/**
* @return array
*/
public function getSamples(): array
{
return $this->samples;
}
/**
* @return array
*/
public function getLabels(): array
{
return $this->labels;
}
}

View File

@ -21,7 +21,7 @@ abstract class CsvDataset implements Dataset
/** /**
* @var array * @var array
*/ */
private $lables = []; private $labels = [];
public function __construct() public function __construct()
{ {
@ -39,7 +39,7 @@ abstract class CsvDataset implements Dataset
continue; continue;
} }
$this->samples[] = array_slice($data, 0, 4); $this->samples[] = array_slice($data, 0, 4);
$this->lables[] = $data[4]; $this->labels[] = $data[4];
} }
fclose($handle); fclose($handle);
} else { } else {
@ -60,6 +60,6 @@ abstract class CsvDataset implements Dataset
*/ */
public function getLabels(): array public function getLabels(): array
{ {
return $this->lables; return $this->labels;
} }
} }

View File

@ -2,7 +2,9 @@
declare (strict_types = 1); declare (strict_types = 1);
namespace Phpml\Dataset; namespace Phpml\Dataset\Demo;
use Phpml\Dataset\CsvDataset;
/** /**
* Classes: 3 * Classes: 3

View File

@ -0,0 +1,95 @@
<?php
declare(strict_types = 1);
namespace tests\Phpml\CrossValidation;
use Phpml\CrossValidation\RandomSplit;
use Phpml\Dataset\ArrayDataset;
class RandomSplitTest extends \PHPUnit_Framework_TestCase
{
/**
* @expectedException \Phpml\Exception\InvalidArgumentException
*/
public function testThrowExceptionOnToSmallTestSize()
{
new RandomSplit(new ArrayDataset([], []), 0);
}
/**
* @expectedException \Phpml\Exception\InvalidArgumentException
*/
public function testThrowExceptionOnToBigTestSize()
{
new RandomSplit(new ArrayDataset([], []), 1);
}
public function testDatasetRandomSplitWithoutSeed()
{
$dataset = new ArrayDataset(
$samples = [[1], [2], [3], [4]],
$labels = ['a', 'a', 'b', 'b']
);
$randomSplit1 = new RandomSplit($dataset, 0.5);
$this->assertEquals(2, count($randomSplit1->getTestSamples()));
$this->assertEquals(2, count($randomSplit1->getTrainSamples()));
$randomSplit2 = new RandomSplit($dataset, 0.25);
$this->assertEquals(1, count($randomSplit2->getTestSamples()));
$this->assertEquals(3, count($randomSplit2->getTrainSamples()));
}
public function testDatasetRandomSplitWithSameSeed()
{
$dataset = new ArrayDataset(
$samples = [[1], [2], [3], [4], [5], [6], [7], [8]],
$labels = ['a', 'a', 'a', 'a', 'b', 'b', 'b', 'b']
);
$seed = 123;
$randomSplit1 = new RandomSplit($dataset, 0.5, $seed);
$randomSplit2 = new RandomSplit($dataset, 0.5, $seed);
$this->assertEquals($randomSplit1->getTestLabels(), $randomSplit2->getTestLabels());
$this->assertEquals($randomSplit1->getTestSamples(), $randomSplit2->getTestSamples());
$this->assertEquals($randomSplit1->getTrainLabels(), $randomSplit2->getTrainLabels());
$this->assertEquals($randomSplit1->getTrainSamples(), $randomSplit2->getTrainSamples());
}
public function testDatasetRandomSplitWithDifferentSeed()
{
$dataset = new ArrayDataset(
$samples = [[1], [2], [3], [4], [5], [6], [7], [8]],
$labels = ['a', 'a', 'a', 'a', 'b', 'b', 'b', 'b']
);
$randomSplit1 = new RandomSplit($dataset, 0.5, 4321);
$randomSplit2 = new RandomSplit($dataset, 0.5, 1234);
$this->assertNotEquals($randomSplit1->getTestLabels(), $randomSplit2->getTestLabels());
$this->assertNotEquals($randomSplit1->getTestSamples(), $randomSplit2->getTestSamples());
$this->assertNotEquals($randomSplit1->getTrainLabels(), $randomSplit2->getTrainLabels());
$this->assertNotEquals($randomSplit1->getTrainSamples(), $randomSplit2->getTrainSamples());
}
public function testRandomSplitCorrectSampleAndLabelPosition()
{
$dataset = new ArrayDataset(
$samples = [[1], [2], [3], [4]],
$labels = [1, 2, 3, 4]
);
$randomSplit = new RandomSplit($dataset, 0.5);
$this->assertEquals($randomSplit->getTestSamples()[0][0], $randomSplit->getTestLabels()[0]);
$this->assertEquals($randomSplit->getTestSamples()[1][0], $randomSplit->getTestLabels()[1]);
$this->assertEquals($randomSplit->getTrainSamples()[0][0], $randomSplit->getTrainLabels()[0]);
$this->assertEquals($randomSplit->getTrainSamples()[1][0], $randomSplit->getTrainLabels()[1]);
}
}

View File

@ -2,9 +2,9 @@
declare (strict_types = 1); declare (strict_types = 1);
namespace tests\Phpml\Dataset; namespace tests\Phpml\Dataset\Demo;
use Phpml\Dataset\Iris; use Phpml\Dataset\Demo\Iris;
class IrisTest extends \PHPUnit_Framework_TestCase class IrisTest extends \PHPUnit_Framework_TestCase
{ {