mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2024-11-21 20:45:10 +00:00
Implement ColumnFilter preprocessor (#378)
This commit is contained in:
parent
717f236ca9
commit
417174d143
42
src/Preprocessing/ColumnFilter.php
Normal file
42
src/Preprocessing/ColumnFilter.php
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
<?php
|
||||||
|
|
||||||
|
declare(strict_types=1);
|
||||||
|
|
||||||
|
namespace Phpml\Preprocessing;
|
||||||
|
|
||||||
|
final class ColumnFilter implements Preprocessor
|
||||||
|
{
|
||||||
|
/**
|
||||||
|
* @var string[]
|
||||||
|
*/
|
||||||
|
private $datasetColumns = [];
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var string[]
|
||||||
|
*/
|
||||||
|
private $filterColumns = [];
|
||||||
|
|
||||||
|
public function __construct(array $datasetColumns, array $filterColumns)
|
||||||
|
{
|
||||||
|
$this->datasetColumns = array_map(static function (string $column): string {
|
||||||
|
return $column;
|
||||||
|
}, $datasetColumns);
|
||||||
|
$this->filterColumns = array_map(static function (string $column): string {
|
||||||
|
return $column;
|
||||||
|
}, $filterColumns);
|
||||||
|
}
|
||||||
|
|
||||||
|
public function fit(array $samples, ?array $targets = null): void
|
||||||
|
{
|
||||||
|
//nothing to do
|
||||||
|
}
|
||||||
|
|
||||||
|
public function transform(array &$samples, ?array &$targets = null): void
|
||||||
|
{
|
||||||
|
$keys = array_intersect($this->datasetColumns, $this->filterColumns);
|
||||||
|
|
||||||
|
foreach ($samples as &$sample) {
|
||||||
|
$sample = array_values(array_intersect_key($sample, $keys));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
27
tests/Preprocessing/ColumnFilterTest.php
Normal file
27
tests/Preprocessing/ColumnFilterTest.php
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
<?php
|
||||||
|
|
||||||
|
declare(strict_types=1);
|
||||||
|
|
||||||
|
namespace Phpml\Tests\Preprocessing;
|
||||||
|
|
||||||
|
use Phpml\Preprocessing\ColumnFilter;
|
||||||
|
use PHPUnit\Framework\TestCase;
|
||||||
|
|
||||||
|
final class ColumnFilterTest extends TestCase
|
||||||
|
{
|
||||||
|
public function testFilterColumns(): void
|
||||||
|
{
|
||||||
|
$datasetColumns = ['age', 'income', 'kids', 'beersPerWeek'];
|
||||||
|
$filterColumns = ['income', 'beersPerWeek'];
|
||||||
|
$samples = [
|
||||||
|
[21, 100000, 1, 4],
|
||||||
|
[35, 120000, 0, 12],
|
||||||
|
[33, 200000, 4, 0],
|
||||||
|
];
|
||||||
|
|
||||||
|
$filter = new ColumnFilter($datasetColumns, $filterColumns);
|
||||||
|
$filter->transform($samples);
|
||||||
|
|
||||||
|
self::assertEquals([[100000, 4], [120000, 12], [200000, 0]], $samples);
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user