diff --git a/src/FeatureExtraction/TfIdfTransformer.php b/src/FeatureExtraction/TfIdfTransformer.php index d1ac35d..34f7533 100644 --- a/src/FeatureExtraction/TfIdfTransformer.php +++ b/src/FeatureExtraction/TfIdfTransformer.php @@ -30,7 +30,7 @@ class TfIdfTransformer implements Transformer } } - public function transform(array &$samples): void + public function transform(array &$samples, ?array &$targets = null): void { foreach ($samples as &$sample) { foreach ($sample as $index => &$feature) { diff --git a/src/FeatureExtraction/TokenCountVectorizer.php b/src/FeatureExtraction/TokenCountVectorizer.php index afd5f33..5cc5e8d 100644 --- a/src/FeatureExtraction/TokenCountVectorizer.php +++ b/src/FeatureExtraction/TokenCountVectorizer.php @@ -46,7 +46,7 @@ class TokenCountVectorizer implements Transformer $this->buildVocabulary($samples); } - public function transform(array &$samples): void + public function transform(array &$samples, ?array &$targets = null): void { array_walk($samples, function (string &$sample): void { $this->transformSample($sample); diff --git a/src/FeatureSelection/SelectKBest.php b/src/FeatureSelection/SelectKBest.php index 36b4245..16e5278 100644 --- a/src/FeatureSelection/SelectKBest.php +++ b/src/FeatureSelection/SelectKBest.php @@ -56,7 +56,7 @@ final class SelectKBest implements Transformer $this->keepColumns = array_slice($sorted, 0, $this->k, true); } - public function transform(array &$samples): void + public function transform(array &$samples, ?array &$targets = null): void { if ($this->keepColumns === null) { return; diff --git a/src/FeatureSelection/VarianceThreshold.php b/src/FeatureSelection/VarianceThreshold.php index 5ca2332..3bbc29d 100644 --- a/src/FeatureSelection/VarianceThreshold.php +++ b/src/FeatureSelection/VarianceThreshold.php @@ -48,7 +48,7 @@ final class VarianceThreshold implements Transformer } } - public function transform(array &$samples): void + public function transform(array &$samples, ?array &$targets = null): void { foreach ($samples as &$sample) { $sample = array_values(array_intersect_key($sample, $this->keepColumns)); diff --git a/src/Preprocessing/Imputer.php b/src/Preprocessing/Imputer.php index e5b5af8..88ee2dd 100644 --- a/src/Preprocessing/Imputer.php +++ b/src/Preprocessing/Imputer.php @@ -49,7 +49,7 @@ class Imputer implements Preprocessor $this->samples = $samples; } - public function transform(array &$samples): void + public function transform(array &$samples, ?array &$targets = null): void { if ($this->samples === []) { throw new InvalidOperationException('Missing training samples for Imputer.'); diff --git a/src/Preprocessing/LabelEncoder.php b/src/Preprocessing/LabelEncoder.php index 9b5df2c..1e612a1 100644 --- a/src/Preprocessing/LabelEncoder.php +++ b/src/Preprocessing/LabelEncoder.php @@ -22,7 +22,7 @@ final class LabelEncoder implements Preprocessor } } - public function transform(array &$samples): void + public function transform(array &$samples, ?array &$targets = null): void { foreach ($samples as &$sample) { $sample = $this->classes[(string) $sample]; diff --git a/src/Preprocessing/Normalizer.php b/src/Preprocessing/Normalizer.php index 9888e0e..5ba43e6 100644 --- a/src/Preprocessing/Normalizer.php +++ b/src/Preprocessing/Normalizer.php @@ -66,7 +66,7 @@ class Normalizer implements Preprocessor $this->fitted = true; } - public function transform(array &$samples): void + public function transform(array &$samples, ?array &$targets = null): void { $methods = [ self::NORM_L1 => 'normalizeL1', diff --git a/src/Preprocessing/NumberConverter.php b/src/Preprocessing/NumberConverter.php new file mode 100644 index 0000000..68247b1 --- /dev/null +++ b/src/Preprocessing/NumberConverter.php @@ -0,0 +1,47 @@ +transformTargets = $transformTargets; + $this->nonNumericPlaceholder = $nonNumericPlaceholder; + } + + public function fit(array $samples, ?array $targets = null): void + { + //nothing to do + } + + public function transform(array &$samples, ?array &$targets = null): void + { + foreach ($samples as &$sample) { + foreach ($sample as &$feature) { + $feature = is_numeric($feature) ? (float) $feature : $this->nonNumericPlaceholder; + } + } + + if ($this->transformTargets && is_array($targets)) { + foreach ($targets as &$target) { + $target = is_numeric($target) ? (float) $target : $this->nonNumericPlaceholder; + } + } + } +} diff --git a/src/Transformer.php b/src/Transformer.php index 7350e2c..3a9b91d 100644 --- a/src/Transformer.php +++ b/src/Transformer.php @@ -11,5 +11,5 @@ interface Transformer */ public function fit(array $samples, ?array $targets = null): void; - public function transform(array &$samples): void; + public function transform(array &$samples, ?array &$targets = null): void; } diff --git a/tests/Preprocessing/NumberConverterTest.php b/tests/Preprocessing/NumberConverterTest.php new file mode 100644 index 0000000..287b739 --- /dev/null +++ b/tests/Preprocessing/NumberConverterTest.php @@ -0,0 +1,47 @@ +transform($samples, $targets); + + self::assertEquals([[1.0, -4.0], [2.0, 3.0], [3.0, 112.5], [5.0, 0.0004]], $samples); + self::assertEquals(['1', '1', '2', '2'], $targets); + } + + public function testConvertTargets(): void + { + $samples = [['1', '-4'], ['2.0', 3.0], ['3', '112.5'], ['5', '0.0004']]; + $targets = ['1', '1', '2', 'not']; + + $converter = new NumberConverter(true); + $converter->transform($samples, $targets); + + self::assertEquals([[1.0, -4.0], [2.0, 3.0], [3.0, 112.5], [5.0, 0.0004]], $samples); + self::assertEquals([1.0, 1.0, 2.0, null], $targets); + } + + public function testConvertWithPlaceholder(): void + { + $samples = [['invalid'], ['13.5']]; + $targets = ['invalid', '2']; + + $converter = new NumberConverter(true, 'missing'); + $converter->transform($samples, $targets); + + self::assertEquals([['missing'], [13.5]], $samples); + self::assertEquals(['missing', 2.0], $targets); + } +}