mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2024-11-25 06:17:34 +00:00
Create MLP Regressor draft
This commit is contained in:
parent
2412f15923
commit
f0bd5ae424
81
src/Phpml/Regression/MLPRegressor.php
Normal file
81
src/Phpml/Regression/MLPRegressor.php
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
<?php
|
||||||
|
|
||||||
|
declare (strict_types = 1);
|
||||||
|
|
||||||
|
namespace Phpml\Regression;
|
||||||
|
|
||||||
|
|
||||||
|
use Phpml\Helper\Predictable;
|
||||||
|
use Phpml\NeuralNetwork\ActivationFunction;
|
||||||
|
use Phpml\NeuralNetwork\Network\MultilayerPerceptron;
|
||||||
|
use Phpml\NeuralNetwork\Training\Backpropagation;
|
||||||
|
|
||||||
|
class MLPRegressor implements Regression
|
||||||
|
{
|
||||||
|
use Predictable;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var MultilayerPerceptron
|
||||||
|
*/
|
||||||
|
private $perceptron;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var array
|
||||||
|
*/
|
||||||
|
private $hiddenLayers;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var float
|
||||||
|
*/
|
||||||
|
private $desiredError;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var int
|
||||||
|
*/
|
||||||
|
private $maxIterations;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var ActivationFunction
|
||||||
|
*/
|
||||||
|
private $activationFunction;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param array $hiddenLayers
|
||||||
|
* @param float $desiredError
|
||||||
|
* @param int $maxIterations
|
||||||
|
* @param ActivationFunction $activationFunction
|
||||||
|
*/
|
||||||
|
public function __construct(array $hiddenLayers = [100], float $desiredError, int $maxIterations, ActivationFunction $activationFunction = null)
|
||||||
|
{
|
||||||
|
$this->hiddenLayers = $hiddenLayers;
|
||||||
|
$this->desiredError = $desiredError;
|
||||||
|
$this->maxIterations = $maxIterations;
|
||||||
|
$this->activationFunction = $activationFunction;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param array $samples
|
||||||
|
* @param array $targets
|
||||||
|
*/
|
||||||
|
public function train(array $samples, array $targets)
|
||||||
|
{
|
||||||
|
$layers = [count($samples[0])] + $this->hiddenLayers + [count($targets[0])];
|
||||||
|
|
||||||
|
$this->perceptron = new MultilayerPerceptron($layers, $this->activationFunction);
|
||||||
|
|
||||||
|
$trainer = new Backpropagation($this->perceptron);
|
||||||
|
$trainer->train($samples, $targets, $this->desiredError, $this->maxIterations);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param array $sample
|
||||||
|
*
|
||||||
|
* @return array
|
||||||
|
*/
|
||||||
|
protected function predictSample(array $sample)
|
||||||
|
{
|
||||||
|
return $this->perceptron->setInput($sample)->getOutput();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user