mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2024-11-21 20:45:10 +00:00
Implement VarianceThreshold - simple baseline approach to feature selection. (#228)
* Add sum of squares deviations * Calculate population variance * Add VarianceThreshold - feature selection transformer * Add docs about VarianceThreshold * Add missing code for pipeline usage
This commit is contained in:
parent
4b5d57fd6f
commit
3ba35918a3
@ -87,6 +87,8 @@ Example scripts are available in a separate repository [php-ai/php-ml-examples](
|
||||
* Cross Validation
|
||||
* [Random Split](http://php-ml.readthedocs.io/en/latest/machine-learning/cross-validation/random-split/)
|
||||
* [Stratified Random Split](http://php-ml.readthedocs.io/en/latest/machine-learning/cross-validation/stratified-random-split/)
|
||||
* Feature Selection
|
||||
* [Variance Threshold](http://php-ml.readthedocs.io/en/latest/machine-learning/feature-selection/variance-threshold/)
|
||||
* Preprocessing
|
||||
* [Normalization](http://php-ml.readthedocs.io/en/latest/machine-learning/preprocessing/normalization/)
|
||||
* [Imputation missing values](http://php-ml.readthedocs.io/en/latest/machine-learning/preprocessing/imputation-missing-values/)
|
||||
|
@ -76,6 +76,8 @@ Example scripts are available in a separate repository [php-ai/php-ml-examples](
|
||||
* Cross Validation
|
||||
* [Random Split](machine-learning/cross-validation/random-split.md)
|
||||
* [Stratified Random Split](machine-learning/cross-validation/stratified-random-split.md)
|
||||
* Feature Selection
|
||||
* [Variance Threshold](machine-learning/feature-selection/variance-threshold.md)
|
||||
* Preprocessing
|
||||
* [Normalization](machine-learning/preprocessing/normalization.md)
|
||||
* [Imputation missing values](machine-learning/preprocessing/imputation-missing-values.md)
|
||||
|
@ -0,0 +1,60 @@
|
||||
# Variance Threshold
|
||||
|
||||
`VarianceThreshold` is a simple baseline approach to feature selection.
|
||||
It removes all features whose variance doesn’t meet some threshold.
|
||||
By default, it removes all zero-variance features, i.e. features that have the same value in all samples.
|
||||
|
||||
## Constructor Parameters
|
||||
|
||||
* $threshold (float) - features whose variance is lower than or equal to this threshold will be removed; must not be negative (default 0.0)
|
||||
|
||||
```php
|
||||
use Phpml\FeatureSelection\VarianceThreshold;
|
||||
|
||||
$transformer = new VarianceThreshold(0.15);
|
||||
```
|
||||
|
||||
## Example of use
|
||||
|
||||
As an example, suppose that we have a dataset with boolean features and
|
||||
we want to remove all features that are either one or zero (on or off)
|
||||
in more than 80% of the samples.
|
||||
Boolean features are Bernoulli random variables, and the variance of such
|
||||
variables is given by
|
||||
```
|
||||
Var[X] = p(1 - p)
|
||||
```
|
||||
so we can select using the threshold .8 * (1 - .8):
|
||||
|
||||
```php
|
||||
use Phpml\FeatureSelection\VarianceThreshold;
|
||||
|
||||
$samples = [[0, 0, 1], [0, 1, 0], [1, 0, 0], [0, 1, 1], [0, 1, 0], [0, 1, 1]];
|
||||
$transformer = new VarianceThreshold(0.8 * (1 - 0.8));
|
||||
|
||||
$transformer->fit($samples);
|
||||
$transformer->transform($samples);
|
||||
|
||||
/*
|
||||
$samples = [[0, 1], [1, 0], [0, 0], [1, 1], [1, 0], [1, 1]];
|
||||
*/
|
||||
```
|
||||
|
||||
## Pipeline
|
||||
|
||||
`VarianceThreshold` implements the `Transformer` interface, so it can be used as part of a pipeline:
|
||||
|
||||
```php
|
||||
use Phpml\FeatureSelection\VarianceThreshold;
|
||||
use Phpml\Classification\SVC;
|
||||
use Phpml\FeatureExtraction\TfIdfTransformer;
|
||||
use Phpml\Pipeline;
|
||||
|
||||
$transformers = [
|
||||
new TfIdfTransformer(),
|
||||
new VarianceThreshold(0.1)
|
||||
];
|
||||
$estimator = new SVC();
|
||||
|
||||
$pipeline = new Pipeline($transformers, $estimator);
|
||||
```
|
@ -25,6 +25,8 @@ pages:
|
||||
- Cross Validation:
|
||||
- RandomSplit: machine-learning/cross-validation/random-split.md
|
||||
- Stratified Random Split: machine-learning/cross-validation/stratified-random-split.md
|
||||
- Feature Selection:
|
||||
- VarianceThreshold: machine-learning/feature-selection/variance-threshold.md
|
||||
- Preprocessing:
|
||||
- Normalization: machine-learning/preprocessing/normalization.md
|
||||
- Imputation missing values: machine-learning/preprocessing/imputation-missing-values.md
|
||||
|
59
src/FeatureSelection/VarianceThreshold.php
Normal file
59
src/FeatureSelection/VarianceThreshold.php
Normal file
@ -0,0 +1,59 @@
|
||||
<?php
|
||||
|
||||
declare(strict_types=1);
|
||||
|
||||
namespace Phpml\FeatureSelection;
|
||||
|
||||
use Phpml\Exception\InvalidArgumentException;
|
||||
use Phpml\Math\Matrix;
|
||||
use Phpml\Math\Statistic\Variance;
|
||||
use Phpml\Transformer;
|
||||
|
||||
/**
 * Baseline feature selector: keeps only the columns whose population variance
 * is strictly greater than a configurable threshold.
 */
final class VarianceThreshold implements Transformer
{
    /**
     * Minimum variance a column must exceed to be kept.
     *
     * @var float
     */
    private $threshold;

    /**
     * Population variance of every column seen by the most recent fit() call,
     * keyed by column index.
     *
     * @var array
     */
    private $variances = [];

    /**
     * Column indexes to keep, stored as index => true so they can be used
     * directly with array_intersect_key(). Rebuilt on every fit() call.
     *
     * @var array
     */
    private $keepColumns = [];

    /**
     * @param float $threshold columns whose variance is not greater than this value are dropped
     *
     * @throws InvalidArgumentException when $threshold is negative
     */
    public function __construct(float $threshold = 0.0)
    {
        if ($threshold < 0) {
            throw new InvalidArgumentException('Threshold can\'t be lower than zero');
        }

        $this->threshold = $threshold;
    }

    /**
     * Computes the population variance of each column and records which
     * columns exceed the threshold.
     *
     * @param array $samples list of rows; presumably every row has the same number of columns
     */
    public function fit(array $samples): void
    {
        $this->variances = array_map(static function (array $column): float {
            return Variance::population($column);
        }, Matrix::transposeArray($samples));

        // Start from an empty selection: without this reset a second fit()
        // would still keep columns selected by the previous fit.
        $this->keepColumns = [];
        foreach ($this->variances as $column => $variance) {
            if ($variance > $this->threshold) {
                $this->keepColumns[$column] = true;
            }
        }
    }

    /**
     * Removes, in place, every column that was not selected by fit() and
     * reindexes the remaining values of each sample from zero.
     *
     * When no column has been selected (for example fit() was never called),
     * every sample becomes an empty array.
     *
     * @param array $samples list of rows, modified by reference
     */
    public function transform(array &$samples): void
    {
        foreach ($samples as &$sample) {
            $sample = array_values(array_intersect_key($sample, $this->keepColumns));
        }

        // Release the reference so $sample no longer aliases the last row.
        unset($sample);
    }
}
|
@ -9,27 +9,24 @@ use Phpml\Exception\InvalidArgumentException;
|
||||
class StandardDeviation
|
||||
{
|
||||
/**
|
||||
* @param array|float[] $a
|
||||
*
|
||||
* @throws InvalidArgumentException
|
||||
* @param array|float[]|int[] $numbers
|
||||
*/
|
||||
public static function population(array $a, bool $sample = true): float
|
||||
public static function population(array $numbers, bool $sample = true): float
|
||||
{
|
||||
if (empty($a)) {
|
||||
if (empty($numbers)) {
|
||||
throw InvalidArgumentException::arrayCantBeEmpty();
|
||||
}
|
||||
|
||||
$n = count($a);
|
||||
$n = count($numbers);
|
||||
|
||||
if ($sample && $n === 1) {
|
||||
throw InvalidArgumentException::arraySizeToSmall(2);
|
||||
}
|
||||
|
||||
$mean = Mean::arithmetic($a);
|
||||
$mean = Mean::arithmetic($numbers);
|
||||
$carry = 0.0;
|
||||
foreach ($a as $val) {
|
||||
$d = $val - $mean;
|
||||
$carry += $d * $d;
|
||||
foreach ($numbers as $val) {
|
||||
$carry += ($val - $mean) ** 2;
|
||||
}
|
||||
|
||||
if ($sample) {
|
||||
@ -38,4 +35,26 @@ class StandardDeviation
|
||||
|
||||
return sqrt((float) ($carry / $n));
|
||||
}
|
||||
|
||||
/**
 * Sum of squared deviations from the arithmetic mean
 * ∑⟮xᵢ - μ⟯²
 *
 * @param array|float[]|int[] $numbers
 *
 * @throws InvalidArgumentException when $numbers is empty
 */
public static function sumOfSquares(array $numbers): float
{
    if (empty($numbers)) {
        throw InvalidArgumentException::arrayCantBeEmpty();
    }

    $mean = Mean::arithmetic($numbers);

    // Accumulate in input order, one squared deviation at a time.
    $total = 0.0;
    foreach ($numbers as $number) {
        $deviation = $number - $mean;
        $total += $deviation ** 2;
    }

    return $total;
}
|
||||
}
|
||||
|
27
src/Math/Statistic/Variance.php
Normal file
27
src/Math/Statistic/Variance.php
Normal file
@ -0,0 +1,27 @@
|
||||
<?php
|
||||
|
||||
declare(strict_types=1);
|
||||
|
||||
namespace Phpml\Math\Statistic;
|
||||
|
||||
/**
 * Variance is the expectation of the squared deviation of a random variable from its mean.
 * Informally, it tells how widely a set of numbers is spread around their average value.
 *
 * https://en.wikipedia.org/wiki/Variance
 */
final class Variance
{
    /**
     * Population variance
     *
     * Appropriate when the data holds every observation of the system.
     * Applied to a mere sample of the data it yields a biased estimate.
     *
     *      ∑⟮xᵢ - μ⟯²
     * σ² = ----------
     *          N
     *
     * @param array $population numeric observations
     *
     * @throws \Phpml\Exception\InvalidArgumentException when $population is empty
     */
    public static function population(array $population): float
    {
        // sumOfSquares() rejects an empty array, so the division below never sees N = 0.
        $squaredDeviations = StandardDeviation::sumOfSquares($population);
        $size = count($population);

        return $squaredDeviations / $size;
    }
}
|
39
tests/FeatureSelection/VarianceThresholdTest.php
Normal file
39
tests/FeatureSelection/VarianceThresholdTest.php
Normal file
@ -0,0 +1,39 @@
|
||||
<?php
|
||||
|
||||
declare(strict_types=1);
|
||||
|
||||
namespace Phpml\Tests\FeatureSelection;
|
||||
|
||||
use Phpml\Exception\InvalidArgumentException;
|
||||
use Phpml\FeatureSelection\VarianceThreshold;
|
||||
use PHPUnit\Framework\TestCase;
|
||||
|
||||
/**
 * Behavioural tests for the VarianceThreshold feature selector.
 */
final class VarianceThresholdTest extends TestCase
{
    public function testVarianceThreshold(): void
    {
        // Boolean features are Bernoulli random variables with variance p(1 - p);
        // this threshold corresponds to a feature having the same value in 80% of samples.
        $threshold = 0.8 * (1 - 0.8);
        $samples = [[0, 0, 1], [0, 1, 0], [1, 0, 0], [0, 1, 1], [0, 1, 0], [0, 1, 1]];

        $selector = new VarianceThreshold($threshold);
        $selector->fit($samples);
        $selector->transform($samples);

        // The first column is expected to be removed.
        $expected = [[0, 1], [1, 0], [0, 0], [1, 1], [1, 0], [1, 1]];
        self::assertEquals($expected, $samples);
    }

    public function testVarianceThresholdWithZeroThreshold(): void
    {
        $samples = [[0, 2, 0, 3], [0, 1, 4, 3], [0, 1, 1, 3]];

        // The default threshold drops only the constant columns (first and last here).
        $selector = new VarianceThreshold();
        $selector->fit($samples);
        $selector->transform($samples);

        $expected = [[2, 0], [1, 4], [1, 1]];
        self::assertEquals($expected, $samples);
    }

    public function testThrowExceptionWhenThresholdBelowZero(): void
    {
        $negativeThreshold = -0.1;

        $this->expectException(InvalidArgumentException::class);
        new VarianceThreshold($negativeThreshold);
    }
}
|
@ -37,4 +37,29 @@ class StandardDeviationTest extends TestCase
|
||||
$this->expectException(InvalidArgumentException::class);
|
||||
StandardDeviation::population([1]);
|
||||
}
|
||||
|
||||
/**
 * Checks sumOfSquares() against precomputed values, to within 0.0001.
 *
 * @dataProvider dataProviderForSumOfSquaresDeviations
 */
public function testSumOfSquares(array $numbers, float $sum): void
{
    $actual = StandardDeviation::sumOfSquares($numbers);

    self::assertEquals($sum, $actual, '', 0.0001);
}
|
||||
|
||||
/**
 * Data sets for testSumOfSquares(): each entry is [numbers, expected sum of squared deviations].
 */
public function dataProviderForSumOfSquaresDeviations(): array
{
    $dataSets = [];

    $dataSets[] = [[3, 6, 7, 11, 12, 13, 17], 136.8571];
    $dataSets[] = [[6, 11, 12, 14, 15, 20, 21], 162.8571];
    $dataSets[] = [[1, 2, 3, 6, 7, 11, 12], 112];
    $dataSets[] = [[1, 2, 3, 4, 5, 6, 7, 8, 9, 0], 82.5];
    $dataSets[] = [[34, 253, 754, 2342, 75, 23, 876, 4, 1, -34, -345, 754, -377, 3, 0], 6453975.7333];

    return $dataSets;
}
|
||||
|
||||
/**
 * An empty input must be rejected with an InvalidArgumentException.
 */
public function testThrowExceptionOnEmptyArraySumOfSquares(): void
{
    $emptyInput = [];

    $this->expectException(InvalidArgumentException::class);
    StandardDeviation::sumOfSquares($emptyInput);
}
|
||||
}
|
||||
|
34
tests/Math/Statistic/VarianceTest.php
Normal file
34
tests/Math/Statistic/VarianceTest.php
Normal file
@ -0,0 +1,34 @@
|
||||
<?php
|
||||
|
||||
declare(strict_types=1);
|
||||
|
||||
namespace Phpml\Tests\Math\Statistic;
|
||||
|
||||
use Phpml\Math\Statistic\Variance;
|
||||
use PHPUnit\Framework\TestCase;
|
||||
|
||||
/**
 * Tests for Variance::population().
 */
final class VarianceTest extends TestCase
{
    /**
     * Checks the population variance against precomputed values, to within 0.001.
     *
     * @dataProvider dataProviderForPopulationVariance
     */
    public function testVarianceFromInt(array $numbers, float $variance): void
    {
        self::assertEquals($variance, Variance::population($numbers), '', 0.001);
    }

    /**
     * Data sets for testVarianceFromInt(): each entry is [numbers, expected population variance].
     *
     * The array return type is declared to match the other data providers in the test suite.
     */
    public function dataProviderForPopulationVariance(): array
    {
        return [
            [[0, 0, 0, 0, 0, 1], 0.138],
            [[-11, 0, 10, 20, 30], 208.16],
            [[7, 8, 9, 10, 11, 12, 13], 4.0],
            [[300, 570, 170, 730, 300], 41944],
            [[-4, 2, 7, 8, 3], 18.16],
            [[3, 7, 34, 25, 46, 7754, 3, 6], 6546331.937],
            [[4, 6, 1, 1, 1, 1, 2, 2, 1, 3], 2.56],
            [[-3732, 5, 27, 9248, -174], 18741676.56],
            [[-554, -555, -554, -554, -555, -555, -556], 0.4897],
        ];
    }
}
|
Loading…
Reference in New Issue
Block a user