2016-04-07 20:12:36 +00:00
|
|
|
<?php
|
2016-04-07 20:35:49 +00:00
|
|
|
|
2016-11-20 21:53:17 +00:00
|
|
|
declare(strict_types=1);
|
2016-04-07 20:12:36 +00:00
|
|
|
|
|
|
|
namespace tests\Phpml\CrossValidation;
|
|
|
|
|
|
|
|
use Phpml\CrossValidation\RandomSplit;
|
|
|
|
use Phpml\Dataset\ArrayDataset;
|
2017-02-03 11:58:25 +00:00
|
|
|
use PHPUnit\Framework\TestCase;
|
2016-04-07 20:12:36 +00:00
|
|
|
|
2017-02-03 11:58:25 +00:00
|
|
|
/**
 * Tests for RandomSplit: argument validation of the test-size ratio,
 * partition sizes, seed handling, and sample/label alignment.
 */
class RandomSplitTest extends TestCase
{
    /**
     * A test size of 0 must be rejected with an InvalidArgumentException.
     */
    public function testThrowExceptionOnToSmallTestSize(): void
    {
        // Declared before the call so the expectation is registered when the
        // constructor throws (replaces the deprecated @expectedException tag).
        $this->expectException(\Phpml\Exception\InvalidArgumentException::class);

        new RandomSplit(new ArrayDataset([], []), 0);
    }

    /**
     * A test size of 1 must be rejected with an InvalidArgumentException.
     */
    public function testThrowExceptionOnToBigTestSize(): void
    {
        $this->expectException(\Phpml\Exception\InvalidArgumentException::class);

        new RandomSplit(new ArrayDataset([], []), 1);
    }

    /**
     * Without a seed, the partition sizes must still follow the requested
     * ratio: 4 samples at 0.5 give 2 test / 2 train, at 0.25 give 1 test / 3 train.
     */
    public function testDatasetRandomSplitWithoutSeed(): void
    {
        $dataset = new ArrayDataset(
            [[1], [2], [3], [4]],
            ['a', 'a', 'b', 'b']
        );

        $halfSplit = new RandomSplit($dataset, 0.5);

        $this->assertCount(2, $halfSplit->getTestSamples());
        $this->assertCount(2, $halfSplit->getTrainSamples());

        $quarterSplit = new RandomSplit($dataset, 0.25);

        $this->assertCount(1, $quarterSplit->getTestSamples());
        $this->assertCount(3, $quarterSplit->getTrainSamples());
    }

    /**
     * Two splits of the same dataset built with the same seed must produce
     * equal test and train partitions (both samples and labels).
     */
    public function testDatasetRandomSplitWithSameSeed(): void
    {
        $dataset = new ArrayDataset(
            [[1], [2], [3], [4], [5], [6], [7], [8]],
            ['a', 'a', 'a', 'a', 'b', 'b', 'b', 'b']
        );

        $seed = 123;

        $first = new RandomSplit($dataset, 0.5, $seed);
        $second = new RandomSplit($dataset, 0.5, $seed);

        $this->assertEquals($first->getTestLabels(), $second->getTestLabels());
        $this->assertEquals($first->getTestSamples(), $second->getTestSamples());
        $this->assertEquals($first->getTrainLabels(), $second->getTrainLabels());
        $this->assertEquals($first->getTrainSamples(), $second->getTrainSamples());
    }

    /**
     * Two splits of the same dataset built with seeds 4321 and 1234 are
     * expected to differ in every partition (both samples and labels).
     */
    public function testDatasetRandomSplitWithDifferentSeed(): void
    {
        $dataset = new ArrayDataset(
            [[1], [2], [3], [4], [5], [6], [7], [8]],
            ['a', 'a', 'a', 'a', 'b', 'b', 'b', 'b']
        );

        $first = new RandomSplit($dataset, 0.5, 4321);
        $second = new RandomSplit($dataset, 0.5, 1234);

        $this->assertNotEquals($first->getTestLabels(), $second->getTestLabels());
        $this->assertNotEquals($first->getTestSamples(), $second->getTestSamples());
        $this->assertNotEquals($first->getTrainLabels(), $second->getTrainLabels());
        $this->assertNotEquals($first->getTrainSamples(), $second->getTrainSamples());
    }

    /**
     * Each sample must stay paired with its own label after the split.
     *
     * The fixture uses labels equal to the single feature value of each
     * sample, so a sample/label pair at the same index must compare equal.
     */
    public function testRandomSplitCorrectSampleAndLabelPosition(): void
    {
        $dataset = new ArrayDataset(
            [[1], [2], [3], [4]],
            [1, 2, 3, 4]
        );

        $randomSplit = new RandomSplit($dataset, 0.5);

        $this->assertEquals($randomSplit->getTestSamples()[0][0], $randomSplit->getTestLabels()[0]);
        $this->assertEquals($randomSplit->getTestSamples()[1][0], $randomSplit->getTestLabels()[1]);
        $this->assertEquals($randomSplit->getTrainSamples()[0][0], $randomSplit->getTrainLabels()[0]);
        $this->assertEquals($randomSplit->getTrainSamples()[1][0], $randomSplit->getTrainLabels()[1]);
    }
}
|