blob: 2e17e8253451cbc155ddebd69787404556aa2c26 [file] [log] [blame]
# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import argparse
import os
import functorch
from functorch._src.compile_utils import strip_overloads
import iree_torch
import torch
import torch_mlir
def _get_argparse():
parser = argparse.ArgumentParser(
description="Train and run a regression model.")
parser.add_argument("output_file",
default="/tmp/native_training.vmfb",
help="The path to output the vmfb file to.")
parser.add_argument(
"--iree-backend",
default="llvm-cpu",
help="See https://iree-org.github.io/iree/deployment-configurations/ "
"for the full list of options.")
return parser
def _suppress_warnings():
import warnings
warnings.simplefilter("ignore")
import os
def forward(w, b, X):
return torch.matmul(X, w) + b
def mse(y_pred, y):
err = y_pred - y
return torch.mean(torch.square(err))
def loss_fn(w, b, X, y):
y_pred = forward(w, b, X)
return mse(y_pred, y)
grad_fn = functorch.grad(loss_fn, argnums=(0, 1))
def update(w, b, grad_w, grad_b):
learning_rate = 0.05
new_w = w - grad_w * learning_rate
new_b = b - grad_b * learning_rate
return new_w, new_b
def train(w, b, X, y):
grad_w, grad_b = grad_fn(w, b, X, y)
loss = loss_fn(w, b, X, y)
return update(w, b, grad_w, grad_b) + (loss,)
def main():
global w, b, X_test, y_test
_suppress_warnings()
args = _get_argparse().parse_args()
#
# Training
#
#
# We use placeholder dummy values for tracing the model, since the training
# functions themselves are stateless. The real data will be fed in at call
# time.
w = torch.tensor([1.0, 1.0, 1.0], dtype=torch.float32)
b = torch.tensor(1.0, dtype=torch.float32)
X_test = torch.tensor([[1.0, 1.0, 1.0]], dtype=torch.float32)
y_test = torch.tensor([1.0], dtype=torch.float32)
train_args = (w, b, X_test, y_test)
graph = functorch.make_fx(train)(*train_args)
# TODO: Remove once https://github.com/llvm/torch-mlir/issues/1495
# is resolved.
strip_overloads(graph)
mlir = torch_mlir.compile(graph,
train_args,
output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)
vmfb = iree_torch.compile_to_vmfb(mlir, args.iree_backend)
with open(args.output_file, "wb") as f:
f.write(vmfb)
if __name__ == "__main__":
main()