Return labels in MultilayerPerceptron output (#315)

This commit is contained in:
Marcin Michalski 2018-10-15 19:47:42 +02:00 committed by Arkadiusz Kondas
parent e255369636
commit d29c5906df
7 changed files with 61 additions and 12 deletions

View File

@ -41,7 +41,7 @@ class MLPClassifier extends MultilayerPerceptron implements Classifier
} }
} }
return $this->classes[$predictedClass]; return $predictedClass;
} }
/** /**
@ -49,9 +49,8 @@ class MLPClassifier extends MultilayerPerceptron implements Classifier
*/ */
protected function trainSample(array $sample, $target): void protected function trainSample(array $sample, $target): void
{ {
// Feed-forward. // Feed-forward.
$this->setInput($sample)->getOutput(); $this->setInput($sample);
// Back-propagate. // Back-propagate.
$this->backpropagation->backpropagate($this->getLayers(), $this->getTargetClass($target)); $this->backpropagation->backpropagate($this->getLayers(), $this->getTargetClass($target));

View File

@ -41,12 +41,9 @@ class Layer
return $this->nodes; return $this->nodes;
} }
/**
* @return Neuron
*/
private function createNode(string $nodeClass, ?ActivationFunction $activationFunction = null): Node private function createNode(string $nodeClass, ?ActivationFunction $activationFunction = null): Node
{ {
if ($nodeClass == Neuron::class) { if ($nodeClass === Neuron::class) {
return new Neuron($activationFunction); return new Neuron($activationFunction);
} }

View File

@ -51,8 +51,6 @@ abstract class LayeredNetwork implements Network
/** /**
* @param mixed $input * @param mixed $input
*
* @return $this
*/ */
public function setInput($input): Network public function setInput($input): Network
{ {

View File

@ -69,6 +69,10 @@ abstract class MultilayerPerceptron extends LayeredNetwork implements Estimator,
throw new InvalidArgumentException('Provide at least 2 different classes'); throw new InvalidArgumentException('Provide at least 2 different classes');
} }
if (count($classes) !== count(array_unique($classes))) {
throw new InvalidArgumentException('Classes must be unique');
}
$this->classes = array_values($classes); $this->classes = array_values($classes);
$this->iterations = $iterations; $this->iterations = $iterations;
$this->inputLayerFeatures = $inputLayerFeatures; $this->inputLayerFeatures = $inputLayerFeatures;
@ -109,6 +113,16 @@ abstract class MultilayerPerceptron extends LayeredNetwork implements Estimator,
$this->backpropagation->setLearningRate($this->learningRate); $this->backpropagation->setLearningRate($this->learningRate);
} }
public function getOutput(): array
{
$result = [];
foreach ($this->getOutputLayer()->getNodes() as $i => $neuron) {
$result[$this->classes[$i]] = $neuron->getOutput();
}
return $result;
}
/** /**
* @param mixed $target * @param mixed $target
*/ */

View File

@ -44,7 +44,7 @@ class Neuron implements Node
/** /**
* @return Synapse[] * @return Synapse[]
*/ */
public function getSynapses() public function getSynapses(): array
{ {
return $this->synapses; return $this->synapses;
} }

View File

@ -183,7 +183,7 @@ class MLPClassifierTest extends TestCase
$testSamples = [[0, 0], [1, 0], [0, 1], [1, 1]]; $testSamples = [[0, 0], [1, 0], [0, 1], [1, 1]];
$predicted = $classifier->predict($testSamples); $predicted = $classifier->predict($testSamples);
$filename = 'perceptron-test-'.random_int(100, 999).'-'.uniqid(); $filename = 'perceptron-test-'.random_int(100, 999).'-'.uniqid('', false);
$filepath = tempnam(sys_get_temp_dir(), $filename); $filepath = tempnam(sys_get_temp_dir(), $filename);
$modelManager = new ModelManager(); $modelManager = new ModelManager();
$modelManager->saveToFile($classifier, $filepath); $modelManager->saveToFile($classifier, $filepath);
@ -204,7 +204,7 @@ class MLPClassifierTest extends TestCase
$this->assertEquals('a', $network->predict([1, 0])); $this->assertEquals('a', $network->predict([1, 0]));
$this->assertEquals('b', $network->predict([0, 1])); $this->assertEquals('b', $network->predict([0, 1]));
$filename = 'perceptron-test-'.random_int(100, 999).'-'.uniqid(); $filename = 'perceptron-test-'.random_int(100, 999).'-'.uniqid('', false);
$filepath = tempnam(sys_get_temp_dir(), $filename); $filepath = tempnam(sys_get_temp_dir(), $filename);
$modelManager = new ModelManager(); $modelManager = new ModelManager();
$modelManager->saveToFile($network, $filepath); $modelManager->saveToFile($network, $filepath);
@ -245,6 +245,13 @@ class MLPClassifierTest extends TestCase
new MLPClassifier(2, [2], [0]); new MLPClassifier(2, [2], [0]);
} }
public function testOutputWithLabels(): void
{
$output = (new MLPClassifier(2, [2, 2], ['T', 'F']))->getOutput();
$this->assertEquals(['T', 'F'], array_keys($output));
}
private function getSynapsesNodes(array $synapses): array private function getSynapsesNodes(array $synapses): array
{ {
$nodes = []; $nodes = [];

View File

@ -4,6 +4,7 @@ declare(strict_types=1);
namespace Phpml\Tests\NeuralNetwork\Network; namespace Phpml\Tests\NeuralNetwork\Network;
use Phpml\Exception\InvalidArgumentException;
use Phpml\NeuralNetwork\ActivationFunction; use Phpml\NeuralNetwork\ActivationFunction;
use Phpml\NeuralNetwork\Layer; use Phpml\NeuralNetwork\Layer;
use Phpml\NeuralNetwork\Network\MultilayerPerceptron; use Phpml\NeuralNetwork\Network\MultilayerPerceptron;
@ -13,6 +14,39 @@ use PHPUnit_Framework_MockObject_MockObject;
class MultilayerPerceptronTest extends TestCase class MultilayerPerceptronTest extends TestCase
{ {
public function testThrowExceptionWhenHiddenLayersAreEmpty(): void
{
$this->expectException(InvalidArgumentException::class);
$this->expectExceptionMessage('Provide at least 1 hidden layer');
$this->getMockForAbstractClass(
MultilayerPerceptron::class,
[5, [], [0, 1], 1000, null, 0.42]
);
}
public function testThrowExceptionWhenThereIsOnlyOneClass(): void
{
$this->expectException(InvalidArgumentException::class);
$this->expectExceptionMessage('Provide at least 2 different classes');
$this->getMockForAbstractClass(
MultilayerPerceptron::class,
[5, [3], [0], 1000, null, 0.42]
);
}
public function testThrowExceptionWhenClassesAreNotUnique(): void
{
$this->expectException(InvalidArgumentException::class);
$this->expectExceptionMessage('Classes must be unique');
$this->getMockForAbstractClass(
MultilayerPerceptron::class,
[5, [3], [0, 1, 2, 3, 1], 1000, null, 0.42]
);
}
public function testLearningRateSetter(): void public function testLearningRateSetter(): void
{ {
/** @var MultilayerPerceptron $mlp */ /** @var MultilayerPerceptron $mlp */