create traits for reduce complexity

This commit is contained in:
Arkadiusz Kondas 2016-04-16 21:24:40 +02:00
parent 6f5f190600
commit a4ab370a48
5 changed files with 79 additions and 78 deletions

View File

@ -4,28 +4,25 @@ declare (strict_types = 1);
namespace Phpml\Classifier; namespace Phpml\Classifier;
use Phpml\Classifier\Traits\Predictable;
use Phpml\Classifier\Traits\Trainable;
use Phpml\Metric\Distance; use Phpml\Metric\Distance;
use Phpml\Metric\Distance\Euclidean; use Phpml\Metric\Distance\Euclidean;
class KNearestNeighbors implements Classifier class KNearestNeighbors implements Classifier
{ {
use Trainable, Predictable;
/** /**
* @var int * @var int
*/ */
private $k; private $k;
/**
* @var Distance
*/
private $distanceMetric; private $distanceMetric;
/**
* @var array
*/
private $samples;
/**
* @var array
*/
private $labels;
/** /**
* @param int $k * @param int $k
* @param Distance|null $distanceMetric (if null then Euclidean distance as default) * @param Distance|null $distanceMetric (if null then Euclidean distance as default)
@ -42,35 +39,6 @@ class KNearestNeighbors implements Classifier
$this->distanceMetric = $distanceMetric; $this->distanceMetric = $distanceMetric;
} }
/**
* @param array $samples
* @param array $labels
*/
public function train(array $samples, array $labels)
{
$this->samples = $samples;
$this->labels = $labels;
}
/**
* @param array $samples
*
* @return mixed
*/
public function predict(array $samples)
{
if (!is_array($samples[0])) {
$predicted = $this->predictSample($samples);
} else {
$predicted = [];
foreach ($samples as $index => $sample) {
$predicted[$index] = $this->predictSample($sample);
}
}
return $predicted;
}
/** /**
* @param array $sample * @param array $sample
* *

View File

@ -4,46 +4,12 @@ declare (strict_types = 1);
namespace Phpml\Classifier; namespace Phpml\Classifier;
use Phpml\Classifier\Traits\Predictable;
use Phpml\Classifier\Traits\Trainable;
class NaiveBayes implements Classifier class NaiveBayes implements Classifier
{ {
/** use Trainable, Predictable;
* @var array
*/
private $samples;
/**
* @var array
*/
private $labels;
/**
* @param array $samples
* @param array $labels
*/
public function train(array $samples, array $labels)
{
$this->samples = $samples;
$this->labels = $labels;
}
/**
* @param array $samples
*
* @return mixed
*/
public function predict(array $samples)
{
if (!is_array($samples[0])) {
$predicted = $this->predictSample($samples);
} else {
$predicted = [];
foreach ($samples as $index => $sample) {
$predicted[$index] = $this->predictSample($sample);
}
}
return $predicted;
}
/** /**
* @param array $sample * @param array $sample
@ -67,4 +33,5 @@ class NaiveBayes implements Classifier
return key($predictions); return key($predictions);
} }
} }

View File

@ -0,0 +1,27 @@
<?php
declare(strict_types = 1);
namespace Phpml\Classifier\Traits;
trait Predictable
{
/**
* @param array $samples
*
* @return mixed
*/
public function predict(array $samples)
{
if (!is_array($samples[0])) {
$predicted = $this->predictSample($samples);
} else {
$predicted = [];
foreach ($samples as $index => $sample) {
$predicted[$index] = $this->predictSample($sample);
}
}
return $predicted;
}
}

View File

@ -0,0 +1,29 @@
<?php
declare(strict_types = 1);
namespace Phpml\Classifier\Traits;
trait Trainable
{
/**
* @var array
*/
private $samples;
/**
* @var array
*/
private $labels;
/**
* @param array $samples
* @param array $labels
*/
public function train(array $samples, array $labels)
{
$this->samples = $samples;
$this->labels = $labels;
}
}

View File

@ -16,7 +16,7 @@ class CsvDatasetTest extends \PHPUnit_Framework_TestCase
new CsvDataset('missingFile', 3); new CsvDataset('missingFile', 3);
} }
public function testSampleCsvDataset() public function testSampleCsvDatasetWithHeaderRow()
{ {
$filePath = dirname(__FILE__).'/Resources/dataset.csv'; $filePath = dirname(__FILE__).'/Resources/dataset.csv';
@ -25,4 +25,14 @@ class CsvDatasetTest extends \PHPUnit_Framework_TestCase
$this->assertEquals(10, count($dataset->getSamples())); $this->assertEquals(10, count($dataset->getSamples()));
$this->assertEquals(10, count($dataset->getLabels())); $this->assertEquals(10, count($dataset->getLabels()));
} }
public function testSampleCsvDatasetWithoutHeaderRow()
{
$filePath = dirname(__FILE__).'/Resources/dataset.csv';
$dataset = new CsvDataset($filePath, 2, false);
$this->assertEquals(11, count($dataset->getSamples()));
$this->assertEquals(11, count($dataset->getLabels()));
}
} }