mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2024-11-22 04:55:10 +00:00
Add activationFunction parameter for Perceptron and Layer
This commit is contained in:
parent
c506a84164
commit
2412f15923
@ -17,20 +17,36 @@ class Layer
|
||||
/**
|
||||
* @param int $nodesNumber
|
||||
* @param string $nodeClass
|
||||
* @param ActivationFunction|null $activationFunction
|
||||
*
|
||||
* @throws InvalidArgumentException
|
||||
*/
|
||||
public function __construct(int $nodesNumber = 0, string $nodeClass = Neuron::class)
|
||||
public function __construct(int $nodesNumber = 0, string $nodeClass = Neuron::class, ActivationFunction $activationFunction = null)
|
||||
{
|
||||
if (!in_array(Node::class, class_implements($nodeClass))) {
|
||||
throw InvalidArgumentException::invalidLayerNodeClass();
|
||||
}
|
||||
|
||||
for ($i = 0; $i < $nodesNumber; ++$i) {
|
||||
$this->nodes[] = new $nodeClass();
|
||||
$this->nodes[] = $this->createNode($nodeClass, $activationFunction);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @param string $nodeClass
|
||||
* @param ActivationFunction|null $activationFunction
|
||||
*
|
||||
* @return Neuron
|
||||
*/
|
||||
private function createNode(string $nodeClass, ActivationFunction $activationFunction = null)
|
||||
{
|
||||
if (Neuron::class == $nodeClass) {
|
||||
return new Neuron($activationFunction);
|
||||
}
|
||||
|
||||
return new $nodeClass();
|
||||
}
|
||||
|
||||
/**
|
||||
* @param Node $node
|
||||
*/
|
||||
|
@ -5,6 +5,7 @@ declare (strict_types = 1);
|
||||
namespace Phpml\NeuralNetwork\Network;
|
||||
|
||||
use Phpml\Exception\InvalidArgumentException;
|
||||
use Phpml\NeuralNetwork\ActivationFunction;
|
||||
use Phpml\NeuralNetwork\Layer;
|
||||
use Phpml\NeuralNetwork\Node\Bias;
|
||||
use Phpml\NeuralNetwork\Node\Input;
|
||||
@ -15,17 +16,18 @@ class MultilayerPerceptron extends LayeredNetwork
|
||||
{
|
||||
/**
|
||||
* @param array $layers
|
||||
* @param ActivationFunction|null $activationFunction
|
||||
*
|
||||
* @throws InvalidArgumentException
|
||||
*/
|
||||
public function __construct(array $layers)
|
||||
public function __construct(array $layers, ActivationFunction $activationFunction = null)
|
||||
{
|
||||
if (count($layers) < 2) {
|
||||
throw InvalidArgumentException::invalidLayersNumber();
|
||||
}
|
||||
|
||||
$this->addInputLayer(array_shift($layers));
|
||||
$this->addNeuronLayers($layers);
|
||||
$this->addNeuronLayers($layers, $activationFunction);
|
||||
$this->addBiasNodes();
|
||||
$this->generateSynapses();
|
||||
}
|
||||
@ -40,11 +42,12 @@ class MultilayerPerceptron extends LayeredNetwork
|
||||
|
||||
/**
|
||||
* @param array $layers
|
||||
* @param ActivationFunction|null $activationFunction
|
||||
*/
|
||||
private function addNeuronLayers(array $layers)
|
||||
private function addNeuronLayers(array $layers, ActivationFunction $activationFunction = null)
|
||||
{
|
||||
foreach ($layers as $neurons) {
|
||||
$this->addLayer(new Layer($neurons, Neuron::class));
|
||||
$this->addLayer(new Layer($neurons, Neuron::class, $activationFunction));
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user