diff --git a/CHANGELOG.md b/CHANGELOG.md index 63507f0..662a086 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] +### Added +- [Preprocessing] Implement LabelEncoder + ## [0.8.0] - 2019-03-20 ### Added - [Tokenization] Added NGramTokenizer (#350) diff --git a/README.md b/README.md index 4df5730..999b48b 100644 --- a/README.md +++ b/README.md @@ -100,6 +100,7 @@ Public datasets are available in a separate repository [php-ai/php-ml-datasets]( * Preprocessing * [Normalization](http://php-ml.readthedocs.io/en/latest/machine-learning/preprocessing/normalization/) * [Imputation missing values](http://php-ml.readthedocs.io/en/latest/machine-learning/preprocessing/imputation-missing-values/) + * LabelEncoder * Feature Extraction * [Token Count Vectorizer](http://php-ml.readthedocs.io/en/latest/machine-learning/feature-extraction/token-count-vectorizer/) * NGramTokenizer diff --git a/docs/index.md b/docs/index.md index 3c6ede2..25ad6b0 100644 --- a/docs/index.md +++ b/docs/index.md @@ -85,6 +85,7 @@ Example scripts are available in a separate repository [php-ai/php-ml-examples]( * Preprocessing * [Normalization](machine-learning/preprocessing/normalization.md) * [Imputation missing values](machine-learning/preprocessing/imputation-missing-values.md) + * LabelEncoder * Feature Extraction * [Token Count Vectorizer](machine-learning/feature-extraction/token-count-vectorizer.md) * [Tf-idf Transformer](machine-learning/feature-extraction/tf-idf-transformer.md) diff --git a/src/Preprocessing/LabelEncoder.php b/src/Preprocessing/LabelEncoder.php new file mode 100644 index 0000000..9b5df2c --- /dev/null +++ b/src/Preprocessing/LabelEncoder.php @@ -0,0 +1,47 @@ +classes = []; + + foreach ($samples as $sample) { + if (!isset($this->classes[(string) $sample])) { + $this->classes[(string) $sample] = count($this->classes); + } + } + } + + public function transform(array &$samples): void + { + foreach ($samples as &$sample) { + $sample = $this->classes[(string) $sample]; + } + } + + public function inverseTransform(array &$samples): void + { + $classes = array_flip($this->classes); + foreach ($samples as &$sample) { + $sample = $classes[$sample]; + } + } + + /** + * @return string[] + */ + public function classes(): array + { + return array_keys($this->classes); + } +} diff --git a/tests/Preprocessing/LabelEncoderTest.php b/tests/Preprocessing/LabelEncoderTest.php new file mode 100644 index 0000000..71dc87e --- /dev/null +++ b/tests/Preprocessing/LabelEncoderTest.php @@ -0,0 +1,68 @@ +fit($samples); + $le->transform($samples); + + self::assertEquals($transformed, $samples); + } + + public function labelEncoderDataProvider(): array + { + return [ + [['one', 'one', 'two', 'three'], [0, 0, 1, 2]], + [['one', 1, 'two', 'three'], [0, 1, 2, 3]], + [['one', null, 'two', 'three'], [0, 1, 2, 3]], + [['one', 'one', 'one', 'one'], [0, 0, 0, 0]], + [['one', 'one', 'one', 'one', null, null, 1, 1, 2, 'two'], [0, 0, 0, 0, 1, 1, 2, 2, 3, 4]], + ]; + } + + public function testResetClassesAfterNextFit(): void + { + $samples = ['Shanghai', 'Beijing', 'Karachi']; + + $le = new LabelEncoder(); + $le->fit($samples); + + self::assertEquals(['Shanghai', 'Beijing', 'Karachi'], $le->classes()); + + $samples = ['Istanbul', 'Dhaka', 'Tokyo']; + + $le->fit($samples); + + self::assertEquals(['Istanbul', 'Dhaka', 'Tokyo'], $le->classes()); + } + + public function testFitAndTransformFullCycle(): void + { + $samples = ['Shanghai', 'Beijing', 'Karachi', 'Beijing', 'Beijing', 'Karachi']; + $encoded = [0, 1, 2, 1, 1, 2]; + + $le = new LabelEncoder(); + $le->fit($samples); + + self::assertEquals(['Shanghai', 'Beijing', 'Karachi'], $le->classes()); + + $transformed = $samples; + $le->transform($transformed); + self::assertEquals($encoded, $transformed); + + $le->inverseTransform($transformed); + self::assertEquals($samples, $transformed); + } +}