Fix DecisionTreeRegressor for big dataset (#376)

This commit is contained in:
Arkadiusz Kondas 2019-05-12 21:27:21 +02:00 committed by GitHub
parent 91812f4c4a
commit 1e1d794655
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 40 additions and 1 deletions

View File

@ -79,6 +79,7 @@ Public datasets are available in a separate repository [php-ai/php-ml-datasets](
* Regression * Regression
* [Least Squares](http://php-ml.readthedocs.io/en/latest/machine-learning/regression/least-squares/) * [Least Squares](http://php-ml.readthedocs.io/en/latest/machine-learning/regression/least-squares/)
* [SVR](http://php-ml.readthedocs.io/en/latest/machine-learning/regression/svr/) * [SVR](http://php-ml.readthedocs.io/en/latest/machine-learning/regression/svr/)
* DecisionTreeRegressor
* Clustering * Clustering
* [k-Means](http://php-ml.readthedocs.io/en/latest/machine-learning/clustering/k-means/) * [k-Means](http://php-ml.readthedocs.io/en/latest/machine-learning/clustering/k-means/)
* [DBSCAN](http://php-ml.readthedocs.io/en/latest/machine-learning/clustering/dbscan/) * [DBSCAN](http://php-ml.readthedocs.io/en/latest/machine-learning/clustering/dbscan/)
@ -87,6 +88,7 @@ Public datasets are available in a separate repository [php-ai/php-ml-datasets](
* [Accuracy](http://php-ml.readthedocs.io/en/latest/machine-learning/metric/accuracy/) * [Accuracy](http://php-ml.readthedocs.io/en/latest/machine-learning/metric/accuracy/)
* [Confusion Matrix](http://php-ml.readthedocs.io/en/latest/machine-learning/metric/confusion-matrix/) * [Confusion Matrix](http://php-ml.readthedocs.io/en/latest/machine-learning/metric/confusion-matrix/)
* [Classification Report](http://php-ml.readthedocs.io/en/latest/machine-learning/metric/classification-report/) * [Classification Report](http://php-ml.readthedocs.io/en/latest/machine-learning/metric/classification-report/)
* Regression
* Workflow * Workflow
* [Pipeline](http://php-ml.readthedocs.io/en/latest/machine-learning/workflow/pipeline) * [Pipeline](http://php-ml.readthedocs.io/en/latest/machine-learning/workflow/pipeline)
* Neural Network * Neural Network

View File

@ -4,6 +4,7 @@ declare(strict_types=1);
namespace Phpml\Regression; namespace Phpml\Regression;
use Phpml\Exception\InvalidArgumentException;
use Phpml\Exception\InvalidOperationException; use Phpml\Exception\InvalidOperationException;
use Phpml\Math\Statistic\Mean; use Phpml\Math\Statistic\Mean;
use Phpml\Math\Statistic\Variance; use Phpml\Math\Statistic\Variance;
@ -29,6 +30,27 @@ final class DecisionTreeRegressor extends CART implements Regression
*/ */
protected $columns = []; protected $columns = [];
/**
 * Validates the regressor-specific options, stores them, and then hands the
 * generic tree options over to the parent constructor.
 *
 * @param int      $maxDepth          forwarded unchanged to the parent constructor
 * @param int      $maxLeafSize       forwarded unchanged to the parent constructor
 * @param float    $minPurityIncrease forwarded unchanged to the parent constructor
 * @param int|null $maxFeatures       null, or an integer of at least 1
 * @param float    $tolerance         any value that is not below 0
 *
 * @throws InvalidArgumentException when $maxFeatures is an integer below 1
 *                                  or $tolerance is below 0
 */
public function __construct(
    int $maxDepth = PHP_INT_MAX,
    int $maxLeafSize = 3,
    float $minPurityIncrease = 0.,
    ?int $maxFeatures = null,
    float $tolerance = 1e-4
) {
    // null is accepted as-is; only an explicit integer is range-checked.
    $featureLimitOutOfRange = $maxFeatures !== null && $maxFeatures < 1;

    if ($featureLimitOutOfRange) {
        throw new InvalidArgumentException('Max features must be greater than 0');
    }

    if (0. > $tolerance) {
        throw new InvalidArgumentException('Tolerance must be equal or greater than 0');
    }

    // Own options are stored first; the parent constructor runs last.
    $this->maxFeatures = $maxFeatures;
    $this->tolerance = $tolerance;

    parent::__construct($maxDepth, $maxLeafSize, $minPurityIncrease);
}
public function train(array $samples, array $targets): void public function train(array $samples, array $targets): void
{ {
$features = count($samples[0]); $features = count($samples[0]);

View File

@ -91,7 +91,7 @@ abstract class CART
$depth++; $depth++;
if ($left === [] || $right === []) { if ($left[1] === [] || $right[1] === []) {
$node = $this->terminate(array_merge($left[1], $right[1])); $node = $this->terminate(array_merge($left[1], $right[1]));
$current->attachLeft($node); $current->attachLeft($node);

View File

@ -4,6 +4,7 @@ declare(strict_types=1);
namespace Phpml\Tests\Regression; namespace Phpml\Tests\Regression;
use Phpml\Exception\InvalidArgumentException;
use Phpml\Exception\InvalidOperationException; use Phpml\Exception\InvalidOperationException;
use Phpml\ModelManager; use Phpml\ModelManager;
use Phpml\Regression\DecisionTreeRegressor; use Phpml\Regression\DecisionTreeRegressor;
@ -45,6 +46,20 @@ class DecisionTreeRegressorTest extends TestCase
$regression->predict([[1]]); $regression->predict([[1]]);
} }
/**
 * Constructing the regressor with a max-features value of 0 must be rejected
 * with an InvalidArgumentException.
 */
public function testMaxFeaturesLowerThanOne(): void
{
    // First integer below the accepted minimum of 1.
    $invalidMaxFeatures = 0;

    $this->expectException(InvalidArgumentException::class);

    new DecisionTreeRegressor(5, 3, 0.0, $invalidMaxFeatures);
}
/**
 * Constructing the regressor with a negative tolerance must be rejected
 * with an InvalidArgumentException.
 */
public function testToleranceSmallerThanZero(): void
{
    // Valid max-features value, so only the tolerance can trigger the failure.
    $validMaxFeatures = 20;
    $negativeTolerance = -1;

    $this->expectException(InvalidArgumentException::class);

    new DecisionTreeRegressor(5, 3, 0.0, $validMaxFeatures, $negativeTolerance);
}
public function testSaveAndRestore(): void public function testSaveAndRestore(): void
{ {
$samples = [[60], [61], [62], [63], [65]]; $samples = [[60], [61], [62], [63], [65]];