AdaBoost improvements (#53)
* AdaBoost improvements

* AdaBoost improvements & test case resolved

* Some coding style fixes
MustafaKarabulut authored and akondas committed Feb 28, 2017
1 parent e8c6005 commit c028a73
Showing 7 changed files with 385 additions and 99 deletions.
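The main change generalizes AdaBoost so it is no longer hard-wired to DecisionStump: a base classifier can be chosen through the new setBaseClassifier() method, and classifiers without sample-weight support are handled via weighted resampling. A minimal usage sketch based on the API visible in this diff (the sample data and the DecisionTree depth argument are illustrative assumptions, not taken from the commit):

<?php

use Phpml\Classification\Ensemble\AdaBoost;
use Phpml\Classification\DecisionTree;

// AdaBoost is a binary classifier: train() throws if more than two labels are given.
$samples = [[1.0, 2.0], [2.0, 1.0], [3.0, 4.0], [4.0, 3.0]];
$targets = [0, 0, 1, 1];

// Default: 50 boosting iterations with DecisionStump as the weak learner.
$classifier = new AdaBoost();

// New in this commit: plug in any classifier as the base learner; constructor
// arguments are passed through $classifierOptions (here an assumed max depth of 2).
$classifier->setBaseClassifier(DecisionTree::class, [2]);

$classifier->train($samples, $targets);
$prediction = $classifier->predict([2.5, 2.5]);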
11 changes: 6 additions & 5 deletions src/Phpml/Classification/DecisionTree.php
@@ -110,12 +110,13 @@ public function train(array $samples, array $targets)
}
}

protected function getColumnTypes(array $samples)
public static function getColumnTypes(array $samples)
{
$types = [];
for ($i=0; $i<$this->featureCount; $i++) {
$featureCount = count($samples[0]);
for ($i=0; $i < $featureCount; $i++) {
$values = array_column($samples, $i);
$isCategorical = $this->isCategoricalColumn($values);
$isCategorical = self::isCategoricalColumn($values);
$types[] = $isCategorical ? self::NOMINAL : self::CONTINUOS;
}
return $types;
@@ -327,13 +328,13 @@ protected function preprocess(array $samples)
* @param array $columnValues
* @return bool
*/
protected function isCategoricalColumn(array $columnValues)
protected static function isCategoricalColumn(array $columnValues)
{
$count = count($columnValues);

// There are two main indicators that *may* show whether a
// column is composed of discrete set of values:
// 1- Column may contain string values and not float values
// 1- Column may contain string values and non-float values
// 2- Number of unique values in the column is only a small fraction of
// all values in that column (lower than or equal to 20% of all values)
$numericValues = array_filter($columnValues, 'is_numeric');
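getColumnTypes() is also promoted from a protected instance method to a public static one, so other learners can reuse the column-type detection (string or low-cardinality columns are treated as nominal, the rest as continuous). A short sketch of the new call, with an illustrative sample matrix:

use Phpml\Classification\DecisionTree;

// First column is continuous, second is nominal (string values).
$samples = [
    [1.2, 'red'],
    [3.4, 'blue'],
    [2.2, 'red'],
];

// Returns one type constant per column,
// e.g. [DecisionTree::CONTINUOS, DecisionTree::NOMINAL]
// (CONTINUOS is the constant's spelling in this codebase).
$types = DecisionTree::getColumnTypes($samples);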
119 changes: 97 additions & 22 deletions src/Phpml/Classification/Ensemble/AdaBoost.php
@@ -5,6 +5,9 @@
namespace Phpml\Classification\Ensemble;

use Phpml\Classification\Linear\DecisionStump;
use Phpml\Classification\WeightedClassifier;
use Phpml\Math\Statistic\Mean;
use Phpml\Math\Statistic\StandardDeviation;
use Phpml\Classification\Classifier;
use Phpml\Helper\Predictable;
use Phpml\Helper\Trainable;
@@ -44,7 +47,7 @@ class AdaBoost implements Classifier
protected $weights = [];

/**
* Base classifiers
* List of selected 'weak' classifiers
*
* @var array
*/
@@ -57,17 +60,39 @@ class AdaBoost implements Classifier
*/
protected $alpha = [];

/**
* @var string
*/
protected $baseClassifier = DecisionStump::class;

/**
* @var array
*/
protected $classifierOptions = [];

/**
* ADAptive BOOSTing (AdaBoost) is an ensemble algorithm to
* improve classification performance of 'weak' classifiers such as
* DecisionStump (default base classifier of AdaBoost).
*
*/
public function __construct(int $maxIterations = 30)
public function __construct(int $maxIterations = 50)
{
$this->maxIterations = $maxIterations;
}

/**
* Sets the base classifier that will be used for boosting (default = DecisionStump)
*
* @param string $baseClassifier
* @param array $classifierOptions
*/
public function setBaseClassifier(string $baseClassifier = DecisionStump::class, array $classifierOptions = [])
{
$this->baseClassifier = $baseClassifier;
$this->classifierOptions = $classifierOptions;
}

/**
* @param array $samples
* @param array $targets
@@ -77,7 +102,7 @@ public function train(array $samples, array $targets)
// Initialize usual variables
$this->labels = array_keys(array_count_values($targets));
if (count($this->labels) != 2) {
throw new \Exception("AdaBoost is a binary classifier and can only classify between two classes");
throw new \Exception("AdaBoost is a binary classifier and can classify between two classes only");
}

// Set all target values to either -1 or 1
@@ -98,9 +123,12 @@ public function train(array $samples, array $targets)
// Execute the algorithm for a maximum number of iterations
$currIter = 0;
while ($this->maxIterations > $currIter++) {

// Determine the best 'weak' classifier based on current weights
// and update alpha & weight values at each iteration
list($classifier, $errorRate) = $this->getBestClassifier();
$classifier = $this->getBestClassifier();
$errorRate = $this->evaluateClassifier($classifier);

// Update alpha & weight values at each iteration
$alpha = $this->calculateAlpha($errorRate);
$this->updateWeights($classifier, $alpha);
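
evaluateClassifier() returns the weighted error rate of the selected weak learner, and calculateAlpha()/updateWeights() then perform the usual boosting update. Their bodies are not shown in this diff; the standard AdaBoost update they are expected to implement looks roughly like the sketch below (targets are already mapped to -1/+1 by train(), as noted above):

// A sketch of the classic AdaBoost update, not necessarily the exact library code.
// alpha = 0.5 * ln((1 - error) / error), where error is the weighted error rate.
$alpha = 0.5 * log((1 - $errorRate) / $errorRate);

// Misclassified samples get heavier, correctly classified ones lighter,
// then the weights are re-normalized to sum to 1.
foreach ($this->weights as $index => $weight) {
    $predicted = $classifier->predict($this->samples[$index]);
    $this->weights[$index] = $weight * exp(-$alpha * $this->targets[$index] * $predicted);
}
$sum = array_sum($this->weights);
$this->weights = array_map(function ($w) use ($sum) {
    return $w / $sum;
}, $this->weights);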

@@ -117,24 +145,71 @@ public function train(array $samples, array $targets)
*/
protected function getBestClassifier()
{
// This method works only for "DecisionStump" classifier, for now.
// As a future task, it will be generalized enough to work with other
// classifiers as well
$minErrorRate = 1.0;
$bestClassifier = null;
for ($i=0; $i < $this->featureCount; $i++) {
$stump = new DecisionStump($i);
$stump->setSampleWeights($this->weights);
$stump->train($this->samples, $this->targets);

$errorRate = $stump->getTrainingErrorRate();
if ($errorRate < $minErrorRate) {
$bestClassifier = $stump;
$minErrorRate = $errorRate;
$ref = new \ReflectionClass($this->baseClassifier);
if ($this->classifierOptions) {
$classifier = $ref->newInstanceArgs($this->classifierOptions);
} else {
$classifier = $ref->newInstance();
}

if (is_subclass_of($classifier, WeightedClassifier::class)) {
$classifier->setSampleWeights($this->weights);
$classifier->train($this->samples, $this->targets);
} else {
list($samples, $targets) = $this->resample();
$classifier->train($samples, $targets);
}

return $classifier;
}

/**
* Resamples the dataset in accordance with the weights and
* returns the new dataset
*
* @return array
*/
protected function resample()
{
$weights = $this->weights;
$std = StandardDeviation::population($weights);
$mean= Mean::arithmetic($weights);
$min = min($weights);
$minZ= (int)round(($min - $mean) / $std);

$samples = [];
$targets = [];
foreach ($weights as $index => $weight) {
$z = (int)round(($weight - $mean) / $std) - $minZ + 1;
for ($i=0; $i < $z; $i++) {
if (rand(0, 1) == 0) {
continue;
}
$samples[] = $this->samples[$index];
$targets[] = $this->targets[$index];
}
}

return [$samples, $targets];
}

/**
* Evaluates the classifier and returns the classification error rate
*
* @param Classifier $classifier
*/
protected function evaluateClassifier(Classifier $classifier)
{
$total = (float) array_sum($this->weights);
$wrong = 0;
foreach ($this->samples as $index => $sample) {
$predicted = $classifier->predict($sample);
if ($predicted != $this->targets[$index]) {
$wrong += $this->weights[$index];
}
}

return [$bestClassifier, $minErrorRate];
return $wrong / $total;
}

/**
@@ -154,10 +229,10 @@ protected function calculateAlpha(float $errorRate)
/**
* Updates the sample weights
*
* @param DecisionStump $classifier
* @param Classifier $classifier
* @param float $alpha
*/
protected function updateWeights(DecisionStump $classifier, float $alpha)
protected function updateWeights(Classifier $classifier, float $alpha)
{
$sumOfWeights = array_sum($this->weights);
$weightsT1 = [];
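The new resample() method is what lets base classifiers without sample-weight support still follow the boosting weights: every sample gets a number of candidate copies derived from the z-score of its weight (shifted so the smallest weight yields one candidate), and each copy is kept with probability 1/2. A rough numeric illustration with made-up weights, using plain arithmetic instead of the Mean/StandardDeviation helpers:

// Illustrative weights only; not taken from the commit.
$weights = [0.1, 0.1, 0.3, 0.5];

$mean = array_sum($weights) / count($weights);             // 0.25
$std = sqrt(array_sum(array_map(function ($w) use ($mean) {
    return ($w - $mean) ** 2;
}, $weights)) / count($weights));                          // ~0.166
$minZ = (int) round((min($weights) - $mean) / $std);       // -1

// Candidate copies per sample: z = round((w - mean) / std) - minZ + 1
// weight 0.1 -> 1 candidate copy, 0.3 -> 2, 0.5 -> 4 (each kept with probability 1/2).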
6 changes: 6 additions & 0 deletions src/Phpml/Classification/Linear/Adaline.php
@@ -76,10 +76,16 @@ protected function runTraining()
// Batch learning is executed:
$currIter = 0;
while ($this->maxIterations > $currIter++) {
$weights = $this->weights;

$outputs = array_map([$this, 'output'], $this->samples);
$updates = array_map([$this, 'gradient'], $this->targets, $outputs);

$this->updateWeights($updates);

if ($this->earlyStop($weights)) {
break;
}
}
}

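In Adaline, the batch-training loop now snapshots the weight vector at the start of each epoch and breaks out when earlyStop() signals convergence. earlyStop() itself is not part of this diff; a hypothetical implementation (the threshold and the exact criterion are assumptions) could compare the old and new weights like this:

// Hypothetical convergence check, not the library's confirmed code.
protected function earlyStop(array $oldWeights)
{
    $maxDelta = 0.0;
    foreach ($this->weights as $i => $w) {
        $maxDelta = max($maxDelta, abs($w - $oldWeights[$i]));
    }

    // Stop when no weight moved more than a tiny epsilon during this epoch.
    return $maxDelta < 1e-6;
}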
