diff --git a/src/Phpml/Exception/NormalizerException.php b/src/Phpml/Exception/NormalizerException.php index 9f88f0c..5c7ced1 100644 --- a/src/Phpml/Exception/NormalizerException.php +++ b/src/Phpml/Exception/NormalizerException.php @@ -13,4 +13,13 @@ class NormalizerException extends \Exception { return new self('Unknown norm supplied.'); } + + /** + * @return NormalizerException + */ + public static function fitNotAllowed() + { + return new self('Fit is not allowed for this preprocessor.'); + } + } diff --git a/src/Phpml/Exception/PreprocessorException.php b/src/Phpml/Exception/PreprocessorException.php deleted file mode 100644 index 15e3975..0000000 --- a/src/Phpml/Exception/PreprocessorException.php +++ /dev/null @@ -1,17 +0,0 @@ -fitTransformers($samples); $this->transformSamples($samples); $this->estimator->train($samples, $targets); } @@ -83,6 +84,16 @@ class Pipeline implements Estimator return $this->estimator->predict($samples); } + /** + * @param array $samples + */ + private function fitTransformers(array &$samples) + { + foreach ($this->transformers as $transformer) { + $transformer->fit($samples); + } + } + /** * @param array $samples */ diff --git a/src/Phpml/Preprocessing/Imputer.php b/src/Phpml/Preprocessing/Imputer.php index fdbfaf6..424efa4 100644 --- a/src/Phpml/Preprocessing/Imputer.php +++ b/src/Phpml/Preprocessing/Imputer.php @@ -26,16 +26,23 @@ class Imputer implements Preprocessor */ private $axis; + /** + * @var $samples + */ + private $samples; + /** * @param mixed $missingValue * @param Strategy $strategy * @param int $axis + * @param array|null $samples */ - public function __construct($missingValue = null, Strategy $strategy, int $axis = self::AXIS_COLUMN) + public function __construct($missingValue = null, Strategy $strategy, int $axis = self::AXIS_COLUMN, array $samples = []) { $this->missingValue = $missingValue; $this->strategy = $strategy; $this->axis = $axis; + $this->samples = $samples; } /** @@ -43,7 +50,7 @@ class Imputer implements Preprocessor */ public function fit(array $samples) { - // TODO: Implement fit() method. + $this->samples = $samples; } /** @@ -52,19 +59,18 @@ class Imputer implements Preprocessor public function transform(array &$samples) { foreach ($samples as &$sample) { - $this->preprocessSample($sample, $samples); + $this->preprocessSample($sample); } } /** * @param array $sample - * @param array $samples */ - private function preprocessSample(array &$sample, array $samples) + private function preprocessSample(array &$sample) { foreach ($sample as $column => &$value) { if ($value === $this->missingValue) { - $value = $this->strategy->replaceValue($this->getAxis($column, $sample, $samples)); + $value = $this->strategy->replaceValue($this->getAxis($column, $sample)); } } } @@ -72,18 +78,17 @@ class Imputer implements Preprocessor /** * @param int $column * @param array $currentSample - * @param array $samples * * @return array */ - private function getAxis(int $column, array $currentSample, array $samples): array + private function getAxis(int $column, array $currentSample): array { if (self::AXIS_ROW === $this->axis) { return array_diff($currentSample, [$this->missingValue]); } $axis = []; - foreach ($samples as $sample) { + foreach ($this->samples as $sample) { if ($sample[$column] !== $this->missingValue) { $axis[] = $sample[$column]; } diff --git a/src/Phpml/Preprocessing/Normalizer.php b/src/Phpml/Preprocessing/Normalizer.php index 11a0218..7647997 100644 --- a/src/Phpml/Preprocessing/Normalizer.php +++ b/src/Phpml/Preprocessing/Normalizer.php @@ -5,7 +5,6 @@ declare (strict_types = 1); namespace Phpml\Preprocessing; use Phpml\Exception\NormalizerException; -use Phpml\Exception\PreprocessorException; class Normalizer implements Preprocessor { @@ -33,12 +32,10 @@ class Normalizer implements Preprocessor /** * @param array $samples - * - * @throws PreprocessorException */ public function fit(array $samples) { - throw PreprocessorException::fitNotAllowed(); + // intentionally not implemented } /** diff --git a/tests/Phpml/Preprocessing/ImputerTest.php b/tests/Phpml/Preprocessing/ImputerTest.php index a7e36f7..9aa3ea3 100644 --- a/tests/Phpml/Preprocessing/ImputerTest.php +++ b/tests/Phpml/Preprocessing/ImputerTest.php @@ -27,7 +27,7 @@ class ImputerTest extends \PHPUnit_Framework_TestCase [8, 7, 4, 5], ]; - $imputer = new Imputer(null, new MeanStrategy(), Imputer::AXIS_COLUMN); + $imputer = new Imputer(null, new MeanStrategy(), Imputer::AXIS_COLUMN, $data); $imputer->transform($data); $this->assertEquals($imputeData, $data, '', $delta = 0.01); @@ -49,7 +49,7 @@ class ImputerTest extends \PHPUnit_Framework_TestCase [8, 7, 6.66, 5], ]; - $imputer = new Imputer(null, new MeanStrategy(), Imputer::AXIS_ROW); + $imputer = new Imputer(null, new MeanStrategy(), Imputer::AXIS_ROW, $data); $imputer->transform($data); $this->assertEquals($imputeData, $data, '', $delta = 0.01); @@ -71,7 +71,7 @@ class ImputerTest extends \PHPUnit_Framework_TestCase [8, 7, 3, 5], ]; - $imputer = new Imputer(null, new MedianStrategy(), Imputer::AXIS_COLUMN); + $imputer = new Imputer(null, new MedianStrategy(), Imputer::AXIS_COLUMN, $data); $imputer->transform($data); $this->assertEquals($imputeData, $data, '', $delta = 0.01); @@ -93,7 +93,7 @@ class ImputerTest extends \PHPUnit_Framework_TestCase [8, 7, 7, 5], ]; - $imputer = new Imputer(null, new MedianStrategy(), Imputer::AXIS_ROW); + $imputer = new Imputer(null, new MedianStrategy(), Imputer::AXIS_ROW, $data); $imputer->transform($data); $this->assertEquals($imputeData, $data, '', $delta = 0.01); @@ -117,7 +117,7 @@ class ImputerTest extends \PHPUnit_Framework_TestCase [8, 3, 2, 5], ]; - $imputer = new Imputer(null, new MostFrequentStrategy(), Imputer::AXIS_COLUMN); + $imputer = new Imputer(null, new MostFrequentStrategy(), Imputer::AXIS_COLUMN, $data); $imputer->transform($data); $this->assertEquals($imputeData, $data); @@ -141,7 +141,7 @@ class ImputerTest extends \PHPUnit_Framework_TestCase [8, 3, 2, 5, 4], ]; - $imputer = new Imputer(null, new MostFrequentStrategy(), Imputer::AXIS_ROW); + $imputer = new Imputer(null, new MostFrequentStrategy(), Imputer::AXIS_ROW, $data); $imputer->transform($data); $this->assertEquals($imputeData, $data);