mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2025-04-09 11:41:50 +00:00
Add RandomForest exception tests (#251)
This commit is contained in:
parent
8976047cbc
commit
941d240ab6
@ -29,12 +29,12 @@ class DecisionTreeLeaf
|
|||||||
public $columnIndex;
|
public $columnIndex;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @var ?DecisionTreeLeaf
|
* @var DecisionTreeLeaf|null
|
||||||
*/
|
*/
|
||||||
public $leftLeaf;
|
public $leftLeaf;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @var ?DecisionTreeLeaf
|
* @var DecisionTreeLeaf|null
|
||||||
*/
|
*/
|
||||||
public $rightLeaf;
|
public $rightLeaf;
|
||||||
|
|
||||||
|
@ -4,9 +4,9 @@ declare(strict_types=1);
|
|||||||
|
|
||||||
namespace Phpml\Classification\Ensemble;
|
namespace Phpml\Classification\Ensemble;
|
||||||
|
|
||||||
use Exception;
|
|
||||||
use Phpml\Classification\Classifier;
|
use Phpml\Classification\Classifier;
|
||||||
use Phpml\Classification\DecisionTree;
|
use Phpml\Classification\DecisionTree;
|
||||||
|
use Phpml\Exception\InvalidArgumentException;
|
||||||
|
|
||||||
class RandomForest extends Bagging
|
class RandomForest extends Bagging
|
||||||
{
|
{
|
||||||
@ -41,20 +41,20 @@ class RandomForest extends Bagging
|
|||||||
* Default value for the ratio is 'log' which results in log(numFeatures, 2) + 1
|
* Default value for the ratio is 'log' which results in log(numFeatures, 2) + 1
|
||||||
* features to be taken into consideration while selecting subspace of features
|
* features to be taken into consideration while selecting subspace of features
|
||||||
*
|
*
|
||||||
* @param mixed $ratio string or float should be given
|
* @param string|float $ratio
|
||||||
*
|
|
||||||
* @return $this
|
|
||||||
*
|
|
||||||
* @throws \Exception
|
|
||||||
*/
|
*/
|
||||||
public function setFeatureSubsetRatio($ratio)
|
public function setFeatureSubsetRatio($ratio): self
|
||||||
{
|
{
|
||||||
|
if (!is_string($ratio) && !is_float($ratio)) {
|
||||||
|
throw new InvalidArgumentException('Feature subset ratio must be a string or a float');
|
||||||
|
}
|
||||||
|
|
||||||
if (is_float($ratio) && ($ratio < 0.1 || $ratio > 1.0)) {
|
if (is_float($ratio) && ($ratio < 0.1 || $ratio > 1.0)) {
|
||||||
throw new Exception('When a float given, feature subset ratio should be between 0.1 and 1.0');
|
throw new InvalidArgumentException('When a float is given, feature subset ratio should be between 0.1 and 1.0');
|
||||||
}
|
}
|
||||||
|
|
||||||
if (is_string($ratio) && $ratio != 'sqrt' && $ratio != 'log') {
|
if (is_string($ratio) && $ratio != 'sqrt' && $ratio != 'log') {
|
||||||
throw new Exception("When a string given, feature subset ratio can only be 'sqrt' or 'log' ");
|
throw new InvalidArgumentException("When a string is given, feature subset ratio can only be 'sqrt' or 'log'");
|
||||||
}
|
}
|
||||||
|
|
||||||
$this->featureSubsetRatio = $ratio;
|
$this->featureSubsetRatio = $ratio;
|
||||||
@ -66,13 +66,11 @@ class RandomForest extends Bagging
|
|||||||
* RandomForest algorithm is usable *only* with DecisionTree
|
* RandomForest algorithm is usable *only* with DecisionTree
|
||||||
*
|
*
|
||||||
* @return $this
|
* @return $this
|
||||||
*
|
|
||||||
* @throws \Exception
|
|
||||||
*/
|
*/
|
||||||
public function setClassifer(string $classifier, array $classifierOptions = [])
|
public function setClassifer(string $classifier, array $classifierOptions = [])
|
||||||
{
|
{
|
||||||
if ($classifier != DecisionTree::class) {
|
if ($classifier != DecisionTree::class) {
|
||||||
throw new Exception('RandomForest can only use DecisionTree as base classifier');
|
throw new InvalidArgumentException('RandomForest can only use DecisionTree as base classifier');
|
||||||
}
|
}
|
||||||
|
|
||||||
return parent::setClassifer($classifier, $classifierOptions);
|
return parent::setClassifer($classifier, $classifierOptions);
|
||||||
@ -133,7 +131,7 @@ class RandomForest extends Bagging
|
|||||||
{
|
{
|
||||||
if (is_float($this->featureSubsetRatio)) {
|
if (is_float($this->featureSubsetRatio)) {
|
||||||
$featureCount = (int) ($this->featureSubsetRatio * $this->featureCount);
|
$featureCount = (int) ($this->featureSubsetRatio * $this->featureCount);
|
||||||
} elseif ($this->featureCount == 'sqrt') {
|
} elseif ($this->featureSubsetRatio == 'sqrt') {
|
||||||
$featureCount = (int) sqrt($this->featureCount) + 1;
|
$featureCount = (int) sqrt($this->featureCount) + 1;
|
||||||
} else {
|
} else {
|
||||||
$featureCount = (int) log($this->featureCount, 2) + 1;
|
$featureCount = (int) log($this->featureCount, 2) + 1;
|
||||||
|
@ -7,19 +7,44 @@ namespace Phpml\Tests\Classification\Ensemble;
|
|||||||
use Phpml\Classification\DecisionTree;
|
use Phpml\Classification\DecisionTree;
|
||||||
use Phpml\Classification\Ensemble\RandomForest;
|
use Phpml\Classification\Ensemble\RandomForest;
|
||||||
use Phpml\Classification\NaiveBayes;
|
use Phpml\Classification\NaiveBayes;
|
||||||
use Throwable;
|
use Phpml\Exception\InvalidArgumentException;
|
||||||
|
|
||||||
class RandomForestTest extends BaggingTest
|
class RandomForestTest extends BaggingTest
|
||||||
{
|
{
|
||||||
public function testOtherBaseClassifier(): void
|
public function testThrowExceptionWithInvalidClassifier(): void
|
||||||
{
|
{
|
||||||
try {
|
$this->expectException(InvalidArgumentException::class);
|
||||||
|
$this->expectExceptionMessage('RandomForest can only use DecisionTree as base classifier');
|
||||||
|
|
||||||
$classifier = new RandomForest();
|
$classifier = new RandomForest();
|
||||||
$classifier->setClassifer(NaiveBayes::class);
|
$classifier->setClassifer(NaiveBayes::class);
|
||||||
$this->assertEquals(0, 1);
|
|
||||||
} catch (Throwable $ex) {
|
|
||||||
$this->assertEquals(1, 1);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public function testThrowExceptionWithInvalidFeatureSubsetRatioType(): void
|
||||||
|
{
|
||||||
|
$this->expectException(InvalidArgumentException::class);
|
||||||
|
$this->expectExceptionMessage('Feature subset ratio must be a string or a float');
|
||||||
|
|
||||||
|
$classifier = new RandomForest();
|
||||||
|
$classifier->setFeatureSubsetRatio(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
public function testThrowExceptionWithInvalidFeatureSubsetRatioFloat(): void
|
||||||
|
{
|
||||||
|
$this->expectException(InvalidArgumentException::class);
|
||||||
|
$this->expectExceptionMessage('When a float is given, feature subset ratio should be between 0.1 and 1.0');
|
||||||
|
|
||||||
|
$classifier = new RandomForest();
|
||||||
|
$classifier->setFeatureSubsetRatio(1.1);
|
||||||
|
}
|
||||||
|
|
||||||
|
public function testThrowExceptionWithInvalidFeatureSubsetRatioString(): void
|
||||||
|
{
|
||||||
|
$this->expectException(InvalidArgumentException::class);
|
||||||
|
$this->expectExceptionMessage("When a string is given, feature subset ratio can only be 'sqrt' or 'log'");
|
||||||
|
|
||||||
|
$classifier = new RandomForest();
|
||||||
|
$classifier->setFeatureSubsetRatio('pow');
|
||||||
}
|
}
|
||||||
|
|
||||||
protected function getClassifier($numBaseClassifiers = 50)
|
protected function getClassifier($numBaseClassifiers = 50)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user