diff --git a/src/Phpml/CrossValidation/RandomSplit.php b/src/Phpml/CrossValidation/RandomSplit.php index 4e2ecc9..c5a24bd 100644 --- a/src/Phpml/CrossValidation/RandomSplit.php +++ b/src/Phpml/CrossValidation/RandomSplit.php @@ -31,8 +31,8 @@ class RandomSplit /** * @param Dataset $dataset - * @param float $testSize - * @param int $seed + * @param float $testSize + * @param int $seed * * @throws InvalidArgumentException */ @@ -47,8 +47,8 @@ class RandomSplit $labels = $dataset->getLabels(); $datasetSize = count($samples); - for($i=$datasetSize; $i>0; $i--) { - $key = mt_rand(0, $datasetSize-1); + for ($i = $datasetSize; $i > 0; --$i) { + $key = mt_rand(0, $datasetSize - 1); $setName = count($this->testSamples) / $datasetSize >= $testSize ? 'train' : 'test'; $this->{$setName.'Samples'}[] = $samples[$key]; diff --git a/src/Phpml/Dataset/ArrayDataset.php b/src/Phpml/Dataset/ArrayDataset.php index 580df0a..d117122 100644 --- a/src/Phpml/Dataset/ArrayDataset.php +++ b/src/Phpml/Dataset/ArrayDataset.php @@ -1,5 +1,6 @@ labels = $labels; } - /** * @return array */ @@ -50,5 +49,4 @@ class ArrayDataset implements Dataset { return $this->labels; } - } diff --git a/src/Phpml/Dataset/CsvDataset.php b/src/Phpml/Dataset/CsvDataset.php index de540c9..e6dafd2 100644 --- a/src/Phpml/Dataset/CsvDataset.php +++ b/src/Phpml/Dataset/CsvDataset.php @@ -14,11 +14,13 @@ class CsvDataset extends ArrayDataset protected $filepath; /** - * @param string|null $filepath + * @param string $filepath + * @param int $features + * @param bool $headingRow * * @throws DatasetException */ - public function __construct(string $filepath = null) + public function __construct(string $filepath, int $features, bool $headingRow = true) { if (!file_exists($filepath)) { throw DatasetException::missingFile(basename($filepath)); @@ -28,11 +30,11 @@ class CsvDataset extends ArrayDataset if (($handle = fopen($filepath, 'r')) !== false) { while (($data = fgetcsv($handle, 1000, ',')) !== false) { ++$row; - if ($row == 1) { + if ($headingRow && $row == 1) { continue; } - $this->samples[] = array_slice($data, 0, 4); - $this->labels[] = $data[4]; + $this->samples[] = array_slice($data, 0, $features); + $this->labels[] = $data[$features]; } fclose($handle); } else { diff --git a/src/Phpml/Dataset/Demo/Iris.php b/src/Phpml/Dataset/Demo/Iris.php index 5a4789e..923f0ba 100644 --- a/src/Phpml/Dataset/Demo/Iris.php +++ b/src/Phpml/Dataset/Demo/Iris.php @@ -14,12 +14,9 @@ use Phpml\Dataset\CsvDataset; */ class Iris extends CsvDataset { - /** - * @param string|null $filepath - */ - public function __construct(string $filepath = null) + public function __construct() { $filepath = dirname(__FILE__).'/../../../../data/iris.csv'; - parent::__construct($filepath); + parent::__construct($filepath, 4, true); } } diff --git a/tests/Phpml/CrossValidation/RandomSplitTest.php b/tests/Phpml/CrossValidation/RandomSplitTest.php index cbbd497..e6ae30e 100644 --- a/tests/Phpml/CrossValidation/RandomSplitTest.php +++ b/tests/Phpml/CrossValidation/RandomSplitTest.php @@ -1,5 +1,6 @@ assertEquals($randomSplit->getTrainSamples()[0][0], $randomSplit->getTrainLabels()[0]); $this->assertEquals($randomSplit->getTrainSamples()[1][0], $randomSplit->getTrainLabels()[1]); } - -} \ No newline at end of file +}