diff --git a/src/Helper/Optimizer/Optimizer.php b/src/Helper/Optimizer/Optimizer.php index 9bac3be..dba0cd0 100644 --- a/src/Helper/Optimizer/Optimizer.php +++ b/src/Helper/Optimizer/Optimizer.php @@ -9,8 +9,6 @@ use Phpml\Exception\InvalidArgumentException; abstract class Optimizer { - public $initialTheta; - /** * Unknown variables to be found * @@ -37,11 +35,9 @@ abstract class Optimizer for ($i = 0; $i < $this->dimensions; ++$i) { $this->theta[] = (random_int(0, PHP_INT_MAX) / PHP_INT_MAX) + 0.1; } - - $this->initialTheta = $this->theta; } - public function setInitialTheta(array $theta) + public function setTheta(array $theta) { if (count($theta) != $this->dimensions) { throw new InvalidArgumentException(sprintf('Number of values in the weights array should be %s', $this->dimensions)); diff --git a/src/Helper/Optimizer/StochasticGD.php b/src/Helper/Optimizer/StochasticGD.php index e1cbeea..c4fabd3 100644 --- a/src/Helper/Optimizer/StochasticGD.php +++ b/src/Helper/Optimizer/StochasticGD.php @@ -89,7 +89,7 @@ class StochasticGD extends Optimizer $this->dimensions = $dimensions; } - public function setInitialTheta(array $theta) + public function setTheta(array $theta) { if (count($theta) != $this->dimensions + 1) { throw new InvalidArgumentException(sprintf('Number of values in the weights array should be %s', $this->dimensions + 1)); diff --git a/tests/Helper/Optimizer/ConjugateGradientTest.php b/tests/Helper/Optimizer/ConjugateGradientTest.php index 78eb718..09c250c 100644 --- a/tests/Helper/Optimizer/ConjugateGradientTest.php +++ b/tests/Helper/Optimizer/ConjugateGradientTest.php @@ -57,7 +57,7 @@ class ConjugateGradientTest extends TestCase $optimizer = new ConjugateGradient(1); // set very weak theta to trigger very bad result - $optimizer->setInitialTheta([0.0000001, 0.0000001]); + $optimizer->setTheta([0.0000001, 0.0000001]); $theta = $optimizer->runOptimization($samples, $targets, $callback); @@ -97,6 +97,6 @@ class ConjugateGradientTest extends TestCase $opimizer = new ConjugateGradient(2); $this->expectException(InvalidArgumentException::class); - $opimizer->setInitialTheta([0.15]); + $opimizer->setTheta([0.15]); } } diff --git a/tests/Helper/Optimizer/OptimizerTest.php b/tests/Helper/Optimizer/OptimizerTest.php new file mode 100644 index 0000000..22efdd2 --- /dev/null +++ b/tests/Helper/Optimizer/OptimizerTest.php @@ -0,0 +1,34 @@ +expectException(InvalidArgumentException::class); + $this->expectExceptionMessage('Number of values in the weights array should be 3'); + /** @var Optimizer $optimizer */ + $optimizer = $this->getMockForAbstractClass(Optimizer::class, [3]); + + $optimizer->setTheta([]); + } + + public function testSetTheta(): void + { + /** @var Optimizer $optimizer */ + $optimizer = $this->getMockForAbstractClass(Optimizer::class, [2]); + $object = $optimizer->setTheta([0.3, 1]); + + $theta = $this->getObjectAttribute($optimizer, 'theta'); + + $this->assertSame($object, $optimizer); + $this->assertSame([0.3, 1], $theta); + } +}