Skip to content

Commit

Permalink
Implement DecisionTreeRegressor (#375)
Browse files Browse the repository at this point in the history
  • Loading branch information
akondas authored May 12, 2019
1 parent 8544cf7 commit 91812f4
Show file tree
Hide file tree
Showing 11 changed files with 759 additions and 0 deletions.
144 changes: 144 additions & 0 deletions src/Regression/DecisionTreeRegressor.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
<?php

declare(strict_types=1);

namespace Phpml\Regression;

use Phpml\Exception\InvalidArgumentException;
use Phpml\Exception\InvalidOperationException;
use Phpml\Math\Statistic\Mean;
use Phpml\Math\Statistic\Variance;
use Phpml\Tree\CART;
use Phpml\Tree\Node\AverageNode;
use Phpml\Tree\Node\BinaryNode;
use Phpml\Tree\Node\DecisionNode;

/**
 * Regression tree built on top of CART.
 *
 * Every split examines a random subset of the feature columns and keeps the
 * partition with the lowest weighted population variance of the targets.
 * Leaves are AverageNode instances holding the arithmetic mean of the targets
 * that reached them.
 */
final class DecisionTreeRegressor extends CART implements Regression
{
    /**
     * Number of feature columns examined per split. Null means the value is
     * derived as round(sqrt(feature count)) for each training run.
     *
     * @var int|null
     */
    protected $maxFeatures;

    /**
     * Weighted variance at or below which the split search stops early.
     *
     * @var float
     */
    protected $tolerance;

    /**
     * Column indices of the data set being trained on; only populated while
     * train() is running.
     *
     * @var array
     */
    protected $columns = [];

    /**
     * The first three arguments are forwarded unchanged to CART.
     *
     * @param int|null $maxFeatures columns examined per split, null to derive it from the data
     * @param float    $tolerance   weighted variance at or below which the split search stops;
     *                              the default of 0 stops only on a perfect split
     *
     * @throws InvalidArgumentException when $maxFeatures is below 1 or $tolerance is negative
     */
    public function __construct(
        int $maxDepth = PHP_INT_MAX,
        int $maxLeafSize = 3,
        float $minPurityIncrease = 0.,
        ?int $maxFeatures = null,
        float $tolerance = 0.
    ) {
        if ($maxFeatures !== null && $maxFeatures < 1) {
            throw new InvalidArgumentException('Max features must be greater than 0');
        }

        if ($tolerance < 0.) {
            throw new InvalidArgumentException('Tolerance must be equal or greater than 0');
        }

        parent::__construct($maxDepth, $maxLeafSize, $minPurityIncrease);

        $this->maxFeatures = $maxFeatures;
        $this->tolerance = $tolerance;
    }

    /**
     * Grows the tree from the given samples and their targets.
     *
     * @param array $samples list of feature vectors; the first one defines the feature count
     * @param array $targets target value for each sample, matched by index
     *
     * @throws InvalidArgumentException when no usable first sample is given
     */
    public function train(array $samples, array $targets): void
    {
        if (!isset($samples[0]) || !is_array($samples[0]) || $samples[0] === []) {
            throw new InvalidArgumentException('Samples must be a non-empty list of non-empty feature vectors');
        }

        $features = count($samples[0]);
        $configuredMaxFeatures = $this->maxFeatures;

        $this->columns = range(0, $features - 1);
        $this->maxFeatures = $configuredMaxFeatures ?? (int) round(sqrt($features));

        try {
            $this->grow($samples, $targets);
        } finally {
            // Training-only state is dropped, and the configured value is put
            // back so a derived default never leaks into the next train() call
            // on data with a different number of features.
            $this->columns = [];
            $this->maxFeatures = $configuredMaxFeatures;
        }
    }

    /**
     * Predicts one value per sample by walking the tree to a leaf.
     *
     * @param array $samples list of feature vectors
     *
     * @return array one entry per sample: the leaf outcome, or null when the
     *               search did not end on an AverageNode
     *
     * @throws InvalidOperationException when the tree has not been grown yet
     */
    public function predict(array $samples)
    {
        if ($this->bare()) {
            throw new InvalidOperationException('Regressor must be trained first');
        }

        $predictions = [];

        foreach ($samples as $sample) {
            $node = $this->search($sample);

            $predictions[] = $node instanceof AverageNode
                ? $node->outcome()
                : null;
        }

        return $predictions;
    }

    /**
     * Finds the best split among a random subset of columns, trying every
     * distinct value of each column as the split point.
     */
    protected function split(array $samples, array $targets): DecisionNode
    {
        $bestVariance = INF;
        $bestColumn = $bestValue = null;
        $bestGroups = [];

        shuffle($this->columns);

        foreach (array_slice($this->columns, 0, $this->maxFeatures) as $column) {
            $values = array_unique(array_column($samples, $column));

            foreach ($values as $value) {
                $groups = $this->partition($column, $value, $samples, $targets);

                $variance = $this->splitImpurity($groups);

                if ($variance < $bestVariance) {
                    $bestColumn = $column;
                    $bestValue = $value;
                    $bestGroups = $groups;
                    $bestVariance = $variance;
                }

                // Good enough: stop scanning both the values and the columns.
                if ($variance <= $this->tolerance) {
                    break 2;
                }
            }
        }

        return new DecisionNode($bestColumn, $bestValue, $bestGroups, $bestVariance);
    }

    /**
     * Builds a leaf from the mean, population variance and count of the targets.
     */
    protected function terminate(array $targets): BinaryNode
    {
        return new AverageNode(Mean::arithmetic($targets), Variance::population($targets), count($targets));
    }

    /**
     * Weighted population variance of the targets across the groups.
     *
     * @param array $groups list of [samples, targets] pairs
     */
    protected function splitImpurity(array $groups): float
    {
        $samplesCount = (int) array_sum(array_map(static function (array $group) {
            return count($group[0]);
        }, $groups));

        $impurity = 0.;

        foreach ($groups as $group) {
            $k = count($group[1]);

            // Groups with fewer than two targets contribute nothing.
            if ($k < 2) {
                continue;
            }

            $variance = Variance::population($group[1]);

            $impurity += ($k / $samplesCount) * $variance;
        }

        return $impurity;
    }

    /**
     * Splits samples and targets into a left and a right group.
     *
     * The routing rule mirrors CART::search(): a string split value sends
     * identical values left, any other split value sends smaller values left.
     * Using the same rule keeps training and prediction consistent.
     *
     * @param int|float|string $value
     *
     * @return array [[leftSamples, leftTargets], [rightSamples, rightTargets]]
     */
    private function partition(int $column, $value, array $samples, array $targets): array
    {
        $categorical = is_string($value);

        $leftSamples = $leftTargets = $rightSamples = $rightTargets = [];

        foreach ($samples as $index => $sample) {
            $goesLeft = $categorical
                ? $sample[$column] === $value
                : $sample[$column] < $value;

            if ($goesLeft) {
                $leftSamples[] = $sample;
                $leftTargets[] = $targets[$index];
            } else {
                $rightSamples[] = $sample;
                $rightTargets[] = $targets[$index];
            }
        }

        return [
            [$leftSamples, $leftTargets],
            [$rightSamples, $rightTargets],
        ];
    }
}
176 changes: 176 additions & 0 deletions src/Tree/CART.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
<?php

declare(strict_types=1);

namespace Phpml\Tree;

use Phpml\Exception\InvalidArgumentException;
use Phpml\Tree\Node\BinaryNode;
use Phpml\Tree\Node\DecisionNode;
use Phpml\Tree\Node\LeafNode;

/**
 * Base class for binary decision trees grown by repeated splitting.
 *
 * Subclasses provide split() to pick a decision node for a set of samples and
 * terminate() to build a leaf for a set of targets.
 */
abstract class CART
{
    /**
     * Root of the grown tree, or null while the tree is bare.
     *
     * @var DecisionNode|null
     */
    protected $root;

    /**
     * Depth limit applied while growing.
     *
     * @var int
     */
    protected $maxDepth;

    /**
     * A group is only split further when it holds more targets than this.
     *
     * @var int
     */
    protected $maxLeafSize;

    /**
     * Purity increase a candidate split must exceed (with a 1e-8 allowance) to be kept.
     *
     * @var float
     */
    protected $minPurityIncrease;

    /**
     * Number of features in the first sample of the last grow() call.
     *
     * @var int
     */
    protected $featureCount;

    /**
     * @throws InvalidArgumentException when $maxDepth or $maxLeafSize is below 1,
     *                                  or $minPurityIncrease is negative
     */
    public function __construct(int $maxDepth = PHP_INT_MAX, int $maxLeafSize = 3, float $minPurityIncrease = 0.)
    {
        if ($maxDepth < 1) {
            throw new InvalidArgumentException('Max depth must be greater than 0');
        }

        if ($maxLeafSize < 1) {
            throw new InvalidArgumentException('Max leaf size must be greater than 0');
        }

        if ($minPurityIncrease < 0.) {
            throw new InvalidArgumentException('Min purity increase must be equal or greater than 0');
        }

        $this->maxDepth = $maxDepth;
        $this->maxLeafSize = $maxLeafSize;
        $this->minPurityIncrease = $minPurityIncrease;
    }

    public function root(): ?DecisionNode
    {
        return $this->root;
    }

    /**
     * Height reported by the root node, or 0 for a bare tree.
     */
    public function height(): int
    {
        return $this->root !== null ? $this->root->height() : 0;
    }

    /**
     * Balance reported by the root node, or 0 for a bare tree.
     */
    public function balance(): int
    {
        return $this->root !== null ? $this->root->balance() : 0;
    }

    /**
     * Whether the tree has no root yet.
     */
    public function bare(): bool
    {
        return $this->root === null;
    }

    /**
     * Grows the tree iteratively, using an explicit stack of pending decision nodes.
     *
     * @param array $samples list of feature vectors
     * @param array $targets target for each sample, matched by index
     *
     * @throws InvalidArgumentException when no first sample is given
     */
    public function grow(array $samples, array $targets): void
    {
        if (!isset($samples[0])) {
            throw new InvalidArgumentException('Cannot grow a tree from an empty list of samples');
        }

        $this->featureCount = count($samples[0]);
        $depth = 1;
        $this->root = $this->split($samples, $targets);
        $stack = [[$this->root, $depth]];

        while ($stack !== []) {
            [$current, $depth] = array_pop($stack);

            [$left, $right] = $current->groups();

            $current->cleanup();

            $depth++;

            // Each group is a [samples, targets] pair, so emptiness has to be
            // checked on the targets; the pair itself is never an empty array.
            $leftTargets = $left[1] ?? [];
            $rightTargets = $right[1] ?? [];

            if ($leftTargets === [] || $rightTargets === []) {
                // A one-sided split separates nothing: both children share one leaf.
                $node = $this->terminate(array_merge($leftTargets, $rightTargets));

                $current->attachLeft($node);
                $current->attachRight($node);

                continue;
            }

            if ($depth >= $this->maxDepth) {
                $current->attachLeft($this->terminate($leftTargets));
                $current->attachRight($this->terminate($rightTargets));

                continue;
            }

            // The left group is always evaluated before the right one so the
            // order of split() calls stays the same as before.
            $node = $this->splitIfWorthwhile($left);

            if ($node !== null) {
                $current->attachLeft($node);

                $stack[] = [$node, $depth];
            } else {
                $current->attachLeft($this->terminate($leftTargets));
            }

            $node = $this->splitIfWorthwhile($right);

            if ($node !== null) {
                $current->attachRight($node);

                $stack[] = [$node, $depth];
            } else {
                $current->attachRight($this->terminate($rightTargets));
            }
        }
    }

    /**
     * Walks from the root to the node the sample ends up in.
     *
     * String split values send identical values left; any other split value
     * sends smaller values left.
     *
     * @return BinaryNode|null the first node reached that is not a decision
     *                         node, or null when the tree is bare or a branch is missing
     */
    public function search(array $sample): ?BinaryNode
    {
        $current = $this->root;

        // Stopping on any non-decision node prevents an endless loop on a
        // node that is neither a decision node nor a leaf.
        while ($current instanceof DecisionNode) {
            $value = $current->value();
            $feature = $sample[$current->column()];

            $goesLeft = is_string($value)
                ? $feature === $value
                : $feature < $value;

            $current = $goesLeft ? $current->left() : $current->right();
        }

        return $current;
    }

    abstract protected function split(array $samples, array $targets): DecisionNode;

    abstract protected function terminate(array $targets): BinaryNode;

    /**
     * Splits a [samples, targets] group when it is larger than the maximum
     * leaf size and the split exceeds the minimum purity increase.
     *
     * @return DecisionNode|null the accepted split, or null when the group should become a leaf
     */
    private function splitIfWorthwhile(array $group): ?DecisionNode
    {
        if (count($group[1]) <= $this->maxLeafSize) {
            return null;
        }

        $node = $this->split($group[0], $group[1]);

        return $node->purityIncrease() + 1e-8 > $this->minPurityIncrease ? $node : null;
    }
}
9 changes: 9 additions & 0 deletions src/Tree/Node.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
<?php

declare(strict_types=1);

namespace Phpml\Tree;

/**
 * Marker interface: it declares no methods or constants of its own.
 */
interface Node
{
}
45 changes: 45 additions & 0 deletions src/Tree/Node/AverageNode.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
<?php

declare(strict_types=1);

namespace Phpml\Tree\Node;

/**
 * Leaf node that stores an outcome, an impurity and a sample count and exposes
 * each through a read-only accessor.
 */
class AverageNode extends BinaryNode implements PurityNode, LeafNode
{
    /** @var float Value returned by outcome(). */
    private $outcome;

    /** @var float Value returned by impurity(). */
    private $impurity;

    /** @var int Value returned by samplesCount(). */
    private $samplesCount;

    /**
     * Stores the three values as given, without validation.
     */
    public function __construct(float $outcome, float $impurity, int $samplesCount)
    {
        [$this->outcome, $this->impurity, $this->samplesCount] = [$outcome, $impurity, $samplesCount];
    }

    /**
     * @return float the outcome passed to the constructor
     */
    public function outcome(): float
    {
        return $this->outcome;
    }

    /**
     * @return float the impurity passed to the constructor
     */
    public function impurity(): float
    {
        return $this->impurity;
    }

    /**
     * @return int the sample count passed to the constructor
     */
    public function samplesCount(): int
    {
        return $this->samplesCount;
    }
}
Loading

0 comments on commit 91812f4

Please sign in to comment.