diff --git a/src/Phpml/NeuralNetwork/Layer.php b/src/Phpml/NeuralNetwork/Layer.php index 6700164..b94da21 100644 --- a/src/Phpml/NeuralNetwork/Layer.php +++ b/src/Phpml/NeuralNetwork/Layer.php @@ -15,22 +15,38 @@ class Layer private $nodes = []; /** - * @param int $nodesNumber - * @param string $nodeClass + * @param int $nodesNumber + * @param string $nodeClass + * @param ActivationFunction|null $activationFunction * * @throws InvalidArgumentException */ - public function __construct(int $nodesNumber = 0, string $nodeClass = Neuron::class) + public function __construct(int $nodesNumber = 0, string $nodeClass = Neuron::class, ActivationFunction $activationFunction = null) { if (!in_array(Node::class, class_implements($nodeClass))) { throw InvalidArgumentException::invalidLayerNodeClass(); } for ($i = 0; $i < $nodesNumber; ++$i) { - $this->nodes[] = new $nodeClass(); + $this->nodes[] = $this->createNode($nodeClass, $activationFunction); } } + /** + * @param string $nodeClass + * @param ActivationFunction|null $activationFunction + * + * @return Neuron + */ + private function createNode(string $nodeClass, ActivationFunction $activationFunction = null) + { + if (Neuron::class == $nodeClass) { + return new Neuron($activationFunction); + } + + return new $nodeClass(); + } + /** * @param Node $node */ diff --git a/src/Phpml/NeuralNetwork/Network/MultilayerPerceptron.php b/src/Phpml/NeuralNetwork/Network/MultilayerPerceptron.php index 4079822..e97e045 100644 --- a/src/Phpml/NeuralNetwork/Network/MultilayerPerceptron.php +++ b/src/Phpml/NeuralNetwork/Network/MultilayerPerceptron.php @@ -5,6 +5,7 @@ declare (strict_types = 1); namespace Phpml\NeuralNetwork\Network; use Phpml\Exception\InvalidArgumentException; +use Phpml\NeuralNetwork\ActivationFunction; use Phpml\NeuralNetwork\Layer; use Phpml\NeuralNetwork\Node\Bias; use Phpml\NeuralNetwork\Node\Input; @@ -14,18 +15,19 @@ use Phpml\NeuralNetwork\Node\Neuron\Synapse; class MultilayerPerceptron extends LayeredNetwork { /** - * @param array $layers + * @param array $layers + * @param ActivationFunction|null $activationFunction * * @throws InvalidArgumentException */ - public function __construct(array $layers) + public function __construct(array $layers, ActivationFunction $activationFunction = null) { if (count($layers) < 2) { throw InvalidArgumentException::invalidLayersNumber(); } $this->addInputLayer(array_shift($layers)); - $this->addNeuronLayers($layers); + $this->addNeuronLayers($layers, $activationFunction); $this->addBiasNodes(); $this->generateSynapses(); } @@ -39,12 +41,13 @@ class MultilayerPerceptron extends LayeredNetwork } /** - * @param array $layers + * @param array $layers + * @param ActivationFunction|null $activationFunction */ - private function addNeuronLayers(array $layers) + private function addNeuronLayers(array $layers, ActivationFunction $activationFunction = null) { foreach ($layers as $neurons) { - $this->addLayer(new Layer($neurons, Neuron::class)); + $this->addLayer(new Layer($neurons, Neuron::class, $activationFunction)); } }