forked from jorgecasas/php-ml
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
DecisionTree and Fuzzy C Means classifiers (#35)
* Fuzzy C-Means implementation * Update FuzzyCMeans * Rename FuzzyCMeans to FuzzyCMeans.php * Update NaiveBayes.php * Small fix applied to improve training performance array_unique is replaced with array_count_values+array_keys which is way faster * Revert "Small fix applied to improve training performance" This reverts commit c20253f. * Revert "Revert "Small fix applied to improve training performance"" This reverts commit ea10e13. * Revert "Small fix applied to improve training performance" This reverts commit c20253f. * DecisionTree * FCM Test * FCM Test * DecisionTree Test
- Loading branch information
1 parent
95fc139
commit 87396eb
Showing
6 changed files
with
740 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,274 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
namespace Phpml\Classification; | ||
|
||
use Phpml\Helper\Predictable; | ||
use Phpml\Helper\Trainable; | ||
use Phpml\Math\Statistic\Mean; | ||
use Phpml\Classification\DecisionTree\DecisionTreeLeaf; | ||
|
||
class DecisionTree implements Classifier | ||
{ | ||
use Trainable, Predictable; | ||
|
||
const CONTINUOS = 1; | ||
const NOMINAL = 2; | ||
|
||
/** | ||
* @var array | ||
*/ | ||
private $samples = array(); | ||
|
||
/** | ||
* @var array | ||
*/ | ||
private $columnTypes; | ||
/** | ||
* @var array | ||
*/ | ||
private $labels = array(); | ||
/** | ||
* @var int | ||
*/ | ||
private $featureCount = 0; | ||
/** | ||
* @var DecisionTreeLeaf | ||
*/ | ||
private $tree = null; | ||
|
||
/** | ||
* @var int | ||
*/ | ||
private $maxDepth; | ||
|
||
/** | ||
* @var int | ||
*/ | ||
public $actualDepth = 0; | ||
|
||
/** | ||
* @param int $maxDepth | ||
*/ | ||
public function __construct($maxDepth = 10) | ||
{ | ||
$this->maxDepth = $maxDepth; | ||
} | ||
/** | ||
* @param array $samples | ||
* @param array $targets | ||
*/ | ||
public function train(array $samples, array $targets) | ||
{ | ||
$this->featureCount = count($samples[0]); | ||
$this->columnTypes = $this->getColumnTypes($samples); | ||
$this->samples = $samples; | ||
$this->targets = $targets; | ||
$this->labels = array_keys(array_count_values($targets)); | ||
$this->tree = $this->getSplitLeaf(range(0, count($samples) - 1)); | ||
} | ||
|
||
protected function getColumnTypes(array $samples) | ||
{ | ||
$types = []; | ||
for ($i=0; $i<$this->featureCount; $i++) { | ||
$values = array_column($samples, $i); | ||
$isCategorical = $this->isCategoricalColumn($values); | ||
$types[] = $isCategorical ? self::NOMINAL : self::CONTINUOS; | ||
} | ||
return $types; | ||
} | ||
|
||
/** | ||
* @param null|array $records | ||
* @return DecisionTreeLeaf | ||
*/ | ||
protected function getSplitLeaf($records, $depth = 0) | ||
{ | ||
$split = $this->getBestSplit($records); | ||
$split->level = $depth; | ||
if ($this->actualDepth < $depth) { | ||
$this->actualDepth = $depth; | ||
} | ||
$leftRecords = []; | ||
$rightRecords= []; | ||
$remainingTargets = []; | ||
$prevRecord = null; | ||
$allSame = true; | ||
foreach ($records as $recordNo) { | ||
$record = $this->samples[$recordNo]; | ||
if ($prevRecord && $prevRecord != $record) { | ||
$allSame = false; | ||
} | ||
$prevRecord = $record; | ||
if ($split->evaluate($record)) { | ||
$leftRecords[] = $recordNo; | ||
} else { | ||
$rightRecords[]= $recordNo; | ||
} | ||
$target = $this->targets[$recordNo]; | ||
if (! in_array($target, $remainingTargets)) { | ||
$remainingTargets[] = $target; | ||
} | ||
} | ||
|
||
if (count($remainingTargets) == 1 || $allSame || $depth >= $this->maxDepth) { | ||
$split->isTerminal = 1; | ||
$classes = array_count_values($remainingTargets); | ||
arsort($classes); | ||
$split->classValue = key($classes); | ||
} else { | ||
if ($leftRecords) { | ||
$split->leftLeaf = $this->getSplitLeaf($leftRecords, $depth + 1); | ||
} | ||
if ($rightRecords) { | ||
$split->rightLeaf= $this->getSplitLeaf($rightRecords, $depth + 1); | ||
} | ||
} | ||
return $split; | ||
} | ||
|
||
/** | ||
* @param array $records | ||
* @return DecisionTreeLeaf[] | ||
*/ | ||
protected function getBestSplit($records) | ||
{ | ||
$targets = array_intersect_key($this->targets, array_flip($records)); | ||
$samples = array_intersect_key($this->samples, array_flip($records)); | ||
$samples = array_combine($records, $this->preprocess($samples)); | ||
$bestGiniVal = 1; | ||
$bestSplit = null; | ||
for ($i=0; $i<$this->featureCount; $i++) { | ||
$colValues = []; | ||
$baseValue = null; | ||
foreach ($samples as $index => $row) { | ||
$colValues[$index] = $row[$i]; | ||
if ($baseValue === null) { | ||
$baseValue = $row[$i]; | ||
} | ||
} | ||
$gini = $this->getGiniIndex($baseValue, $colValues, $targets); | ||
if ($bestSplit == null || $bestGiniVal > $gini) { | ||
$split = new DecisionTreeLeaf(); | ||
$split->value = $baseValue; | ||
$split->giniIndex = $gini; | ||
$split->columnIndex = $i; | ||
$split->records = $records; | ||
$bestSplit = $split; | ||
$bestGiniVal = $gini; | ||
} | ||
} | ||
return $bestSplit; | ||
} | ||
|
||
/** | ||
* @param string $baseValue | ||
* @param array $colValues | ||
* @param array $targets | ||
*/ | ||
public function getGiniIndex($baseValue, $colValues, $targets) | ||
{ | ||
$countMatrix = []; | ||
foreach ($this->labels as $label) { | ||
$countMatrix[$label] = [0, 0]; | ||
} | ||
foreach ($colValues as $index => $value) { | ||
$label = $targets[$index]; | ||
$rowIndex = $value == $baseValue ? 0 : 1; | ||
$countMatrix[$label][$rowIndex]++; | ||
} | ||
$giniParts = [0, 0]; | ||
for ($i=0; $i<=1; $i++) { | ||
$part = 0; | ||
$sum = array_sum(array_column($countMatrix, $i)); | ||
if ($sum > 0) { | ||
foreach ($this->labels as $label) { | ||
$part += pow($countMatrix[$label][$i] / floatval($sum), 2); | ||
} | ||
} | ||
$giniParts[$i] = (1 - $part) * $sum; | ||
} | ||
return array_sum($giniParts) / count($colValues); | ||
} | ||
|
||
/** | ||
* @param array $samples | ||
* @return array | ||
*/ | ||
protected function preprocess(array $samples) | ||
{ | ||
// Detect and convert continuous data column values into | ||
// discrete values by using the median as a threshold value | ||
$columns = array(); | ||
for ($i=0; $i<$this->featureCount; $i++) { | ||
$values = array_column($samples, $i); | ||
if ($this->columnTypes[$i] == self::CONTINUOS) { | ||
$median = Mean::median($values); | ||
foreach ($values as &$value) { | ||
if ($value <= $median) { | ||
$value = "<= $median"; | ||
} else { | ||
$value = "> $median"; | ||
} | ||
} | ||
} | ||
$columns[] = $values; | ||
} | ||
// Below method is a strange yet very simple & efficient method | ||
// to get the transpose of a 2D array | ||
return array_map(null, ...$columns); | ||
} | ||
|
||
/** | ||
* @param array $columnValues | ||
* @return bool | ||
*/ | ||
protected function isCategoricalColumn(array $columnValues) | ||
{ | ||
$count = count($columnValues); | ||
// There are two main indicators that *may* show whether a | ||
// column is composed of discrete set of values: | ||
// 1- Column may contain string values | ||
// 2- Number of unique values in the column is only a small fraction of | ||
// all values in that column (Lower than or equal to %20 of all values) | ||
$numericValues = array_filter($columnValues, 'is_numeric'); | ||
if (count($numericValues) != $count) { | ||
return true; | ||
} | ||
$distinctValues = array_count_values($columnValues); | ||
if (count($distinctValues) <= $count / 5) { | ||
return true; | ||
} | ||
return false; | ||
} | ||
|
||
/** | ||
* @return string | ||
*/ | ||
public function getHtml() | ||
{ | ||
return $this->tree->__toString(); | ||
} | ||
|
||
/** | ||
* @param array $sample | ||
* @return mixed | ||
*/ | ||
protected function predictSample(array $sample) | ||
{ | ||
$node = $this->tree; | ||
do { | ||
if ($node->isTerminal) { | ||
break; | ||
} | ||
if ($node->evaluate($sample)) { | ||
$node = $node->leftLeaf; | ||
} else { | ||
$node = $node->rightLeaf; | ||
} | ||
} while ($node); | ||
return $node->classValue; | ||
} | ||
} |
106 changes: 106 additions & 0 deletions
106
src/Phpml/Classification/DecisionTree/DecisionTreeLeaf.php
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
<?php | ||
declare(strict_types=1); | ||
|
||
namespace Phpml\Classification\DecisionTree; | ||
|
||
class DecisionTreeLeaf | ||
{ | ||
const OPERATOR_EQ = '='; | ||
/** | ||
* @var string | ||
*/ | ||
public $value; | ||
|
||
/** | ||
* @var int | ||
*/ | ||
public $columnIndex; | ||
|
||
/** | ||
* @var DecisionTreeLeaf | ||
*/ | ||
public $leftLeaf = null; | ||
|
||
/** | ||
* @var DecisionTreeLeaf | ||
*/ | ||
public $rightLeaf= null; | ||
|
||
/** | ||
* @var array | ||
*/ | ||
public $records = []; | ||
|
||
/** | ||
* Class value represented by the leaf, this value is non-empty | ||
* only for terminal leaves | ||
* | ||
* @var string | ||
*/ | ||
public $classValue = ''; | ||
|
||
/** | ||
* @var bool | ||
*/ | ||
public $isTerminal = false; | ||
|
||
/** | ||
* @var float | ||
*/ | ||
public $giniIndex = 0; | ||
|
||
/** | ||
* @var int | ||
*/ | ||
public $level = 0; | ||
|
||
/** | ||
* @param array $record | ||
* @return bool | ||
*/ | ||
public function evaluate($record) | ||
{ | ||
$recordField = $record[$this->columnIndex]; | ||
if (preg_match("/^([<>=]{1,2})\s*(.*)/", $this->value, $matches)) { | ||
$op = $matches[1]; | ||
$value= floatval($matches[2]); | ||
$recordField = strval($recordField); | ||
eval("\$result = $recordField $op $value;"); | ||
return $result; | ||
} | ||
return $recordField == $this->value; | ||
} | ||
|
||
public function __toString() | ||
{ | ||
if ($this->isTerminal) { | ||
$value = "<b>$this->classValue</b>"; | ||
} else { | ||
$value = $this->value; | ||
$col = "col_$this->columnIndex"; | ||
if (! preg_match("/^[<>=]{1,2}/", $value)) { | ||
$value = "=$value"; | ||
} | ||
$value = "<b>$col $value</b><br>Gini: ". number_format($this->giniIndex, 2); | ||
} | ||
$str = "<table ><tr><td colspan=3 align=center style='border:1px solid;'> | ||
$value</td></tr>"; | ||
if ($this->leftLeaf || $this->rightLeaf) { | ||
$str .='<tr>'; | ||
if ($this->leftLeaf) { | ||
$str .="<td valign=top><b>| Yes</b><br>$this->leftLeaf</td>"; | ||
} else { | ||
$str .='<td></td>'; | ||
} | ||
$str .='<td> </td>'; | ||
if ($this->rightLeaf) { | ||
$str .="<td valign=top align=right><b>No |</b><br>$this->rightLeaf</td>"; | ||
} else { | ||
$str .='<td></td>'; | ||
} | ||
$str .= '</tr>'; | ||
} | ||
$str .= '</table>'; | ||
return $str; | ||
} | ||
} |
Oops, something went wrong.