mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2025-02-13 17:38:36 +00:00
random split implementation and tests
This commit is contained in:
parent
c3f98e4093
commit
bbcc8a3e68
@ -29,13 +29,34 @@ class RandomSplit
|
|||||||
*/
|
*/
|
||||||
private $testLabels = [];
|
private $testLabels = [];
|
||||||
|
|
||||||
public function __construct(Dataset $dataset, float $testSize = 0.3)
|
/**
|
||||||
|
* @param Dataset $dataset
|
||||||
|
* @param float $testSize
|
||||||
|
* @param int $seed
|
||||||
|
*
|
||||||
|
* @throws InvalidArgumentException
|
||||||
|
*/
|
||||||
|
public function __construct(Dataset $dataset, float $testSize = 0.3, int $seed = null)
|
||||||
{
|
{
|
||||||
if (0 > $testSize || 1 < $testSize) {
|
if (0 >= $testSize || 1 <= $testSize) {
|
||||||
throw InvalidArgumentException::percentNotInRange('testSize');
|
throw InvalidArgumentException::percentNotInRange('testSize');
|
||||||
}
|
}
|
||||||
|
$this->seedGenerator($seed);
|
||||||
|
|
||||||
// TODO: implement this !
|
$samples = $dataset->getSamples();
|
||||||
|
$labels = $dataset->getLabels();
|
||||||
|
$datasetSize = count($samples);
|
||||||
|
|
||||||
|
for($i=$datasetSize; $i>0; $i--) {
|
||||||
|
$key = mt_rand(0, $datasetSize-1);
|
||||||
|
$setName = count($this->testSamples) / $datasetSize >= $testSize ? 'train' : 'test';
|
||||||
|
|
||||||
|
$this->{$setName.'Samples'}[] = $samples[$key];
|
||||||
|
$this->{$setName.'Labels'}[] = $labels[$key];
|
||||||
|
|
||||||
|
$samples = array_values($samples);
|
||||||
|
$labels = array_values($labels);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -69,4 +90,16 @@ class RandomSplit
|
|||||||
{
|
{
|
||||||
return $this->testLabels;
|
return $this->testLabels;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param int|null $seed
|
||||||
|
*/
|
||||||
|
private function seedGenerator(int $seed = null)
|
||||||
|
{
|
||||||
|
if (null === $seed) {
|
||||||
|
mt_srand();
|
||||||
|
} else {
|
||||||
|
mt_srand($seed);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
46
src/Phpml/Dataset/ArrayDataset.php
Normal file
46
src/Phpml/Dataset/ArrayDataset.php
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
<?php
|
||||||
|
declare(strict_types = 1);
|
||||||
|
|
||||||
|
namespace Phpml\Dataset;
|
||||||
|
|
||||||
|
class ArrayDataset implements Dataset
|
||||||
|
{
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var array
|
||||||
|
*/
|
||||||
|
private $samples = [];
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var array
|
||||||
|
*/
|
||||||
|
private $labels = [];
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param array $samples
|
||||||
|
* @param array $labels
|
||||||
|
*/
|
||||||
|
public function __construct(array $samples, array $labels)
|
||||||
|
{
|
||||||
|
$this->samples = $samples;
|
||||||
|
$this->labels = $labels;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return array
|
||||||
|
*/
|
||||||
|
public function getSamples(): array
|
||||||
|
{
|
||||||
|
return $this->samples;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return array
|
||||||
|
*/
|
||||||
|
public function getLabels(): array
|
||||||
|
{
|
||||||
|
return $this->labels;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -21,7 +21,7 @@ abstract class CsvDataset implements Dataset
|
|||||||
/**
|
/**
|
||||||
* @var array
|
* @var array
|
||||||
*/
|
*/
|
||||||
private $lables = [];
|
private $labels = [];
|
||||||
|
|
||||||
public function __construct()
|
public function __construct()
|
||||||
{
|
{
|
||||||
@ -39,7 +39,7 @@ abstract class CsvDataset implements Dataset
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
$this->samples[] = array_slice($data, 0, 4);
|
$this->samples[] = array_slice($data, 0, 4);
|
||||||
$this->lables[] = $data[4];
|
$this->labels[] = $data[4];
|
||||||
}
|
}
|
||||||
fclose($handle);
|
fclose($handle);
|
||||||
} else {
|
} else {
|
||||||
@ -60,6 +60,6 @@ abstract class CsvDataset implements Dataset
|
|||||||
*/
|
*/
|
||||||
public function getLabels(): array
|
public function getLabels(): array
|
||||||
{
|
{
|
||||||
return $this->lables;
|
return $this->labels;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,9 @@
|
|||||||
|
|
||||||
declare (strict_types = 1);
|
declare (strict_types = 1);
|
||||||
|
|
||||||
namespace Phpml\Dataset;
|
namespace Phpml\Dataset\Demo;
|
||||||
|
|
||||||
|
use Phpml\Dataset\CsvDataset;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Classes: 3
|
* Classes: 3
|
95
tests/Phpml/CrossValidation/RandomSplitTest.php
Normal file
95
tests/Phpml/CrossValidation/RandomSplitTest.php
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
<?php
|
||||||
|
declare(strict_types = 1);
|
||||||
|
|
||||||
|
namespace tests\Phpml\CrossValidation;
|
||||||
|
|
||||||
|
use Phpml\CrossValidation\RandomSplit;
|
||||||
|
use Phpml\Dataset\ArrayDataset;
|
||||||
|
|
||||||
|
class RandomSplitTest extends \PHPUnit_Framework_TestCase
|
||||||
|
{
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @expectedException \Phpml\Exception\InvalidArgumentException
|
||||||
|
*/
|
||||||
|
public function testThrowExceptionOnToSmallTestSize()
|
||||||
|
{
|
||||||
|
new RandomSplit(new ArrayDataset([], []), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @expectedException \Phpml\Exception\InvalidArgumentException
|
||||||
|
*/
|
||||||
|
public function testThrowExceptionOnToBigTestSize()
|
||||||
|
{
|
||||||
|
new RandomSplit(new ArrayDataset([], []), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
public function testDatasetRandomSplitWithoutSeed()
|
||||||
|
{
|
||||||
|
$dataset = new ArrayDataset(
|
||||||
|
$samples = [[1], [2], [3], [4]],
|
||||||
|
$labels = ['a', 'a', 'b', 'b']
|
||||||
|
);
|
||||||
|
|
||||||
|
$randomSplit1 = new RandomSplit($dataset, 0.5);
|
||||||
|
|
||||||
|
$this->assertEquals(2, count($randomSplit1->getTestSamples()));
|
||||||
|
$this->assertEquals(2, count($randomSplit1->getTrainSamples()));
|
||||||
|
|
||||||
|
$randomSplit2 = new RandomSplit($dataset, 0.25);
|
||||||
|
|
||||||
|
$this->assertEquals(1, count($randomSplit2->getTestSamples()));
|
||||||
|
$this->assertEquals(3, count($randomSplit2->getTrainSamples()));
|
||||||
|
}
|
||||||
|
|
||||||
|
public function testDatasetRandomSplitWithSameSeed()
|
||||||
|
{
|
||||||
|
$dataset = new ArrayDataset(
|
||||||
|
$samples = [[1], [2], [3], [4], [5], [6], [7], [8]],
|
||||||
|
$labels = ['a', 'a', 'a', 'a', 'b', 'b', 'b', 'b']
|
||||||
|
);
|
||||||
|
|
||||||
|
$seed = 123;
|
||||||
|
|
||||||
|
$randomSplit1 = new RandomSplit($dataset, 0.5, $seed);
|
||||||
|
$randomSplit2 = new RandomSplit($dataset, 0.5, $seed);
|
||||||
|
|
||||||
|
$this->assertEquals($randomSplit1->getTestLabels(), $randomSplit2->getTestLabels());
|
||||||
|
$this->assertEquals($randomSplit1->getTestSamples(), $randomSplit2->getTestSamples());
|
||||||
|
$this->assertEquals($randomSplit1->getTrainLabels(), $randomSplit2->getTrainLabels());
|
||||||
|
$this->assertEquals($randomSplit1->getTrainSamples(), $randomSplit2->getTrainSamples());
|
||||||
|
}
|
||||||
|
|
||||||
|
public function testDatasetRandomSplitWithDifferentSeed()
|
||||||
|
{
|
||||||
|
$dataset = new ArrayDataset(
|
||||||
|
$samples = [[1], [2], [3], [4], [5], [6], [7], [8]],
|
||||||
|
$labels = ['a', 'a', 'a', 'a', 'b', 'b', 'b', 'b']
|
||||||
|
);
|
||||||
|
|
||||||
|
$randomSplit1 = new RandomSplit($dataset, 0.5, 4321);
|
||||||
|
$randomSplit2 = new RandomSplit($dataset, 0.5, 1234);
|
||||||
|
|
||||||
|
$this->assertNotEquals($randomSplit1->getTestLabels(), $randomSplit2->getTestLabels());
|
||||||
|
$this->assertNotEquals($randomSplit1->getTestSamples(), $randomSplit2->getTestSamples());
|
||||||
|
$this->assertNotEquals($randomSplit1->getTrainLabels(), $randomSplit2->getTrainLabels());
|
||||||
|
$this->assertNotEquals($randomSplit1->getTrainSamples(), $randomSplit2->getTrainSamples());
|
||||||
|
}
|
||||||
|
|
||||||
|
public function testRandomSplitCorrectSampleAndLabelPosition()
|
||||||
|
{
|
||||||
|
$dataset = new ArrayDataset(
|
||||||
|
$samples = [[1], [2], [3], [4]],
|
||||||
|
$labels = [1, 2, 3, 4]
|
||||||
|
);
|
||||||
|
|
||||||
|
$randomSplit = new RandomSplit($dataset, 0.5);
|
||||||
|
|
||||||
|
$this->assertEquals($randomSplit->getTestSamples()[0][0], $randomSplit->getTestLabels()[0]);
|
||||||
|
$this->assertEquals($randomSplit->getTestSamples()[1][0], $randomSplit->getTestLabels()[1]);
|
||||||
|
$this->assertEquals($randomSplit->getTrainSamples()[0][0], $randomSplit->getTrainLabels()[0]);
|
||||||
|
$this->assertEquals($randomSplit->getTrainSamples()[1][0], $randomSplit->getTrainLabels()[1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -2,9 +2,9 @@
|
|||||||
|
|
||||||
declare (strict_types = 1);
|
declare (strict_types = 1);
|
||||||
|
|
||||||
namespace tests\Phpml\Dataset;
|
namespace tests\Phpml\Dataset\Demo;
|
||||||
|
|
||||||
use Phpml\Dataset\Iris;
|
use Phpml\Dataset\Demo\Iris;
|
||||||
|
|
||||||
class IrisTest extends \PHPUnit_Framework_TestCase
|
class IrisTest extends \PHPUnit_Framework_TestCase
|
||||||
{
|
{
|
Loading…
x
Reference in New Issue
Block a user