Skip to content

Commit

Permalink
AdaBoost algorithm along with some improvements (#51)
Browse files Browse the repository at this point in the history
  • Loading branch information
MustafaKarabulut authored and akondas committed Feb 21, 2017
1 parent cf222bc commit 4daa0a2
Show file tree
Hide file tree
Showing 7 changed files with 463 additions and 61 deletions.
30 changes: 24 additions & 6 deletions src/Phpml/Classification/DecisionTree.php
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class DecisionTree implements Classifier
/**
* @var array
*/
private $columnTypes;
protected $columnTypes;

/**
* @var array
Expand All @@ -39,12 +39,12 @@ class DecisionTree implements Classifier
/**
* @var DecisionTreeLeaf
*/
private $tree = null;
protected $tree = null;

/**
* @var int
*/
private $maxDepth;
protected $maxDepth;

/**
* @var int
Expand Down Expand Up @@ -79,6 +79,7 @@ public function __construct($maxDepth = 10)
{
$this->maxDepth = $maxDepth;
}

/**
* @param array $samples
* @param array $targets
Expand Down Expand Up @@ -209,6 +210,17 @@ protected function getBestSplit($records)
$split->columnIndex = $i;
$split->isContinuous = $this->columnTypes[$i] == self::CONTINUOS;
$split->records = $records;

// If a numeric column is to be selected, then
// the original numeric value and the selected operator
// will also be saved into the leaf for future access
if ($this->columnTypes[$i] == self::CONTINUOS) {
$matches = [];
preg_match("/^([<>=]{1,2})\s*(.*)/", strval($split->value), $matches);
$split->operator = $matches[1];
$split->numericValue = floatval($matches[2]);
}

$bestSplit = $split;
$bestGiniVal = $gini;
}
Expand Down Expand Up @@ -318,15 +330,21 @@ protected function preprocess(array $samples)
protected function isCategoricalColumn(array $columnValues)
{
$count = count($columnValues);

// There are two main indicators that *may* show whether a
// column is composed of discrete set of values:
// 1- Column may contain string values
// 1- Column may contain string values and not float values
// 2- Number of unique values in the column is only a small fraction of
// all values in that column (Lower than or equal to %20 of all values)
$numericValues = array_filter($columnValues, 'is_numeric');
$floatValues = array_filter($columnValues, 'is_float');
if ($floatValues) {
return false;
}
if (count($numericValues) != $count) {
return true;
}

$distinctValues = array_count_values($columnValues);
if (count($distinctValues) <= $count / 5) {
return true;
Expand Down Expand Up @@ -357,9 +375,9 @@ public function setNumFeatures(int $numFeatures)
}

/**
* Used to set predefined features to consider while deciding which column to use for a split,
* Used to set predefined features to consider while deciding which column to use for a split
*
* @param array $features
* @param array $selectedFeatures
*/
protected function setSelectedFeatures(array $selectedFeatures)
{
Expand Down
18 changes: 15 additions & 3 deletions src/Phpml/Classification/DecisionTree/DecisionTreeLeaf.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@ class DecisionTreeLeaf
*/
public $value;

/**
* @var float
*/
public $numericValue;

/**
* @var string
*/
public $operator;

/**
* @var int
*/
Expand Down Expand Up @@ -66,13 +76,15 @@ class DecisionTreeLeaf
public function evaluate($record)
{
$recordField = $record[$this->columnIndex];
if ($this->isContinuous && preg_match("/^([<>=]{1,2})\s*(.*)/", strval($this->value), $matches)) {
$op = $matches[1];
$value= floatval($matches[2]);

if ($this->isContinuous) {
$op = $this->operator;
$value= $this->numericValue;
$recordField = strval($recordField);
eval("\$result = $recordField $op $value;");
return $result;
}

return $recordField == $this->value;
}

Expand Down
190 changes: 190 additions & 0 deletions src/Phpml/Classification/Ensemble/AdaBoost.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
<?php

declare(strict_types=1);

namespace Phpml\Classification\Ensemble;

use Phpml\Classification\Linear\DecisionStump;
use Phpml\Classification\Classifier;
use Phpml\Helper\Predictable;
use Phpml\Helper\Trainable;

class AdaBoost implements Classifier
{
use Predictable, Trainable;

/**
* Actual labels given in the targets array
* @var array
*/
protected $labels = [];

/**
* @var int
*/
protected $sampleCount;

/**
* @var int
*/
protected $featureCount;

/**
* Number of maximum iterations to be done
*
* @var int
*/
protected $maxIterations;

/**
* Sample weights
*
* @var array
*/
protected $weights = [];

/**
* Base classifiers
*
* @var array
*/
protected $classifiers = [];

/**
* Base classifier weights
*
* @var array
*/
protected $alpha = [];

/**
* ADAptive BOOSTing (AdaBoost) is an ensemble algorithm to
* improve classification performance of 'weak' classifiers such as
* DecisionStump (default base classifier of AdaBoost).
*
*/
public function __construct(int $maxIterations = 30)
{
$this->maxIterations = $maxIterations;
}

/**
* @param array $samples
* @param array $targets
*/
public function train(array $samples, array $targets)
{
// Initialize usual variables
$this->labels = array_keys(array_count_values($targets));
if (count($this->labels) != 2) {
throw new \Exception("AdaBoost is a binary classifier and can only classify between two classes");
}

// Set all target values to either -1 or 1
$this->labels = [1 => $this->labels[0], -1 => $this->labels[1]];
foreach ($targets as $target) {
$this->targets[] = $target == $this->labels[1] ? 1 : -1;
}

$this->samples = array_merge($this->samples, $samples);
$this->featureCount = count($samples[0]);
$this->sampleCount = count($this->samples);

// Initialize AdaBoost parameters
$this->weights = array_fill(0, $this->sampleCount, 1.0 / $this->sampleCount);
$this->classifiers = [];
$this->alpha = [];

// Execute the algorithm for a maximum number of iterations
$currIter = 0;
while ($this->maxIterations > $currIter++) {
// Determine the best 'weak' classifier based on current weights
// and update alpha & weight values at each iteration
list($classifier, $errorRate) = $this->getBestClassifier();
$alpha = $this->calculateAlpha($errorRate);
$this->updateWeights($classifier, $alpha);

$this->classifiers[] = $classifier;
$this->alpha[] = $alpha;
}
}

/**
* Returns the classifier with the lowest error rate with the
* consideration of current sample weights
*
* @return Classifier
*/
protected function getBestClassifier()
{
// This method works only for "DecisionStump" classifier, for now.
// As a future task, it will be generalized enough to work with other
// classifiers as well
$minErrorRate = 1.0;
$bestClassifier = null;
for ($i=0; $i < $this->featureCount; $i++) {
$stump = new DecisionStump($i);
$stump->setSampleWeights($this->weights);
$stump->train($this->samples, $this->targets);

$errorRate = $stump->getTrainingErrorRate();
if ($errorRate < $minErrorRate) {
$bestClassifier = $stump;
$minErrorRate = $errorRate;
}
}

return [$bestClassifier, $minErrorRate];
}

/**
* Calculates alpha of a classifier
*
* @param float $errorRate
* @return float
*/
protected function calculateAlpha(float $errorRate)
{
if ($errorRate == 0) {
$errorRate = 1e-10;
}
return 0.5 * log((1 - $errorRate) / $errorRate);
}

/**
* Updates the sample weights
*
* @param DecisionStump $classifier
* @param float $alpha
*/
protected function updateWeights(DecisionStump $classifier, float $alpha)
{
$sumOfWeights = array_sum($this->weights);
$weightsT1 = [];
foreach ($this->weights as $index => $weight) {
$desired = $this->targets[$index];
$output = $classifier->predict($this->samples[$index]);

$weight *= exp(-$alpha * $desired * $output) / $sumOfWeights;

$weightsT1[] = $weight;
}

$this->weights = $weightsT1;
}

/**
* @param array $sample
* @return mixed
*/
public function predictSample(array $sample)
{
$sum = 0;
foreach ($this->alpha as $index => $alpha) {
$h = $this->classifiers[$index]->predict($sample);
$sum += $h * $alpha;
}

return $this->labels[ $sum > 0 ? 1 : -1];
}
}
Loading

0 comments on commit 4daa0a2

Please sign in to comment.