mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2024-11-25 06:17:34 +00:00
implement fit on Imputer
This commit is contained in:
parent
557f344018
commit
3e9e70810d
@ -13,4 +13,13 @@ class NormalizerException extends \Exception
|
||||
{
|
||||
return new self('Unknown norm supplied.');
|
||||
}
|
||||
|
||||
/**
|
||||
* @return NormalizerException
|
||||
*/
|
||||
public static function fitNotAllowed()
|
||||
{
|
||||
return new self('Fit is not allowed for this preprocessor.');
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -1,17 +0,0 @@
|
||||
<?php
|
||||
declare(strict_types = 1);
|
||||
|
||||
namespace Phpml\Exception;
|
||||
|
||||
class PreprocessorException extends \Exception
|
||||
{
|
||||
|
||||
/**
|
||||
* @return PreprocessorException
|
||||
*/
|
||||
public static function fitNotAllowed()
|
||||
{
|
||||
return new self('Fit is not allowed for this preprocessor.');
|
||||
}
|
||||
|
||||
}
|
@ -67,6 +67,7 @@ class Pipeline implements Estimator
|
||||
*/
|
||||
public function train(array $samples, array $targets)
|
||||
{
|
||||
$this->fitTransformers($samples);
|
||||
$this->transformSamples($samples);
|
||||
$this->estimator->train($samples, $targets);
|
||||
}
|
||||
@ -83,6 +84,16 @@ class Pipeline implements Estimator
|
||||
return $this->estimator->predict($samples);
|
||||
}
|
||||
|
||||
/**
|
||||
* @param array $samples
|
||||
*/
|
||||
private function fitTransformers(array &$samples)
|
||||
{
|
||||
foreach ($this->transformers as $transformer) {
|
||||
$transformer->fit($samples);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @param array $samples
|
||||
*/
|
||||
|
@ -26,16 +26,23 @@ class Imputer implements Preprocessor
|
||||
*/
|
||||
private $axis;
|
||||
|
||||
/**
|
||||
* @var $samples
|
||||
*/
|
||||
private $samples;
|
||||
|
||||
/**
|
||||
* @param mixed $missingValue
|
||||
* @param Strategy $strategy
|
||||
* @param int $axis
|
||||
* @param array|null $samples
|
||||
*/
|
||||
public function __construct($missingValue = null, Strategy $strategy, int $axis = self::AXIS_COLUMN)
|
||||
public function __construct($missingValue = null, Strategy $strategy, int $axis = self::AXIS_COLUMN, array $samples = [])
|
||||
{
|
||||
$this->missingValue = $missingValue;
|
||||
$this->strategy = $strategy;
|
||||
$this->axis = $axis;
|
||||
$this->samples = $samples;
|
||||
}
|
||||
|
||||
/**
|
||||
@ -43,7 +50,7 @@ class Imputer implements Preprocessor
|
||||
*/
|
||||
public function fit(array $samples)
|
||||
{
|
||||
// TODO: Implement fit() method.
|
||||
$this->samples = $samples;
|
||||
}
|
||||
|
||||
/**
|
||||
@ -52,19 +59,18 @@ class Imputer implements Preprocessor
|
||||
public function transform(array &$samples)
|
||||
{
|
||||
foreach ($samples as &$sample) {
|
||||
$this->preprocessSample($sample, $samples);
|
||||
$this->preprocessSample($sample);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @param array $sample
|
||||
* @param array $samples
|
||||
*/
|
||||
private function preprocessSample(array &$sample, array $samples)
|
||||
private function preprocessSample(array &$sample)
|
||||
{
|
||||
foreach ($sample as $column => &$value) {
|
||||
if ($value === $this->missingValue) {
|
||||
$value = $this->strategy->replaceValue($this->getAxis($column, $sample, $samples));
|
||||
$value = $this->strategy->replaceValue($this->getAxis($column, $sample));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -72,18 +78,17 @@ class Imputer implements Preprocessor
|
||||
/**
|
||||
* @param int $column
|
||||
* @param array $currentSample
|
||||
* @param array $samples
|
||||
*
|
||||
* @return array
|
||||
*/
|
||||
private function getAxis(int $column, array $currentSample, array $samples): array
|
||||
private function getAxis(int $column, array $currentSample): array
|
||||
{
|
||||
if (self::AXIS_ROW === $this->axis) {
|
||||
return array_diff($currentSample, [$this->missingValue]);
|
||||
}
|
||||
|
||||
$axis = [];
|
||||
foreach ($samples as $sample) {
|
||||
foreach ($this->samples as $sample) {
|
||||
if ($sample[$column] !== $this->missingValue) {
|
||||
$axis[] = $sample[$column];
|
||||
}
|
||||
|
@ -5,7 +5,6 @@ declare (strict_types = 1);
|
||||
namespace Phpml\Preprocessing;
|
||||
|
||||
use Phpml\Exception\NormalizerException;
|
||||
use Phpml\Exception\PreprocessorException;
|
||||
|
||||
class Normalizer implements Preprocessor
|
||||
{
|
||||
@ -33,12 +32,10 @@ class Normalizer implements Preprocessor
|
||||
|
||||
/**
|
||||
* @param array $samples
|
||||
*
|
||||
* @throws PreprocessorException
|
||||
*/
|
||||
public function fit(array $samples)
|
||||
{
|
||||
throw PreprocessorException::fitNotAllowed();
|
||||
// intentionally not implemented
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -27,7 +27,7 @@ class ImputerTest extends \PHPUnit_Framework_TestCase
|
||||
[8, 7, 4, 5],
|
||||
];
|
||||
|
||||
$imputer = new Imputer(null, new MeanStrategy(), Imputer::AXIS_COLUMN);
|
||||
$imputer = new Imputer(null, new MeanStrategy(), Imputer::AXIS_COLUMN, $data);
|
||||
$imputer->transform($data);
|
||||
|
||||
$this->assertEquals($imputeData, $data, '', $delta = 0.01);
|
||||
@ -49,7 +49,7 @@ class ImputerTest extends \PHPUnit_Framework_TestCase
|
||||
[8, 7, 6.66, 5],
|
||||
];
|
||||
|
||||
$imputer = new Imputer(null, new MeanStrategy(), Imputer::AXIS_ROW);
|
||||
$imputer = new Imputer(null, new MeanStrategy(), Imputer::AXIS_ROW, $data);
|
||||
$imputer->transform($data);
|
||||
|
||||
$this->assertEquals($imputeData, $data, '', $delta = 0.01);
|
||||
@ -71,7 +71,7 @@ class ImputerTest extends \PHPUnit_Framework_TestCase
|
||||
[8, 7, 3, 5],
|
||||
];
|
||||
|
||||
$imputer = new Imputer(null, new MedianStrategy(), Imputer::AXIS_COLUMN);
|
||||
$imputer = new Imputer(null, new MedianStrategy(), Imputer::AXIS_COLUMN, $data);
|
||||
$imputer->transform($data);
|
||||
|
||||
$this->assertEquals($imputeData, $data, '', $delta = 0.01);
|
||||
@ -93,7 +93,7 @@ class ImputerTest extends \PHPUnit_Framework_TestCase
|
||||
[8, 7, 7, 5],
|
||||
];
|
||||
|
||||
$imputer = new Imputer(null, new MedianStrategy(), Imputer::AXIS_ROW);
|
||||
$imputer = new Imputer(null, new MedianStrategy(), Imputer::AXIS_ROW, $data);
|
||||
$imputer->transform($data);
|
||||
|
||||
$this->assertEquals($imputeData, $data, '', $delta = 0.01);
|
||||
@ -117,7 +117,7 @@ class ImputerTest extends \PHPUnit_Framework_TestCase
|
||||
[8, 3, 2, 5],
|
||||
];
|
||||
|
||||
$imputer = new Imputer(null, new MostFrequentStrategy(), Imputer::AXIS_COLUMN);
|
||||
$imputer = new Imputer(null, new MostFrequentStrategy(), Imputer::AXIS_COLUMN, $data);
|
||||
$imputer->transform($data);
|
||||
|
||||
$this->assertEquals($imputeData, $data);
|
||||
@ -141,7 +141,7 @@ class ImputerTest extends \PHPUnit_Framework_TestCase
|
||||
[8, 3, 2, 5, 4],
|
||||
];
|
||||
|
||||
$imputer = new Imputer(null, new MostFrequentStrategy(), Imputer::AXIS_ROW);
|
||||
$imputer = new Imputer(null, new MostFrequentStrategy(), Imputer::AXIS_ROW, $data);
|
||||
$imputer->transform($data);
|
||||
|
||||
$this->assertEquals($imputeData, $data);
|
||||
|
Loading…
Reference in New Issue
Block a user