Add RandomForest exception tests (#251)

This commit is contained in:
Marcin Michalski 2018-03-04 17:02:36 +01:00 committed by Arkadiusz Kondas
parent 8976047cbc
commit 941d240ab6
3 changed files with 47 additions and 24 deletions

View File

@ -29,12 +29,12 @@ class DecisionTreeLeaf
public $columnIndex;
/**
* @var ?DecisionTreeLeaf
* @var DecisionTreeLeaf|null
*/
public $leftLeaf;
/**
* @var ?DecisionTreeLeaf
* @var DecisionTreeLeaf|null
*/
public $rightLeaf;

View File

@ -4,9 +4,9 @@ declare(strict_types=1);
namespace Phpml\Classification\Ensemble;
use Exception;
use Phpml\Classification\Classifier;
use Phpml\Classification\DecisionTree;
use Phpml\Exception\InvalidArgumentException;
class RandomForest extends Bagging
{
@ -41,20 +41,20 @@ class RandomForest extends Bagging
* Default value for the ratio is 'log' which results in log(numFeatures, 2) + 1
* features to be taken into consideration while selecting subspace of features
*
* @param mixed $ratio string or float should be given
*
* @return $this
*
* @throws \Exception
* @param string|float $ratio
*/
public function setFeatureSubsetRatio($ratio)
public function setFeatureSubsetRatio($ratio): self
{
if (!is_string($ratio) && !is_float($ratio)) {
throw new InvalidArgumentException('Feature subset ratio must be a string or a float');
}
if (is_float($ratio) && ($ratio < 0.1 || $ratio > 1.0)) {
throw new Exception('When a float given, feature subset ratio should be between 0.1 and 1.0');
throw new InvalidArgumentException('When a float is given, feature subset ratio should be between 0.1 and 1.0');
}
if (is_string($ratio) && $ratio != 'sqrt' && $ratio != 'log') {
throw new Exception("When a string given, feature subset ratio can only be 'sqrt' or 'log' ");
throw new InvalidArgumentException("When a string is given, feature subset ratio can only be 'sqrt' or 'log'");
}
$this->featureSubsetRatio = $ratio;
@ -66,13 +66,11 @@ class RandomForest extends Bagging
* RandomForest algorithm is usable *only* with DecisionTree
*
* @return $this
*
* @throws \Exception
*/
public function setClassifer(string $classifier, array $classifierOptions = [])
{
if ($classifier != DecisionTree::class) {
throw new Exception('RandomForest can only use DecisionTree as base classifier');
throw new InvalidArgumentException('RandomForest can only use DecisionTree as base classifier');
}
return parent::setClassifer($classifier, $classifierOptions);
@ -133,7 +131,7 @@ class RandomForest extends Bagging
{
if (is_float($this->featureSubsetRatio)) {
$featureCount = (int) ($this->featureSubsetRatio * $this->featureCount);
} elseif ($this->featureCount == 'sqrt') {
} elseif ($this->featureSubsetRatio == 'sqrt') {
$featureCount = (int) sqrt($this->featureCount) + 1;
} else {
$featureCount = (int) log($this->featureCount, 2) + 1;

View File

@ -7,19 +7,44 @@ namespace Phpml\Tests\Classification\Ensemble;
use Phpml\Classification\DecisionTree;
use Phpml\Classification\Ensemble\RandomForest;
use Phpml\Classification\NaiveBayes;
use Throwable;
use Phpml\Exception\InvalidArgumentException;
class RandomForestTest extends BaggingTest
{
public function testOtherBaseClassifier(): void
public function testThrowExceptionWithInvalidClassifier(): void
{
try {
$this->expectException(InvalidArgumentException::class);
$this->expectExceptionMessage('RandomForest can only use DecisionTree as base classifier');
$classifier = new RandomForest();
$classifier->setClassifer(NaiveBayes::class);
$this->assertEquals(0, 1);
} catch (Throwable $ex) {
$this->assertEquals(1, 1);
}
public function testThrowExceptionWithInvalidFeatureSubsetRatioType(): void
{
$this->expectException(InvalidArgumentException::class);
$this->expectExceptionMessage('Feature subset ratio must be a string or a float');
$classifier = new RandomForest();
$classifier->setFeatureSubsetRatio(1);
}
public function testThrowExceptionWithInvalidFeatureSubsetRatioFloat(): void
{
$this->expectException(InvalidArgumentException::class);
$this->expectExceptionMessage('When a float is given, feature subset ratio should be between 0.1 and 1.0');
$classifier = new RandomForest();
$classifier->setFeatureSubsetRatio(1.1);
}
public function testThrowExceptionWithInvalidFeatureSubsetRatioString(): void
{
$this->expectException(InvalidArgumentException::class);
$this->expectExceptionMessage("When a string is given, feature subset ratio can only be 'sqrt' or 'log'");
$classifier = new RandomForest();
$classifier->setFeatureSubsetRatio('pow');
}
protected function getClassifier($numBaseClassifiers = 50)