diff --git a/README.md b/README.md index d93996f..f518fd0 100644 --- a/README.md +++ b/README.md @@ -112,6 +112,7 @@ Public datasets are available in a separate repository [php-ai/php-ml-datasets]( * [CSV](http://php-ml.readthedocs.io/en/latest/machine-learning/datasets/csv-dataset/) * [Files](http://php-ml.readthedocs.io/en/latest/machine-learning/datasets/files-dataset/) * [SVM](http://php-ml.readthedocs.io/en/latest/machine-learning/datasets/svm-dataset/) + * [MNIST](http://php-ml.readthedocs.io/en/latest/machine-learning/datasets/mnist-dataset.md) * Ready to use: * [Iris](http://php-ml.readthedocs.io/en/latest/machine-learning/datasets/demo/iris/) * [Wine](http://php-ml.readthedocs.io/en/latest/machine-learning/datasets/demo/wine/) diff --git a/docs/index.md b/docs/index.md index 12cbbd5..3c6ede2 100644 --- a/docs/index.md +++ b/docs/index.md @@ -93,6 +93,7 @@ Example scripts are available in a separate repository [php-ai/php-ml-examples]( * [CSV](machine-learning/datasets/csv-dataset.md) * [Files](machine-learning/datasets/files-dataset.md) * [SVM](machine-learning/datasets/svm-dataset.md) + * [MNIST](machine-learning/datasets/mnist-dataset.md) * Ready to use: * [Iris](machine-learning/datasets/demo/iris.md) * [Wine](machine-learning/datasets/demo/wine.md) diff --git a/docs/machine-learning/datasets/mnist-dataset.md b/docs/machine-learning/datasets/mnist-dataset.md new file mode 100644 index 0000000..1ed5081 --- /dev/null +++ b/docs/machine-learning/datasets/mnist-dataset.md @@ -0,0 +1,26 @@ +# MnistDataset + +Helper class that load data from MNIST dataset: [http://yann.lecun.com/exdb/mnist/](http://yann.lecun.com/exdb/mnist/) + +> The MNIST database of handwritten digits, available from this page, has a training set of 60,000 examples, and a test set of 10,000 examples. It is a subset of a larger set available from NIST. The digits have been size-normalized and centered in a fixed-size image. + It is a good database for people who want to try learning techniques and pattern recognition methods on real-world data while spending minimal efforts on preprocessing and formatting. + +### Constructors Parameters + +* $imagePath - (string) path to image file +* $labelPath - (string) path to label file + +``` +use Phpml\Dataset\MnistDataset; + +$trainDataset = new MnistDataset('train-images-idx3-ubyte', 'train-labels-idx1-ubyte'); +``` + +### Samples and labels + +To get samples or labels you can use getters: + +``` +$dataset->getSamples(); +$dataset->getTargets(); +``` diff --git a/mkdocs.yml b/mkdocs.yml index 490e5dc..451d6e9 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -39,6 +39,7 @@ pages: - CSV Dataset: machine-learning/datasets/csv-dataset.md - Files Dataset: machine-learning/datasets/files-dataset.md - SVM Dataset: machine-learning/datasets/svm-dataset.md + - MNIST Dataset: machine-learning/datasets/mnist-dataset.md - Ready to use datasets: - Iris: machine-learning/datasets/demo/iris.md - Wine: machine-learning/datasets/demo/wine.md diff --git a/phpstan.neon b/phpstan.neon index 7a676fa..0ee43c4 100644 --- a/phpstan.neon +++ b/phpstan.neon @@ -6,7 +6,7 @@ includes: parameters: ignoreErrors: - '#Property Phpml\\Clustering\\KMeans\\Cluster\:\:\$points \(iterable\\&SplObjectStorage\) does not accept SplObjectStorage#' - - '#Phpml\\Dataset\\FilesDataset::__construct\(\) does not call parent constructor from Phpml\\Dataset\\ArrayDataset#' + - '#Phpml\\Dataset\\(.*)Dataset::__construct\(\) does not call parent constructor from Phpml\\Dataset\\ArrayDataset#' # wide range cases - '#Parameter \#1 \$coordinates of class Phpml\\Clustering\\KMeans\\Point constructor expects array, array\|Phpml\\Clustering\\KMeans\\Point given#' diff --git a/src/Dataset/MnistDataset.php b/src/Dataset/MnistDataset.php new file mode 100644 index 0000000..59a3a26 --- /dev/null +++ b/src/Dataset/MnistDataset.php @@ -0,0 +1,101 @@ +samples = $this->readImages($imagePath); + $this->targets = $this->readLabels($labelPath); + + if (count($this->samples) !== count($this->targets)) { + throw new InvalidArgumentException('Must have the same number of images and labels'); + } + } + + private function readImages(string $imagePath): array + { + $stream = fopen($imagePath, 'rb'); + + if ($stream === false) { + throw new InvalidArgumentException('Could not open file: '.$imagePath); + } + + $images = []; + + try { + $header = fread($stream, 16); + + $fields = unpack('Nmagic/Nsize/Nrows/Ncols', (string) $header); + + if ($fields['magic'] !== self::MAGIC_IMAGE) { + throw new InvalidArgumentException('Invalid magic number: '.$imagePath); + } + + if ($fields['rows'] != self::IMAGE_ROWS) { + throw new InvalidArgumentException('Invalid number of image rows: '.$imagePath); + } + + if ($fields['cols'] != self::IMAGE_COLS) { + throw new InvalidArgumentException('Invalid number of image cols: '.$imagePath); + } + + for ($i = 0; $i < $fields['size']; $i++) { + $imageBytes = fread($stream, $fields['rows'] * $fields['cols']); + + // Convert to float between 0 and 1 + $images[] = array_map(function ($b) { + return $b / 255; + }, array_values(unpack('C*', (string) $imageBytes))); + } + } finally { + fclose($stream); + } + + return $images; + } + + private function readLabels(string $labelPath): array + { + $stream = fopen($labelPath, 'rb'); + + if ($stream === false) { + throw new InvalidArgumentException('Could not open file: '.$labelPath); + } + + $labels = []; + + try { + $header = fread($stream, 8); + + $fields = unpack('Nmagic/Nsize', (string) $header); + + if ($fields['magic'] !== self::MAGIC_LABEL) { + throw new InvalidArgumentException('Invalid magic number: '.$labelPath); + } + + $labels = fread($stream, $fields['size']); + } finally { + fclose($stream); + } + + return array_values(unpack('C*', (string) $labels)); + } +} diff --git a/tests/Dataset/MnistDatasetTest.php b/tests/Dataset/MnistDatasetTest.php new file mode 100644 index 0000000..5fc7374 --- /dev/null +++ b/tests/Dataset/MnistDatasetTest.php @@ -0,0 +1,33 @@ +getSamples()); + self::assertCount(10, $dataset->getTargets()); + } + + public function testCheckSamplesAndTargetsCountMatch(): void + { + $this->expectException(InvalidArgumentException::class); + + new MnistDataset( + __DIR__.'/Resources/mnist/images-idx-ubyte', + __DIR__.'/Resources/mnist/labels-11-idx-ubyte' + ); + } +} diff --git a/tests/Dataset/Resources/mnist/images-idx-ubyte b/tests/Dataset/Resources/mnist/images-idx-ubyte new file mode 100644 index 0000000..40b870a Binary files /dev/null and b/tests/Dataset/Resources/mnist/images-idx-ubyte differ diff --git a/tests/Dataset/Resources/mnist/labels-11-idx-ubyte b/tests/Dataset/Resources/mnist/labels-11-idx-ubyte new file mode 100644 index 0000000..db9362d Binary files /dev/null and b/tests/Dataset/Resources/mnist/labels-11-idx-ubyte differ diff --git a/tests/Dataset/Resources/mnist/labels-idx-ubyte b/tests/Dataset/Resources/mnist/labels-idx-ubyte new file mode 100644 index 0000000..eca5265 Binary files /dev/null and b/tests/Dataset/Resources/mnist/labels-idx-ubyte differ