Implement VarianceThreshold - simple baseline approach to feature selection. (#228)

* Add sum of squares deviations

* Calculate population variance

* Add VarianceThreshold - feature selection transformer

* Add docs about VarianceThreshold

* Add missing code for pipeline usage
This commit is contained in:
Arkadiusz Kondas 2018-02-10 18:07:09 +01:00 committed by GitHub
parent 4b5d57fd6f
commit 3ba35918a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 279 additions and 10 deletions

View File

@ -87,6 +87,8 @@ Example scripts are available in a separate repository [php-ai/php-ml-examples](
* Cross Validation
* [Random Split](http://php-ml.readthedocs.io/en/latest/machine-learning/cross-validation/random-split/)
* [Stratified Random Split](http://php-ml.readthedocs.io/en/latest/machine-learning/cross-validation/stratified-random-split/)
* Feature Selection
* [Variance Threshold](http://php-ml.readthedocs.io/en/latest/machine-learning/feature-selection/variance-threshold/)
* Preprocessing
* [Normalization](http://php-ml.readthedocs.io/en/latest/machine-learning/preprocessing/normalization/)
* [Imputation missing values](http://php-ml.readthedocs.io/en/latest/machine-learning/preprocessing/imputation-missing-values/)

View File

@ -76,6 +76,8 @@ Example scripts are available in a separate repository [php-ai/php-ml-examples](
* Cross Validation
* [Random Split](machine-learning/cross-validation/random-split.md)
* [Stratified Random Split](machine-learning/cross-validation/stratified-random-split.md)
* Feature Selection
* [Variance Threshold](machine-learning/feature-selection/variance-threshold.md)
* Preprocessing
* [Normalization](machine-learning/preprocessing/normalization.md)
* [Imputation missing values](machine-learning/preprocessing/imputation-missing-values.md)

View File

@ -0,0 +1,60 @@
# Variance Threshold
`VarianceThreshold` is a simple baseline approach to feature selection.
It removes all features whose variance doesn't meet some threshold.
By default, it removes all zero-variance features, i.e. features that have the same value in all samples.
## Constructor Parameters
* $threshold (float) - features with a variance lower than or equal to this threshold will be removed (default 0.0)
```php
use Phpml\FeatureSelection\VarianceThreshold;
$transformer = new VarianceThreshold(0.15);
```
## Example of use
As an example, suppose that we have a dataset with boolean features and
we want to remove all features that are either one or zero (on or off)
in more than 80% of the samples.
Boolean features are Bernoulli random variables, and the variance of such
variables is given by
```
Var[X] = p(1 - p)
```
so we can select using the threshold .8 * (1 - .8):
```php
use Phpml\FeatureSelection\VarianceThreshold;
$samples = [[0, 0, 1], [0, 1, 0], [1, 0, 0], [0, 1, 1], [0, 1, 0], [0, 1, 1]];
$transformer = new VarianceThreshold(0.8 * (1 - 0.8));
$transformer->fit($samples);
$transformer->transform($samples);
/*
$samples = [[0, 1], [1, 0], [0, 0], [1, 1], [1, 0], [1, 1]];
*/
```
## Pipeline
`VarianceThreshold` implements the `Transformer` interface, so it can be used as part of a pipeline:
```php
use Phpml\FeatureSelection\VarianceThreshold;
use Phpml\Classification\SVC;
use Phpml\FeatureExtraction\TfIdfTransformer;
use Phpml\Pipeline;
$transformers = [
new TfIdfTransformer(),
new VarianceThreshold(0.1)
];
$estimator = new SVC();
$pipeline = new Pipeline($transformers, $estimator);
```

View File

@ -25,6 +25,8 @@ pages:
- Cross Validation:
- RandomSplit: machine-learning/cross-validation/random-split.md
- Stratified Random Split: machine-learning/cross-validation/stratified-random-split.md
- Feature Selection:
- VarianceThreshold: machine-learning/feature-selection/variance-threshold.md
- Preprocessing:
- Normalization: machine-learning/preprocessing/normalization.md
- Imputation missing values: machine-learning/preprocessing/imputation-missing-values.md

View File

@ -0,0 +1,59 @@
<?php
declare(strict_types=1);
namespace Phpml\FeatureSelection;
use Phpml\Exception\InvalidArgumentException;
use Phpml\Math\Matrix;
use Phpml\Math\Statistic\Variance;
use Phpml\Transformer;
/**
 * Feature selector that drops every column whose population variance does not
 * exceed a configurable threshold.
 *
 * Call fit() to learn which columns to keep, then transform() to strip the
 * rejected columns from a sample set in place.
 */
final class VarianceThreshold implements Transformer
{
    /**
     * Variance a column must strictly exceed in order to be kept.
     *
     * @var float
     */
    private $threshold;

    /**
     * Population variance of each column, as computed by the most recent fit().
     *
     * @var array
     */
    private $variances = [];

    /**
     * Set of columns to keep, stored as column key => true so that
     * array_intersect_key() can be used for the selection.
     *
     * @var array
     */
    private $keepColumns = [];

    /**
     * @param float $threshold columns whose variance is not greater than this value are removed
     *
     * @throws InvalidArgumentException when the threshold is negative
     */
    public function __construct(float $threshold = 0.0)
    {
        if ($threshold < 0) {
            throw new InvalidArgumentException('Threshold can\'t be lower than zero');
        }

        $this->threshold = $threshold;
    }

    /**
     * Computes the population variance of every column of $samples and records
     * which columns exceed the threshold.
     *
     * @param array $samples rows of feature values
     */
    public function fit(array $samples): void
    {
        $this->variances = array_map(function (array $column) {
            return Variance::population($column);
        }, Matrix::transposeArray($samples));

        // Start from an empty selection: without this reset, refitting the same
        // instance would still keep columns chosen by an earlier fit().
        $this->keepColumns = [];
        foreach ($this->variances as $column => $variance) {
            if ($variance > $this->threshold) {
                $this->keepColumns[$column] = true;
            }
        }
    }

    /**
     * Removes the rejected columns from every sample, in place, and reindexes
     * the remaining values from zero.
     *
     * If fit() has not selected any column, every sample becomes an empty array.
     *
     * @param array $samples rows of feature values, modified by reference
     */
    public function transform(array &$samples): void
    {
        foreach ($samples as &$sample) {
            $sample = array_values(array_intersect_key($sample, $this->keepColumns));
        }

        // Break the reference left behind by the by-reference foreach.
        unset($sample);
    }
}

View File

@ -9,27 +9,24 @@ use Phpml\Exception\InvalidArgumentException;
class StandardDeviation
{
/**
* @param array|float[] $a
*
* @throws InvalidArgumentException
* @param array|float[]|int[] $numbers
*/
public static function population(array $a, bool $sample = true): float
public static function population(array $numbers, bool $sample = true): float
{
if (empty($a)) {
if (empty($numbers)) {
throw InvalidArgumentException::arrayCantBeEmpty();
}
$n = count($a);
$n = count($numbers);
if ($sample && $n === 1) {
throw InvalidArgumentException::arraySizeToSmall(2);
}
$mean = Mean::arithmetic($a);
$mean = Mean::arithmetic($numbers);
$carry = 0.0;
foreach ($a as $val) {
$d = $val - $mean;
$carry += $d * $d;
foreach ($numbers as $val) {
$carry += ($val - $mean) ** 2;
}
if ($sample) {
@ -38,4 +35,26 @@ class StandardDeviation
return sqrt((float) ($carry / $n));
}
/**
 * Sum of squared deviations from the arithmetic mean
 * ∑⟮xᵢ - μ⟯²
 *
 * @param array|float[]|int[] $numbers
 *
 * @throws InvalidArgumentException when $numbers is empty
 */
public static function sumOfSquares(array $numbers): float
{
    if (count($numbers) === 0) {
        throw InvalidArgumentException::arrayCantBeEmpty();
    }

    $average = Mean::arithmetic($numbers);

    // Accumulate the squared distance of every value from the mean.
    $total = 0.0;
    foreach ($numbers as $number) {
        $total += ($number - $average) ** 2;
    }

    return $total;
}
}

View File

@ -0,0 +1,27 @@
<?php
declare(strict_types=1);
namespace Phpml\Math\Statistic;
/**
 * Variance is the expectation of the squared deviation of a random variable
 * from its mean; informally, it measures how far a set of numbers is spread
 * out from their average value.
 *
 * https://en.wikipedia.org/wiki/Variance
 */
final class Variance
{
    /**
     * Population variance: the sum of squared deviations from the mean,
     * divided by the number of observations.
     *
     * σ² = ∑⟮xᵢ - μ⟯² / N
     *
     * Use when all possible observations of the system are present.
     * Applied to a subset of the data it yields a biased estimate.
     */
    public static function population(array $population): float
    {
        $squaredDeviations = StandardDeviation::sumOfSquares($population);
        $observations = count($population);

        return $squaredDeviations / $observations;
    }
}

View File

@ -0,0 +1,39 @@
<?php
declare(strict_types=1);
namespace Phpml\Tests\FeatureSelection;
use Phpml\Exception\InvalidArgumentException;
use Phpml\FeatureSelection\VarianceThreshold;
use PHPUnit\Framework\TestCase;
/**
 * Behavioural tests for the VarianceThreshold feature selector.
 */
final class VarianceThresholdTest extends TestCase
{
    public function testVarianceThreshold(): void
    {
        // Boolean features are Bernoulli random variables with variance p(1 - p);
        // p = 0.8 targets features that hold the same value in 80% of the samples.
        $threshold = 0.8 * (1 - 0.8);
        $dataset = [[0, 0, 1], [0, 1, 0], [1, 0, 0], [0, 1, 1], [0, 1, 0], [0, 1, 1]];

        // The first column is expected to be removed.
        $expected = [[0, 1], [1, 0], [0, 0], [1, 1], [1, 0], [1, 1]];

        $selector = new VarianceThreshold($threshold);
        $selector->fit($dataset);
        $selector->transform($dataset);

        self::assertEquals($expected, $dataset);
    }

    public function testVarianceThresholdWithZeroThreshold(): void
    {
        $dataset = [[0, 2, 0, 3], [0, 1, 4, 3], [0, 1, 1, 3]];

        // The first and last columns are constant, so only the middle two remain.
        $expected = [[2, 0], [1, 4], [1, 1]];

        $selector = new VarianceThreshold();
        $selector->fit($dataset);
        $selector->transform($dataset);

        self::assertEquals($expected, $dataset);
    }

    public function testThrowExceptionWhenThresholdBelowZero(): void
    {
        $negativeThreshold = -0.1;

        $this->expectException(InvalidArgumentException::class);
        new VarianceThreshold($negativeThreshold);
    }
}

View File

@ -37,4 +37,29 @@ class StandardDeviationTest extends TestCase
$this->expectException(InvalidArgumentException::class);
StandardDeviation::population([1]);
}
/**
 * Checks the sum of squared deviations against precomputed values.
 *
 * @dataProvider dataProviderForSumOfSquaresDeviations
 */
public function testSumOfSquares(array $numbers, float $sum): void
{
    $actual = StandardDeviation::sumOfSquares($numbers);

    self::assertEquals($sum, $actual, '', 0.0001);
}
/**
 * Data sets for testSumOfSquares: each entry is [input numbers, expected sum of squared deviations].
 */
public function dataProviderForSumOfSquaresDeviations(): array
{
    $cases = [];

    $cases[] = [[3, 6, 7, 11, 12, 13, 17], 136.8571];
    $cases[] = [[6, 11, 12, 14, 15, 20, 21], 162.8571];
    $cases[] = [[1, 2, 3, 6, 7, 11, 12], 112];
    $cases[] = [[1, 2, 3, 4, 5, 6, 7, 8, 9, 0], 82.5];
    $cases[] = [[34, 253, 754, 2342, 75, 23, 876, 4, 1, -34, -345, 754, -377, 3, 0], 6453975.7333];

    return $cases;
}
/**
 * An empty input must be rejected with an InvalidArgumentException.
 */
public function testThrowExceptionOnEmptyArraySumOfSquares(): void
{
    $noNumbers = [];

    $this->expectException(InvalidArgumentException::class);
    StandardDeviation::sumOfSquares($noNumbers);
}
}

View File

@ -0,0 +1,34 @@
<?php
declare(strict_types=1);
namespace Phpml\Tests\Math\Statistic;
use Phpml\Math\Statistic\Variance;
use PHPUnit\Framework\TestCase;
/**
 * Tests for the population variance calculation.
 */
final class VarianceTest extends TestCase
{
    /**
     * Checks Variance::population() against precomputed values.
     *
     * @dataProvider dataProviderForPopulationVariance
     */
    public function testVarianceFromInt(array $numbers, float $variance): void
    {
        self::assertEquals($variance, Variance::population($numbers), '', 0.001);
    }

    /**
     * Each entry is [input numbers, expected population variance].
     */
    public function dataProviderForPopulationVariance(): array
    {
        return [
            // 5/36 ≈ 0.13889; rounded rather than truncated so the check does not sit at the edge of the delta
            [[0, 0, 0, 0, 0, 1], 0.1389],
            [[-11, 0, 10, 20, 30], 208.16],
            [[7, 8, 9, 10, 11, 12, 13], 4.0],
            [[300, 570, 170, 730, 300], 41944],
            [[-4, 2, 7, 8, 3], 18.16],
            [[3, 7, 34, 25, 46, 7754, 3, 6], 6546331.937],
            [[4, 6, 1, 1, 1, 1, 2, 2, 1, 3], 2.56],
            [[-3732, 5, 27, 9248, -174], 18741676.56],
            [[-554, -555, -554, -554, -555, -555, -556], 0.4897],
        ];
    }
}