implement data Normalizer with L1 and L2 norm

This commit is contained in:
Arkadiusz Kondas 2016-05-08 20:35:01 +02:00
parent 65cdfe64b2
commit fb04b57853
3 changed files with 157 additions and 0 deletions

View File

@ -0,0 +1,16 @@
<?php
declare (strict_types = 1);
namespace Phpml\Exception;
class NormalizerException extends \Exception
{
    /**
     * Named constructor for the error reported when an unknown norm is supplied.
     *
     * @return NormalizerException
     */
    public static function unknownNorm()
    {
        $message = 'Unknown norm supplied.';

        return new self($message);
    }
}

View File

@ -0,0 +1,83 @@
<?php
declare (strict_types = 1);
namespace Phpml\Preprocessing;
use Phpml\Exception\NormalizerException;
/**
 * Rescales every sample (a row of numeric features) in place, dividing it by
 * its L1 norm (sum of absolute values) or its L2 norm (Euclidean length).
 */
class Normalizer implements Preprocessor
{
    const NORM_L1 = 1;
    const NORM_L2 = 2;

    /**
     * One of the NORM_* constants; selects the normalization applied by preprocess().
     *
     * @var int
     */
    private $norm;

    /**
     * @param int $norm either self::NORM_L1 or self::NORM_L2
     *
     * @throws NormalizerException when $norm is not one of the NORM_* constants
     */
    public function __construct(int $norm = self::NORM_L2)
    {
        // Strict comparison: only the exact integer constants are accepted.
        if (!in_array($norm, [self::NORM_L1, self::NORM_L2], true)) {
            throw NormalizerException::unknownNorm();
        }

        $this->norm = $norm;
    }

    /**
     * Normalizes every sample in place using the norm chosen at construction.
     *
     * @param array $samples list of samples, each an array of numeric features;
     *                       modified by reference
     */
    public function preprocess(array &$samples)
    {
        foreach ($samples as &$sample) {
            // The constructor guarantees the norm is either L1 or L2.
            if (self::NORM_L1 === $this->norm) {
                $this->normalizeL1($sample);
            } else {
                $this->normalizeL2($sample);
            }
        }
        // Break the reference so later writes to $sample cannot touch $samples.
        unset($sample);
    }

    /**
     * Divides each feature by the sum of absolute feature values.
     *
     * An empty sample is left untouched. A sample whose L1 norm is zero is
     * replaced by a uniform vector (each entry 1 / count) under the same keys.
     *
     * @param array $sample numeric features; modified by reference
     */
    private function normalizeL1(array &$sample)
    {
        // Nothing to scale; also avoids dividing by a zero count below.
        if ([] === $sample) {
            return;
        }

        // Float accumulator so the result is always float, whatever the input types.
        $norm1 = 0.0;
        foreach ($sample as $feature) {
            $norm1 += abs($feature);
        }

        if (0.0 === $norm1) {
            // array_fill_keys keeps the sample's own keys, like the scaling branch does.
            $sample = array_fill_keys(array_keys($sample), 1.0 / count($sample));

            return;
        }

        foreach ($sample as &$feature) {
            $feature = $feature / $norm1;
        }
        unset($feature);
    }

    /**
     * Divides each feature by the Euclidean length of the sample.
     *
     * An empty sample is left untouched. A sample whose L2 norm is zero is
     * replaced by a vector of ones under the same keys.
     * NOTE(review): that fallback vector has L2 norm sqrt(count), not 1; it is
     * kept as-is for backward compatibility — confirm this is intended.
     *
     * @param array $sample numeric features; modified by reference
     */
    private function normalizeL2(array &$sample)
    {
        if ([] === $sample) {
            return;
        }

        $sumOfSquares = 0.0;
        foreach ($sample as $feature) {
            $sumOfSquares += $feature * $feature;
        }
        // sqrt() always returns a float, so the strict check against 0.0 is safe.
        $norm2 = sqrt($sumOfSquares);

        if (0.0 === $norm2) {
            $sample = array_fill_keys(array_keys($sample), 1);

            return;
        }

        foreach ($sample as &$feature) {
            $feature = $feature / $norm2;
        }
        unset($feature);
    }
}

View File

@ -0,0 +1,58 @@
<?php
declare (strict_types = 1);
namespace tests\Preprocessing;
use Phpml\Preprocessing\Normalizer;
/**
 * Exercises Normalizer construction and both supported norms.
 */
class NormalizerTest extends \PHPUnit_Framework_TestCase
{
    /**
     * Constructing a Normalizer with a norm outside the supported set must fail.
     *
     * @expectedException \Phpml\Exception\NormalizerException
     */
    public function testThrowExceptionOnInvalidNorm()
    {
        new Normalizer(99);
    }

    /**
     * The default constructor argument selects the L2 norm.
     */
    public function testNormalizeSamplesWithL2Norm()
    {
        $data = [
            [1, -1, 2],
            [2, 0, 0],
            [0, 1, -1],
        ];
        $expected = [
            [0.4, -0.4, 0.81],
            [1.0, 0.0, 0.0],
            [0.0, 0.7, -0.7],
        ];
        // Expected values are rounded, so compare within an absolute tolerance.
        $tolerance = 0.01;

        $l2 = new Normalizer();
        $l2->preprocess($data);

        $this->assertEquals($expected, $data, '', $tolerance);
    }

    /**
     * With NORM_L1 the absolute values of each row sum to one.
     */
    public function testNormalizeSamplesWithL1Norm()
    {
        $data = [
            [1, -1, 2],
            [2, 0, 0],
            [0, 1, -1],
        ];
        $expected = [
            [0.25, -0.25, 0.5],
            [1.0, 0.0, 0.0],
            [0.0, 0.5, -0.5],
        ];
        $tolerance = 0.01;

        $l1 = new Normalizer(Normalizer::NORM_L1);
        $l1->preprocess($data);

        $this->assertEquals($expected, $data, '', $tolerance);
    }
}