Mirror of https://github.com/Llewellynvdm/php-ml.git (synced 2024-11-21 20:45:10 +00:00)
Fix DecisionTreeRegressor for big dataset (#376)
This commit is contained in:
parent 91812f4c4a
commit 1e1d794655
@@ -79,6 +79,7 @@ Public datasets are available in a separate repository [php-ai/php-ml-datasets](
|
|||||||
* Regression
|
* Regression
|
||||||
* [Least Squares](http://php-ml.readthedocs.io/en/latest/machine-learning/regression/least-squares/)
|
* [Least Squares](http://php-ml.readthedocs.io/en/latest/machine-learning/regression/least-squares/)
|
||||||
* [SVR](http://php-ml.readthedocs.io/en/latest/machine-learning/regression/svr/)
|
* [SVR](http://php-ml.readthedocs.io/en/latest/machine-learning/regression/svr/)
|
||||||
|
* DecisionTreeRegressor
|
||||||
* Clustering
|
* Clustering
|
||||||
* [k-Means](http://php-ml.readthedocs.io/en/latest/machine-learning/clustering/k-means/)
|
* [k-Means](http://php-ml.readthedocs.io/en/latest/machine-learning/clustering/k-means/)
|
||||||
* [DBSCAN](http://php-ml.readthedocs.io/en/latest/machine-learning/clustering/dbscan/)
|
* [DBSCAN](http://php-ml.readthedocs.io/en/latest/machine-learning/clustering/dbscan/)
|
||||||
@@ -87,6 +88,7 @@ Public datasets are available in a separate repository [php-ai/php-ml-datasets](
|
|||||||
* [Accuracy](http://php-ml.readthedocs.io/en/latest/machine-learning/metric/accuracy/)
|
* [Accuracy](http://php-ml.readthedocs.io/en/latest/machine-learning/metric/accuracy/)
|
||||||
* [Confusion Matrix](http://php-ml.readthedocs.io/en/latest/machine-learning/metric/confusion-matrix/)
|
* [Confusion Matrix](http://php-ml.readthedocs.io/en/latest/machine-learning/metric/confusion-matrix/)
|
||||||
* [Classification Report](http://php-ml.readthedocs.io/en/latest/machine-learning/metric/classification-report/)
|
* [Classification Report](http://php-ml.readthedocs.io/en/latest/machine-learning/metric/classification-report/)
|
||||||
|
* Regression
|
||||||
* Workflow
|
* Workflow
|
||||||
* [Pipeline](http://php-ml.readthedocs.io/en/latest/machine-learning/workflow/pipeline)
|
* [Pipeline](http://php-ml.readthedocs.io/en/latest/machine-learning/workflow/pipeline)
|
||||||
* Neural Network
|
* Neural Network
|
||||||
|
@@ -4,6 +4,7 @@ declare(strict_types=1);
|
|||||||
|
|
||||||
namespace Phpml\Regression;
|
namespace Phpml\Regression;
|
||||||
|
|
||||||
|
use Phpml\Exception\InvalidArgumentException;
|
||||||
use Phpml\Exception\InvalidOperationException;
|
use Phpml\Exception\InvalidOperationException;
|
||||||
use Phpml\Math\Statistic\Mean;
|
use Phpml\Math\Statistic\Mean;
|
||||||
use Phpml\Math\Statistic\Variance;
|
use Phpml\Math\Statistic\Variance;
|
||||||
@@ -29,6 +30,27 @@ final class DecisionTreeRegressor extends CART implements Regression
|
|||||||
*/
|
*/
|
||||||
protected $columns = [];
|
protected $columns = [];
|
||||||
|
|
||||||
|
public function __construct(
|
||||||
|
int $maxDepth = PHP_INT_MAX,
|
||||||
|
int $maxLeafSize = 3,
|
||||||
|
float $minPurityIncrease = 0.,
|
||||||
|
?int $maxFeatures = null,
|
||||||
|
float $tolerance = 1e-4
|
||||||
|
) {
|
||||||
|
if ($maxFeatures !== null && $maxFeatures < 1) {
|
||||||
|
throw new InvalidArgumentException('Max features must be greater than 0');
|
||||||
|
}
|
||||||
|
|
||||||
|
if ($tolerance < 0.) {
|
||||||
|
throw new InvalidArgumentException('Tolerance must be equal or greater than 0');
|
||||||
|
}
|
||||||
|
|
||||||
|
$this->maxFeatures = $maxFeatures;
|
||||||
|
$this->tolerance = $tolerance;
|
||||||
|
|
||||||
|
parent::__construct($maxDepth, $maxLeafSize, $minPurityIncrease);
|
||||||
|
}
|
||||||
|
|
||||||
public function train(array $samples, array $targets): void
|
public function train(array $samples, array $targets): void
|
||||||
{
|
{
|
||||||
$features = count($samples[0]);
|
$features = count($samples[0]);
|
||||||
|
@@ -91,7 +91,7 @@ abstract class CART
|
|||||||
|
|
||||||
$depth++;
|
$depth++;
|
||||||
|
|
||||||
if ($left === [] || $right === []) {
|
if ($left[1] === [] || $right[1] === []) {
|
||||||
$node = $this->terminate(array_merge($left[1], $right[1]));
|
$node = $this->terminate(array_merge($left[1], $right[1]));
|
||||||
|
|
||||||
$current->attachLeft($node);
|
$current->attachLeft($node);
|
||||||
|
@@ -4,6 +4,7 @@ declare(strict_types=1);
|
|||||||
|
|
||||||
namespace Phpml\Tests\Regression;
|
namespace Phpml\Tests\Regression;
|
||||||
|
|
||||||
|
use Phpml\Exception\InvalidArgumentException;
|
||||||
use Phpml\Exception\InvalidOperationException;
|
use Phpml\Exception\InvalidOperationException;
|
||||||
use Phpml\ModelManager;
|
use Phpml\ModelManager;
|
||||||
use Phpml\Regression\DecisionTreeRegressor;
|
use Phpml\Regression\DecisionTreeRegressor;
|
||||||
@@ -45,6 +46,20 @@ class DecisionTreeRegressorTest extends TestCase
|
|||||||
$regression->predict([[1]]);
|
$regression->predict([[1]]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public function testMaxFeaturesLowerThanOne(): void
|
||||||
|
{
|
||||||
|
$this->expectException(InvalidArgumentException::class);
|
||||||
|
|
||||||
|
new DecisionTreeRegressor(5, 3, 0.0, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
public function testToleranceSmallerThanZero(): void
|
||||||
|
{
|
||||||
|
$this->expectException(InvalidArgumentException::class);
|
||||||
|
|
||||||
|
new DecisionTreeRegressor(5, 3, 0.0, 20, -1);
|
||||||
|
}
|
||||||
|
|
||||||
public function testSaveAndRestore(): void
|
public function testSaveAndRestore(): void
|
||||||
{
|
{
|
||||||
$samples = [[60], [61], [62], [63], [65]];
|
$samples = [[60], [61], [62], [63], [65]];
|
||||||
|
Loading…
Reference in New Issue
Block a user