mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2025-01-09 16:36:34 +00:00
Comparison - replace eval (#130)
* Replace eval with strategy * Use Factory Pattern, add tests * Add missing dockblocks * Replace strategy with simple object
This commit is contained in:
parent
dda9e16b4c
commit
11d05ce89d
@ -4,6 +4,8 @@ declare(strict_types=1);
|
|||||||
|
|
||||||
namespace Phpml\Classification\DecisionTree;
|
namespace Phpml\Classification\DecisionTree;
|
||||||
|
|
||||||
|
use Phpml\Math\Comparison;
|
||||||
|
|
||||||
class DecisionTreeLeaf
|
class DecisionTreeLeaf
|
||||||
{
|
{
|
||||||
/**
|
/**
|
||||||
@ -79,12 +81,7 @@ class DecisionTreeLeaf
|
|||||||
$recordField = $record[$this->columnIndex];
|
$recordField = $record[$this->columnIndex];
|
||||||
|
|
||||||
if ($this->isContinuous) {
|
if ($this->isContinuous) {
|
||||||
$op = $this->operator;
|
return Comparison::compare((string) $recordField, $this->numericValue, $this->operator);
|
||||||
$value = $this->numericValue;
|
|
||||||
$recordField = (string) $recordField;
|
|
||||||
eval("\$result = $recordField $op $value;");
|
|
||||||
|
|
||||||
return $result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return $recordField == $this->value;
|
return $recordField == $this->value;
|
||||||
|
@ -8,6 +8,7 @@ use Phpml\Helper\Predictable;
|
|||||||
use Phpml\Helper\OneVsRest;
|
use Phpml\Helper\OneVsRest;
|
||||||
use Phpml\Classification\WeightedClassifier;
|
use Phpml\Classification\WeightedClassifier;
|
||||||
use Phpml\Classification\DecisionTree;
|
use Phpml\Classification\DecisionTree;
|
||||||
|
use Phpml\Math\Comparison;
|
||||||
|
|
||||||
class DecisionStump extends WeightedClassifier
|
class DecisionStump extends WeightedClassifier
|
||||||
{
|
{
|
||||||
@ -236,29 +237,6 @@ class DecisionStump extends WeightedClassifier
|
|||||||
return $split;
|
return $split;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param mixed $leftValue
|
|
||||||
* @param string $operator
|
|
||||||
* @param mixed $rightValue
|
|
||||||
*
|
|
||||||
* @return boolean
|
|
||||||
*/
|
|
||||||
protected function evaluate($leftValue, string $operator, $rightValue)
|
|
||||||
{
|
|
||||||
switch ($operator) {
|
|
||||||
case '>': return $leftValue > $rightValue;
|
|
||||||
case '>=': return $leftValue >= $rightValue;
|
|
||||||
case '<': return $leftValue < $rightValue;
|
|
||||||
case '<=': return $leftValue <= $rightValue;
|
|
||||||
case '=': return $leftValue === $rightValue;
|
|
||||||
case '!=':
|
|
||||||
case '<>': return $leftValue !== $rightValue;
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Calculates the ratio of wrong predictions based on the new threshold
|
* Calculates the ratio of wrong predictions based on the new threshold
|
||||||
* value given as the parameter
|
* value given as the parameter
|
||||||
@ -278,7 +256,7 @@ class DecisionStump extends WeightedClassifier
|
|||||||
$rightLabel = $this->binaryLabels[1];
|
$rightLabel = $this->binaryLabels[1];
|
||||||
|
|
||||||
foreach ($values as $index => $value) {
|
foreach ($values as $index => $value) {
|
||||||
if ($this->evaluate($value, $operator, $threshold)) {
|
if (Comparison::compare($value, $threshold, $operator)) {
|
||||||
$predicted = $leftLabel;
|
$predicted = $leftLabel;
|
||||||
} else {
|
} else {
|
||||||
$predicted = $rightLabel;
|
$predicted = $rightLabel;
|
||||||
@ -337,7 +315,7 @@ class DecisionStump extends WeightedClassifier
|
|||||||
*/
|
*/
|
||||||
protected function predictSampleBinary(array $sample)
|
protected function predictSampleBinary(array $sample)
|
||||||
{
|
{
|
||||||
if ($this->evaluate($sample[$this->column], $this->operator, $this->value)) {
|
if (Comparison::compare($sample[$this->column], $this->value, $this->operator)) {
|
||||||
return $this->binaryLabels[0];
|
return $this->binaryLabels[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -157,4 +157,14 @@ class InvalidArgumentException extends \Exception
|
|||||||
{
|
{
|
||||||
return new self(sprintf('The specified path "%s" is not writable', $path));
|
return new self(sprintf('The specified path "%s" is not writable', $path));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param string $operator
|
||||||
|
*
|
||||||
|
* @return InvalidArgumentException
|
||||||
|
*/
|
||||||
|
public static function invalidOperator(string $operator)
|
||||||
|
{
|
||||||
|
return new self(sprintf('Invalid operator "%s" provided', $operator));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,8 @@ declare(strict_types=1);
|
|||||||
|
|
||||||
namespace Phpml\Helper;
|
namespace Phpml\Helper;
|
||||||
|
|
||||||
|
use Phpml\Classification\Classifier;
|
||||||
|
|
||||||
trait OneVsRest
|
trait OneVsRest
|
||||||
{
|
{
|
||||||
/**
|
/**
|
||||||
@ -100,7 +102,7 @@ trait OneVsRest
|
|||||||
/**
|
/**
|
||||||
* Returns an instance of the current class after cleaning up OneVsRest stuff.
|
* Returns an instance of the current class after cleaning up OneVsRest stuff.
|
||||||
*
|
*
|
||||||
* @return \Phpml\Estimator
|
* @return Classifier|OneVsRest
|
||||||
*/
|
*/
|
||||||
protected function getClassifierCopy()
|
protected function getClassifierCopy()
|
||||||
{
|
{
|
||||||
|
45
src/Phpml/Math/Comparison.php
Normal file
45
src/Phpml/Math/Comparison.php
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
<?php
|
||||||
|
|
||||||
|
declare(strict_types=1);
|
||||||
|
|
||||||
|
namespace Phpml\Math;
|
||||||
|
|
||||||
|
use Phpml\Exception\InvalidArgumentException;
|
||||||
|
|
||||||
|
class Comparison
|
||||||
|
{
|
||||||
|
/**
|
||||||
|
* @param mixed $a
|
||||||
|
* @param mixed $b
|
||||||
|
* @param string $operator
|
||||||
|
*
|
||||||
|
* @return bool
|
||||||
|
*
|
||||||
|
* @throws InvalidArgumentException
|
||||||
|
*/
|
||||||
|
public static function compare($a, $b, string $operator): bool
|
||||||
|
{
|
||||||
|
switch ($operator) {
|
||||||
|
case '>':
|
||||||
|
return $a > $b;
|
||||||
|
case '>=':
|
||||||
|
return $a >= $b;
|
||||||
|
case '=':
|
||||||
|
case '==':
|
||||||
|
return $a == $b;
|
||||||
|
case '===':
|
||||||
|
return $a === $b;
|
||||||
|
case '<=':
|
||||||
|
return $a <= $b;
|
||||||
|
case '<':
|
||||||
|
return $a < $b;
|
||||||
|
case '!=':
|
||||||
|
case '<>':
|
||||||
|
return $a != $b;
|
||||||
|
case '!==':
|
||||||
|
return $a !== $b;
|
||||||
|
default:
|
||||||
|
throw InvalidArgumentException::invalidOperator($operator);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
80
tests/Phpml/Math/ComparisonTest.php
Normal file
80
tests/Phpml/Math/ComparisonTest.php
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
<?php
|
||||||
|
|
||||||
|
declare(strict_types=1);
|
||||||
|
|
||||||
|
namespace tests\Phpml\Math;
|
||||||
|
|
||||||
|
use Phpml\Math\Comparison;
|
||||||
|
use PHPUnit\Framework\TestCase;
|
||||||
|
|
||||||
|
class ComparisonTest extends TestCase
|
||||||
|
{
|
||||||
|
/**
|
||||||
|
* @param mixed $a
|
||||||
|
* @param mixed $b
|
||||||
|
* @param string $operator
|
||||||
|
* @param bool $expected
|
||||||
|
*
|
||||||
|
* @dataProvider provideData
|
||||||
|
*/
|
||||||
|
public function testResult($a, $b, string $operator, bool $expected)
|
||||||
|
{
|
||||||
|
$result = Comparison::compare($a, $b, $operator);
|
||||||
|
|
||||||
|
$this->assertEquals($expected, $result);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @expectedException \Phpml\Exception\InvalidArgumentException
|
||||||
|
* @expectedExceptionMessage Invalid operator "~=" provided
|
||||||
|
*/
|
||||||
|
public function testThrowExceptionWhenOperatorIsInvalid()
|
||||||
|
{
|
||||||
|
Comparison::compare(1, 1, '~=');
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return array
|
||||||
|
*/
|
||||||
|
public function provideData()
|
||||||
|
{
|
||||||
|
return [
|
||||||
|
// Greater
|
||||||
|
[1, 0, '>', true],
|
||||||
|
[1, 1, '>', false],
|
||||||
|
[0, 1, '>', false],
|
||||||
|
// Greater or equal
|
||||||
|
[1, 0, '>=', true],
|
||||||
|
[1, 1, '>=', true],
|
||||||
|
[0, 1, '>=', false],
|
||||||
|
// Equal
|
||||||
|
[1, 0, '=', false],
|
||||||
|
[1, 1, '==', true],
|
||||||
|
[1, '1', '=', true],
|
||||||
|
[1, '0', '==', false],
|
||||||
|
// Identical
|
||||||
|
[1, 0, '===', false],
|
||||||
|
[1, 1, '===', true],
|
||||||
|
[1, '1', '===', false],
|
||||||
|
['a', 'a', '===', true],
|
||||||
|
// Not equal
|
||||||
|
[1, 0, '!=', true],
|
||||||
|
[1, 1, '<>', false],
|
||||||
|
[1, '1', '!=', false],
|
||||||
|
[1, '0', '<>', true],
|
||||||
|
// Not identical
|
||||||
|
[1, 0, '!==', true],
|
||||||
|
[1, 1, '!==', false],
|
||||||
|
[1, '1', '!==', true],
|
||||||
|
[1, '0', '!==', true],
|
||||||
|
// Less or equal
|
||||||
|
[1, 0, '<=', false],
|
||||||
|
[1, 1, '<=', true],
|
||||||
|
[0, 1, '<=', true],
|
||||||
|
// Less
|
||||||
|
[1, 0, '<', false],
|
||||||
|
[1, 1, '<', false],
|
||||||
|
[0, 1, '<', true],
|
||||||
|
];
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user