mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2024-11-21 12:35:10 +00:00
Implement OneHotEncoder (#384)
This commit is contained in:
parent
3baf1520e3
commit
4590d5cc32
@ -4,9 +4,16 @@ All notable changes to this project will be documented in this file.
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
## [0.9.0] - Unreleased
|
||||
### Added
|
||||
- [Preprocessing] Implement LabelEncoder
|
||||
- [Preprocessing] Implement ColumnFilter
|
||||
- [Preprocessing] Implement LambdaTransformer
|
||||
- [Preprocessing] Implement NumberConverter
|
||||
- [Preprocessing] Implement OneHotEncoder
|
||||
- [Workflow] Implement FeatureUnion
|
||||
- [Metric] Add Regression metrics: meanSquaredError, meanSquaredLogarithmicError, meanAbsoluteError, medianAbsoluteError, r2Score, maxError
|
||||
- [Regression] Implement DecisionTreeRegressor
|
||||
|
||||
## [0.8.0] - 2019-03-20
|
||||
### Added
|
||||
|
@ -107,6 +107,7 @@ Public datasets are available in a separate repository [php-ai/php-ml-datasets](
|
||||
* LambdaTransformer
|
||||
* NumberConverter
|
||||
* ColumnFilter
|
||||
* OneHotEncoder
|
||||
* Feature Extraction
|
||||
* [Token Count Vectorizer](http://php-ml.readthedocs.io/en/latest/machine-learning/feature-extraction/token-count-vectorizer/)
|
||||
* NGramTokenizer
|
||||
|
66
src/Preprocessing/OneHotEncoder.php
Normal file
66
src/Preprocessing/OneHotEncoder.php
Normal file
@ -0,0 +1,66 @@
|
||||
<?php
|
||||
|
||||
declare(strict_types=1);
|
||||
|
||||
namespace Phpml\Preprocessing;
|
||||
|
||||
use Phpml\Exception\InvalidArgumentException;
|
||||
|
||||
final class OneHotEncoder implements Preprocessor
|
||||
{
|
||||
/**
|
||||
* @var bool
|
||||
*/
|
||||
private $ignoreUnknown;
|
||||
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
private $categories = [];
|
||||
|
||||
public function __construct(bool $ignoreUnknown = false)
|
||||
{
|
||||
$this->ignoreUnknown = $ignoreUnknown;
|
||||
}
|
||||
|
||||
public function fit(array $samples, ?array $targets = null): void
|
||||
{
|
||||
foreach (array_keys(array_values(current($samples))) as $column) {
|
||||
$this->fitColumn($column, array_values(array_unique(array_column($samples, $column))));
|
||||
}
|
||||
}
|
||||
|
||||
public function transform(array &$samples, ?array &$targets = null): void
|
||||
{
|
||||
foreach ($samples as &$sample) {
|
||||
$sample = $this->transformSample(array_values($sample));
|
||||
}
|
||||
}
|
||||
|
||||
private function fitColumn(int $column, array $values): void
|
||||
{
|
||||
$count = count($values);
|
||||
foreach ($values as $index => $value) {
|
||||
$map = array_fill(0, $count, 0);
|
||||
$map[$index] = 1;
|
||||
$this->categories[$column][$value] = $map;
|
||||
}
|
||||
}
|
||||
|
||||
private function transformSample(array $sample): array
|
||||
{
|
||||
$encoded = [];
|
||||
foreach ($sample as $column => $feature) {
|
||||
if (!isset($this->categories[$column][$feature]) && !$this->ignoreUnknown) {
|
||||
throw new InvalidArgumentException(sprintf('Missing category "%s" for column %s in trained encoder', $feature, $column));
|
||||
}
|
||||
|
||||
$encoded = array_merge(
|
||||
$encoded,
|
||||
$this->categories[$column][$feature] ?? array_fill(0, count($this->categories[$column]), 0)
|
||||
);
|
||||
}
|
||||
|
||||
return $encoded;
|
||||
}
|
||||
}
|
66
tests/Preprocessing/OneHotEncoderTest.php
Normal file
66
tests/Preprocessing/OneHotEncoderTest.php
Normal file
@ -0,0 +1,66 @@
|
||||
<?php
|
||||
|
||||
declare(strict_types=1);
|
||||
|
||||
namespace Phpml\Tests\Preprocessing;
|
||||
|
||||
use Phpml\Exception\InvalidArgumentException;
|
||||
use Phpml\Preprocessing\OneHotEncoder;
|
||||
use PHPUnit\Framework\TestCase;
|
||||
|
||||
final class OneHotEncoderTest extends TestCase
|
||||
{
|
||||
public function testOneHotEncodingWithoutIgnoreUnknown(): void
|
||||
{
|
||||
$samples = [
|
||||
['fish', 'New York', 'regression'],
|
||||
['dog', 'New York', 'regression'],
|
||||
['fish', 'Vancouver', 'classification'],
|
||||
['dog', 'Vancouver', 'regression'],
|
||||
];
|
||||
|
||||
$encoder = new OneHotEncoder();
|
||||
$encoder->fit($samples);
|
||||
$encoder->transform($samples);
|
||||
|
||||
self::assertEquals([
|
||||
[1, 0, 1, 0, 1, 0],
|
||||
[0, 1, 1, 0, 1, 0],
|
||||
[1, 0, 0, 1, 0, 1],
|
||||
[0, 1, 0, 1, 1, 0],
|
||||
], $samples);
|
||||
}
|
||||
|
||||
public function testThrowExceptionWhenUnknownCategory(): void
|
||||
{
|
||||
$encoder = new OneHotEncoder();
|
||||
$encoder->fit([
|
||||
['fish', 'New York', 'regression'],
|
||||
['dog', 'New York', 'regression'],
|
||||
['fish', 'Vancouver', 'classification'],
|
||||
['dog', 'Vancouver', 'regression'],
|
||||
]);
|
||||
$samples = [['fish', 'New York', 'ka boom']];
|
||||
|
||||
$this->expectException(InvalidArgumentException::class);
|
||||
|
||||
$encoder->transform($samples);
|
||||
}
|
||||
|
||||
public function testIgnoreMissingCategory(): void
|
||||
{
|
||||
$encoder = new OneHotEncoder(true);
|
||||
$encoder->fit([
|
||||
['fish', 'New York', 'regression'],
|
||||
['dog', 'New York', 'regression'],
|
||||
['fish', 'Vancouver', 'classification'],
|
||||
['dog', 'Vancouver', 'regression'],
|
||||
]);
|
||||
$samples = [['ka', 'boom', 'riko']];
|
||||
$encoder->transform($samples);
|
||||
|
||||
self::assertEquals([
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
], $samples);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user