refactor Backpropagation methods and simplify things

Arkadiusz Kondas 2016-08-10 23:03:02 +02:00
parent 66d029e94f
commit c506a84164
3 changed files with 58 additions and 14 deletions


@@ -21,6 +21,11 @@ class Backpropagation implements Training
*/
private $theta;
+ /**
+ * @var array
+ */
+ private $sigmas;
/**
* @param Network $network
* @param int $theta
@@ -71,20 +76,22 @@ class Backpropagation implements Training
return $resultsWithinError;
}
/**
* @param array $sample
* @param array $target
*/
private function trainSample(array $sample, array $target)
{
$this->network->setInput($sample)->getOutput();
- $sigmas = [];
+ $this->sigmas = [];
$layers = $this->network->getLayers();
$layersNumber = count($layers);
for ($i = $layersNumber; $i > 1; --$i) {
foreach ($layers[$i - 1]->getNodes() as $key => $neuron) {
if ($neuron instanceof Neuron) {
$neuronOutput = $neuron->getOutput();
- $sigma = $neuronOutput * (1 - $neuronOutput) * ($i == $layersNumber ? ($target[$key] - $neuronOutput) : $this->getPrevSigma($sigmas, $neuron));
- $sigmas[] = new Sigma($neuron, $sigma);
+ $sigma = $this->getSigma($neuron, $target, $key, $i == $layersNumber);
foreach ($neuron->getSynapses() as $synapse) {
$synapse->changeWeight($this->theta * $sigma * $synapse->getNode()->getOutput());
}
@@ -94,21 +101,40 @@ class Backpropagation implements Training
}
/**
- * @param Sigma[] $sigmas
- * @param Neuron $forNeuron
+ * @param Neuron $neuron
+ * @param array $target
+ * @param int $key
+ * @param bool $lastLayer
*
* @return float
*/
- private function getPrevSigma(array $sigmas, Neuron $forNeuron): float
+ private function getSigma(Neuron $neuron, array $target, int $key, bool $lastLayer): float
{
+ $neuronOutput = $neuron->getOutput();
+ $sigma = $neuronOutput * (1 - $neuronOutput);
+ if ($lastLayer) {
+ $sigma *= ($target[$key] - $neuronOutput);
+ } else {
+ $sigma *= $this->getPrevSigma($neuron);
+ }
+ $this->sigmas[] = new Sigma($neuron, $sigma);
+ return $sigma;
+ }
+ /**
+ * @param Neuron $neuron
+ *
+ * @return float
+ */
+ private function getPrevSigma(Neuron $neuron): float
{
$sigma = 0.0;
- foreach ($sigmas as $neuronSigma) {
- foreach ($neuronSigma->getNeuron()->getSynapses() as $synapse) {
- if ($synapse->getNode() == $forNeuron) {
- $sigma += $synapse->getWeight() * $neuronSigma->getSigma();
- }
- }
+ foreach ($this->sigmas as $neuronSigma) {
+ $sigma += $neuronSigma->getSigmaForNeuron($neuron);
}
return $sigma;
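For reference, the refactored getSigma()/getPrevSigma() pair computes the standard backpropagation delta rule, assuming the neurons use a logistic (sigmoid) activation, which is what the o(1 - o) factor implies. With o_j the output of neuron j, t_j the target value, w_jk the weight of the synapse connecting j into a downstream neuron k, and theta the learning rate, the code corresponds to:

For the output layer (the $lastLayer branch):

\sigma_j = o_j (1 - o_j) (t_j - o_j)

For hidden layers, summing over the downstream neurons k already processed in earlier loop iterations (the getPrevSigma() branch):

\sigma_j = o_j (1 - o_j) \sum_k w_{jk} \sigma_k

And the update applied through changeWeight(), where o_i is the output of the synapse's source node:

\Delta w_{ij} = \theta \, \sigma_j \, o_i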


@@ -43,4 +43,22 @@ class Sigma
{
return $this->sigma;
}
+ /**
+ * @param Neuron $neuron
+ *
+ * @return float
+ */
+ public function getSigmaForNeuron(Neuron $neuron): float
+ {
+ $sigma = 0.0;
+ foreach ($this->neuron->getSynapses() as $synapse) {
+ if ($synapse->getNode() == $neuron) {
+ $sigma += $synapse->getWeight() * $this->getSigma();
+ }
+ }
+ return $sigma;
+ }
}
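The Sigma value object that the new getSigmaForNeuron() lives on is only partially visible in this hunk. As context, a minimal sketch of the rest of the class, pieced together from the calls visible in this commit (new Sigma($neuron, $sigma), getNeuron(), getSigma(), and the $this->neuron property); the import path and exact declarations are assumptions, not the actual file contents:

<?php
// Sketch only: reconstructed from the usages in this commit, not copied from the repository.
use Phpml\NeuralNetwork\Node\Neuron; // assumed import path

class Sigma
{
    /**
     * @var Neuron
     */
    private $neuron;

    /**
     * @var float
     */
    private $sigma;

    public function __construct(Neuron $neuron, float $sigma)
    {
        $this->neuron = $neuron;
        $this->sigma = $sigma;
    }

    public function getNeuron(): Neuron
    {
        return $this->neuron;
    }

    public function getSigma(): float
    {
        return $this->sigma;
    }
}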


@@ -18,7 +18,7 @@ class BackpropagationTest extends \PHPUnit_Framework_TestCase
[[1, 0], [0, 1], [1, 1], [0, 0]],
[[1], [1], [0], [0]],
$desiredError = 0.2,
- 10000
+ 30000
);
$this->assertEquals(0, $network->setInput([1, 1])->getOutput()[0], '', $desiredError);
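Putting the pieces together, a usage sketch of the trainer on the XOR data from this test; the MultilayerPerceptron class, the namespaces, the default value of $theta, and the train() method name come from the surrounding library and are assumptions here rather than part of the diff:

<?php
use Phpml\NeuralNetwork\Network\MultilayerPerceptron; // assumed namespace
use Phpml\NeuralNetwork\Training\Backpropagation;     // assumed namespace

// 2 inputs, one hidden layer of 2 neurons, 1 output (assumed constructor signature)
$network = new MultilayerPerceptron([2, 2, 1]);

// the diff shows the constructor takes a Network and an int $theta (learning rate)
$training = new Backpropagation($network, $theta = 1);

// XOR samples and targets, desired error 0.2, at most 30000 iterations, as in the test above
// (train() is the assumed Training interface method behind this argument list)
$training->train(
    [[1, 0], [0, 1], [1, 1], [0, 0]],
    [[1], [1], [0], [0]],
    0.2,
    30000
);

// after training, the network should output roughly 0 for the input [1, 1]
echo $network->setInput([1, 1])->getOutput()[0];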