php-ml/src/Dataset/MnistDataset.php

102 lines
2.8 KiB
PHP

<?php
declare(strict_types=1);
namespace Phpml\Dataset;
use Phpml\Exception\InvalidArgumentException;
/**
* MNIST dataset: http://yann.lecun.com/exdb/mnist/
* original mnist dataset reader: https://github.com/AndrewCarterUK/mnist-neural-network-plain-php
*/
final class MnistDataset extends ArrayDataset
{
private const MAGIC_IMAGE = 0x00000803;
private const MAGIC_LABEL = 0x00000801;
private const IMAGE_ROWS = 28;
private const IMAGE_COLS = 28;
public function __construct(string $imagePath, string $labelPath)
{
$this->samples = $this->readImages($imagePath);
$this->targets = $this->readLabels($labelPath);
if (count($this->samples) !== count($this->targets)) {
throw new InvalidArgumentException('Must have the same number of images and labels');
}
}
private function readImages(string $imagePath): array
{
$stream = fopen($imagePath, 'rb');
if ($stream === false) {
throw new InvalidArgumentException('Could not open file: '.$imagePath);
}
$images = [];
try {
$header = fread($stream, 16);
$fields = unpack('Nmagic/Nsize/Nrows/Ncols', (string) $header);
if ($fields['magic'] !== self::MAGIC_IMAGE) {
throw new InvalidArgumentException('Invalid magic number: '.$imagePath);
}
if ($fields['rows'] != self::IMAGE_ROWS) {
throw new InvalidArgumentException('Invalid number of image rows: '.$imagePath);
}
if ($fields['cols'] != self::IMAGE_COLS) {
throw new InvalidArgumentException('Invalid number of image cols: '.$imagePath);
}
for ($i = 0; $i < $fields['size']; $i++) {
$imageBytes = fread($stream, $fields['rows'] * $fields['cols']);
// Convert to float between 0 and 1
$images[] = array_map(function ($b) {
return $b / 255;
}, array_values(unpack('C*', (string) $imageBytes)));
}
} finally {
fclose($stream);
}
return $images;
}
private function readLabels(string $labelPath): array
{
$stream = fopen($labelPath, 'rb');
if ($stream === false) {
throw new InvalidArgumentException('Could not open file: '.$labelPath);
}
$labels = [];
try {
$header = fread($stream, 8);
$fields = unpack('Nmagic/Nsize', (string) $header);
if ($fields['magic'] !== self::MAGIC_LABEL) {
throw new InvalidArgumentException('Invalid magic number: '.$labelPath);
}
$labels = fread($stream, $fields['size']);
} finally {
fclose($stream);
}
return array_values(unpack('C*', (string) $labels));
}
}