Add RandomForest exception tests (#251)

This commit is contained in:
Marcin Michalski 2018-03-04 17:02:36 +01:00 committed by Arkadiusz Kondas
parent 8976047cbc
commit 941d240ab6
3 changed files with 47 additions and 24 deletions

View File

@ -29,12 +29,12 @@ class DecisionTreeLeaf
public $columnIndex; public $columnIndex;
/** /**
* @var ?DecisionTreeLeaf * @var DecisionTreeLeaf|null
*/ */
public $leftLeaf; public $leftLeaf;
/** /**
* @var ?DecisionTreeLeaf * @var DecisionTreeLeaf|null
*/ */
public $rightLeaf; public $rightLeaf;

View File

@ -4,9 +4,9 @@ declare(strict_types=1);
namespace Phpml\Classification\Ensemble; namespace Phpml\Classification\Ensemble;
use Exception;
use Phpml\Classification\Classifier; use Phpml\Classification\Classifier;
use Phpml\Classification\DecisionTree; use Phpml\Classification\DecisionTree;
use Phpml\Exception\InvalidArgumentException;
class RandomForest extends Bagging class RandomForest extends Bagging
{ {
@ -41,20 +41,20 @@ class RandomForest extends Bagging
* Default value for the ratio is 'log' which results in log(numFeatures, 2) + 1 * Default value for the ratio is 'log' which results in log(numFeatures, 2) + 1
* features to be taken into consideration while selecting subspace of features * features to be taken into consideration while selecting subspace of features
* *
* @param mixed $ratio string or float should be given * @param string|float $ratio
*
* @return $this
*
* @throws \Exception
*/ */
public function setFeatureSubsetRatio($ratio) public function setFeatureSubsetRatio($ratio): self
{ {
if (!is_string($ratio) && !is_float($ratio)) {
throw new InvalidArgumentException('Feature subset ratio must be a string or a float');
}
if (is_float($ratio) && ($ratio < 0.1 || $ratio > 1.0)) { if (is_float($ratio) && ($ratio < 0.1 || $ratio > 1.0)) {
throw new Exception('When a float given, feature subset ratio should be between 0.1 and 1.0'); throw new InvalidArgumentException('When a float is given, feature subset ratio should be between 0.1 and 1.0');
} }
if (is_string($ratio) && $ratio != 'sqrt' && $ratio != 'log') { if (is_string($ratio) && $ratio != 'sqrt' && $ratio != 'log') {
throw new Exception("When a string given, feature subset ratio can only be 'sqrt' or 'log' "); throw new InvalidArgumentException("When a string is given, feature subset ratio can only be 'sqrt' or 'log'");
} }
$this->featureSubsetRatio = $ratio; $this->featureSubsetRatio = $ratio;
@ -66,13 +66,11 @@ class RandomForest extends Bagging
* RandomForest algorithm is usable *only* with DecisionTree * RandomForest algorithm is usable *only* with DecisionTree
* *
* @return $this * @return $this
*
* @throws \Exception
*/ */
public function setClassifer(string $classifier, array $classifierOptions = []) public function setClassifer(string $classifier, array $classifierOptions = [])
{ {
if ($classifier != DecisionTree::class) { if ($classifier != DecisionTree::class) {
throw new Exception('RandomForest can only use DecisionTree as base classifier'); throw new InvalidArgumentException('RandomForest can only use DecisionTree as base classifier');
} }
return parent::setClassifer($classifier, $classifierOptions); return parent::setClassifer($classifier, $classifierOptions);
@ -133,7 +131,7 @@ class RandomForest extends Bagging
{ {
if (is_float($this->featureSubsetRatio)) { if (is_float($this->featureSubsetRatio)) {
$featureCount = (int) ($this->featureSubsetRatio * $this->featureCount); $featureCount = (int) ($this->featureSubsetRatio * $this->featureCount);
} elseif ($this->featureCount == 'sqrt') { } elseif ($this->featureSubsetRatio == 'sqrt') {
$featureCount = (int) sqrt($this->featureCount) + 1; $featureCount = (int) sqrt($this->featureCount) + 1;
} else { } else {
$featureCount = (int) log($this->featureCount, 2) + 1; $featureCount = (int) log($this->featureCount, 2) + 1;

View File

@ -7,19 +7,44 @@ namespace Phpml\Tests\Classification\Ensemble;
use Phpml\Classification\DecisionTree; use Phpml\Classification\DecisionTree;
use Phpml\Classification\Ensemble\RandomForest; use Phpml\Classification\Ensemble\RandomForest;
use Phpml\Classification\NaiveBayes; use Phpml\Classification\NaiveBayes;
use Throwable; use Phpml\Exception\InvalidArgumentException;
class RandomForestTest extends BaggingTest class RandomForestTest extends BaggingTest
{ {
public function testOtherBaseClassifier(): void public function testThrowExceptionWithInvalidClassifier(): void
{ {
try { $this->expectException(InvalidArgumentException::class);
$this->expectExceptionMessage('RandomForest can only use DecisionTree as base classifier');
$classifier = new RandomForest(); $classifier = new RandomForest();
$classifier->setClassifer(NaiveBayes::class); $classifier->setClassifer(NaiveBayes::class);
$this->assertEquals(0, 1);
} catch (Throwable $ex) {
$this->assertEquals(1, 1);
} }
public function testThrowExceptionWithInvalidFeatureSubsetRatioType(): void
{
$this->expectException(InvalidArgumentException::class);
$this->expectExceptionMessage('Feature subset ratio must be a string or a float');
$classifier = new RandomForest();
$classifier->setFeatureSubsetRatio(1);
}
public function testThrowExceptionWithInvalidFeatureSubsetRatioFloat(): void
{
$this->expectException(InvalidArgumentException::class);
$this->expectExceptionMessage('When a float is given, feature subset ratio should be between 0.1 and 1.0');
$classifier = new RandomForest();
$classifier->setFeatureSubsetRatio(1.1);
}
public function testThrowExceptionWithInvalidFeatureSubsetRatioString(): void
{
$this->expectException(InvalidArgumentException::class);
$this->expectExceptionMessage("When a string is given, feature subset ratio can only be 'sqrt' or 'log'");
$classifier = new RandomForest();
$classifier->setFeatureSubsetRatio('pow');
} }
protected function getClassifier($numBaseClassifiers = 50) protected function getClassifier($numBaseClassifiers = 50)