forked from jorgecasas/php-ml
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Implement MnistDataset * Add MNIST dataset documentation
- Loading branch information
Showing
10 changed files
with
164 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# MnistDataset | ||
|
||
Helper class that load data from MNIST dataset: [http://yann.lecun.com/exdb/mnist/](http://yann.lecun.com/exdb/mnist/) | ||
|
||
> The MNIST database of handwritten digits, available from this page, has a training set of 60,000 examples, and a test set of 10,000 examples. It is a subset of a larger set available from NIST. The digits have been size-normalized and centered in a fixed-size image. | ||
It is a good database for people who want to try learning techniques and pattern recognition methods on real-world data while spending minimal efforts on preprocessing and formatting. | ||
|
||
### Constructors Parameters | ||
|
||
* $imagePath - (string) path to image file | ||
* $labelPath - (string) path to label file | ||
|
||
``` | ||
use Phpml\Dataset\MnistDataset; | ||
$trainDataset = new MnistDataset('train-images-idx3-ubyte', 'train-labels-idx1-ubyte'); | ||
``` | ||
|
||
### Samples and labels | ||
|
||
To get samples or labels you can use getters: | ||
|
||
``` | ||
$dataset->getSamples(); | ||
$dataset->getTargets(); | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
namespace Phpml\Dataset; | ||
|
||
use Phpml\Exception\InvalidArgumentException; | ||
|
||
/** | ||
* MNIST dataset: http://yann.lecun.com/exdb/mnist/ | ||
* original mnist dataset reader: https://github.com/AndrewCarterUK/mnist-neural-network-plain-php | ||
*/ | ||
final class MnistDataset extends ArrayDataset | ||
{ | ||
private const MAGIC_IMAGE = 0x00000803; | ||
|
||
private const MAGIC_LABEL = 0x00000801; | ||
|
||
private const IMAGE_ROWS = 28; | ||
|
||
private const IMAGE_COLS = 28; | ||
|
||
public function __construct(string $imagePath, string $labelPath) | ||
{ | ||
$this->samples = $this->readImages($imagePath); | ||
$this->targets = $this->readLabels($labelPath); | ||
|
||
if (count($this->samples) !== count($this->targets)) { | ||
throw new InvalidArgumentException('Must have the same number of images and labels'); | ||
} | ||
} | ||
|
||
private function readImages(string $imagePath): array | ||
{ | ||
$stream = fopen($imagePath, 'rb'); | ||
|
||
if ($stream === false) { | ||
throw new InvalidArgumentException('Could not open file: '.$imagePath); | ||
} | ||
|
||
$images = []; | ||
|
||
try { | ||
$header = fread($stream, 16); | ||
|
||
$fields = unpack('Nmagic/Nsize/Nrows/Ncols', (string) $header); | ||
|
||
if ($fields['magic'] !== self::MAGIC_IMAGE) { | ||
throw new InvalidArgumentException('Invalid magic number: '.$imagePath); | ||
} | ||
|
||
if ($fields['rows'] != self::IMAGE_ROWS) { | ||
throw new InvalidArgumentException('Invalid number of image rows: '.$imagePath); | ||
} | ||
|
||
if ($fields['cols'] != self::IMAGE_COLS) { | ||
throw new InvalidArgumentException('Invalid number of image cols: '.$imagePath); | ||
} | ||
|
||
for ($i = 0; $i < $fields['size']; $i++) { | ||
$imageBytes = fread($stream, $fields['rows'] * $fields['cols']); | ||
|
||
// Convert to float between 0 and 1 | ||
$images[] = array_map(function ($b) { | ||
return $b / 255; | ||
}, array_values(unpack('C*', (string) $imageBytes))); | ||
} | ||
} finally { | ||
fclose($stream); | ||
} | ||
|
||
return $images; | ||
} | ||
|
||
private function readLabels(string $labelPath): array | ||
{ | ||
$stream = fopen($labelPath, 'rb'); | ||
|
||
if ($stream === false) { | ||
throw new InvalidArgumentException('Could not open file: '.$labelPath); | ||
} | ||
|
||
$labels = []; | ||
|
||
try { | ||
$header = fread($stream, 8); | ||
|
||
$fields = unpack('Nmagic/Nsize', (string) $header); | ||
|
||
if ($fields['magic'] !== self::MAGIC_LABEL) { | ||
throw new InvalidArgumentException('Invalid magic number: '.$labelPath); | ||
} | ||
|
||
$labels = fread($stream, $fields['size']); | ||
} finally { | ||
fclose($stream); | ||
} | ||
|
||
return array_values(unpack('C*', (string) $labels)); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
namespace Phpml\Tests\Dataset; | ||
|
||
use Phpml\Dataset\MnistDataset; | ||
use Phpml\Exception\InvalidArgumentException; | ||
use PHPUnit\Framework\TestCase; | ||
|
||
class MnistDatasetTest extends TestCase | ||
{ | ||
public function testSimpleMnistDataset(): void | ||
{ | ||
$dataset = new MnistDataset( | ||
__DIR__.'/Resources/mnist/images-idx-ubyte', | ||
__DIR__.'/Resources/mnist/labels-idx-ubyte' | ||
); | ||
|
||
self::assertCount(10, $dataset->getSamples()); | ||
self::assertCount(10, $dataset->getTargets()); | ||
} | ||
|
||
public function testCheckSamplesAndTargetsCountMatch(): void | ||
{ | ||
$this->expectException(InvalidArgumentException::class); | ||
|
||
new MnistDataset( | ||
__DIR__.'/Resources/mnist/images-idx-ubyte', | ||
__DIR__.'/Resources/mnist/labels-11-idx-ubyte' | ||
); | ||
} | ||
} |
Binary file not shown.
Binary file not shown.
Binary file not shown.