diff --git a/README.md b/README.md index 874d0d2..595714d 100644 --- a/README.md +++ b/README.md @@ -87,6 +87,8 @@ Example scripts are available in a separate repository [php-ai/php-ml-examples]( * Cross Validation * [Random Split](http://php-ml.readthedocs.io/en/latest/machine-learning/cross-validation/random-split/) * [Stratified Random Split](http://php-ml.readthedocs.io/en/latest/machine-learning/cross-validation/stratified-random-split/) +* Feature Selection + * [Variance Threshold](http://php-ml.readthedocs.io/en/latest/machine-learning/feature-selection/variance-threshold/) * Preprocessing * [Normalization](http://php-ml.readthedocs.io/en/latest/machine-learning/preprocessing/normalization/) * [Imputation missing values](http://php-ml.readthedocs.io/en/latest/machine-learning/preprocessing/imputation-missing-values/) diff --git a/docs/index.md b/docs/index.md index f817b0a..eb50563 100644 --- a/docs/index.md +++ b/docs/index.md @@ -76,6 +76,8 @@ Example scripts are available in a separate repository [php-ai/php-ml-examples]( * Cross Validation * [Random Split](machine-learning/cross-validation/random-split.md) * [Stratified Random Split](machine-learning/cross-validation/stratified-random-split.md) +* Feature Selection + * [Variance Threshold](machine-learning/feature-selection/variance-threshold.md) * Preprocessing * [Normalization](machine-learning/preprocessing/normalization.md) * [Imputation missing values](machine-learning/preprocessing/imputation-missing-values.md) diff --git a/docs/machine-learning/feature-selection/variance-threshold.md b/docs/machine-learning/feature-selection/variance-threshold.md new file mode 100644 index 0000000..9c942e7 --- /dev/null +++ b/docs/machine-learning/feature-selection/variance-threshold.md @@ -0,0 +1,60 @@ +# Variance Threshold + +`VarianceThreshold` is a simple baseline approach to feature selection. +It removes all features whose variance doesn’t meet some threshold. 
+By default, it removes all zero-variance features, i.e. features that have the same value in all samples. + +## Constructor Parameters + +* $threshold (float) - features with a variance lower than or equal to this threshold will be removed (default 0.0) + +```php +use Phpml\FeatureSelection\VarianceThreshold; + +$transformer = new VarianceThreshold(0.15); +``` + +## Example of use + +As an example, suppose that we have a dataset with boolean features and +we want to remove all features that are either one or zero (on or off) +in more than 80% of the samples. +Boolean features are Bernoulli random variables, and the variance of such +variables is given by +``` +Var[X] = p(1 - p) +``` +so we can select using the threshold .8 * (1 - .8): + +```php +use Phpml\FeatureSelection\VarianceThreshold; + +$samples = [[0, 0, 1], [0, 1, 0], [1, 0, 0], [0, 1, 1], [0, 1, 0], [0, 1, 1]]; +$transformer = new VarianceThreshold(0.8 * (1 - 0.8)); + +$transformer->fit($samples); +$transformer->transform($samples); + +/* +$samples = [[0, 1], [1, 0], [0, 0], [1, 1], [1, 0], [1, 1]]; +*/ +``` + +## Pipeline + +`VarianceThreshold` implements the `Transformer` interface, so it can be used as part of a pipeline: + +```php +use Phpml\FeatureSelection\VarianceThreshold; +use Phpml\Classification\SVC; +use Phpml\FeatureExtraction\TfIdfTransformer; +use Phpml\Pipeline; + +$transformers = [ + new TfIdfTransformer(), + new VarianceThreshold(0.1) +]; +$estimator = new SVC(); + +$pipeline = new Pipeline($transformers, $estimator); +``` diff --git a/mkdocs.yml b/mkdocs.yml index 8c9c10c..f794320 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -25,6 +25,8 @@ pages: - Cross Validation: - RandomSplit: machine-learning/cross-validation/random-split.md - Stratified Random Split: machine-learning/cross-validation/stratified-random-split.md + - Feature Selection: + - VarianceThreshold: machine-learning/feature-selection/variance-threshold.md - Preprocessing: - Normalization: machine-learning/preprocessing/normalization.md - 
Imputation missing values: machine-learning/preprocessing/imputation-missing-values.md diff --git a/src/FeatureSelection/VarianceThreshold.php b/src/FeatureSelection/VarianceThreshold.php new file mode 100644 index 0000000..6a3d639 --- /dev/null +++ b/src/FeatureSelection/VarianceThreshold.php @@ -0,0 +1,59 @@ +threshold = $threshold; + $this->variances = []; + $this->keepColumns = []; + } + + public function fit(array $samples): void + { + $this->variances = array_map(function (array $column) { + return Variance::population($column); + }, Matrix::transposeArray($samples)); + + foreach ($this->variances as $column => $variance) { + if ($variance > $this->threshold) { + $this->keepColumns[$column] = true; + } + } + } + + public function transform(array &$samples): void + { + foreach ($samples as &$sample) { + $sample = array_values(array_intersect_key($sample, $this->keepColumns)); + } + } +} diff --git a/src/Math/Statistic/StandardDeviation.php b/src/Math/Statistic/StandardDeviation.php index 8a0d241..426e4fd 100644 --- a/src/Math/Statistic/StandardDeviation.php +++ b/src/Math/Statistic/StandardDeviation.php @@ -9,27 +9,24 @@ use Phpml\Exception\InvalidArgumentException; class StandardDeviation { /** - * @param array|float[] $a - * - * @throws InvalidArgumentException + * @param array|float[]|int[] $numbers */ - public static function population(array $a, bool $sample = true): float + public static function population(array $numbers, bool $sample = true): float { - if (empty($a)) { + if (empty($numbers)) { throw InvalidArgumentException::arrayCantBeEmpty(); } - $n = count($a); + $n = count($numbers); if ($sample && $n === 1) { throw InvalidArgumentException::arraySizeToSmall(2); } - $mean = Mean::arithmetic($a); + $mean = Mean::arithmetic($numbers); $carry = 0.0; - foreach ($a as $val) { - $d = $val - $mean; - $carry += $d * $d; + foreach ($numbers as $val) { + $carry += ($val - $mean) ** 2; } if ($sample) { @@ -38,4 +35,26 @@ class StandardDeviation return 
sqrt((float) ($carry / $n)); } + + /** + * Sum of squares deviations + * ∑⟮xᵢ - μ⟯² + * + * @param array|float[]|int[] $numbers + */ + public static function sumOfSquares(array $numbers): float + { + if (empty($numbers)) { + throw InvalidArgumentException::arrayCantBeEmpty(); + } + + $mean = Mean::arithmetic($numbers); + + return array_sum(array_map( + function ($val) use ($mean) { + return ($val - $mean) ** 2; + }, + $numbers + )); + } } diff --git a/src/Math/Statistic/Variance.php b/src/Math/Statistic/Variance.php new file mode 100644 index 0000000..641cf00 --- /dev/null +++ b/src/Math/Statistic/Variance.php @@ -0,0 +1,27 @@ +fit($samples); + $transformer->transform($samples); + + // expecting to remove first column + self::assertEquals([[0, 1], [1, 0], [0, 0], [1, 1], [1, 0], [1, 1]], $samples); + } + + public function testVarianceThresholdWithZeroThreshold(): void + { + $samples = [[0, 2, 0, 3], [0, 1, 4, 3], [0, 1, 1, 3]]; + $transformer = new VarianceThreshold(); + $transformer->fit($samples); + $transformer->transform($samples); + + self::assertEquals([[2, 0], [1, 4], [1, 1]], $samples); + } + + public function testThrowExceptionWhenThresholdBelowZero(): void + { + $this->expectException(InvalidArgumentException::class); + new VarianceThreshold(-0.1); + } +} diff --git a/tests/Math/Statistic/StandardDeviationTest.php b/tests/Math/Statistic/StandardDeviationTest.php index 8333740..51c2770 100644 --- a/tests/Math/Statistic/StandardDeviationTest.php +++ b/tests/Math/Statistic/StandardDeviationTest.php @@ -37,4 +37,29 @@ class StandardDeviationTest extends TestCase $this->expectException(InvalidArgumentException::class); StandardDeviation::population([1]); } + + /** + * @dataProvider dataProviderForSumOfSquaresDeviations + */ + public function testSumOfSquares(array $numbers, float $sum): void + { + self::assertEquals($sum, StandardDeviation::sumOfSquares($numbers), '', 0.0001); + } + + public function dataProviderForSumOfSquaresDeviations(): array + { + return 
[ + [[3, 6, 7, 11, 12, 13, 17], 136.8571], + [[6, 11, 12, 14, 15, 20, 21], 162.8571], + [[1, 2, 3, 6, 7, 11, 12], 112], + [[1, 2, 3, 4, 5, 6, 7, 8, 9, 0], 82.5], + [[34, 253, 754, 2342, 75, 23, 876, 4, 1, -34, -345, 754, -377, 3, 0], 6453975.7333], + ]; + } + + public function testThrowExceptionOnEmptyArraySumOfSquares(): void + { + $this->expectException(InvalidArgumentException::class); + StandardDeviation::sumOfSquares([]); + } } diff --git a/tests/Math/Statistic/VarianceTest.php b/tests/Math/Statistic/VarianceTest.php new file mode 100644 index 0000000..19b2cd8 --- /dev/null +++ b/tests/Math/Statistic/VarianceTest.php @@ -0,0 +1,34 @@ +