From e1854d44a26c062a4f1538f9eb8135a12793314f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Monlla=C3=B3?= Date: Thu, 20 Apr 2017 04:26:31 +0800 Subject: [PATCH] Partial training base (#78) * Cost values for multiclass OneVsRest uses * Partial training interface * Reduce linear classifiers memory usage * Testing partial training and isolated training * Partial trainer naming switched to incremental estimator Other changes according to review's feedback. * Clean optimization data once optimize is finished * Abstract resetBinary --- src/Phpml/Classification/Linear/Adaline.php | 7 +- .../Classification/Linear/DecisionStump.php | 54 ++++--- .../Linear/LogisticRegression.php | 23 +-- .../Classification/Linear/Perceptron.php | 80 ++++++---- src/Phpml/Helper/OneVsRest.php | 137 +++++++++++++----- .../Helper/Optimizer/ConjugateGradient.php | 2 + src/Phpml/Helper/Optimizer/GD.php | 15 +- src/Phpml/Helper/Optimizer/StochasticGD.php | 20 ++- src/Phpml/IncrementalEstimator.php | 16 ++ .../Classification/Linear/AdalineTest.php | 18 ++- .../Classification/Linear/PerceptronTest.php | 18 ++- 11 files changed, 283 insertions(+), 107 deletions(-) create mode 100644 src/Phpml/IncrementalEstimator.php diff --git a/src/Phpml/Classification/Linear/Adaline.php b/src/Phpml/Classification/Linear/Adaline.php index 8d94be41..f34dc5c4 100644 --- a/src/Phpml/Classification/Linear/Adaline.php +++ b/src/Phpml/Classification/Linear/Adaline.php @@ -53,8 +53,11 @@ public function __construct(float $learningRate = 0.001, int $maxIterations = 10 /** * Adapts the weights with respect to given samples and targets * by use of gradient descent learning rule + * + * @param array $samples + * @param array $targets */ - protected function runTraining() + protected function runTraining(array $samples, array $targets) { // The cost function is the sum of squares $callback = function ($weights, $sample, $target) { @@ -69,6 +72,6 @@ protected function runTraining() $isBatch = $this->trainingType == self::BATCH_TRAINING; - return parent::runGradientDescent($callback, $isBatch); + return parent::runGradientDescent($samples, $targets, $callback, $isBatch); } } diff --git a/src/Phpml/Classification/Linear/DecisionStump.php b/src/Phpml/Classification/Linear/DecisionStump.php index 13bc4a5a..99f982ff 100644 --- a/src/Phpml/Classification/Linear/DecisionStump.php +++ b/src/Phpml/Classification/Linear/DecisionStump.php @@ -89,15 +89,13 @@ public function __construct(int $columnIndex = self::AUTO_SELECT) * @param array $targets * @throws \Exception */ - protected function trainBinary(array $samples, array $targets) + protected function trainBinary(array $samples, array $targets, array $labels) { - $this->samples = array_merge($this->samples, $samples); - $this->targets = array_merge($this->targets, $targets); - $this->binaryLabels = array_keys(array_count_values($this->targets)); - $this->featureCount = count($this->samples[0]); + $this->binaryLabels = $labels; + $this->featureCount = count($samples[0]); // If a column index is given, it should be among the existing columns - if ($this->givenColumnIndex > count($this->samples[0]) - 1) { + if ($this->givenColumnIndex > count($samples[0]) - 1) { $this->givenColumnIndex = self::AUTO_SELECT; } @@ -105,19 +103,19 @@ protected function trainBinary(array $samples, array $targets) // If none given, then assign 1 as a weight to each sample if ($this->weights) { $numWeights = count($this->weights); - if ($numWeights != count($this->samples)) { + if ($numWeights != count($samples)) { throw new \Exception("Number of sample weights does not match with number of samples"); } } else { - $this->weights = array_fill(0, count($this->samples), 1); + $this->weights = array_fill(0, count($samples), 1); } // Determine type of each column as either "continuous" or "nominal" - $this->columnTypes = DecisionTree::getColumnTypes($this->samples); + $this->columnTypes = DecisionTree::getColumnTypes($samples); // Try to find the best split in the columns of the dataset // by calculating error rate for each split point in each column - $columns = range(0, count($this->samples[0]) - 1); + $columns = range(0, count($samples[0]) - 1); if ($this->givenColumnIndex != self::AUTO_SELECT) { $columns = [$this->givenColumnIndex]; } @@ -128,9 +126,9 @@ protected function trainBinary(array $samples, array $targets) 'trainingErrorRate' => 1.0]; foreach ($columns as $col) { if ($this->columnTypes[$col] == DecisionTree::CONTINUOUS) { - $split = $this->getBestNumericalSplit($col); + $split = $this->getBestNumericalSplit($samples, $targets, $col); } else { - $split = $this->getBestNominalSplit($col); + $split = $this->getBestNominalSplit($samples, $targets, $col); } if ($split['trainingErrorRate'] < $bestSplit['trainingErrorRate']) { @@ -161,13 +159,15 @@ public function setNumericalSplitCount(float $count) /** * Determines best split point for the given column * + * @param array $samples + * @param array $targets * @param int $col * * @return array */ - protected function getBestNumericalSplit(int $col) + protected function getBestNumericalSplit(array $samples, array $targets, int $col) { - $values = array_column($this->samples, $col); + $values = array_column($samples, $col); // Trying all possible points may be accomplished in two general ways: // 1- Try all values in the $samples array ($values) // 2- Artificially split the range of values into several parts and try them @@ -182,7 +182,7 @@ protected function getBestNumericalSplit(int $col) // Before trying all possible split points, let's first try // the average value for the cut point $threshold = array_sum($values) / (float) count($values); - list($errorRate, $prob) = $this->calculateErrorRate($threshold, $operator, $values); + list($errorRate, $prob) = $this->calculateErrorRate($targets, $threshold, $operator, $values); if ($split == null || $errorRate < $split['trainingErrorRate']) { $split = ['value' => $threshold, 'operator' => $operator, 'prob' => $prob, 'column' => $col, @@ -192,7 +192,7 @@ protected function getBestNumericalSplit(int $col) // Try other possible points one by one for ($step = $minValue; $step <= $maxValue; $step+= $stepSize) { $threshold = (float)$step; - list($errorRate, $prob) = $this->calculateErrorRate($threshold, $operator, $values); + list($errorRate, $prob) = $this->calculateErrorRate($targets, $threshold, $operator, $values); if ($errorRate < $split['trainingErrorRate']) { $split = ['value' => $threshold, 'operator' => $operator, 'prob' => $prob, 'column' => $col, @@ -205,13 +205,15 @@ protected function getBestNumericalSplit(int $col) } /** + * @param array $samples + * @param array $targets * @param int $col * * @return array */ - protected function getBestNominalSplit(int $col) : array + protected function getBestNominalSplit(array $samples, array $targets, int $col) : array { - $values = array_column($this->samples, $col); + $values = array_column($samples, $col); $valueCounts = array_count_values($values); $distinctVals= array_keys($valueCounts); @@ -219,7 +221,7 @@ protected function getBestNominalSplit(int $col) : array foreach (['=', '!='] as $operator) { foreach ($distinctVals as $val) { - list($errorRate, $prob) = $this->calculateErrorRate($val, $operator, $values); + list($errorRate, $prob) = $this->calculateErrorRate($targets, $val, $operator, $values); if ($split == null || $split['trainingErrorRate'] < $errorRate) { $split = ['value' => $val, 'operator' => $operator, @@ -260,13 +262,14 @@ protected function evaluate($leftValue, $operator, $rightValue) * Calculates the ratio of wrong predictions based on the new threshold * value given as the parameter * + * @param array $targets * @param float $threshold * @param string $operator * @param array $values * * @return array */ - protected function calculateErrorRate(float $threshold, string $operator, array $values) : array + protected function calculateErrorRate(array $targets, float $threshold, string $operator, array $values) : array { $wrong = 0.0; $prob = []; @@ -280,8 +283,8 @@ protected function calculateErrorRate(float $threshold, string $operator, array $predicted = $rightLabel; } - $target = $this->targets[$index]; - if (strval($predicted) != strval($this->targets[$index])) { + $target = $targets[$index]; + if (strval($predicted) != strval($targets[$index])) { $wrong += $this->weights[$index]; } @@ -340,6 +343,13 @@ protected function predictSampleBinary(array $sample) return $this->binaryLabels[1]; } + /** + * @return void + */ + protected function resetBinary() + { + } + /** * @return string */ diff --git a/src/Phpml/Classification/Linear/LogisticRegression.php b/src/Phpml/Classification/Linear/LogisticRegression.php index a0ec2900..bd56d347 100644 --- a/src/Phpml/Classification/Linear/LogisticRegression.php +++ b/src/Phpml/Classification/Linear/LogisticRegression.php @@ -123,20 +123,23 @@ public function setLambda(float $lambda) /** * Adapts the weights with respect to given samples and targets * by use of selected solver + * + * @param array $samples + * @param array $targets */ - protected function runTraining() + protected function runTraining(array $samples, array $targets) { $callback = $this->getCostFunction(); switch ($this->trainingType) { case self::BATCH_TRAINING: - return $this->runGradientDescent($callback, true); + return $this->runGradientDescent($samples, $targets, $callback, true); case self::ONLINE_TRAINING: - return $this->runGradientDescent($callback, false); + return $this->runGradientDescent($samples, $targets, $callback, false); case self::CONJUGATE_GRAD_TRAINING: - return $this->runConjugateGradient($callback); + return $this->runConjugateGradient($samples, $targets, $callback); } } @@ -144,13 +147,15 @@ protected function runTraining() * Executes Conjugate Gradient method to optimize the * weights of the LogReg model */ - protected function runConjugateGradient(\Closure $gradientFunc) + protected function runConjugateGradient(array $samples, array $targets, \Closure $gradientFunc) { - $optimizer = (new ConjugateGradient($this->featureCount)) - ->setMaxIterations($this->maxIterations); + if (empty($this->optimizer)) { + $this->optimizer = (new ConjugateGradient($this->featureCount)) + ->setMaxIterations($this->maxIterations); + } - $this->weights = $optimizer->runOptimization($this->samples, $this->targets, $gradientFunc); - $this->costValues = $optimizer->getCostValues(); + $this->weights = $this->optimizer->runOptimization($samples, $targets, $gradientFunc); + $this->costValues = $this->optimizer->getCostValues(); } /** diff --git a/src/Phpml/Classification/Linear/Perceptron.php b/src/Phpml/Classification/Linear/Perceptron.php index 8280bcbf..2cf96cc6 100644 --- a/src/Phpml/Classification/Linear/Perceptron.php +++ b/src/Phpml/Classification/Linear/Perceptron.php @@ -10,20 +10,17 @@ use Phpml\Helper\Optimizer\GD; use Phpml\Classification\Classifier; use Phpml\Preprocessing\Normalizer; +use Phpml\IncrementalEstimator; +use Phpml\Helper\PartiallyTrainable; -class Perceptron implements Classifier +class Perceptron implements Classifier, IncrementalEstimator { use Predictable, OneVsRest; - /** - * @var array - */ - protected $samples = []; - /** - * @var array + * @var \Phpml\Helper\Optimizer\Optimizer */ - protected $targets = []; + protected $optimizer; /** * @var array @@ -93,32 +90,47 @@ public function __construct(float $learningRate = 0.001, int $maxIterations = 10 $this->maxIterations = $maxIterations; } + /** + * @param array $samples + * @param array $targets + * @param array $labels + */ + public function partialTrain(array $samples, array $targets, array $labels = array()) + { + return $this->trainByLabel($samples, $targets, $labels); + } + /** * @param array $samples * @param array $targets + * @param array $labels */ - public function trainBinary(array $samples, array $targets) + public function trainBinary(array $samples, array $targets, array $labels) { - $this->labels = array_keys(array_count_values($targets)); - if (count($this->labels) > 2) { - throw new \Exception("Perceptron is for binary (two-class) classification only"); - } if ($this->normalizer) { $this->normalizer->transform($samples); } // Set all target values to either -1 or 1 - $this->labels = [1 => $this->labels[0], -1 => $this->labels[1]]; - foreach ($targets as $target) { - $this->targets[] = strval($target) == strval($this->labels[1]) ? 1 : -1; + $this->labels = [1 => $labels[0], -1 => $labels[1]]; + foreach ($targets as $key => $target) { + $targets[$key] = strval($target) == strval($this->labels[1]) ? 1 : -1; } // Set samples and feature count vars - $this->samples = array_merge($this->samples, $samples); - $this->featureCount = count($this->samples[0]); + $this->featureCount = count($samples[0]); + + $this->runTraining($samples, $targets); + } - $this->runTraining(); + protected function resetBinary() + { + $this->labels = []; + $this->optimizer = null; + $this->featureCount = 0; + $this->weights = null; + $this->costValues = []; } /** @@ -151,8 +163,11 @@ public function getCostValues() /** * Trains the perceptron model with Stochastic Gradient Descent optimization * to get the correct set of weights + * + * @param array $samples + * @param array $targets */ - protected function runTraining() + protected function runTraining(array $samples, array $targets) { // The cost function is the sum of squares $callback = function ($weights, $sample, $target) { @@ -165,25 +180,30 @@ protected function runTraining() return [$error, $gradient]; }; - $this->runGradientDescent($callback); + $this->runGradientDescent($samples, $targets, $callback); } /** - * Executes Stochastic Gradient Descent algorithm for + * Executes a Gradient Descent algorithm for * the given cost function + * + * @param array $samples + * @param array $targets */ - protected function runGradientDescent(\Closure $gradientFunc, bool $isBatch = false) + protected function runGradientDescent(array $samples, array $targets, \Closure $gradientFunc, bool $isBatch = false) { $class = $isBatch ? GD::class : StochasticGD::class; - $optimizer = (new $class($this->featureCount)) - ->setLearningRate($this->learningRate) - ->setMaxIterations($this->maxIterations) - ->setChangeThreshold(1e-6) - ->setEarlyStop($this->enableEarlyStop); + if (empty($this->optimizer)) { + $this->optimizer = (new $class($this->featureCount)) + ->setLearningRate($this->learningRate) + ->setMaxIterations($this->maxIterations) + ->setChangeThreshold(1e-6) + ->setEarlyStop($this->enableEarlyStop); + } - $this->weights = $optimizer->runOptimization($this->samples, $this->targets, $gradientFunc); - $this->costValues = $optimizer->getCostValues(); + $this->weights = $this->optimizer->runOptimization($samples, $targets, $gradientFunc); + $this->costValues = $this->optimizer->getCostValues(); } /** diff --git a/src/Phpml/Helper/OneVsRest.php b/src/Phpml/Helper/OneVsRest.php index 9e7bc82c..98269cdb 100644 --- a/src/Phpml/Helper/OneVsRest.php +++ b/src/Phpml/Helper/OneVsRest.php @@ -6,30 +6,23 @@ trait OneVsRest { - /** - * @var array - */ - protected $samples = []; - - /** - * @var array - */ - protected $targets = []; /** * @var array */ - protected $classifiers; + protected $classifiers = []; /** + * All provided training targets' labels. + * * @var array */ - protected $labels; + protected $allLabels = []; /** * @var array */ - protected $costValues; + protected $costValues = []; /** * Train a binary classifier in the OvR style @@ -39,51 +32,111 @@ trait OneVsRest */ public function train(array $samples, array $targets) { - // Clone the current classifier, so that - // we don't mess up its variables while training - // multiple instances of this classifier - $classifier = clone $this; - $this->classifiers = []; + // Clears previous stuff. + $this->reset(); + + return $this->trainBylabel($samples, $targets); + } + + /** + * @param array $samples + * @param array $targets + * @param array $allLabels All training set labels + * @return void + */ + protected function trainByLabel(array $samples, array $targets, array $allLabels = array()) + { + + // Overwrites the current value if it exist. $allLabels must be provided for each partialTrain run. + if (!empty($allLabels)) { + $this->allLabels = $allLabels; + } else { + $this->allLabels = array_keys(array_count_values($targets)); + } + sort($this->allLabels, SORT_STRING); // If there are only two targets, then there is no need to perform OvR - $this->labels = array_keys(array_count_values($targets)); - if (count($this->labels) == 2) { - $classifier->trainBinary($samples, $targets); - $this->classifiers[] = $classifier; + if (count($this->allLabels) == 2) { + + // Init classifier if required. + if (empty($this->classifiers)) { + $this->classifiers[0] = $this->getClassifierCopy(); + } + + $this->classifiers[0]->trainBinary($samples, $targets, $this->allLabels); } else { // Train a separate classifier for each label and memorize them - $this->samples = $samples; - $this->targets = $targets; - foreach ($this->labels as $label) { - $predictor = clone $classifier; - $targets = $this->binarizeTargets($label); - $predictor->trainBinary($samples, $targets); - $this->classifiers[$label] = $predictor; + + foreach ($this->allLabels as $label) { + + // Init classifier if required. + if (empty($this->classifiers[$label])) { + $this->classifiers[$label] = $this->getClassifierCopy(); + } + + list($binarizedTargets, $classifierLabels) = $this->binarizeTargets($targets, $label); + $this->classifiers[$label]->trainBinary($samples, $binarizedTargets, $classifierLabels); } } - + // If the underlying classifier is capable of giving the cost values // during the training, then assign it to the relevant variable - if (method_exists($this->classifiers[0], 'getCostValues')) { - $this->costValues = $this->classifiers[0]->getCostValues(); + // Adding just the first classifier cost values to avoid complex average calculations. + $classifierref = reset($this->classifiers); + if (method_exists($classifierref, 'getCostValues')) { + $this->costValues = $classifierref->getCostValues(); } } + /** + * Resets the classifier and the vars internally used by OneVsRest to create multiple classifiers. + */ + public function reset() + { + $this->classifiers = []; + $this->allLabels = []; + $this->costValues = []; + + $this->resetBinary(); + } + + /** + * Returns an instance of the current class after cleaning up OneVsRest stuff. + * + * @return \Phpml\Estimator + */ + protected function getClassifierCopy() + { + + // Clone the current classifier, so that + // we don't mess up its variables while training + // multiple instances of this classifier + $classifier = clone $this; + $classifier->reset(); + return $classifier; + } + /** * Groups all targets into two groups: Targets equal to * the given label and the others * + * $targets is not passed by reference nor contains objects so this method + * changes will not affect the caller $targets array. + * + * @param array $targets * @param mixed $label + * @return array Binarized targets and target's labels */ - private function binarizeTargets($label) + private function binarizeTargets($targets, $label) { - $targets = []; - foreach ($this->targets as $target) { - $targets[] = $target == $label ? $label : "not_$label"; + $notLabel = "not_$label"; + foreach ($targets as $key => $target) { + $targets[$key] = $target == $label ? $label : $notLabel; } - return $targets; + $labels = array($label, $notLabel); + return array($targets, $labels); } @@ -94,7 +147,7 @@ private function binarizeTargets($label) */ protected function predictSample(array $sample) { - if (count($this->labels) == 2) { + if (count($this->allLabels) == 2) { return $this->classifiers[0]->predictSampleBinary($sample); } @@ -113,8 +166,16 @@ protected function predictSample(array $sample) * * @param array $samples * @param array $targets + * @param array $labels + */ + abstract protected function trainBinary(array $samples, array $targets, array $labels); + + /** + * To be overwritten by OneVsRest classifiers. + * + * @return void */ - abstract protected function trainBinary(array $samples, array $targets); + abstract protected function resetBinary(); /** * Each classifier that make use of OvR approach should be able to diff --git a/src/Phpml/Helper/Optimizer/ConjugateGradient.php b/src/Phpml/Helper/Optimizer/ConjugateGradient.php index 9bcb338d..18ae89a0 100644 --- a/src/Phpml/Helper/Optimizer/ConjugateGradient.php +++ b/src/Phpml/Helper/Optimizer/ConjugateGradient.php @@ -57,6 +57,8 @@ public function runOptimization(array $samples, array $targets, \Closure $gradie } } + $this->clear(); + return $this->theta; } diff --git a/src/Phpml/Helper/Optimizer/GD.php b/src/Phpml/Helper/Optimizer/GD.php index 14029309..8974c8e7 100644 --- a/src/Phpml/Helper/Optimizer/GD.php +++ b/src/Phpml/Helper/Optimizer/GD.php @@ -15,7 +15,7 @@ class GD extends StochasticGD * * @var int */ - protected $sampleCount; + protected $sampleCount = null; /** * @param array $samples @@ -49,6 +49,8 @@ public function runOptimization(array $samples, array $targets, \Closure $gradie } } + $this->clear(); + return $this->theta; } @@ -105,4 +107,15 @@ protected function updateWeightsWithUpdates(array $updates, float $penalty) } } } + + /** + * Clears the optimizer internal vars after the optimization process. + * + * @return void + */ + protected function clear() + { + $this->sampleCount = null; + parent::clear(); + } } diff --git a/src/Phpml/Helper/Optimizer/StochasticGD.php b/src/Phpml/Helper/Optimizer/StochasticGD.php index 5379a283..e9e318a8 100644 --- a/src/Phpml/Helper/Optimizer/StochasticGD.php +++ b/src/Phpml/Helper/Optimizer/StochasticGD.php @@ -16,14 +16,14 @@ class StochasticGD extends Optimizer * * @var array */ - protected $samples; + protected $samples = []; /** * y (targets) * * @var array */ - protected $targets; + protected $targets = []; /** * Callback function to get the gradient and cost value @@ -31,7 +31,7 @@ class StochasticGD extends Optimizer * * @var \Closure */ - protected $gradientCb; + protected $gradientCb = null; /** * Maximum number of iterations used to train the model @@ -192,6 +192,8 @@ public function runOptimization(array $samples, array $targets, \Closure $gradie } } + $this->clear(); + // Solution in the pocket is better than or equal to the last state // so, we use this solution return $this->theta = $bestTheta; @@ -268,4 +270,16 @@ public function getCostValues() { return $this->costValues; } + + /** + * Clears the optimizer internal vars after the optimization process. + * + * @return void + */ + protected function clear() + { + $this->samples = []; + $this->targets = []; + $this->gradientCb = null; + } } diff --git a/src/Phpml/IncrementalEstimator.php b/src/Phpml/IncrementalEstimator.php new file mode 100644 index 00000000..df188728 --- /dev/null +++ b/src/Phpml/IncrementalEstimator.php @@ -0,0 +1,16 @@ +assertEquals(1, $classifier->predict([6.0, 5.0])); $this->assertEquals(2, $classifier->predict([3.0, 9.5])); - return $classifier; + // Extra partial training should lead to the same results. + $classifier->partialTrain([[0, 1], [1, 0]], [0, 0], [0, 1, 2]); + $this->assertEquals(0, $classifier->predict([0.5, 0.5])); + $this->assertEquals(1, $classifier->predict([6.0, 5.0])); + $this->assertEquals(2, $classifier->predict([3.0, 9.5])); + + // Train should clear previous data. + $samples = [ + [0, 0], [0, 1], [1, 0], [1, 1], // First group : a cluster at bottom-left corner in 2D + [5, 5], [6, 5], [5, 6], [7, 5], // Second group: another cluster at the middle-right + [3, 10],[3, 10],[3, 8], [3, 9] // Third group : cluster at the top-middle + ]; + $targets = [2, 2, 2, 2, 0, 0, 0, 0, 1, 1, 1, 1]; + $classifier->train($samples, $targets); + $this->assertEquals(2, $classifier->predict([0.5, 0.5])); + $this->assertEquals(0, $classifier->predict([6.0, 5.0])); + $this->assertEquals(1, $classifier->predict([3.0, 9.5])); } public function testSaveAndRestore() diff --git a/tests/Phpml/Classification/Linear/PerceptronTest.php b/tests/Phpml/Classification/Linear/PerceptronTest.php index 1f40c461..132a6d79 100644 --- a/tests/Phpml/Classification/Linear/PerceptronTest.php +++ b/tests/Phpml/Classification/Linear/PerceptronTest.php @@ -48,7 +48,23 @@ public function testPredictSingleSample() $this->assertEquals(1, $classifier->predict([6.0, 5.0])); $this->assertEquals(2, $classifier->predict([3.0, 9.5])); - return $classifier; + // Extra partial training should lead to the same results. + $classifier->partialTrain([[0, 1], [1, 0]], [0, 0], [0, 1, 2]); + $this->assertEquals(0, $classifier->predict([0.5, 0.5])); + $this->assertEquals(1, $classifier->predict([6.0, 5.0])); + $this->assertEquals(2, $classifier->predict([3.0, 9.5])); + + // Train should clear previous data. + $samples = [ + [0, 0], [0, 1], [1, 0], [1, 1], // First group : a cluster at bottom-left corner in 2D + [5, 5], [6, 5], [5, 6], [7, 5], // Second group: another cluster at the middle-right + [3, 10],[3, 10],[3, 8], [3, 9] // Third group : cluster at the top-middle + ]; + $targets = [2, 2, 2, 2, 0, 0, 0, 0, 1, 1, 1, 1]; + $classifier->train($samples, $targets); + $this->assertEquals(2, $classifier->predict([0.5, 0.5])); + $this->assertEquals(0, $classifier->predict([6.0, 5.0])); + $this->assertEquals(1, $classifier->predict([3.0, 9.5])); } public function testSaveAndRestore()