mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2025-01-24 15:48:24 +00:00
implement support vector regression
This commit is contained in:
parent
c409658483
commit
430c1078cf
@ -4,8 +4,8 @@ declare (strict_types = 1);
|
||||
|
||||
namespace Phpml\Classification;
|
||||
|
||||
use Phpml\Classification\Traits\Predictable;
|
||||
use Phpml\Classification\Traits\Trainable;
|
||||
use Phpml\Helper\Predictable;
|
||||
use Phpml\Helper\Trainable;
|
||||
use Phpml\Math\Distance;
|
||||
use Phpml\Math\Distance\Euclidean;
|
||||
|
||||
|
@ -4,8 +4,8 @@ declare (strict_types = 1);
|
||||
|
||||
namespace Phpml\Classification;
|
||||
|
||||
use Phpml\Classification\Traits\Predictable;
|
||||
use Phpml\Classification\Traits\Trainable;
|
||||
use Phpml\Helper\Predictable;
|
||||
use Phpml\Helper\Trainable;
|
||||
|
||||
class NaiveBayes implements Classifier
|
||||
{
|
||||
|
@ -4,6 +4,7 @@ declare (strict_types = 1);
|
||||
|
||||
namespace Phpml\Classification;
|
||||
|
||||
use Phpml\SupportVectorMachine\Kernel;
|
||||
use Phpml\SupportVectorMachine\SupportVectorMachine;
|
||||
use Phpml\SupportVectorMachine\Type;
|
||||
|
||||
@ -21,7 +22,7 @@ class SVC extends SupportVectorMachine implements Classifier
|
||||
* @param bool $probabilityEstimates
|
||||
*/
|
||||
public function __construct(
|
||||
int $kernel, float $cost = 1.0, int $degree = 3, float $gamma = null, float $coef0 = 0.0,
|
||||
int $kernel = Kernel::LINEAR, float $cost = 1.0, int $degree = 3, float $gamma = null, float $coef0 = 0.0,
|
||||
float $tolerance = 0.001, int $cacheSize = 100, bool $shrinking = true,
|
||||
bool $probabilityEstimates = false
|
||||
) {
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
declare (strict_types = 1);
|
||||
|
||||
namespace Phpml\Classification\Traits;
|
||||
namespace Phpml\Helper;
|
||||
|
||||
trait Predictable
|
||||
{
|
@ -2,7 +2,7 @@
|
||||
|
||||
declare (strict_types = 1);
|
||||
|
||||
namespace Phpml\Classification\Traits;
|
||||
namespace Phpml\Helper;
|
||||
|
||||
trait Trainable
|
||||
{
|
@ -4,10 +4,12 @@ declare (strict_types = 1);
|
||||
|
||||
namespace Phpml\Regression;
|
||||
|
||||
use Phpml\Helper\Predictable;
|
||||
use Phpml\Math\Matrix;
|
||||
|
||||
class LeastSquares implements Regression
|
||||
{
|
||||
use Predictable;
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
@ -45,7 +47,7 @@ class LeastSquares implements Regression
|
||||
*
|
||||
* @return mixed
|
||||
*/
|
||||
public function predict($sample)
|
||||
public function predictSample(array $sample)
|
||||
{
|
||||
$result = $this->intercept;
|
||||
foreach ($this->coefficients as $index => $coefficient) {
|
||||
|
@ -13,9 +13,9 @@ interface Regression
|
||||
public function train(array $samples, array $targets);
|
||||
|
||||
/**
|
||||
* @param float $sample
|
||||
* @param array $samples
|
||||
*
|
||||
* @return mixed
|
||||
*/
|
||||
public function predict($sample);
|
||||
public function predict(array $samples);
|
||||
}
|
||||
|
31
src/Phpml/Regression/SVR.php
Normal file
31
src/Phpml/Regression/SVR.php
Normal file
@ -0,0 +1,31 @@
|
||||
<?php
|
||||
|
||||
declare (strict_types = 1);
|
||||
|
||||
namespace Phpml\Regression;
|
||||
|
||||
use Phpml\SupportVectorMachine\Kernel;
|
||||
use Phpml\SupportVectorMachine\SupportVectorMachine;
|
||||
use Phpml\SupportVectorMachine\Type;
|
||||
|
||||
class SVR extends SupportVectorMachine implements Regression
|
||||
{
|
||||
/**
|
||||
* @param int $kernel
|
||||
* @param int $degree
|
||||
* @param float $epsilon
|
||||
* @param float $cost
|
||||
* @param float|null $gamma
|
||||
* @param float $coef0
|
||||
* @param float $tolerance
|
||||
* @param int $cacheSize
|
||||
* @param bool $shrinking
|
||||
*/
|
||||
public function __construct(
|
||||
int $kernel = Kernel::RBF, int $degree = 3, float $epsilon = 0.1, float $cost = 1.0,
|
||||
float $gamma = null, float $coef0 = 0.0, float $tolerance = 0.001,
|
||||
int $cacheSize = 100, bool $shrinking = true
|
||||
) {
|
||||
parent::__construct(Type::EPSILON_SVR, $kernel, $cost, 0.5, $degree, $gamma, $coef0, $epsilon, $tolerance, $cacheSize, $shrinking, false);
|
||||
}
|
||||
}
|
@ -9,15 +9,19 @@ class DataTransformer
|
||||
/**
|
||||
* @param array $samples
|
||||
* @param array $labels
|
||||
* @param bool $targets
|
||||
*
|
||||
* @return string
|
||||
*/
|
||||
public static function trainingSet(array $samples, array $labels): string
|
||||
public static function trainingSet(array $samples, array $labels, bool $targets = false): string
|
||||
{
|
||||
$set = '';
|
||||
if (!$targets) {
|
||||
$numericLabels = self::numericLabels($labels);
|
||||
}
|
||||
|
||||
foreach ($labels as $index => $label) {
|
||||
$set .= sprintf('%s %s %s', $numericLabels[$label], self::sampleRow($samples[$index]), PHP_EOL);
|
||||
$set .= sprintf('%s %s %s', ($targets ? $label : $numericLabels[$label]), self::sampleRow($samples[$index]), PHP_EOL);
|
||||
}
|
||||
|
||||
return $set;
|
||||
|
@ -131,7 +131,7 @@ class SupportVectorMachine
|
||||
public function train(array $samples, array $labels)
|
||||
{
|
||||
$this->labels = $labels;
|
||||
$trainingSet = DataTransformer::trainingSet($samples, $labels);
|
||||
$trainingSet = DataTransformer::trainingSet($samples, $labels, in_array($this->type, [Type::EPSILON_SVR, Type::NU_SVR]));
|
||||
file_put_contents($trainingSetFileName = $this->varPath.uniqid(), $trainingSet);
|
||||
$modelFileName = $trainingSetFileName.'-model';
|
||||
|
||||
@ -169,13 +169,17 @@ class SupportVectorMachine
|
||||
$output = '';
|
||||
exec(escapeshellcmd($command), $output);
|
||||
|
||||
$rawPredictions = file_get_contents($outputFileName);
|
||||
$predictions = file_get_contents($outputFileName);
|
||||
|
||||
unlink($testSetFileName);
|
||||
unlink($modelFileName);
|
||||
unlink($outputFileName);
|
||||
|
||||
$predictions = DataTransformer::predictions($rawPredictions, $this->labels);
|
||||
if (in_array($this->type, [Type::C_SVC, Type::NU_SVC])) {
|
||||
$predictions = DataTransformer::predictions($predictions, $this->labels);
|
||||
} else {
|
||||
$predictions = explode(PHP_EOL, trim($predictions));
|
||||
}
|
||||
|
||||
if (!is_array($samples[0])) {
|
||||
return $predictions[0];
|
||||
|
50
tests/Phpml/Regression/SVRTest.php
Normal file
50
tests/Phpml/Regression/SVRTest.php
Normal file
@ -0,0 +1,50 @@
|
||||
<?php
|
||||
|
||||
declare (strict_types = 1);
|
||||
|
||||
namespace tests\Regression;
|
||||
|
||||
use Phpml\Regression\SVR;
|
||||
use Phpml\SupportVectorMachine\Kernel;
|
||||
|
||||
class SVRTest extends \PHPUnit_Framework_TestCase
|
||||
{
|
||||
public function testPredictSingleFeatureSamples()
|
||||
{
|
||||
$delta = 0.01;
|
||||
|
||||
$samples = [[60], [61], [62], [63], [65]];
|
||||
$targets = [3.1, 3.6, 3.8, 4, 4.1];
|
||||
|
||||
$regression = new SVR(Kernel::LINEAR);
|
||||
$regression->train($samples, $targets);
|
||||
|
||||
$this->assertEquals(4.03, $regression->predict([64]), '', $delta);
|
||||
|
||||
$samples = [[9300], [10565], [15000], [15000], [17764], [57000], [65940], [73676], [77006], [93739], [146088], [153260]];
|
||||
$targets = [7100, 15500, 4400, 4400, 5900, 4600, 8800, 2000, 2750, 2550, 960, 1025];
|
||||
|
||||
$regression = new SVR(Kernel::LINEAR);
|
||||
$regression->train($samples, $targets);
|
||||
|
||||
$this->assertEquals(6236.12, $regression->predict([9300]), '', $delta);
|
||||
$this->assertEquals(4718.29, $regression->predict([57000]), '', $delta);
|
||||
$this->assertEquals(4081.69, $regression->predict([77006]), '', $delta);
|
||||
$this->assertEquals(6236.12, $regression->predict([9300]), '', $delta);
|
||||
$this->assertEquals(1655.26, $regression->predict([153260]), '', $delta);
|
||||
}
|
||||
|
||||
public function testPredictMultiFeaturesSamples()
|
||||
{
|
||||
$delta = 0.01;
|
||||
|
||||
$samples = [[73676, 1996], [77006, 1998], [10565, 2000], [146088, 1995], [15000, 2001], [65940, 2000], [9300, 2000], [93739, 1996], [153260, 1994], [17764, 2002], [57000, 1998], [15000, 2000]];
|
||||
$targets = [2000, 2750, 15500, 960, 4400, 8800, 7100, 2550, 1025, 5900, 4600, 4400];
|
||||
|
||||
$regression = new SVR(Kernel::LINEAR);
|
||||
$regression->train($samples, $targets);
|
||||
|
||||
$this->assertEquals(4109.82, $regression->predict([60000, 1996]), '', $delta);
|
||||
$this->assertEquals(4112.28, $regression->predict([60000, 2000]), '', $delta);
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user