diff --git a/src/Phpml/Classification/KNearestNeighbors.php b/src/Phpml/Classification/KNearestNeighbors.php index f1a87cf..95ebeaf 100644 --- a/src/Phpml/Classification/KNearestNeighbors.php +++ b/src/Phpml/Classification/KNearestNeighbors.php @@ -35,7 +35,7 @@ class KNearestNeighbors implements Classifier $this->k = $k; $this->samples = []; - $this->labels = []; + $this->targets = []; $this->distanceMetric = $distanceMetric; } @@ -48,10 +48,10 @@ class KNearestNeighbors implements Classifier { $distances = $this->kNeighborsDistances($sample); - $predictions = array_combine(array_values($this->labels), array_fill(0, count($this->labels), 0)); + $predictions = array_combine(array_values($this->targets), array_fill(0, count($this->targets), 0)); foreach ($distances as $index => $distance) { - ++$predictions[$this->labels[$index]]; + ++$predictions[$this->targets[$index]]; } arsort($predictions); diff --git a/src/Phpml/Classification/NaiveBayes.php b/src/Phpml/Classification/NaiveBayes.php index 9726b40..3e7d819 100644 --- a/src/Phpml/Classification/NaiveBayes.php +++ b/src/Phpml/Classification/NaiveBayes.php @@ -19,7 +19,7 @@ class NaiveBayes implements Classifier protected function predictSample(array $sample) { $predictions = []; - foreach ($this->labels as $index => $label) { + foreach ($this->targets as $index => $label) { $predictions[$label] = 0; foreach ($sample as $token => $count) { if (array_key_exists($token, $this->samples[$index])) { diff --git a/src/Phpml/Helper/Trainable.php b/src/Phpml/Helper/Trainable.php index 36b8993..fda27d5 100644 --- a/src/Phpml/Helper/Trainable.php +++ b/src/Phpml/Helper/Trainable.php @@ -14,15 +14,15 @@ trait Trainable /** * @var array */ - private $labels; + private $targets; /** * @param array $samples - * @param array $labels + * @param array $targets */ - public function train(array $samples, array $labels) + public function train(array $samples, array $targets) { $this->samples = $samples; - $this->labels = $labels; + $this->targets = $targets; } } diff --git a/src/Phpml/Pipeline.php b/src/Phpml/Pipeline.php index 2ac0ed0..f0017fe 100644 --- a/src/Phpml/Pipeline.php +++ b/src/Phpml/Pipeline.php @@ -2,28 +2,85 @@ declare (strict_types = 1); -namespace Phpml\Pipeline; +namespace Phpml; -class Pipeline +class Pipeline implements Estimator { /** - * @var array + * @var array|Transformer[] */ - private $stages; + private $transformers; /** - * @param array $stages + * @var Estimator */ - public function __construct(array $stages) + private $estimator; + + /** + * @param array|Transformer[] $transformers + * @param Estimator $estimator + */ + public function __construct(array $transformers = [], Estimator $estimator) { - $this->stages = $stages; + foreach ($transformers as $transformer) { + $this->addTransformer($transformer); + } + + $this->estimator = $estimator; } /** - * @param mixed $stage + * @param Transformer $transformer */ - public function addStage($stage) + public function addTransformer(Transformer $transformer) { - $this->stages[] = $stage; + $this->transformers[] = $transformer; } + + /** + * @param Estimator $estimator + */ + public function setEstimator(Estimator $estimator) + { + $this->estimator = $estimator; + } + + /** + * @return array|Transformer[] + */ + public function getTransformers() + { + return $this->transformers; + } + + /** + * @return Estimator + */ + public function getEstimator() + { + return $this->estimator; + } + + /** + * @param array $samples + * @param array $targets + */ + public function train(array $samples, array $targets) + { + foreach ($this->transformers as $transformer) { + $samples = $transformer->transform($samples); + } + + $this->estimator->train($samples, $targets); + } + + /** + * @param array $samples + * @return mixed + */ + public function predict(array $samples) + { + return $this->estimator->predict($samples); + } + }