Change from theta to learning rate var name in NN (#159)

This commit is contained in:
David Monllaó 2017-11-20 23:39:50 +01:00 committed by Arkadiusz Kondas
parent 333598b472
commit b1d40bfa30
3 changed files with 11 additions and 11 deletions

View File

@ -8,7 +8,7 @@ A multilayer perceptron (MLP) is a feedforward artificial neural network model t
* $hiddenLayers (array) - array with the hidden layers configuration, each value represent number of neurons in each layers * $hiddenLayers (array) - array with the hidden layers configuration, each value represent number of neurons in each layers
* $classes (array) - array with the different training set classes (array keys are ignored) * $classes (array) - array with the different training set classes (array keys are ignored)
* $iterations (int) - number of training iterations * $iterations (int) - number of training iterations
* $theta (int) - network theta parameter * $learningRate (float) - the learning rate
* $activationFunction (ActivationFunction) - neuron activation function * $activationFunction (ActivationFunction) - neuron activation function
``` ```

View File

@ -46,9 +46,9 @@ abstract class MultilayerPerceptron extends LayeredNetwork implements Estimator,
protected $activationFunction; protected $activationFunction;
/** /**
* @var int * @var float
*/ */
private $theta; private $learningRate;
/** /**
* @var Backpropagation * @var Backpropagation
@ -58,7 +58,7 @@ abstract class MultilayerPerceptron extends LayeredNetwork implements Estimator,
/** /**
* @throws InvalidArgumentException * @throws InvalidArgumentException
*/ */
public function __construct(int $inputLayerFeatures, array $hiddenLayers, array $classes, int $iterations = 10000, ?ActivationFunction $activationFunction = null, int $theta = 1) public function __construct(int $inputLayerFeatures, array $hiddenLayers, array $classes, int $iterations = 10000, ?ActivationFunction $activationFunction = null, float $learningRate = 1)
{ {
if (empty($hiddenLayers)) { if (empty($hiddenLayers)) {
throw InvalidArgumentException::invalidLayersNumber(); throw InvalidArgumentException::invalidLayersNumber();
@ -73,7 +73,7 @@ abstract class MultilayerPerceptron extends LayeredNetwork implements Estimator,
$this->inputLayerFeatures = $inputLayerFeatures; $this->inputLayerFeatures = $inputLayerFeatures;
$this->hiddenLayers = $hiddenLayers; $this->hiddenLayers = $hiddenLayers;
$this->activationFunction = $activationFunction; $this->activationFunction = $activationFunction;
$this->theta = $theta; $this->learningRate = $learningRate;
$this->initNetwork(); $this->initNetwork();
} }
@ -87,7 +87,7 @@ abstract class MultilayerPerceptron extends LayeredNetwork implements Estimator,
$this->addBiasNodes(); $this->addBiasNodes();
$this->generateSynapses(); $this->generateSynapses();
$this->backpropagation = new Backpropagation($this->theta); $this->backpropagation = new Backpropagation($this->learningRate);
} }
public function train(array $samples, array $targets): void public function train(array $samples, array $targets): void

View File

@ -10,9 +10,9 @@ use Phpml\NeuralNetwork\Training\Backpropagation\Sigma;
class Backpropagation class Backpropagation
{ {
/** /**
* @var int * @var float
*/ */
private $theta; private $learningRate;
/** /**
* @var array * @var array
@ -24,9 +24,9 @@ class Backpropagation
*/ */
private $prevSigmas = null; private $prevSigmas = null;
public function __construct(int $theta) public function __construct(float $learningRate)
{ {
$this->theta = $theta; $this->learningRate = $learningRate;
} }
/** /**
@ -43,7 +43,7 @@ class Backpropagation
if ($neuron instanceof Neuron) { if ($neuron instanceof Neuron) {
$sigma = $this->getSigma($neuron, $targetClass, $key, $i == $layersNumber); $sigma = $this->getSigma($neuron, $targetClass, $key, $i == $layersNumber);
foreach ($neuron->getSynapses() as $synapse) { foreach ($neuron->getSynapses() as $synapse) {
$synapse->changeWeight($this->theta * $sigma * $synapse->getNode()->getOutput()); $synapse->changeWeight($this->learningRate * $sigma * $synapse->getNode()->getOutput());
} }
} }
} }