Skip to content

Commit

Permalink
Code Style (#86)
Browse files Browse the repository at this point in the history
* Code Style

* Code Review fixes
  • Loading branch information
marmichalski authored and akondas committed May 17, 2017
1 parent 43f15d2 commit 7ab80b6
Show file tree
Hide file tree
Showing 40 changed files with 535 additions and 388 deletions.
42 changes: 31 additions & 11 deletions src/Phpml/Classification/DecisionTree.php
Original file line number Diff line number Diff line change
Expand Up @@ -102,19 +102,21 @@ public function train(array $samples, array $targets)
$this->columnNames = array_slice($this->columnNames, 0, $this->featureCount);
} elseif (count($this->columnNames) < $this->featureCount) {
$this->columnNames = array_merge($this->columnNames,
range(count($this->columnNames), $this->featureCount - 1));
range(count($this->columnNames), $this->featureCount - 1)
);
}
}

/**
* @param array $samples
*
* @return array
*/
public static function getColumnTypes(array $samples) : array
{
$types = [];
$featureCount = count($samples[0]);
for ($i=0; $i < $featureCount; $i++) {
for ($i = 0; $i < $featureCount; ++$i) {
$values = array_column($samples, $i);
$isCategorical = self::isCategoricalColumn($values);
$types[] = $isCategorical ? self::NOMINAL : self::CONTINUOUS;
Expand All @@ -125,7 +127,8 @@ public static function getColumnTypes(array $samples) : array

/**
* @param array $records
* @param int $depth
* @param int $depth
*
* @return DecisionTreeLeaf
*/
protected function getSplitLeaf(array $records, int $depth = 0) : DecisionTreeLeaf
Expand Down Expand Up @@ -163,10 +166,10 @@ protected function getSplitLeaf(array $records, int $depth = 0) : DecisionTreeLe

// Group remaining targets
$target = $this->targets[$recordNo];
if (! array_key_exists($target, $remainingTargets)) {
if (!array_key_exists($target, $remainingTargets)) {
$remainingTargets[$target] = 1;
} else {
$remainingTargets[$target]++;
++$remainingTargets[$target];
}
}

Expand All @@ -188,6 +191,7 @@ protected function getSplitLeaf(array $records, int $depth = 0) : DecisionTreeLe

/**
* @param array $records
*
* @return DecisionTreeLeaf
*/
protected function getBestSplit(array $records) : DecisionTreeLeaf
Expand Down Expand Up @@ -251,7 +255,7 @@ protected function getBestSplit(array $records) : DecisionTreeLeaf
protected function getSelectedFeatures() : array
{
$allFeatures = range(0, $this->featureCount - 1);
if ($this->numUsableFeatures === 0 && ! $this->selectedFeatures) {
if ($this->numUsableFeatures === 0 && !$this->selectedFeatures) {
return $allFeatures;
}

Expand All @@ -271,9 +275,10 @@ protected function getSelectedFeatures() : array
}

/**
* @param $baseValue
* @param mixed $baseValue
* @param array $colValues
* @param array $targets
*
* @return float
*/
public function getGiniIndex($baseValue, array $colValues, array $targets) : float
Expand All @@ -282,20 +287,23 @@ public function getGiniIndex($baseValue, array $colValues, array $targets) : flo
foreach ($this->labels as $label) {
$countMatrix[$label] = [0, 0];
}

foreach ($colValues as $index => $value) {
$label = $targets[$index];
$rowIndex = $value === $baseValue ? 0 : 1;
$countMatrix[$label][$rowIndex]++;
++$countMatrix[$label][$rowIndex];
}

$giniParts = [0, 0];
for ($i=0; $i<=1; $i++) {
for ($i = 0; $i <= 1; ++$i) {
$part = 0;
$sum = array_sum(array_column($countMatrix, $i));
if ($sum > 0) {
foreach ($this->labels as $label) {
$part += pow($countMatrix[$label][$i] / floatval($sum), 2);
}
}

$giniParts[$i] = (1 - $part) * $sum;
}

Expand All @@ -304,14 +312,15 @@ public function getGiniIndex($baseValue, array $colValues, array $targets) : flo

/**
* @param array $samples
*
* @return array
*/
protected function preprocess(array $samples) : array
{
// Detect and convert continuous data column values into
// discrete values by using the median as a threshold value
$columns = [];
for ($i=0; $i<$this->featureCount; $i++) {
for ($i = 0; $i < $this->featureCount; ++$i) {
$values = array_column($samples, $i);
if ($this->columnTypes[$i] == self::CONTINUOUS) {
$median = Mean::median($values);
Expand All @@ -332,6 +341,7 @@ protected function preprocess(array $samples) : array

/**
* @param array $columnValues
*
* @return bool
*/
protected static function isCategoricalColumn(array $columnValues) : bool
Expand All @@ -348,6 +358,7 @@ protected static function isCategoricalColumn(array $columnValues) : bool
if ($floatValues) {
return false;
}

if (count($numericValues) !== $count) {
return true;
}
Expand All @@ -365,7 +376,9 @@ protected static function isCategoricalColumn(array $columnValues) : bool
* randomly selected for each split operation.
*
* @param int $numFeatures
*
* @return $this
*
* @throws InvalidArgumentException
*/
public function setNumFeatures(int $numFeatures)
Expand Down Expand Up @@ -394,7 +407,9 @@ protected function setSelectedFeatures(array $selectedFeatures)
* column importances are desired to be inspected.
*
* @param array $names
*
* @return $this
*
* @throws InvalidArgumentException
*/
public function setColumnNames(array $names)
Expand Down Expand Up @@ -458,8 +473,9 @@ public function getFeatureImportances()
* Collects and returns an array of internal nodes that use the given
* column as a split criterion
*
* @param int $column
* @param int $column
* @param DecisionTreeLeaf $node
*
* @return array
*/
protected function getSplitNodesByColumn(int $column, DecisionTreeLeaf $node) : array
Expand All @@ -478,16 +494,19 @@ protected function getSplitNodesByColumn(int $column, DecisionTreeLeaf $node) :
if ($node->leftLeaf) {
$lNodes = $this->getSplitNodesByColumn($column, $node->leftLeaf);
}

if ($node->rightLeaf) {
$rNodes = $this->getSplitNodesByColumn($column, $node->rightLeaf);
}

$nodes = array_merge($nodes, $lNodes, $rNodes);

return $nodes;
}

/**
* @param array $sample
*
* @return mixed
*/
protected function predictSample(array $sample)
Expand All @@ -497,6 +516,7 @@ protected function predictSample(array $sample)
if ($node->isTerminal) {
break;
}

if ($node->evaluate($sample)) {
$node = $node->leftLeaf;
} else {
Expand Down
4 changes: 3 additions & 1 deletion src/Phpml/Classification/DecisionTree/DecisionTreeLeaf.php
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ public function evaluate($record)
* Returns Mean Decrease Impurity (MDI) in the node.
* For terminal nodes, this value is equal to 0
*
* @param int $parentRecordCount
*
* @return float
*/
public function getNodeImpurityDecrease(int $parentRecordCount)
Expand Down Expand Up @@ -133,7 +135,7 @@ public function getHTML($columnNames = null)
} else {
$col = "col_$this->columnIndex";
}
if (! preg_match("/^[<>=]{1,2}/", $value)) {
if (!preg_match("/^[<>=]{1,2}/", $value)) {
$value = "=$value";
}
$value = "<b>$col $value</b><br>Gini: ". number_format($this->giniIndex, 2);
Expand Down
8 changes: 6 additions & 2 deletions src/Phpml/Classification/Ensemble/AdaBoost.php
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class AdaBoost implements Classifier
* improve classification performance of 'weak' classifiers such as
* DecisionStump (default base classifier of AdaBoost).
*
* @param int $maxIterations
*/
public function __construct(int $maxIterations = 50)
{
Expand All @@ -96,6 +97,8 @@ public function setBaseClassifier(string $baseClassifier = DecisionStump::class,
/**
* @param array $samples
* @param array $targets
*
* @throws \Exception
*/
public function train(array $samples, array $targets)
{
Expand Down Expand Up @@ -123,7 +126,6 @@ public function train(array $samples, array $targets)
// Execute the algorithm for a maximum number of iterations
$currIter = 0;
while ($this->maxIterations > $currIter++) {

// Determine the best 'weak' classifier based on current weights
$classifier = $this->getBestClassifier();
$errorRate = $this->evaluateClassifier($classifier);
Expand Down Expand Up @@ -181,7 +183,7 @@ protected function resample()
$targets = [];
foreach ($weights as $index => $weight) {
$z = (int)round(($weight - $mean) / $std) - $minZ + 1;
for ($i=0; $i < $z; $i++) {
for ($i = 0; $i < $z; ++$i) {
if (rand(0, 1) == 0) {
continue;
}
Expand All @@ -197,6 +199,8 @@ protected function resample()
* Evaluates the classifier and returns the classification error rate
*
* @param Classifier $classifier
*
* @return float
*/
protected function evaluateClassifier(Classifier $classifier)
{
Expand Down
25 changes: 16 additions & 9 deletions src/Phpml/Classification/Ensemble/Bagging.php
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,13 @@ class Bagging implements Classifier
private $samples = [];

/**
* Creates an ensemble classifier with given number of base classifiers<br>
* Default number of base classifiers is 100.
* Creates an ensemble classifier with given number of base classifiers
* Default number of base classifiers is 50.
* The more number of base classifiers, the better performance but at the cost of procesing time
*
* @param int $numClassifier
*/
public function __construct($numClassifier = 50)
public function __construct(int $numClassifier = 50)
{
$this->numClassifier = $numClassifier;
}
Expand All @@ -76,14 +76,17 @@ public function __construct($numClassifier = 50)
* to train each base classifier.
*
* @param float $ratio
*
* @return $this
* @throws Exception
*
* @throws \Exception
*/
public function setSubsetRatio(float $ratio)
{
if ($ratio < 0.1 || $ratio > 1.0) {
throw new \Exception("Subset ratio should be between 0.1 and 1.0");
}

$this->subsetRatio = $ratio;
return $this;
}
Expand All @@ -98,12 +101,14 @@ public function setSubsetRatio(float $ratio)
*
* @param string $classifier
* @param array $classifierOptions
*
* @return $this
*/
public function setClassifer(string $classifier, array $classifierOptions = [])
{
$this->classifier = $classifier;
$this->classifierOptions = $classifierOptions;

return $this;
}

Expand Down Expand Up @@ -138,11 +143,12 @@ protected function getRandomSubset(int $index)
$targets = [];
srand($index);
$bootstrapSize = $this->subsetRatio * $this->numSamples;
for ($i=0; $i < $bootstrapSize; $i++) {
for ($i = 0; $i < $bootstrapSize; ++$i) {
$rand = rand(0, $this->numSamples - 1);
$samples[] = $this->samples[$rand];
$targets[] = $this->targets[$rand];
}

return [$samples, $targets];
}

Expand All @@ -152,24 +158,25 @@ protected function getRandomSubset(int $index)
protected function initClassifiers()
{
$classifiers = [];
for ($i=0; $i<$this->numClassifier; $i++) {
for ($i = 0; $i < $this->numClassifier; ++$i) {
$ref = new \ReflectionClass($this->classifier);
if ($this->classifierOptions) {
$obj = $ref->newInstanceArgs($this->classifierOptions);
} else {
$obj = $ref->newInstance();
}
$classifiers[] = $this->initSingleClassifier($obj, $i);

$classifiers[] = $this->initSingleClassifier($obj);
}
return $classifiers;
}

/**
* @param Classifier $classifier
* @param int $index
*
* @return Classifier
*/
protected function initSingleClassifier($classifier, $index)
protected function initSingleClassifier($classifier)
{
return $classifier;
}
Expand Down
Loading

0 comments on commit 7ab80b6

Please sign in to comment.