diff --git a/src/Phpml/SupportVectorMachine/SupportVectorMachine.php b/src/Phpml/SupportVectorMachine/SupportVectorMachine.php index a812065..e9bd8c8 100644 --- a/src/Phpml/SupportVectorMachine/SupportVectorMachine.php +++ b/src/Phpml/SupportVectorMachine/SupportVectorMachine.php @@ -124,5 +124,4 @@ class SupportVectorMachine return ''; } - } diff --git a/tests/Phpml/SupportVectorMachine/SupportVectorMachineTest.php b/tests/Phpml/SupportVectorMachine/SupportVectorMachineTest.php index 330f7f0..d14f777 100644 --- a/tests/Phpml/SupportVectorMachine/SupportVectorMachineTest.php +++ b/tests/Phpml/SupportVectorMachine/SupportVectorMachineTest.php @@ -34,7 +34,7 @@ SV $this->assertEquals($model, $svm->getModel()); } - public function testPredictCSVCModelWithLinearKernel() + public function testPredictSampleWithLinearKernel() { $samples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]]; $labels = ['a', 'a', 'a', 'b', 'b', 'b']; @@ -52,4 +52,31 @@ SV $this->assertEquals('a', $predictions[1]); $this->assertEquals('b', $predictions[2]); } + + public function testPredictSampleFromMultipleClassWithRbfKernel() + { + $samples = [ + [1, 3], [1, 4], [1, 4], + [3, 1], [4, 1], [4, 2], + [-3, -1], [-4, -1], [-4, -2], + ]; + $labels = [ + 'a', 'a', 'a', + 'b', 'b', 'b', + 'c', 'c', 'c', + ]; + + $svm = new SupportVectorMachine(Type::C_SVC, Kernel::RBF, 100.0); + $svm->train($samples, $labels); + + $predictions = $svm->predict([ + [1, 5], + [4, 3], + [-4, -3], + ]); + + $this->assertEquals('a', $predictions[0]); + $this->assertEquals('b', $predictions[1]); + $this->assertEquals('c', $predictions[2]); + } }