| # Copyright 2018 The IREE Authors |
| # |
| # Licensed under the Apache License v2.0 with LLVM Exceptions. |
| # See https://llvm.org/LICENSE.txt for license information. |
| # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| """Datasets used in examples.""" |
| |
| import array |
| import gzip |
| import os |
| from os import path |
| import struct |
| import urllib.request |
| |
| import numpy as np |
| |
# Local directory that _download() creates on demand and stores fetched
# dataset archives in; the parsers read the archives back from here.
_DATA = "/tmp/jax_example_data/"
| |
| |
def _download(url, filename):
    """Download a url to a file in the JAX data temp directory.

    The download is skipped when the destination file already exists, so the
    data directory acts as a cache across calls.

    Args:
      url: Address to fetch.
      filename: Name of the destination file inside the data directory.

    Raises:
      Whatever `urllib.request.urlretrieve` or the filesystem calls raise;
      a failed download leaves no file behind under `filename`.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(_DATA, exist_ok=True)
    out_file = path.join(_DATA, filename)
    if path.isfile(out_file):
        return

    # Fetch into a scratch file and rename it into place only on success.
    # Otherwise an interrupted transfer would leave a truncated file that the
    # isfile() check above would treat as a valid cached download.
    tmp_file = out_file + ".part"
    try:
        urllib.request.urlretrieve(url, tmp_file)
        os.replace(tmp_file, out_file)
    except BaseException:
        if path.exists(tmp_file):
            os.remove(tmp_file)
        raise
    print("downloaded {} to {}".format(url, _DATA))
| |
| |
def _partial_flatten(x):
    """Collapse every axis after the first into a single trailing axis.

    The leading dimension of `x` is kept as-is; the remaining dimensions are
    merged, so the result is always two-dimensional.
    """
    leading = x.shape[0]
    target_shape = (leading, -1)
    return np.reshape(x, target_shape)
| |
| |
def _one_hot(x, k, dtype=np.float32):
    """Return a one-hot encoding of the 1-D index array `x` with `k` columns.

    Row i has a one in column x[i] (and zeros elsewhere); entries of `x` that
    match no value in 0..k-1 give an all-zero row. The result has the
    requested `dtype`.
    """
    classes = np.arange(k)
    # Broadcasting a column of indices against the row of class ids gives a
    # (len(x), k) boolean mask.
    matches = x[:, None] == classes
    return np.array(matches, dtype)
| |
| |
def mnist_raw():
    """Fetch the four MNIST archives and decode them into uint8 arrays.

    Each archive is downloaded via `_download` (which skips files that are
    already present) and then parsed from the data directory.

    Returns:
      A tuple `(train_images, train_labels, test_images, test_labels)`.
      Image arrays are shaped `(num_data, rows, cols)` using the counts read
      from each file header; label arrays are flat.
    """
    # CVDF mirror of http://yann.lecun.com/exdb/mnist/
    base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"
    train_images_name = "train-images-idx3-ubyte.gz"
    train_labels_name = "train-labels-idx1-ubyte.gz"
    test_images_name = "t10k-images-idx3-ubyte.gz"
    test_labels_name = "t10k-labels-idx1-ubyte.gz"

    def read_payload(stream):
        """Read the remainder of `stream` as a flat uint8 array."""
        return np.array(array.array("B", stream.read()), dtype=np.uint8)

    def parse_labels(filename):
        with gzip.open(filename, "rb") as stream:
            # Consume the 8-byte big-endian header; its values are not used.
            struct.unpack(">II", stream.read(8))
            return read_payload(stream)

    def parse_images(filename):
        with gzip.open(filename, "rb") as stream:
            # 16-byte big-endian header; the first field is ignored and the
            # other three give the array shape.
            header = struct.unpack(">IIII", stream.read(16))
            _, count, height, width = header
            return read_payload(stream).reshape(count, height, width)

    # Make sure every archive is on disk before parsing any of them.
    for archive in (
        train_images_name,
        train_labels_name,
        test_images_name,
        test_labels_name,
    ):
        _download(base_url + archive, archive)

    train_images = parse_images(path.join(_DATA, train_images_name))
    train_labels = parse_labels(path.join(_DATA, train_labels_name))
    test_images = parse_images(path.join(_DATA, test_images_name))
    test_labels = parse_labels(path.join(_DATA, test_labels_name))

    return train_images, train_labels, test_images, test_labels
| |
| |
def mnist(permute_train=False):
    """Load MNIST with flattened, rescaled images and one-hot labels.

    Images from `mnist_raw` are flattened to one row per example and divided
    by 255 (as float32); labels are one-hot encoded over 10 classes.

    Args:
      permute_train: When true, reorder the training images and labels with
        the same permutation, drawn from a `RandomState` seeded with 0.

    Returns:
      A tuple `(train_images, train_labels, test_images, test_labels)`.
    """
    raw_train_x, raw_train_y, raw_test_x, raw_test_y = mnist_raw()

    scale = np.float32(255.0)
    num_classes = 10

    train_images = _partial_flatten(raw_train_x) / scale
    test_images = _partial_flatten(raw_test_x) / scale
    train_labels = _one_hot(raw_train_y, num_classes)
    test_labels = _one_hot(raw_test_y, num_classes)

    if not permute_train:
        return train_images, train_labels, test_images, test_labels

    # Fixed seed so the shuffled order is the same on every call.
    order = np.random.RandomState(0).permutation(train_images.shape[0])
    return train_images[order], train_labels[order], test_images, test_labels