Skip to content

Commit

Permalink
DecisionTree and Fuzzy C Means classifiers (#35)
Browse files Browse the repository at this point in the history
* Fuzzy C-Means implementation

* Update FuzzyCMeans

* Rename FuzzyCMeans to FuzzyCMeans.php

* Update NaiveBayes.php

* Small fix applied to improve training performance

array_unique is replaced with array_count_values + array_keys, which is
much faster

* Revert "Small fix applied to improve training performance"

This reverts commit c20253f.

* Revert "Revert "Small fix applied to improve training performance""

This reverts commit ea10e13.

* Revert "Small fix applied to improve training performance"

This reverts commit c20253f.

* DecisionTree

* FCM Test

* FCM Test

* DecisionTree Test
  • Loading branch information
MustafaKarabulut authored and akondas committed Jan 31, 2017
1 parent 95fc139 commit 87396eb
Show file tree
Hide file tree
Showing 6 changed files with 740 additions and 27 deletions.
274 changes: 274 additions & 0 deletions src/Phpml/Classification/DecisionTree.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,274 @@
<?php

declare(strict_types=1);

namespace Phpml\Classification;

use Phpml\Helper\Predictable;
use Phpml\Helper\Trainable;
use Phpml\Math\Statistic\Mean;
use Phpml\Classification\DecisionTree\DecisionTreeLeaf;

/**
 * Binary decision tree classifier.
 *
 * Each node tests a single column against a single value (equality for
 * nominal columns, a median threshold for continuous columns). The candidate
 * split with the lowest weighted Gini index is chosen at every node, and the
 * tree grows until a node is pure, all of its records are identical, or
 * $maxDepth is reached.
 */
class DecisionTree implements Classifier
{
    use Trainable, Predictable;

    /**
     * Column type marker for numeric columns that are split on a threshold.
     * The historical spelling is kept because the constant is public API.
     */
    const CONTINUOS = 1;

    /**
     * Column type marker for columns that are split on equality.
     */
    const NOMINAL = 2;

    /**
     * Training samples, indexed by record number.
     *
     * @var array
     */
    private $samples = [];

    /**
     * Detected type (self::CONTINUOS or self::NOMINAL) of every column.
     *
     * @var array
     */
    private $columnTypes;

    /**
     * Distinct target values seen during training.
     *
     * @var array
     */
    private $labels = [];

    /**
     * Number of columns, taken from the first training sample.
     *
     * @var int
     */
    private $featureCount = 0;

    /**
     * Root node of the trained tree, null until train() has been called.
     *
     * @var DecisionTreeLeaf
     */
    private $tree = null;

    /**
     * Depth at which nodes are forced to become terminal.
     *
     * @var int
     */
    private $maxDepth;

    /**
     * Deepest level reached while building the current tree.
     *
     * @var int
     */
    public $actualDepth = 0;

    /**
     * @param int $maxDepth
     */
    public function __construct($maxDepth = 10)
    {
        $this->maxDepth = $maxDepth;
    }

    /**
     * Builds the tree from the given samples and their targets.
     *
     * @param array $samples
     * @param array $targets
     *
     * @throws \InvalidArgumentException when no samples are given
     */
    public function train(array $samples, array $targets)
    {
        if (count($samples) === 0) {
            throw new \InvalidArgumentException('Cannot train a decision tree without samples');
        }

        $this->featureCount = count($samples[0]);
        $this->columnTypes = $this->getColumnTypes($samples);
        $this->samples = $samples;
        // NOTE(review): $targets is not declared in this class; presumably the
        // property comes from the Trainable trait - confirm.
        $this->targets = $targets;
        $this->labels = array_keys(array_count_values($targets));
        // Reset so that re-training does not report the depth of an older tree.
        $this->actualDepth = 0;
        $this->tree = $this->getSplitLeaf(range(0, count($samples) - 1));
    }

    /**
     * Detects the type of every column of the sample matrix.
     *
     * @param array $samples
     *
     * @return array list of self::NOMINAL / self::CONTINUOS, one per column
     */
    protected function getColumnTypes(array $samples)
    {
        $types = [];
        for ($i = 0; $i < $this->featureCount; $i++) {
            $values = array_column($samples, $i);
            $isCategorical = $this->isCategoricalColumn($values);
            $types[] = $isCategorical ? self::NOMINAL : self::CONTINUOS;
        }

        return $types;
    }

    /**
     * Recursively builds the (sub)tree for the given record numbers.
     *
     * @param array $records record numbers (keys of the training samples)
     * @param int   $depth   level of the node being built
     *
     * @return DecisionTreeLeaf
     */
    protected function getSplitLeaf($records, $depth = 0)
    {
        $split = $this->getBestSplit($records);
        $split->level = $depth;
        if ($this->actualDepth < $depth) {
            $this->actualDepth = $depth;
        }

        $leftRecords = [];
        $rightRecords = [];
        // Number of records per target value. Real counts (not just the set of
        // distinct targets) are needed to pick the majority class below.
        $targetCounts = [];
        $prevRecord = null;
        $allSame = true;
        foreach ($records as $recordNo) {
            $record = $this->samples[$recordNo];
            if ($prevRecord !== null && $prevRecord != $record) {
                $allSame = false;
            }
            $prevRecord = $record;

            if ($split->evaluate($record)) {
                $leftRecords[] = $recordNo;
            } else {
                $rightRecords[] = $recordNo;
            }

            $target = $this->targets[$recordNo];
            if (isset($targetCounts[$target])) {
                $targetCounts[$target]++;
            } else {
                $targetCounts[$target] = 1;
            }
        }

        if (count($targetCounts) == 1 || $allSame || $depth >= $this->maxDepth) {
            // Pure node, indistinguishable records or depth limit: stop here
            // and label the node with its most frequent target.
            $split->isTerminal = true;
            arsort($targetCounts);
            $split->classValue = key($targetCounts);
        } else {
            if ($leftRecords) {
                $split->leftLeaf = $this->getSplitLeaf($leftRecords, $depth + 1);
            }
            if ($rightRecords) {
                $split->rightLeaf = $this->getSplitLeaf($rightRecords, $depth + 1);
            }
        }

        return $split;
    }

    /**
     * Finds the column whose split yields the lowest Gini index.
     *
     * For every column the (preprocessed) value of the first record is used
     * as the split value.
     *
     * @param array $records record numbers to consider
     *
     * @return DecisionTreeLeaf
     */
    protected function getBestSplit($records)
    {
        $recordKeys = array_flip($records);
        $targets = array_intersect_key($this->targets, $recordKeys);
        $samples = array_intersect_key($this->samples, $recordKeys);
        $samples = array_combine($records, $this->preprocess($samples));

        $bestGiniVal = 1;
        $bestSplit = null;
        for ($i = 0; $i < $this->featureCount; $i++) {
            $colValues = [];
            $baseValue = null;
            foreach ($samples as $index => $row) {
                $colValues[$index] = $row[$i];
                if ($baseValue === null) {
                    $baseValue = $row[$i];
                }
            }

            $gini = $this->getGiniIndex($baseValue, $colValues, $targets);
            if ($bestSplit === null || $bestGiniVal > $gini) {
                $split = new DecisionTreeLeaf();
                $split->value = $baseValue;
                $split->giniIndex = $gini;
                $split->columnIndex = $i;
                $split->records = $records;
                $bestSplit = $split;
                $bestGiniVal = $gini;
            }
        }

        return $bestSplit;
    }

    /**
     * Computes the weighted Gini index of splitting a column on $baseValue.
     *
     * Records are divided into "equal to $baseValue" and "everything else";
     * the impurity of both parts is weighted by the size of each part.
     *
     * @param mixed $baseValue value the column is compared against
     * @param array $colValues column values, keyed by record number
     * @param array $targets   targets, keyed by record number
     *
     * @return float
     */
    public function getGiniIndex($baseValue, $colValues, $targets)
    {
        $countMatrix = [];
        foreach ($this->labels as $label) {
            $countMatrix[$label] = [0, 0];
        }
        foreach ($colValues as $index => $value) {
            $label = $targets[$index];
            $rowIndex = $value == $baseValue ? 0 : 1;
            $countMatrix[$label][$rowIndex]++;
        }

        $giniParts = [0, 0];
        for ($i = 0; $i <= 1; $i++) {
            $part = 0;
            $sum = array_sum(array_column($countMatrix, $i));
            if ($sum > 0) {
                foreach ($this->labels as $label) {
                    $part += pow($countMatrix[$label][$i] / floatval($sum), 2);
                }
            }
            $giniParts[$i] = (1 - $part) * $sum;
        }

        return array_sum($giniParts) / count($colValues);
    }

    /**
     * Converts continuous columns into two discrete buckets, using the median
     * of the given samples as the threshold. Nominal columns are left as-is.
     *
     * @param array $samples
     *
     * @return array the samples, re-indexed from 0, with discretised values
     */
    protected function preprocess(array $samples)
    {
        $columns = [];
        for ($i = 0; $i < $this->featureCount; $i++) {
            $values = array_column($samples, $i);
            if ($this->columnTypes[$i] == self::CONTINUOS) {
                $median = Mean::median($values);
                foreach ($values as &$value) {
                    $value = $value <= $median ? "<= $median" : "> $median";
                }
                // Break the reference so later writes cannot hit this column.
                unset($value);
            }
            $columns[] = $values;
        }

        // Transpose columns back into rows. An explicit loop is used because
        // array_map(null, ...$columns) returns the column unchanged (instead
        // of one-element rows) when there is only a single feature.
        $rows = [];
        foreach ($columns as $columnIndex => $column) {
            foreach ($column as $rowIndex => $cell) {
                $rows[$rowIndex][$columnIndex] = $cell;
            }
        }

        return $rows;
    }

    /**
     * Guesses whether a column holds a discrete set of values.
     *
     * A column is considered categorical when
     *  1- it contains at least one non-numeric value, or
     *  2- it holds only integer-like values and the number of distinct values
     *     is at most 20% of the number of values.
     * Columns containing float values are always treated as continuous.
     *
     * @param array $columnValues
     *
     * @return bool
     */
    protected function isCategoricalColumn(array $columnValues)
    {
        $count = count($columnValues);

        $numericValues = array_filter($columnValues, 'is_numeric');
        if (count($numericValues) != $count) {
            return true;
        }

        // array_count_values() only counts strings and integers: floats raise
        // a warning and are skipped, which used to make every float column
        // look like it had zero distinct values (and therefore categorical).
        if (count(array_filter($columnValues, 'is_float')) > 0) {
            return false;
        }

        $distinctValues = array_count_values($columnValues);

        return count($distinctValues) <= $count / 5;
    }

    /**
     * Returns the HTML rendering of the tree, or an empty string when the
     * classifier has not been trained yet.
     *
     * @return string
     */
    public function getHtml()
    {
        if ($this->tree === null) {
            return '';
        }

        return $this->tree->__toString();
    }

    /**
     * Walks the tree for one sample and returns the predicted class.
     *
     * @param array $sample
     *
     * @return mixed
     *
     * @throws \RuntimeException when the classifier has not been trained
     */
    protected function predictSample(array $sample)
    {
        if ($this->tree === null) {
            throw new \RuntimeException('The decision tree must be trained before predicting');
        }

        $node = $this->tree;
        $lastNode = $node;
        while ($node !== null && !$node->isTerminal) {
            $lastNode = $node;
            $node = $node->evaluate($sample) ? $node->leftLeaf : $node->rightLeaf;
        }

        if ($node !== null) {
            return $node->classValue;
        }

        // The sample was routed to a branch that received no training records,
        // so that child was never built. Fall back to the most frequent target
        // among the records of the last node that does exist.
        return $this->getMajorityTarget($lastNode->records);
    }

    /**
     * Returns the most frequent target among the given record numbers.
     *
     * @param array $records
     *
     * @return mixed
     */
    private function getMajorityTarget(array $records)
    {
        $counts = array_count_values(
            array_intersect_key($this->targets, array_flip($records))
        );
        arsort($counts);

        return key($counts);
    }
}
106 changes: 106 additions & 0 deletions src/Phpml/Classification/DecisionTree/DecisionTreeLeaf.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
<?php
declare(strict_types=1);

namespace Phpml\Classification\DecisionTree;

/**
 * A node of the decision tree: either an inner node holding a split rule
 * (column + value) or a terminal node holding a class value.
 */
class DecisionTreeLeaf
{
    const OPERATOR_EQ = '=';

    /**
     * Split value. Either a plain value that the column must equal, or a
     * string of the form "<operator> <number>" (e.g. "<= 72") describing a
     * numeric comparison.
     *
     * @var string
     */
    public $value;

    /**
     * Index of the column tested by this node.
     *
     * @var int
     */
    public $columnIndex;

    /**
     * Subtree followed when evaluate() returns true.
     *
     * @var DecisionTreeLeaf
     */
    public $leftLeaf = null;

    /**
     * Subtree followed when evaluate() returns false.
     *
     * @var DecisionTreeLeaf
     */
    public $rightLeaf = null;

    /**
     * @var array
     */
    public $records = [];

    /**
     * Class value represented by the leaf, this value is non-empty
     * only for terminal leaves
     *
     * @var string
     */
    public $classValue = '';

    /**
     * @var bool
     */
    public $isTerminal = false;

    /**
     * @var float
     */
    public $giniIndex = 0;

    /**
     * @var int
     */
    public $level = 0;

    /**
     * Tests a record against this node's split rule.
     *
     * @param array $record
     *
     * @return bool
     *
     * @throws \InvalidArgumentException when the rule uses an unsupported operator
     */
    public function evaluate($record)
    {
        $recordField = $record[$this->columnIndex];

        $comparison = $this->parseComparison();
        if ($comparison === null) {
            return $recordField == $this->value;
        }

        list($operator, $threshold) = $comparison;
        // The threshold was produced by converting a float to a string, so the
        // record value takes the same float -> string -> float round trip.
        // Without it a value equal to the threshold could differ from it in
        // the last digits and land on the wrong side of the split.
        $field = floatval(strval($recordField));

        switch ($operator) {
            case '<':
                return $field < $threshold;
            case '<=':
                return $field <= $threshold;
            case '>':
                return $field > $threshold;
            case '>=':
                return $field >= $threshold;
            case self::OPERATOR_EQ:
            case '==':
                return $field == $threshold;
            case '<>':
                return $field != $threshold;
            default:
                throw new \InvalidArgumentException(
                    sprintf('Unsupported split operator "%s"', $operator)
                );
        }
    }

    /**
     * Splits a "<operator> <number>" rule into its operator and threshold.
     *
     * Only string values whose remainder is numeric are treated as
     * comparisons; anything else (integers, or nominal strings that merely
     * start with "<", ">" or "=" such as ">50K") is an equality rule.
     *
     * @return array|null [operator, threshold] or null for an equality rule
     */
    private function parseComparison()
    {
        // preg_match() requires a string subject; nominal values may be ints.
        if (!is_string($this->value)) {
            return null;
        }
        if (!preg_match('/^([<>=]{1,2})\s*(.*)/', $this->value, $matches)) {
            return null;
        }
        if (!is_numeric($matches[2])) {
            return null;
        }

        return [$matches[1], floatval($matches[2])];
    }

    /**
     * Renders this node and its subtrees as nested HTML tables.
     *
     * @return string
     */
    public function __toString()
    {
        if ($this->isTerminal) {
            $value = '<b>' . $this->escape($this->classValue) . '</b>';
        } else {
            $rule = strval($this->value);
            $col = "col_$this->columnIndex";
            if ($this->parseComparison() === null) {
                $rule = self::OPERATOR_EQ . $rule;
            }
            // Escaped so that "<" / ">" in the rule and arbitrary data values
            // are shown literally instead of being parsed as markup.
            $value = "<b>$col " . $this->escape($rule) . '</b><br>Gini: '
                . number_format($this->giniIndex, 2);
        }

        $str = "<table ><tr><td colspan=3 align=center style='border:1px solid;'>\n$value</td></tr>";
        if ($this->leftLeaf || $this->rightLeaf) {
            $str .= '<tr>';
            if ($this->leftLeaf) {
                $str .= "<td valign=top><b>| Yes</b><br>$this->leftLeaf</td>";
            } else {
                $str .= '<td></td>';
            }
            $str .= '<td>&nbsp;</td>';
            if ($this->rightLeaf) {
                $str .= "<td valign=top align=right><b>No |</b><br>$this->rightLeaf</td>";
            } else {
                $str .= '<td></td>';
            }
            $str .= '</tr>';
        }
        $str .= '</table>';

        return $str;
    }

    /**
     * Escapes a value for use as HTML text.
     *
     * @param mixed $text
     *
     * @return string
     */
    private function escape($text)
    {
        return htmlspecialchars(strval($text), ENT_QUOTES, 'UTF-8');
    }
}
Loading

0 comments on commit 87396eb

Please sign in to comment.