Implement NumberConverter (#377)

This commit is contained in:
Arkadiusz Kondas 2019-05-12 22:25:17 +02:00 committed by GitHub
parent 1e1d794655
commit 717f236ca9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 102 additions and 8 deletions

View File

@ -30,7 +30,7 @@ class TfIdfTransformer implements Transformer
}
}
public function transform(array &$samples): void
public function transform(array &$samples, ?array &$targets = null): void
{
foreach ($samples as &$sample) {
foreach ($sample as $index => &$feature) {

View File

@ -46,7 +46,7 @@ class TokenCountVectorizer implements Transformer
$this->buildVocabulary($samples);
}
public function transform(array &$samples): void
public function transform(array &$samples, ?array &$targets = null): void
{
array_walk($samples, function (string &$sample): void {
$this->transformSample($sample);

View File

@ -56,7 +56,7 @@ final class SelectKBest implements Transformer
$this->keepColumns = array_slice($sorted, 0, $this->k, true);
}
public function transform(array &$samples): void
public function transform(array &$samples, ?array &$targets = null): void
{
if ($this->keepColumns === null) {
return;

View File

@ -48,7 +48,7 @@ final class VarianceThreshold implements Transformer
}
}
public function transform(array &$samples): void
public function transform(array &$samples, ?array &$targets = null): void
{
foreach ($samples as &$sample) {
$sample = array_values(array_intersect_key($sample, $this->keepColumns));

View File

@ -49,7 +49,7 @@ class Imputer implements Preprocessor
$this->samples = $samples;
}
public function transform(array &$samples): void
public function transform(array &$samples, ?array &$targets = null): void
{
if ($this->samples === []) {
throw new InvalidOperationException('Missing training samples for Imputer.');

View File

@ -22,7 +22,7 @@ final class LabelEncoder implements Preprocessor
}
}
public function transform(array &$samples): void
public function transform(array &$samples, ?array &$targets = null): void
{
foreach ($samples as &$sample) {
$sample = $this->classes[(string) $sample];

View File

@ -66,7 +66,7 @@ class Normalizer implements Preprocessor
$this->fitted = true;
}
public function transform(array &$samples): void
public function transform(array &$samples, ?array &$targets = null): void
{
$methods = [
self::NORM_L1 => 'normalizeL1',

View File

@ -0,0 +1,47 @@
<?php
declare(strict_types=1);
namespace Phpml\Preprocessing;
final class NumberConverter implements Preprocessor
{
/**
* @var bool
*/
private $transformTargets;
/**
* @var mixed
*/
private $nonNumericPlaceholder;
/**
* @param mixed $nonNumericPlaceholder
*/
public function __construct(bool $transformTargets = false, $nonNumericPlaceholder = null)
{
$this->transformTargets = $transformTargets;
$this->nonNumericPlaceholder = $nonNumericPlaceholder;
}
public function fit(array $samples, ?array $targets = null): void
{
//nothing to do
}
public function transform(array &$samples, ?array &$targets = null): void
{
foreach ($samples as &$sample) {
foreach ($sample as &$feature) {
$feature = is_numeric($feature) ? (float) $feature : $this->nonNumericPlaceholder;
}
}
if ($this->transformTargets && is_array($targets)) {
foreach ($targets as &$target) {
$target = is_numeric($target) ? (float) $target : $this->nonNumericPlaceholder;
}
}
}
}

View File

@ -11,5 +11,5 @@ interface Transformer
*/
public function fit(array $samples, ?array $targets = null): void;
public function transform(array &$samples): void;
public function transform(array &$samples, ?array &$targets = null): void;
}

View File

@ -0,0 +1,47 @@
<?php
declare(strict_types=1);
namespace Phpml\Tests\Preprocessing;
use Phpml\Preprocessing\NumberConverter;
use PHPUnit\Framework\TestCase;
final class NumberConverterTest extends TestCase
{
public function testConvertSamples(): void
{
$samples = [['1', '-4'], ['2.0', 3.0], ['3', '112.5'], ['5', '0.0004']];
$targets = ['1', '1', '2', '2'];
$converter = new NumberConverter();
$converter->transform($samples, $targets);
self::assertEquals([[1.0, -4.0], [2.0, 3.0], [3.0, 112.5], [5.0, 0.0004]], $samples);
self::assertEquals(['1', '1', '2', '2'], $targets);
}
public function testConvertTargets(): void
{
$samples = [['1', '-4'], ['2.0', 3.0], ['3', '112.5'], ['5', '0.0004']];
$targets = ['1', '1', '2', 'not'];
$converter = new NumberConverter(true);
$converter->transform($samples, $targets);
self::assertEquals([[1.0, -4.0], [2.0, 3.0], [3.0, 112.5], [5.0, 0.0004]], $samples);
self::assertEquals([1.0, 1.0, 2.0, null], $targets);
}
public function testConvertWithPlaceholder(): void
{
$samples = [['invalid'], ['13.5']];
$targets = ['invalid', '2'];
$converter = new NumberConverter(true, 'missing');
$converter->transform($samples, $targets);
self::assertEquals([['missing'], [13.5]], $samples);
self::assertEquals(['missing', 2.0], $targets);
}
}