Skip to content

Commit

Permalink
RandomForest::getFeatureImportances() method (#47)
Browse files Browse the repository at this point in the history
* RandomForest::getFeatureImportances() method

* CsvDataset update for column names
  • Loading branch information
MustafaKarabulut authored and akondas committed Feb 13, 2017
1 parent 240a227 commit a33d5fe
Show file tree
Hide file tree
Showing 5 changed files with 273 additions and 22 deletions.
121 changes: 120 additions & 1 deletion src/Phpml/Classification/DecisionTree.php
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,17 @@ class DecisionTree implements Classifier
*/
private $numUsableFeatures = 0;

/**
* @var array
*/
private $featureImportances = null;

/**
*
* @var array
*/
private $columnNames = null;

/**
* @param int $maxDepth
*/
Expand All @@ -76,6 +87,21 @@ public function train(array $samples, array $targets)
$this->columnTypes = $this->getColumnTypes($this->samples);
$this->labels = array_keys(array_count_values($this->targets));
$this->tree = $this->getSplitLeaf(range(0, count($this->samples) - 1));

// Each time the tree is trained, the cached feature importances are reset
// so that they are recomputed from the new data on the next request
$this->featureImportances = null;

// If column names were already given or computed, do not reinitialize
// them; doing so would accidentally discard the previously given names
if ($this->columnNames === null) {
$this->columnNames = range(0, $this->featureCount - 1);
} elseif (count($this->columnNames) > $this->featureCount) {
$this->columnNames = array_slice($this->columnNames, 0, $this->featureCount);
} elseif (count($this->columnNames) < $this->featureCount) {
$this->columnNames = array_merge($this->columnNames,
range(count($this->columnNames), $this->featureCount - 1));
}
}

protected function getColumnTypes(array $samples)
Expand Down Expand Up @@ -164,6 +190,7 @@ protected function getBestSplit($records)
$split->value = $baseValue;
$split->giniIndex = $gini;
$split->columnIndex = $i;
$split->isContinuous = $this->columnTypes[$i] == self::CONTINUOS;
$split->records = $records;
$bestSplit = $split;
$bestGiniVal = $gini;
Expand Down Expand Up @@ -292,6 +319,25 @@ public function setNumFeatures(int $numFeatures)
}

$this->numUsableFeatures = $numFeatures;

return $this;
}

/**
 * Assigns a name to each feature column. The names are used as column
 * labels in the HTML output and as keys of the feature importances.
 *
 * @param array $names One name per feature column
 * @return $this
 * @throws \Exception When the feature count is already known (non-zero)
 *                    and the number of names does not match it
 */
public function setColumnNames(array $names)
{
    // A feature count of 0 means the count is not known yet, so any
    // number of names is accepted at this point.
    $expected = $this->featureCount;
    if ($expected != 0 && count($names) != $expected) {
        throw new \Exception("Length of the given array should be equal to feature count ($this->featureCount)");
    }

    $this->columnNames = $names;

    return $this;
}

Expand All @@ -300,7 +346,80 @@ public function setNumFeatures(int $numFeatures)
*/
public function getHtml()
{
return $this->tree->__toString();
return $this->tree->getHTML($this->columnNames);
}

/**
 * Returns the importance of every column, keyed by column name. The
 * result is computed on the first call and cached; training the tree
 * again clears the cache.
 *
 * The importance of a column is the sum of the impurity decreases of all
 * split nodes that use that column. When the sum over all columns is
 * positive, the values are normalized so that they total 1 and sorted in
 * descending order (keys preserved); otherwise the raw values are
 * returned in column order.
 *
 * @return array Map of column name => importance
 */
public function getFeatureImportances()
{
    if ($this->featureImportances !== null) {
        return $this->featureImportances;
    }

    $sampleCount = count($this->samples);

    // Build the result in a local array and publish it to the cache only
    // once it is complete, so a failure part-way through cannot leave a
    // partial result behind that later calls would return as final.
    $importances = [];
    foreach ($this->columnNames as $column => $columnName) {
        $importance = 0;
        foreach ($this->getSplitNodesByColumn($column, $this->tree) as $node) {
            $importance += $node->getNodeImpurityDecrease($sampleCount);
        }

        $importances[$columnName] = $importance;
    }

    // Normalize & sort the importances. Assigning by key (instead of
    // iterating by reference) avoids leaving a dangling reference to the
    // last element after the loop.
    $total = array_sum($importances);
    if ($total > 0) {
        foreach ($importances as $columnName => $importance) {
            $importances[$columnName] = $importance / $total;
        }
        arsort($importances);
    }

    $this->featureImportances = $importances;

    return $this->featureImportances;
}

/**
 * Gathers every non-terminal node in the subtree rooted at $node that
 * uses the given column as its split criterion. A matching root comes
 * first, followed by the matches from its left subtree and then the
 * matches from its right subtree.
 *
 * @param int $column Index of the column to look for
 * @param DecisionTreeLeaf $node Root of the subtree to search
 *
 * @return array
 */
protected function getSplitNodesByColumn($column, DecisionTreeLeaf $node)
{
    // Terminal nodes do not split on anything and are not descended into.
    if ($node->isTerminal) {
        return [];
    }

    $matches = ($node->columnIndex == $column) ? [$node] : [];

    // Left before right, to keep the order of the collected nodes stable.
    foreach ([$node->leftLeaf, $node->rightLeaf] as $child) {
        if ($child) {
            $matches = array_merge($matches, $this->getSplitNodesByColumn($column, $child));
        }
    }

    return $matches;
}

/**
Expand Down
64 changes: 58 additions & 6 deletions src/Phpml/Classification/DecisionTree/DecisionTreeLeaf.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

class DecisionTreeLeaf
{
const OPERATOR_EQ = '=';
/**
* @var string
*/
Expand Down Expand Up @@ -45,6 +44,11 @@ class DecisionTreeLeaf
*/
public $isTerminal = false;

/**
* @var bool
*/
public $isContinuous = false;

/**
* @var float
*/
Expand All @@ -62,7 +66,7 @@ class DecisionTreeLeaf
public function evaluate($record)
{
$recordField = $record[$this->columnIndex];
if (is_string($this->value) && preg_match("/^([<>=]{1,2})\s*(.*)/", $this->value, $matches)) {
if ($this->isContinuous && preg_match("/^([<>=]{1,2})\s*(.*)/", strval($this->value), $matches)) {
$op = $matches[1];
$value= floatval($matches[2]);
$recordField = strval($recordField);
Expand All @@ -72,13 +76,51 @@ public function evaluate($record)
return $recordField == $this->value;
}

public function __toString()
/**
 * Returns this node's contribution to the Mean Decrease Impurity (MDI):
 * the node's Gini index minus the Gini index of each existing child
 * weighted by that child's share of this node's records, then scaled by
 * the ratio of this node's record count to $parentRecordCount.
 * For terminal nodes, this value is equal to 0.
 *
 * @param int $parentRecordCount Record count the node's weight is measured against
 * @return float
 */
public function getNodeImpurityDecrease(int $parentRecordCount)
{
    if ($this->isTerminal) {
        return 0.0;
    }

    $ownCount = (float)count($this->records);
    $decrease = $this->giniIndex;

    // Subtract the weighted impurity of each child that exists
    // (left first, then right).
    foreach ([$this->leftLeaf, $this->rightLeaf] as $child) {
        if (!$child) {
            continue;
        }

        $share = count($child->records) / $ownCount;
        $decrease -= $share * $child->giniIndex;
    }

    return $decrease * $ownCount / $parentRecordCount;
}

/**
* Returns HTML representation of the node including children nodes
*
* @param $columnNames
* @return string
*/
public function getHTML($columnNames = null)
{
if ($this->isTerminal) {
$value = "<b>$this->classValue</b>";
} else {
$value = $this->value;
$col = "col_$this->columnIndex";
if ($columnNames !== null) {
$col = $columnNames[$this->columnIndex];
} else {
$col = "col_$this->columnIndex";
}
if (! preg_match("/^[<>=]{1,2}/", $value)) {
$value = "=$value";
}
Expand All @@ -89,13 +131,13 @@ public function __toString()
if ($this->leftLeaf || $this->rightLeaf) {
$str .='<tr>';
if ($this->leftLeaf) {
$str .="<td valign=top><b>| Yes</b><br>$this->leftLeaf</td>";
$str .="<td valign=top><b>| Yes</b><br>" . $this->leftLeaf->getHTML($columnNames) . "</td>";
} else {
$str .='<td></td>';
}
$str .='<td>&nbsp;</td>';
if ($this->rightLeaf) {
$str .="<td valign=top align=right><b>No |</b><br>$this->rightLeaf</td>";
$str .="<td valign=top align=right><b>No |</b><br>" . $this->rightLeaf->getHTML($columnNames) . "</td>";
} else {
$str .='<td></td>';
}
Expand All @@ -104,4 +146,14 @@ public function __toString()
$str .= '</table>';
return $str;
}

/**
 * String form of the node: its HTML representation (children included)
 * rendered without any column names.
 *
 * @return string
 */
public function __toString()
{
    // Passing null explicitly is the same as omitting the argument.
    return $this->getHTML(null);
}
}
16 changes: 7 additions & 9 deletions src/Phpml/Classification/Ensemble/Bagging.php
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class Bagging implements Classifier
/**
* @var float
*/
protected $subsetRatio = 0.5;
protected $subsetRatio = 0.7;

/**
* @var array
Expand Down Expand Up @@ -120,7 +120,7 @@ public function train(array $samples, array $targets)
$this->featureCount = count($samples[0]);
$this->numSamples = count($this->samples);

// Init classifiers and train them with random sub-samples
// Init classifiers and train them with bootstrap samples
$this->classifiers = $this->initClassifiers();
$index = 0;
foreach ($this->classifiers as $classifier) {
Expand All @@ -134,16 +134,14 @@ public function train(array $samples, array $targets)
* @param int $index
* @return array
*/
protected function getRandomSubset($index)
protected function getRandomSubset(int $index)
{
$subsetLength = (int)ceil(sqrt($this->numSamples));
$denom = $this->subsetRatio / 2;
$subsetLength = $this->numSamples / (1 / $denom);
$index = $index * $subsetLength % $this->numSamples;
$samples = [];
$targets = [];
for ($i=0; $i<$subsetLength * 2; $i++) {
$rand = rand($index, $this->numSamples - 1);
srand($index);
$bootstrapSize = $this->subsetRatio * $this->numSamples;
for ($i=0; $i < $bootstrapSize; $i++) {
$rand = rand(0, $this->numSamples - 1);
$samples[] = $this->samples[$rand];
$targets[] = $this->targets[$rand];
}
Expand Down
Loading

0 comments on commit a33d5fe

Please sign in to comment.