diff --git a/CHANGELOG.md b/CHANGELOG.md index 13bae10..5990242 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ This changelog references the relevant changes done in PHP-ML library. * fix ensure DataTransformer::testSet samples array is not empty (#204) * fix optimizer initial theta randomization (#239) * fix travis build on osx (#281) + * fix SVM locale (non-locale aware) (#288) * typo, tests, code styles and documentation fixes (#265, #261, #254, #253, #251, #250, #248, #245, #243) * 0.6.2 (2018-02-22) diff --git a/src/SupportVectorMachine/DataTransformer.php b/src/SupportVectorMachine/DataTransformer.php index 06272e2..fcc18f6 100644 --- a/src/SupportVectorMachine/DataTransformer.php +++ b/src/SupportVectorMachine/DataTransformer.php @@ -104,7 +104,7 @@ class DataTransformer { $row = []; foreach ($sample as $index => $feature) { - $row[] = sprintf('%s:%s', $index + 1, $feature); + $row[] = sprintf('%s:%F', $index + 1, $feature); } return implode(' ', $row); diff --git a/src/SupportVectorMachine/SupportVectorMachine.php b/src/SupportVectorMachine/SupportVectorMachine.php index be16ff4..4c2f87b 100644 --- a/src/SupportVectorMachine/SupportVectorMachine.php +++ b/src/SupportVectorMachine/SupportVectorMachine.php @@ -269,7 +269,7 @@ class SupportVectorMachine private function buildTrainCommand(string $trainingSetFileName, string $modelFileName): string { return sprintf( - '%ssvm-train%s -s %s -t %s -c %s -n %s -d %s%s -r %s -p %s -m %s -e %s -h %d -b %d %s %s', + '%ssvm-train%s -s %s -t %s -c %s -n %F -d %s%s -r %s -p %F -m %F -e %F -h %d -b %d %s %s', $this->binPath, $this->getOSExtension(), $this->type, diff --git a/tests/Classification/SVCTest.php b/tests/Classification/SVCTest.php index 0709e04..1ec1541 100644 --- a/tests/Classification/SVCTest.php +++ b/tests/Classification/SVCTest.php @@ -57,13 +57,32 @@ class SVCTest extends TestCase $classifier->train($trainSamples, $trainLabels); $predicted = $classifier->predict($testSamples); - $filename = 'svc-test-'.random_int(100, 999).'-'.uniqid(); - $filepath = tempnam(sys_get_temp_dir(), $filename); + $filepath = tempnam(sys_get_temp_dir(), uniqid('svc-test', true)); $modelManager = new ModelManager(); $modelManager->saveToFile($classifier, $filepath); $restoredClassifier = $modelManager->restoreFromFile($filepath); $this->assertEquals($classifier, $restoredClassifier); $this->assertEquals($predicted, $restoredClassifier->predict($testSamples)); + $this->assertEquals($predicted, $testLabels); + } + + public function testWithNonDotDecimalLocale(): void + { + $currentLocale = setlocale(LC_NUMERIC, '0'); + setlocale(LC_NUMERIC, 'pl_PL.utf8'); + + $trainSamples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]]; + $trainLabels = ['a', 'a', 'a', 'b', 'b', 'b']; + + $testSamples = [[3, 2], [5, 1], [4, 3]]; + $testLabels = ['b', 'b', 'b']; + + $classifier = new SVC(Kernel::LINEAR, $cost = 1000); + $classifier->train($trainSamples, $trainLabels); + + $this->assertEquals($classifier->predict($testSamples), $testLabels); + + setlocale(LC_NUMERIC, $currentLocale); } } diff --git a/tests/SupportVectorMachine/DataTransformerTest.php b/tests/SupportVectorMachine/DataTransformerTest.php index 75c23d6..df29806 100644 --- a/tests/SupportVectorMachine/DataTransformerTest.php +++ b/tests/SupportVectorMachine/DataTransformerTest.php @@ -16,10 +16,10 @@ class DataTransformerTest extends TestCase $labels = ['a', 'a', 'b', 'b']; $trainingSet = - '0 1:1 2:1 '.PHP_EOL. - '0 1:2 2:1 '.PHP_EOL. - '1 1:3 2:2 '.PHP_EOL. - '1 1:4 2:5 '.PHP_EOL + '0 1:1.000000 2:1.000000 '.PHP_EOL. + '0 1:2.000000 2:1.000000 '.PHP_EOL. + '1 1:3.000000 2:2.000000 '.PHP_EOL. + '1 1:4.000000 2:5.000000 '.PHP_EOL ; $this->assertEquals($trainingSet, DataTransformer::trainingSet($samples, $labels)); @@ -30,10 +30,10 @@ class DataTransformerTest extends TestCase $samples = [[1, 1], [2, 1], [3, 2], [4, 5]]; $testSet = - '0 1:1 2:1 '.PHP_EOL. - '0 1:2 2:1 '.PHP_EOL. - '0 1:3 2:2 '.PHP_EOL. - '0 1:4 2:5 '.PHP_EOL + '0 1:1.000000 2:1.000000 '.PHP_EOL. + '0 1:2.000000 2:1.000000 '.PHP_EOL. + '0 1:3.000000 2:2.000000 '.PHP_EOL. + '0 1:4.000000 2:5.000000 '.PHP_EOL ; $this->assertEquals($testSet, DataTransformer::testSet($samples));