mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2024-11-21 20:45:10 +00:00
Fix DecisionTreeRegressor for big dataset (#376)
This commit is contained in:
parent
91812f4c4a
commit
1e1d794655
@ -79,6 +79,7 @@ Public datasets are available in a separate repository [php-ai/php-ml-datasets](
|
||||
* Regression
|
||||
* [Least Squares](http://php-ml.readthedocs.io/en/latest/machine-learning/regression/least-squares/)
|
||||
* [SVR](http://php-ml.readthedocs.io/en/latest/machine-learning/regression/svr/)
|
||||
* DecisionTreeRegressor
|
||||
* Clustering
|
||||
* [k-Means](http://php-ml.readthedocs.io/en/latest/machine-learning/clustering/k-means/)
|
||||
* [DBSCAN](http://php-ml.readthedocs.io/en/latest/machine-learning/clustering/dbscan/)
|
||||
@ -87,6 +88,7 @@ Public datasets are available in a separate repository [php-ai/php-ml-datasets](
|
||||
* [Accuracy](http://php-ml.readthedocs.io/en/latest/machine-learning/metric/accuracy/)
|
||||
* [Confusion Matrix](http://php-ml.readthedocs.io/en/latest/machine-learning/metric/confusion-matrix/)
|
||||
* [Classification Report](http://php-ml.readthedocs.io/en/latest/machine-learning/metric/classification-report/)
|
||||
* Regression
|
||||
* Workflow
|
||||
* [Pipeline](http://php-ml.readthedocs.io/en/latest/machine-learning/workflow/pipeline)
|
||||
* Neural Network
|
||||
|
@ -4,6 +4,7 @@ declare(strict_types=1);
|
||||
|
||||
namespace Phpml\Regression;
|
||||
|
||||
use Phpml\Exception\InvalidArgumentException;
|
||||
use Phpml\Exception\InvalidOperationException;
|
||||
use Phpml\Math\Statistic\Mean;
|
||||
use Phpml\Math\Statistic\Variance;
|
||||
@ -29,6 +30,27 @@ final class DecisionTreeRegressor extends CART implements Regression
|
||||
*/
|
||||
protected $columns = [];
|
||||
|
||||
public function __construct(
|
||||
int $maxDepth = PHP_INT_MAX,
|
||||
int $maxLeafSize = 3,
|
||||
float $minPurityIncrease = 0.,
|
||||
?int $maxFeatures = null,
|
||||
float $tolerance = 1e-4
|
||||
) {
|
||||
if ($maxFeatures !== null && $maxFeatures < 1) {
|
||||
throw new InvalidArgumentException('Max features must be greater than 0');
|
||||
}
|
||||
|
||||
if ($tolerance < 0.) {
|
||||
throw new InvalidArgumentException('Tolerance must be equal or greater than 0');
|
||||
}
|
||||
|
||||
$this->maxFeatures = $maxFeatures;
|
||||
$this->tolerance = $tolerance;
|
||||
|
||||
parent::__construct($maxDepth, $maxLeafSize, $minPurityIncrease);
|
||||
}
|
||||
|
||||
public function train(array $samples, array $targets): void
|
||||
{
|
||||
$features = count($samples[0]);
|
||||
|
@ -91,7 +91,7 @@ abstract class CART
|
||||
|
||||
$depth++;
|
||||
|
||||
if ($left === [] || $right === []) {
|
||||
if ($left[1] === [] || $right[1] === []) {
|
||||
$node = $this->terminate(array_merge($left[1], $right[1]));
|
||||
|
||||
$current->attachLeft($node);
|
||||
|
@ -4,6 +4,7 @@ declare(strict_types=1);
|
||||
|
||||
namespace Phpml\Tests\Regression;
|
||||
|
||||
use Phpml\Exception\InvalidArgumentException;
|
||||
use Phpml\Exception\InvalidOperationException;
|
||||
use Phpml\ModelManager;
|
||||
use Phpml\Regression\DecisionTreeRegressor;
|
||||
@ -45,6 +46,20 @@ class DecisionTreeRegressorTest extends TestCase
|
||||
$regression->predict([[1]]);
|
||||
}
|
||||
|
||||
public function testMaxFeaturesLowerThanOne(): void
|
||||
{
|
||||
$this->expectException(InvalidArgumentException::class);
|
||||
|
||||
new DecisionTreeRegressor(5, 3, 0.0, 0);
|
||||
}
|
||||
|
||||
public function testToleranceSmallerThanZero(): void
|
||||
{
|
||||
$this->expectException(InvalidArgumentException::class);
|
||||
|
||||
new DecisionTreeRegressor(5, 3, 0.0, 20, -1);
|
||||
}
|
||||
|
||||
public function testSaveAndRestore(): void
|
||||
{
|
||||
$samples = [[60], [61], [62], [63], [65]];
|
||||
|
Loading…
Reference in New Issue
Block a user