diff --git a/src/Phpml/SupportVectorMachine/SupportVectorMachine.php b/src/Phpml/SupportVectorMachine/SupportVectorMachine.php index 55a5305..27b9e5a 100644 --- a/src/Phpml/SupportVectorMachine/SupportVectorMachine.php +++ b/src/Phpml/SupportVectorMachine/SupportVectorMachine.php @@ -4,9 +4,14 @@ declare(strict_types=1); namespace Phpml\SupportVectorMachine; +use Phpml\Helper\Trainable; + + class SupportVectorMachine { - /** + use Trainable; + + /** * @var int */ private $type; @@ -84,7 +89,7 @@ class SupportVectorMachine /** * @var array */ - private $labels; + private $targets = []; /** * @param int $type @@ -126,12 +131,14 @@ class SupportVectorMachine /** * @param array $samples - * @param array $labels + * @param array $targets */ - public function train(array $samples, array $labels) + public function train(array $samples, array $targets) { - $this->labels = $labels; - $trainingSet = DataTransformer::trainingSet($samples, $labels, in_array($this->type, [Type::EPSILON_SVR, Type::NU_SVR])); + $this->samples = array_merge($this->samples, $samples); + $this->targets = array_merge($this->targets, $targets); + + $trainingSet = DataTransformer::trainingSet($this->samples, $this->targets, in_array($this->type, [Type::EPSILON_SVR, Type::NU_SVR])); file_put_contents($trainingSetFileName = $this->varPath.uniqid('phpml', true), $trainingSet); $modelFileName = $trainingSetFileName.'-model'; @@ -176,7 +183,7 @@ class SupportVectorMachine unlink($outputFileName); if (in_array($this->type, [Type::C_SVC, Type::NU_SVC])) { - $predictions = DataTransformer::predictions($predictions, $this->labels); + $predictions = DataTransformer::predictions($predictions, $this->targets); } else { $predictions = explode(PHP_EOL, trim($predictions)); }