implement fit on Imputer

This commit is contained in:
Arkadiusz Kondas 2016-06-17 00:16:49 +02:00
parent 557f344018
commit 3e9e70810d
6 changed files with 41 additions and 36 deletions

View File

@ -13,4 +13,13 @@ class NormalizerException extends \Exception
{
return new self('Unknown norm supplied.');
}
/**
* @return NormalizerException
*/
public static function fitNotAllowed()
{
return new self('Fit is not allowed for this preprocessor.');
}
}

View File

@ -1,17 +0,0 @@
<?php
declare(strict_types = 1);
namespace Phpml\Exception;
class PreprocessorException extends \Exception
{
/**
* @return PreprocessorException
*/
public static function fitNotAllowed()
{
return new self('Fit is not allowed for this preprocessor.');
}
}

View File

@ -67,6 +67,7 @@ class Pipeline implements Estimator
*/
public function train(array $samples, array $targets)
{
$this->fitTransformers($samples);
$this->transformSamples($samples);
$this->estimator->train($samples, $targets);
}
@ -83,6 +84,16 @@ class Pipeline implements Estimator
return $this->estimator->predict($samples);
}
/**
* @param array $samples
*/
private function fitTransformers(array &$samples)
{
foreach ($this->transformers as $transformer) {
$transformer->fit($samples);
}
}
/**
* @param array $samples
*/

View File

@ -26,16 +26,23 @@ class Imputer implements Preprocessor
*/
private $axis;
/**
* @var $samples
*/
private $samples;
/**
* @param mixed $missingValue
* @param Strategy $strategy
* @param int $axis
* @param array|null $samples
*/
public function __construct($missingValue = null, Strategy $strategy, int $axis = self::AXIS_COLUMN)
public function __construct($missingValue = null, Strategy $strategy, int $axis = self::AXIS_COLUMN, array $samples = [])
{
$this->missingValue = $missingValue;
$this->strategy = $strategy;
$this->axis = $axis;
$this->samples = $samples;
}
/**
@ -43,7 +50,7 @@ class Imputer implements Preprocessor
*/
public function fit(array $samples)
{
// TODO: Implement fit() method.
$this->samples = $samples;
}
/**
@ -52,19 +59,18 @@ class Imputer implements Preprocessor
public function transform(array &$samples)
{
foreach ($samples as &$sample) {
$this->preprocessSample($sample, $samples);
$this->preprocessSample($sample);
}
}
/**
* @param array $sample
* @param array $samples
*/
private function preprocessSample(array &$sample, array $samples)
private function preprocessSample(array &$sample)
{
foreach ($sample as $column => &$value) {
if ($value === $this->missingValue) {
$value = $this->strategy->replaceValue($this->getAxis($column, $sample, $samples));
$value = $this->strategy->replaceValue($this->getAxis($column, $sample));
}
}
}
@ -72,18 +78,17 @@ class Imputer implements Preprocessor
/**
* @param int $column
* @param array $currentSample
* @param array $samples
*
* @return array
*/
private function getAxis(int $column, array $currentSample, array $samples): array
private function getAxis(int $column, array $currentSample): array
{
if (self::AXIS_ROW === $this->axis) {
return array_diff($currentSample, [$this->missingValue]);
}
$axis = [];
foreach ($samples as $sample) {
foreach ($this->samples as $sample) {
if ($sample[$column] !== $this->missingValue) {
$axis[] = $sample[$column];
}

View File

@ -5,7 +5,6 @@ declare (strict_types = 1);
namespace Phpml\Preprocessing;
use Phpml\Exception\NormalizerException;
use Phpml\Exception\PreprocessorException;
class Normalizer implements Preprocessor
{
@ -33,12 +32,10 @@ class Normalizer implements Preprocessor
/**
* @param array $samples
*
* @throws PreprocessorException
*/
public function fit(array $samples)
{
throw PreprocessorException::fitNotAllowed();
// intentionally not implemented
}
/**

View File

@ -27,7 +27,7 @@ class ImputerTest extends \PHPUnit_Framework_TestCase
[8, 7, 4, 5],
];
$imputer = new Imputer(null, new MeanStrategy(), Imputer::AXIS_COLUMN);
$imputer = new Imputer(null, new MeanStrategy(), Imputer::AXIS_COLUMN, $data);
$imputer->transform($data);
$this->assertEquals($imputeData, $data, '', $delta = 0.01);
@ -49,7 +49,7 @@ class ImputerTest extends \PHPUnit_Framework_TestCase
[8, 7, 6.66, 5],
];
$imputer = new Imputer(null, new MeanStrategy(), Imputer::AXIS_ROW);
$imputer = new Imputer(null, new MeanStrategy(), Imputer::AXIS_ROW, $data);
$imputer->transform($data);
$this->assertEquals($imputeData, $data, '', $delta = 0.01);
@ -71,7 +71,7 @@ class ImputerTest extends \PHPUnit_Framework_TestCase
[8, 7, 3, 5],
];
$imputer = new Imputer(null, new MedianStrategy(), Imputer::AXIS_COLUMN);
$imputer = new Imputer(null, new MedianStrategy(), Imputer::AXIS_COLUMN, $data);
$imputer->transform($data);
$this->assertEquals($imputeData, $data, '', $delta = 0.01);
@ -93,7 +93,7 @@ class ImputerTest extends \PHPUnit_Framework_TestCase
[8, 7, 7, 5],
];
$imputer = new Imputer(null, new MedianStrategy(), Imputer::AXIS_ROW);
$imputer = new Imputer(null, new MedianStrategy(), Imputer::AXIS_ROW, $data);
$imputer->transform($data);
$this->assertEquals($imputeData, $data, '', $delta = 0.01);
@ -117,7 +117,7 @@ class ImputerTest extends \PHPUnit_Framework_TestCase
[8, 3, 2, 5],
];
$imputer = new Imputer(null, new MostFrequentStrategy(), Imputer::AXIS_COLUMN);
$imputer = new Imputer(null, new MostFrequentStrategy(), Imputer::AXIS_COLUMN, $data);
$imputer->transform($data);
$this->assertEquals($imputeData, $data);
@ -141,7 +141,7 @@ class ImputerTest extends \PHPUnit_Framework_TestCase
[8, 3, 2, 5, 4],
];
$imputer = new Imputer(null, new MostFrequentStrategy(), Imputer::AXIS_ROW);
$imputer = new Imputer(null, new MostFrequentStrategy(), Imputer::AXIS_ROW, $data);
$imputer->transform($data);
$this->assertEquals($imputeData, $data);