Add activationFunction parameter for Perceptron and Layer

This commit is contained in:
Arkadiusz Kondas 2016-08-11 13:21:22 +02:00
parent c506a84164
commit 2412f15923
2 changed files with 29 additions and 10 deletions

View File

@ -17,20 +17,36 @@ class Layer
/** /**
* @param int $nodesNumber * @param int $nodesNumber
* @param string $nodeClass * @param string $nodeClass
* @param ActivationFunction|null $activationFunction
* *
* @throws InvalidArgumentException * @throws InvalidArgumentException
*/ */
public function __construct(int $nodesNumber = 0, string $nodeClass = Neuron::class) public function __construct(int $nodesNumber = 0, string $nodeClass = Neuron::class, ActivationFunction $activationFunction = null)
{ {
if (!in_array(Node::class, class_implements($nodeClass))) { if (!in_array(Node::class, class_implements($nodeClass))) {
throw InvalidArgumentException::invalidLayerNodeClass(); throw InvalidArgumentException::invalidLayerNodeClass();
} }
for ($i = 0; $i < $nodesNumber; ++$i) { for ($i = 0; $i < $nodesNumber; ++$i) {
$this->nodes[] = new $nodeClass(); $this->nodes[] = $this->createNode($nodeClass, $activationFunction);
} }
} }
/**
* @param string $nodeClass
* @param ActivationFunction|null $activationFunction
*
* @return Neuron
*/
private function createNode(string $nodeClass, ActivationFunction $activationFunction = null)
{
if (Neuron::class == $nodeClass) {
return new Neuron($activationFunction);
}
return new $nodeClass();
}
/** /**
* @param Node $node * @param Node $node
*/ */

View File

@ -5,6 +5,7 @@ declare (strict_types = 1);
namespace Phpml\NeuralNetwork\Network; namespace Phpml\NeuralNetwork\Network;
use Phpml\Exception\InvalidArgumentException; use Phpml\Exception\InvalidArgumentException;
use Phpml\NeuralNetwork\ActivationFunction;
use Phpml\NeuralNetwork\Layer; use Phpml\NeuralNetwork\Layer;
use Phpml\NeuralNetwork\Node\Bias; use Phpml\NeuralNetwork\Node\Bias;
use Phpml\NeuralNetwork\Node\Input; use Phpml\NeuralNetwork\Node\Input;
@ -15,17 +16,18 @@ class MultilayerPerceptron extends LayeredNetwork
{ {
/** /**
* @param array $layers * @param array $layers
* @param ActivationFunction|null $activationFunction
* *
* @throws InvalidArgumentException * @throws InvalidArgumentException
*/ */
public function __construct(array $layers) public function __construct(array $layers, ActivationFunction $activationFunction = null)
{ {
if (count($layers) < 2) { if (count($layers) < 2) {
throw InvalidArgumentException::invalidLayersNumber(); throw InvalidArgumentException::invalidLayersNumber();
} }
$this->addInputLayer(array_shift($layers)); $this->addInputLayer(array_shift($layers));
$this->addNeuronLayers($layers); $this->addNeuronLayers($layers, $activationFunction);
$this->addBiasNodes(); $this->addBiasNodes();
$this->generateSynapses(); $this->generateSynapses();
} }
@ -40,11 +42,12 @@ class MultilayerPerceptron extends LayeredNetwork
/** /**
* @param array $layers * @param array $layers
* @param ActivationFunction|null $activationFunction
*/ */
private function addNeuronLayers(array $layers) private function addNeuronLayers(array $layers, ActivationFunction $activationFunction = null)
{ {
foreach ($layers as $neurons) { foreach ($layers as $neurons) {
$this->addLayer(new Layer($neurons, Neuron::class)); $this->addLayer(new Layer($neurons, Neuron::class, $activationFunction));
} }
} }