add tests for datasets

This commit is contained in:
Arkadiusz Kondas 2016-04-07 22:35:49 +02:00
parent a20f474324
commit 9c18a5a22d
5 changed files with 18 additions and 22 deletions

View File

@ -31,8 +31,8 @@ class RandomSplit
/** /**
* @param Dataset $dataset * @param Dataset $dataset
* @param float $testSize * @param float $testSize
* @param int $seed * @param int $seed
* *
* @throws InvalidArgumentException * @throws InvalidArgumentException
*/ */
@ -47,8 +47,8 @@ class RandomSplit
$labels = $dataset->getLabels(); $labels = $dataset->getLabels();
$datasetSize = count($samples); $datasetSize = count($samples);
for($i=$datasetSize; $i>0; $i--) { for ($i = $datasetSize; $i > 0; --$i) {
$key = mt_rand(0, $datasetSize-1); $key = mt_rand(0, $datasetSize - 1);
$setName = count($this->testSamples) / $datasetSize >= $testSize ? 'train' : 'test'; $setName = count($this->testSamples) / $datasetSize >= $testSize ? 'train' : 'test';
$this->{$setName.'Samples'}[] = $samples[$key]; $this->{$setName.'Samples'}[] = $samples[$key];

View File

@ -1,5 +1,6 @@
<?php <?php
declare(strict_types = 1);
declare (strict_types = 1);
namespace Phpml\Dataset; namespace Phpml\Dataset;
@ -7,7 +8,6 @@ use Phpml\Exception\InvalidArgumentException;
class ArrayDataset implements Dataset class ArrayDataset implements Dataset
{ {
/** /**
* @var array * @var array
*/ */
@ -34,7 +34,6 @@ class ArrayDataset implements Dataset
$this->labels = $labels; $this->labels = $labels;
} }
/** /**
* @return array * @return array
*/ */
@ -50,5 +49,4 @@ class ArrayDataset implements Dataset
{ {
return $this->labels; return $this->labels;
} }
} }

View File

@ -14,11 +14,13 @@ class CsvDataset extends ArrayDataset
protected $filepath; protected $filepath;
/** /**
* @param string|null $filepath * @param string $filepath
* @param int $features
* @param bool $headingRow
* *
* @throws DatasetException * @throws DatasetException
*/ */
public function __construct(string $filepath = null) public function __construct(string $filepath, int $features, bool $headingRow = true)
{ {
if (!file_exists($filepath)) { if (!file_exists($filepath)) {
throw DatasetException::missingFile(basename($filepath)); throw DatasetException::missingFile(basename($filepath));
@ -28,11 +30,11 @@ class CsvDataset extends ArrayDataset
if (($handle = fopen($filepath, 'r')) !== false) { if (($handle = fopen($filepath, 'r')) !== false) {
while (($data = fgetcsv($handle, 1000, ',')) !== false) { while (($data = fgetcsv($handle, 1000, ',')) !== false) {
++$row; ++$row;
if ($row == 1) { if ($headingRow && $row == 1) {
continue; continue;
} }
$this->samples[] = array_slice($data, 0, 4); $this->samples[] = array_slice($data, 0, $features);
$this->labels[] = $data[4]; $this->labels[] = $data[$features];
} }
fclose($handle); fclose($handle);
} else { } else {

View File

@ -14,12 +14,9 @@ use Phpml\Dataset\CsvDataset;
*/ */
class Iris extends CsvDataset class Iris extends CsvDataset
{ {
/** public function __construct()
* @param string|null $filepath
*/
public function __construct(string $filepath = null)
{ {
$filepath = dirname(__FILE__).'/../../../../data/iris.csv'; $filepath = dirname(__FILE__).'/../../../../data/iris.csv';
parent::__construct($filepath); parent::__construct($filepath, 4, true);
} }
} }

View File

@ -1,5 +1,6 @@
<?php <?php
declare(strict_types = 1);
declare (strict_types = 1);
namespace tests\Phpml\CrossValidation; namespace tests\Phpml\CrossValidation;
@ -8,7 +9,6 @@ use Phpml\Dataset\ArrayDataset;
class RandomSplitTest extends \PHPUnit_Framework_TestCase class RandomSplitTest extends \PHPUnit_Framework_TestCase
{ {
/** /**
* @expectedException \Phpml\Exception\InvalidArgumentException * @expectedException \Phpml\Exception\InvalidArgumentException
*/ */
@ -91,5 +91,4 @@ class RandomSplitTest extends \PHPUnit_Framework_TestCase
$this->assertEquals($randomSplit->getTrainSamples()[0][0], $randomSplit->getTrainLabels()[0]); $this->assertEquals($randomSplit->getTrainSamples()[0][0], $randomSplit->getTrainLabels()[0]);
$this->assertEquals($randomSplit->getTrainSamples()[1][0], $randomSplit->getTrainLabels()[1]); $this->assertEquals($randomSplit->getTrainSamples()[1][0], $randomSplit->getTrainLabels()[1]);
} }
}
}