php-ml/tests/CrossValidation/RandomSplitTest.php

93 lines
3.4 KiB
PHP
Raw Normal View History

2016-04-07 22:12:36 +02:00
<?php
2016-04-07 22:35:49 +02:00
2016-11-20 22:53:17 +01:00
declare(strict_types=1);
2016-04-07 22:12:36 +02:00
namespace Phpml\Tests\CrossValidation;
2016-04-07 22:12:36 +02:00
use Phpml\CrossValidation\RandomSplit;
use Phpml\Dataset\ArrayDataset;
use Phpml\Exception\InvalidArgumentException;
2017-02-03 12:58:25 +01:00
use PHPUnit\Framework\TestCase;
2016-04-07 22:12:36 +02:00
2017-02-03 12:58:25 +01:00
class RandomSplitTest extends TestCase
{
    /**
     * A test size of 0 is out of range and must be rejected by the constructor.
     */
    public function testThrowExceptionOnToSmallTestSize(): void
    {
        $this->expectException(InvalidArgumentException::class);
        new RandomSplit(new ArrayDataset([], []), 0);
    }

    /**
     * A test size of 1 is out of range and must be rejected by the constructor.
     */
    public function testThrowExceptionOnToBigTestSize(): void
    {
        $this->expectException(InvalidArgumentException::class);
        new RandomSplit(new ArrayDataset([], []), 1);
    }

    /**
     * Without a seed, the split sizes still follow the requested test ratio.
     */
    public function testDatasetRandomSplitWithoutSeed(): void
    {
        $dataset = new ArrayDataset([[1], [2], [3], [4]], ['a', 'a', 'b', 'b']);

        $randomSplit = new RandomSplit($dataset, 0.5);
        $this->assertCount(2, $randomSplit->getTestSamples());
        $this->assertCount(2, $randomSplit->getTrainSamples());

        $randomSplit2 = new RandomSplit($dataset, 0.25);
        $this->assertCount(1, $randomSplit2->getTestSamples());
        $this->assertCount(3, $randomSplit2->getTrainSamples());
    }

    /**
     * Two splits built with the same seed must produce identical partitions.
     */
    public function testDatasetRandomSplitWithSameSeed(): void
    {
        $dataset = new ArrayDataset(
            [[1], [2], [3], [4], [5], [6], [7], [8]],
            ['a', 'a', 'a', 'a', 'b', 'b', 'b', 'b']
        );

        $seed = 123;
        $randomSplit1 = new RandomSplit($dataset, 0.5, $seed);
        $randomSplit2 = new RandomSplit($dataset, 0.5, $seed);

        $this->assertEquals($randomSplit1->getTestLabels(), $randomSplit2->getTestLabels());
        $this->assertEquals($randomSplit1->getTestSamples(), $randomSplit2->getTestSamples());
        $this->assertEquals($randomSplit1->getTrainLabels(), $randomSplit2->getTrainLabels());
        $this->assertEquals($randomSplit1->getTrainSamples(), $randomSplit2->getTrainSamples());
    }

    /**
     * Two splits built with these two different seeds are expected to differ.
     */
    public function testDatasetRandomSplitWithDifferentSeed(): void
    {
        $dataset = new ArrayDataset(
            [[1], [2], [3], [4], [5], [6], [7], [8]],
            ['a', 'a', 'a', 'a', 'b', 'b', 'b', 'b']
        );

        $randomSplit1 = new RandomSplit($dataset, 0.5, 4321);
        $randomSplit2 = new RandomSplit($dataset, 0.5, 1234);

        $this->assertNotEquals($randomSplit1->getTestLabels(), $randomSplit2->getTestLabels());
        $this->assertNotEquals($randomSplit1->getTestSamples(), $randomSplit2->getTestSamples());
        $this->assertNotEquals($randomSplit1->getTrainLabels(), $randomSplit2->getTrainLabels());
        $this->assertNotEquals($randomSplit1->getTrainSamples(), $randomSplit2->getTrainSamples());
    }

    /**
     * Samples and labels must stay paired at the same index after splitting.
     */
    public function testRandomSplitCorrectSampleAndLabelPosition(): void
    {
        // Each label equals its sample's only feature, so any misalignment between
        // the sample and label arrays surfaces as a mismatch.
        $dataset = new ArrayDataset([[1], [2], [3], [4]], [1, 2, 3, 4]);

        $randomSplit = new RandomSplit($dataset, 0.5);

        // Pin the set sizes so the per-index checks below can never pass vacuously.
        $this->assertCount(2, $randomSplit->getTestSamples());
        $this->assertCount(2, $randomSplit->getTrainSamples());

        $this->assertSamplesMatchLabels($randomSplit->getTestSamples(), $randomSplit->getTestLabels());
        $this->assertSamplesMatchLabels($randomSplit->getTrainSamples(), $randomSplit->getTrainLabels());
    }

    /**
     * Asserts that every sample's first feature is identical to the label at the same index.
     *
     * @param array<int, array<int, mixed>> $samples
     * @param array<int, mixed>             $labels
     */
    private function assertSamplesMatchLabels(array $samples, array $labels): void
    {
        $this->assertCount(count($samples), $labels);

        foreach ($samples as $index => $sample) {
            $this->assertSame($sample[0], $labels[$index]);
        }
    }
}