Additional training for SVR (#59)

* additional training SVR

* additional training SVR, missed old labels reference

* SVM labels parameter now targets

* SVM member labels now targets

* SVM init targets empty array
This commit is contained in:
Kyle Warren 2017-03-17 06:44:45 -04:00 committed by Arkadiusz Kondas
parent 8be19567a2
commit c44f3b2730

View File

@ -4,8 +4,13 @@ declare(strict_types=1);
namespace Phpml\SupportVectorMachine;
use Phpml\Helper\Trainable;
class SupportVectorMachine
{
use Trainable;
/**
* @var int
*/
@ -84,7 +89,7 @@ class SupportVectorMachine
/**
* @var array
*/
private $labels;
private $targets = [];
/**
* @param int $type
@ -126,12 +131,14 @@ class SupportVectorMachine
/**
* @param array $samples
* @param array $labels
* @param array $targets
*/
public function train(array $samples, array $labels)
public function train(array $samples, array $targets)
{
$this->labels = $labels;
$trainingSet = DataTransformer::trainingSet($samples, $labels, in_array($this->type, [Type::EPSILON_SVR, Type::NU_SVR]));
$this->samples = array_merge($this->samples, $samples);
$this->targets = array_merge($this->targets, $targets);
$trainingSet = DataTransformer::trainingSet($this->samples, $this->targets, in_array($this->type, [Type::EPSILON_SVR, Type::NU_SVR]));
file_put_contents($trainingSetFileName = $this->varPath.uniqid('phpml', true), $trainingSet);
$modelFileName = $trainingSetFileName.'-model';
@ -176,7 +183,7 @@ class SupportVectorMachine
unlink($outputFileName);
if (in_array($this->type, [Type::C_SVC, Type::NU_SVC])) {
$predictions = DataTransformer::predictions($predictions, $this->labels);
$predictions = DataTransformer::predictions($predictions, $this->targets);
} else {
$predictions = explode(PHP_EOL, trim($predictions));
}