Remove native training example.
This was migrated:
https://github.com/iree-org/iree-torch/commit/6a6f60b0b0ccb9107ee40d0994c472e1126fcc48
diff --git a/build_tools/testing/test_samples.sh b/build_tools/testing/test_samples.sh
index 14eaf52..1387a6d 100755
--- a/build_tools/testing/test_samples.sh
+++ b/build_tools/testing/test_samples.sh
@@ -16,4 +16,3 @@
./samples/dynamic_shapes/test.sh
./samples/variables_and_state/test.sh
-./samples/native_training/test.sh
diff --git a/samples/native_training/Makefile b/samples/native_training/Makefile
deleted file mode 100644
index 4781108..0000000
--- a/samples/native_training/Makefile
+++ /dev/null
@@ -1,53 +0,0 @@
-# Copyright 2022 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-# This is an example showing a basic makefile that links in the IREE runtime by
-# way of the unified static library. It's recommended that IREE is added as a
-# subproject and cmake is used to add the dependencies, but when using other
-# build systems this is easier to adapt.
-#
-# Configure the runtime:
-# cmake -GNinja -B /tmp/iree-build-runtime/ .
-# Build the runtime:
-# cmake --build /tmp/iree-build-runtime/ --target iree_runtime_unified
-# Make this binary:
-# make
-#
-# Note that if IREE_SIZE_OPTIMIZED is used to build the runtime then the
-# -DNDEBUG and -DIREE_STATUS_MODE=0 are required on any binaries using it. YMMV
-# if changing any compiler options and not keeping them in sync. Prefer using
-# cmake to ensure consistency between the builds.
-#
-# If cpuinfo is not supported on your platform then configure the runtime with
-# -DIREE_ENABLE_CPUINFO=OFF.
-
-RUNTIME_SRC_DIR ?= ../../runtime/src/
-RUNTIME_BUILD_DIR ?= /tmp/iree-build-runtime/
-
-SRC_FILES := main.c
-INCLUDE_DIRS := ${RUNTIME_SRC_DIR}
-INCLUDE_FLAGS := $(addprefix -I,${INCLUDE_DIRS})
-LIBRARY_DIRS := \
- ${RUNTIME_BUILD_DIR}/build_tools/third_party/flatcc/ \
- ${RUNTIME_BUILD_DIR}/runtime/src/iree/runtime/ \
- ${RUNTIME_BUILD_DIR}/third_party/cpuinfo/ \
- ${RUNTIME_BUILD_DIR}/third_party/cpuinfo/deps/clog
-LINK_LIBRARIES := \
- iree_runtime_unified \
- flatcc_parsing \
- cpuinfo \
- clog \
- dl \
- pthread
-LIBRARY_FLAGS := $(addprefix -L,${LIBRARY_DIRS}) $(addprefix -l,${LINK_LIBRARIES})
-CXX_FLAGS := -flto -O2 ${INCLUDE_FLAGS} ${LIBRARY_FLAGS}
-
-all: native-training
-clean:
- rm -f native-training
-
-native-training: ${SRC_FILES}
- ${CXX} ${SRC_FILES} ${CXX_FLAGS} -o $@
diff --git a/samples/native_training/README.md b/samples/native_training/README.md
deleted file mode 100644
index d223976..0000000
--- a/samples/native_training/README.md
+++ /dev/null
@@ -1,69 +0,0 @@
-# Native Training Example
-
-This example shows how to
-
-1. Build a PyTorch functional model for training
-2. Import that model into IREE's compiler
-3. Compile that model to an IREE VM bytecode module
-4. Load the compiled module using IREE's high level runtime C API into a
- lightweight program
-5. Train the loaded model
-
-This example was built with the goal of allowing you to be able to build it
-outside this repo in your own project with minimal changes.
-
-The weights for the model are stored in the program itself and updated in
-memory. This can be modified to be stored however you see fit.
-
-## Running the Example
-
-Install `iree-torch` and other dependencies necessary for this example.
-[iree-torch](https://github.com/iree-org/iree-torch) provides a number of
-convenient wrappers around `torch-mlir` and `iree` compilation:
-
-> **Note**
-> We recommend installing Python packages inside a
-> [virtual environment](https://docs.python.org/3/tutorial/venv.html).
-
-```shell
-pip install -f https://iree-org.github.io/iree/pip-release-links.html iree-compiler
-pip install -f https://llvm.github.io/torch-mlir/package-index/ torch-mlir
-pip install git+https://github.com/iree-org/iree-torch.git
-```
-
-Update submodules in this repo:
-
-```shell
-(cd $(git rev-parse --show-toplevel) && git submodule update --init)
-```
-
-Build the IREE runtime:
-
-```shell
-(cd $(git rev-parse --show-toplevel) && cmake -GNinja -B /tmp/iree-build-runtime/ .)
-cmake --build /tmp/iree-build-runtime/ --target iree_runtime_unified
-```
-
-Make sure you're in this example's directory:
-
-```shell
-cd $(git rev-parse --show-toplevel)/samples/native_training
-```
-
-Build the native training example:
-
-```shell
-make
-```
-
-Generate the IREE VM bytecode for the model:
-
-```shell
-python native_training.py /tmp/native_training.vmfb
-```
-
-Run the native training model:
-
-```shell
-./native-training /tmp/native_training.vmfb
-```
diff --git a/samples/native_training/main.c b/samples/native_training/main.c
deleted file mode 100644
index 9309201..0000000
--- a/samples/native_training/main.c
+++ /dev/null
@@ -1,260 +0,0 @@
-// Copyright 2022 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include <stdio.h>
-
-#include "iree/runtime/api.h"
-
-struct State {
- float w[3];
- float b[1];
- float X[3];
- float y[1];
- float loss[1];
-};
-
-void print_state(struct State* state) {
- fprintf(stdout, "Weights:");
- for (iree_host_size_t i = 0; i < IREE_ARRAYSIZE(state->w); ++i) {
- fprintf(stdout, " %f", state->w[i]);
- }
- fprintf(stdout, ", Bias: %f", state->b[0]);
- fprintf(stdout, ", Loss: %f\n", state->loss[0]);
-}
-
-iree_status_t train(iree_runtime_session_t* session, struct State* state) {
- iree_status_t status = iree_ok_status();
-
- // Lookup the entry point function.
- iree_runtime_call_t call;
- status = iree_runtime_call_initialize_by_name(
- session, iree_make_cstring_view("module.forward"), &call);
-
- // Allocate buffers in device-local memory so that if the device has an
- // independent address space they live on the fast side of the fence.
- iree_hal_dim_t shape_w[1] = {IREE_ARRAYSIZE(state->w)};
- iree_hal_dim_t shape_b[0] = {};
- iree_hal_dim_t shape_X[2] = {1, IREE_ARRAYSIZE(state->X)};
- iree_hal_dim_t shape_y[1] = {IREE_ARRAYSIZE(state->y)};
- iree_hal_buffer_view_t* arg0 = NULL;
- iree_hal_buffer_view_t* arg1 = NULL;
- iree_hal_buffer_view_t* arg2 = NULL;
- iree_hal_buffer_view_t* arg3 = NULL;
- if (iree_status_is_ok(status)) {
- status = iree_hal_buffer_view_allocate_buffer(
- iree_runtime_session_device_allocator(session), IREE_ARRAYSIZE(shape_w),
- shape_w, IREE_HAL_ELEMENT_TYPE_FLOAT_32,
- IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
- (iree_hal_buffer_params_t){
- .usage = IREE_HAL_BUFFER_USAGE_DEFAULT,
- .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL,
- },
- iree_make_const_byte_span(state->w, sizeof(state->w)), &arg0);
- }
- if (iree_status_is_ok(status)) {
- status = iree_hal_buffer_view_allocate_buffer(
- iree_runtime_session_device_allocator(session), IREE_ARRAYSIZE(shape_b),
- shape_b, IREE_HAL_ELEMENT_TYPE_FLOAT_32,
- IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
- (iree_hal_buffer_params_t){
- .usage = IREE_HAL_BUFFER_USAGE_DEFAULT,
- .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL,
- },
- iree_make_const_byte_span(state->b, sizeof(state->b)), &arg1);
- }
- if (iree_status_is_ok(status)) {
- status = iree_hal_buffer_view_allocate_buffer(
- iree_runtime_session_device_allocator(session), IREE_ARRAYSIZE(shape_X),
- shape_X, IREE_HAL_ELEMENT_TYPE_FLOAT_32,
- IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
- (iree_hal_buffer_params_t){
- .usage = IREE_HAL_BUFFER_USAGE_DEFAULT,
- .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL,
- },
- iree_make_const_byte_span(state->X, sizeof(state->X)), &arg2);
- }
- if (iree_status_is_ok(status)) {
- status = iree_hal_buffer_view_allocate_buffer(
- iree_runtime_session_device_allocator(session), IREE_ARRAYSIZE(shape_y),
- shape_y, IREE_HAL_ELEMENT_TYPE_FLOAT_32,
- IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
- (iree_hal_buffer_params_t){
- .usage = IREE_HAL_BUFFER_USAGE_DEFAULT,
- .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL,
- },
- iree_make_const_byte_span(state->y, sizeof(state->y)), &arg3);
- }
-
- // Setup call inputs with our buffers.
- if (iree_status_is_ok(status)) {
- status = iree_runtime_call_inputs_push_back_buffer_view(&call, arg0);
- }
- iree_hal_buffer_view_release(arg0);
- if (iree_status_is_ok(status)) {
- status = iree_runtime_call_inputs_push_back_buffer_view(&call, arg1);
- }
- iree_hal_buffer_view_release(arg1);
- if (iree_status_is_ok(status)) {
- status = iree_runtime_call_inputs_push_back_buffer_view(&call, arg2);
- }
- iree_hal_buffer_view_release(arg2);
- if (iree_status_is_ok(status)) {
- status = iree_runtime_call_inputs_push_back_buffer_view(&call, arg3);
- }
- iree_hal_buffer_view_release(arg3);
-
- // Invoke the function
- IREE_RETURN_IF_ERROR(iree_runtime_call_invoke(&call, /*flags=*/0));
-
- // Update weights
- iree_hal_buffer_view_t* result = NULL;
- if (iree_status_is_ok(status)) {
- status = iree_runtime_call_outputs_pop_front_buffer_view(&call, &result);
- }
- if (iree_status_is_ok(status)) {
- status = iree_hal_device_transfer_d2h(
- iree_runtime_session_device(session),
- iree_hal_buffer_view_buffer(result), 0, &state->w, sizeof(state->w),
- IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout());
- }
-
- // Update bias
- if (iree_status_is_ok(status)) {
- status = iree_runtime_call_outputs_pop_front_buffer_view(&call, &result);
- }
- if (iree_status_is_ok(status)) {
- status = iree_hal_device_transfer_d2h(
- iree_runtime_session_device(session),
- iree_hal_buffer_view_buffer(result), 0, &state->b, sizeof(state->b),
- IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout());
- }
-
- // Update loss
- if (iree_status_is_ok(status)) {
- status = iree_runtime_call_outputs_pop_front_buffer_view(&call, &result);
- }
- if (iree_status_is_ok(status)) {
- status = iree_hal_device_transfer_d2h(iree_runtime_session_device(session),
- iree_hal_buffer_view_buffer(result),
- 0, &state->loss, sizeof(state->loss),
- IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
- iree_infinite_timeout());
- }
- iree_hal_buffer_view_release(result);
-
- return status;
-}
-
-iree_status_t run_sample(iree_string_view_t bytecode_module_path,
- iree_string_view_t driver_name, struct State* state) {
- iree_status_t status = iree_ok_status();
-
- //===-------------------------------------------------------------------===//
- // Instance configuration (this should be shared across sessions).
- iree_runtime_instance_options_t instance_options;
- iree_runtime_instance_options_initialize(&instance_options);
- iree_runtime_instance_options_use_all_available_drivers(&instance_options);
- iree_runtime_instance_t* instance = NULL;
- if (iree_status_is_ok(status)) {
- fprintf(stdout, "Configuring IREE runtime instance and '%s' device\n",
- driver_name.data);
- status = iree_runtime_instance_create(&instance_options,
- iree_allocator_system(), &instance);
- }
- // TODO(#5724): move device selection into the compiled modules.
- iree_hal_device_t* device = NULL;
- if (iree_status_is_ok(status)) {
- status = iree_runtime_instance_try_create_default_device(
- instance, driver_name, &device);
- }
- //===-------------------------------------------------------------------===//
-
- //===-------------------------------------------------------------------===//
- // Session configuration (one per loaded module to hold module state).
- iree_runtime_session_options_t session_options;
- iree_runtime_session_options_initialize(&session_options);
- iree_runtime_session_t* session = NULL;
- if (iree_status_is_ok(status)) {
- fprintf(stdout, "Creating IREE runtime session\n");
- status = iree_runtime_session_create_with_device(
- instance, &session_options, device,
- iree_runtime_instance_host_allocator(instance), &session);
- }
- iree_hal_device_release(device);
-
- if (iree_status_is_ok(status)) {
- fprintf(stdout, "Loading bytecode module at '%s'\n",
- bytecode_module_path.data);
- status = iree_runtime_session_append_bytecode_module_from_file(
- session, bytecode_module_path.data);
- }
- //===-------------------------------------------------------------------===//
-
- //===-------------------------------------------------------------------===//
- if (iree_status_is_ok(status)) {
- fprintf(stdout, "Training...\n");
- print_state(state);
- for (int i = 0; i < 10; i++) {
- status = train(session, state);
- print_state(state);
- if (!iree_status_is_ok(status)) {
- break;
- }
- }
- }
- //===-------------------------------------------------------------------===//
-
- //===-------------------------------------------------------------------===//
- // Cleanup.
- iree_runtime_session_release(session);
- iree_runtime_instance_release(instance);
- //===-------------------------------------------------------------------===//
-
- return status;
-}
-
-int main(int argc, char** argv) {
- // Parse args
- if (argc < 2) {
- fprintf(stderr,
- "Usage: native-training </path/to/native_training.vmfb> "
- "[<driver_name>]\n");
- fprintf(stderr, " (See the README for this sample for details)\n ");
- return -1;
- }
- iree_string_view_t bytecode_module_path = iree_make_cstring_view(argv[1]);
- iree_string_view_t driver_name;
- if (argc >= 3) {
- driver_name = iree_make_cstring_view(argv[2]);
- } else {
- driver_name = iree_make_cstring_view("local-sync");
- }
-
- // Run training
- struct State state = {
- {4.0f, 4.0f, 5.0f}, // w
- {2.0f}, // b
- {1.0f, 1.0f, 1.0f}, // X
- {14.0f}, // y
- {1.0f}, // loss
- };
- iree_status_t result = run_sample(bytecode_module_path, driver_name, &state);
- if (!iree_status_is_ok(result)) {
- fprintf(stdout, "Failed!\n");
- iree_status_fprint(stderr, result);
- iree_status_ignore(result);
- return -1;
- }
-
- // Validate result
- if (*state.loss > 0.1f) {
- fprintf(stdout, "Loss unexpectedly high\n");
- return -1;
- }
-
- fprintf(stdout, "\nSuccess!\n");
- return 0;
-}
diff --git a/samples/native_training/native_training.py b/samples/native_training/native_training.py
deleted file mode 100644
index 3020a7b..0000000
--- a/samples/native_training/native_training.py
+++ /dev/null
@@ -1,97 +0,0 @@
-# Copyright 2022 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-import argparse
-import os
-
-import functorch
-import iree_torch
-import torch
-import torch_mlir
-
-
-def _get_argparse():
- parser = argparse.ArgumentParser(
- description="Train and run a regression model.")
- parser.add_argument("output_file",
- default="/tmp/native_training.vmfb",
- help="The path to output the vmfb file to.")
- parser.add_argument(
- "--iree-backend",
- default="llvm-cpu",
- help="See https://iree-org.github.io/iree/deployment-configurations/ "
- "for the full list of options.")
- return parser
-
-
-def _suppress_warnings():
- import warnings
- warnings.simplefilter("ignore")
- import os
-
-
-def forward(w, b, X):
- return torch.matmul(X, w) + b
-
-
-def mse(y_pred, y):
- err = y_pred - y
- return torch.mean(torch.square(err))
-
-
-def loss_fn(w, b, X, y):
- y_pred = forward(w, b, X)
- return mse(y_pred, y)
-
-
-grad_fn = functorch.grad(loss_fn, argnums=(0, 1))
-
-
-def update(w, b, grad_w, grad_b):
- learning_rate = 0.05
- new_w = w - grad_w * learning_rate
- new_b = b - grad_b * learning_rate
- return new_w, new_b
-
-
-def train(w, b, X, y):
- grad_w, grad_b = grad_fn(w, b, X, y)
- loss = loss_fn(w, b, X, y)
- return update(w, b, grad_w, grad_b) + (loss,)
-
-
-def main():
- global w, b, X_test, y_test
- _suppress_warnings()
- args = _get_argparse().parse_args()
-
- #
- # Training
- #
- #
-
- # We use placeholder dummy values for tracing the model, since the training
- # functions themselves are stateless. The real data will be fed in at call
- # time.
- w = torch.tensor([1.0, 1.0, 1.0], dtype=torch.float32)
- b = torch.tensor(1.0, dtype=torch.float32)
- X_test = torch.tensor([[1.0, 1.0, 1.0]], dtype=torch.float32)
- y_test = torch.tensor([1.0], dtype=torch.float32)
-
- train_args = (w, b, X_test, y_test)
- graph = functorch.make_fx(train)(*train_args)
-
- mlir = torch_mlir.compile(graph,
- train_args,
- output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)
-
- vmfb = iree_torch.compile_to_vmfb(mlir, args.iree_backend)
- with open(args.output_file, "wb") as f:
- f.write(vmfb)
-
-
-if __name__ == "__main__":
- main()
diff --git a/samples/native_training/test.sh b/samples/native_training/test.sh
deleted file mode 100755
index ecf4168..0000000
--- a/samples/native_training/test.sh
+++ /dev/null
@@ -1,48 +0,0 @@
-#!/bin/bash
-# Copyright 2022 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-# This script runs the steps laid out in the README for this sample. It is
-# intended for use on continuous integration servers and as a reference for
-# users, but can also be run manually.
-
-set -x
-set -e
-
-# Run under a virtual environment to isolate Python packages.
-#
-# For more information, see `build_tools/testing/run_python_notebook.sh`
-python3 -m venv .script.venv --system-site-packages --clear
-source .script.venv/bin/activate 2> /dev/null
-trap "deactivate 2> /dev/null" EXIT
-
-# Update pip within the venv (you'd think this wouldn't be needed, but it is).
-python3 -m pip install --quiet --upgrade pip
-
-# Install script requirements, reusing system versions if possible.
-python3 -m pip install --quiet \
- -f https://iree-org.github.io/iree/pip-release-links.html iree-compiler
-python3 -m pip install --quiet \
- -f https://llvm.github.io/torch-mlir/package-index/ torch-mlir
-python3 -m pip install --quiet \
- git+https://github.com/iree-org/iree-torch.git
-
-# Update submodules in this repo.
-(cd $(git rev-parse --show-toplevel) && git submodule update --init)
-
-# Build the IREE runtime.
-(cd $(git rev-parse --show-toplevel) && cmake -GNinja -B /tmp/iree-build-runtime/ .)
-cmake --build /tmp/iree-build-runtime/ --target iree_runtime_unified
-
-# Build the example.
-cd $(git rev-parse --show-toplevel)/samples/native_training
-make
-
-# Generate the VM bytecode.
-python native_training.py /tmp/native_training.vmfb
-
-# Run the native training model.
-./native-training /tmp/native_training.vmfb
\ No newline at end of file