|  | # Copyright 2018 The IREE Authors | 
|  | # | 
|  | # Licensed under the Apache License v2.0 with LLVM Exceptions. | 
|  | # See https://llvm.org/LICENSE.txt for license information. | 
|  | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | 
|  | """Datasets used in examples.""" | 
|  |  | 
|  | import array | 
|  | import gzip | 
|  | import os | 
|  | from os import path | 
|  | import struct | 
|  | import urllib.request | 
|  |  | 
|  | import numpy as np | 
|  |  | 
# Directory that downloaded dataset files are written to and re-read from.
_DATA = "/tmp/jax_example_data/"


def _download(url, filename):
    """Download a url to a file in the JAX data temp directory.

    The download is skipped when `filename` already exists inside `_DATA`.
    Data is first written to a sibling "<filename>.part" file and only renamed
    to its final name once the transfer has finished, so an interrupted
    download is never mistaken for a complete file by a later call.

    Args:
      url: Location to fetch.
      filename: Name of the destination file, relative to `_DATA`.

    Raises:
      Whatever `urllib.request.urlretrieve` or the file system calls raise;
      nothing is caught here.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(_DATA, exist_ok=True)
    out_file = path.join(_DATA, filename)
    if path.isfile(out_file):
        return

    partial_file = out_file + ".part"
    try:
        urllib.request.urlretrieve(url, partial_file)
        # Publish the finished download under its final name in one step.
        os.replace(partial_file, out_file)
    finally:
        # Only still present when the download or the rename failed.
        if path.exists(partial_file):
            os.remove(partial_file)
    print("downloaded {} to {}".format(url, _DATA))
|  |  | 
|  |  | 
def _partial_flatten(x):
    """Flatten all but the first dimension of an ndarray.

    Args:
      x: Array with at least one dimension.

    Returns:
      A 2-D array of shape `(x.shape[0], prod(x.shape[1:]))`; a 1-D input of
      length n becomes shape `(n, 1)`.
    """
    # The trailing size is computed explicitly instead of passing -1: NumPy
    # cannot infer a -1 dimension when the leading dimension is 0, so an empty
    # batch would otherwise raise ValueError.
    trailing = int(np.prod(x.shape[1:]))
    return np.reshape(x, (x.shape[0], trailing))
|  |  | 
|  |  | 
def _one_hot(x, k, dtype=np.float32):
    """Create a one-hot encoding of x of size k.

    Each entry of `x` is compared against the class indices 0..k-1; the
    matching position is 1 and all others are 0, cast to `dtype`. Entries
    outside that range produce an all-zero row.
    """
    class_ids = np.arange(k)
    as_column = x[:, None]
    matches = as_column == class_ids
    return np.array(matches, dtype=dtype)
|  |  | 
|  |  | 
def mnist_raw():
    """Fetch the four gzipped MNIST files and decode them into uint8 arrays.

    Returns:
      A tuple `(train_images, train_labels, test_images, test_labels)`.
      Image arrays are reshaped to the three dimensions read from each file's
      header; label arrays are 1-D.
    """
    # CVDF mirror of http://yann.lecun.com/exdb/mnist/
    base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"

    def load_labels(gz_path):
        with gzip.open(gz_path, "rb") as stream:
            # Consume the 8-byte header (two big-endian uint32 fields); its
            # values are not used.
            struct.unpack(">II", stream.read(8))
            payload = stream.read()
        return np.array(array.array("B", payload), dtype=np.uint8)

    def load_images(gz_path):
        with gzip.open(gz_path, "rb") as stream:
            # 16-byte header: four big-endian uint32 fields. The first is
            # ignored; the remaining three give the array shape.
            header = struct.unpack(">IIII", stream.read(16))
            payload = stream.read()
        count, height, width = header[1:]
        pixels = np.array(array.array("B", payload), dtype=np.uint8)
        return pixels.reshape(count, height, width)

    archives = (
        "train-images-idx3-ubyte.gz",
        "train-labels-idx1-ubyte.gz",
        "t10k-images-idx3-ubyte.gz",
        "t10k-labels-idx1-ubyte.gz",
    )
    for archive in archives:
        _download(base_url + archive, archive)

    train_img_path, train_lbl_path, test_img_path, test_lbl_path = (
        path.join(_DATA, archive) for archive in archives
    )

    # Tuple elements are evaluated left to right, so files are parsed in the
    # same order as they are returned.
    return (
        load_images(train_img_path),
        load_labels(train_lbl_path),
        load_images(test_img_path),
        load_labels(test_lbl_path),
    )
|  |  | 
|  |  | 
def mnist(permute_train=False):
    """Load MNIST with flattened, rescaled images and one-hot labels.

    Images from `mnist_raw()` are flattened per example and divided by
    `np.float32(255.)`; labels are one-hot encoded over 10 classes.

    Args:
      permute_train: When true, the training images and labels are reordered
        with the same permutation drawn from `np.random.RandomState(0)`.

    Returns:
      A tuple `(train_images, train_labels, test_images, test_labels)`.
    """
    raw_train_x, raw_train_y, raw_test_x, raw_test_y = mnist_raw()

    scale = np.float32(255.)
    num_classes = 10

    train_x = _partial_flatten(raw_train_x) / scale
    test_x = _partial_flatten(raw_test_x) / scale
    train_y = _one_hot(raw_train_y, num_classes)
    test_y = _one_hot(raw_test_y, num_classes)

    if not permute_train:
        return train_x, train_y, test_x, test_y

    # Fixed seed: the same ordering is produced on every call.
    order = np.random.RandomState(0).permutation(train_x.shape[0])
    return train_x[order], train_y[order], test_x, test_y