<!--
Copyright 2019 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# MNIST IR Example
This shows the MNIST MLP model as it is compiled from Keras, lowered to XLA HLO,
and then lowered to an IREE module with SPIR-V. Several steps are omitted for
brevity.
## TensorFlow Keras Model
```python
def simple_mnist_model(input_shape):
  """Creates a simple (multi-layer perceptron) MNIST model."""
  model = tf.keras.models.Sequential()
  # Flatten to a 1d array (e.g. 28x28 -> 784)
  model.add(tf.keras.layers.Flatten(input_shape=input_shape))
  # Fully-connected neural layer with 128 neurons, RELU activation
  model.add(tf.keras.layers.Dense(128, activation='relu'))
  # Fully-connected neural layer returning probability scores for each class
  model.add(tf.keras.layers.Dense(10, activation='softmax'))
  return model
```
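For reference, this is the model that gets traced and handed to the compiler. A minimal sketch of instantiating and tracing it at the fixed 1x28x28x1 shape that appears throughout the IR below (the `tf.function`/`TensorSpec` export path shown here is illustrative, not necessarily the exact path IREE uses):

```python
import tensorflow as tf

# Build the model with the standard 28x28 grayscale MNIST input shape.
model = simple_mnist_model(input_shape=(28, 28, 1))

# Trace inference at a fixed batch-1 signature so a graph compiler sees the
# static 1x28x28x1 -> 1x10 shapes that show up in the IR dumps below.
predict = tf.function(
    lambda x: model(x),
    input_signature=[tf.TensorSpec([1, 28, 28, 1], tf.float32)])

concrete_fn = predict.get_concrete_function()
print(concrete_fn(tf.zeros([1, 28, 28, 1])).shape)  # (1, 10)
```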
## XLA HLO
**NOTE**: this uses placeholder weights to keep the page from being a few
thousand lines of floats.
```mlir
module {
  func @main(%arg0: tensor<1x28x28x1xf32>) -> tuple<tensor<1x10xf32>>
      attributes {iree.module.export} {
    %cst = constant {name = "constant.9"} dense<0.5> : tensor<f32>
    %0 = "xla_hlo.broadcast_in_dim"(%cst) {name = "broadcast.10"} : (tensor<f32>) -> tensor<1x128xf32>
    %1 = "xla_hlo.copy"(%arg0) {name = "copy.1"} : (tensor<1x28x28x1xf32>) -> tensor<1x28x28x1xf32>
    %2 = "xla_hlo.reshape"(%1) {name = "reshape.2"} : (tensor<1x28x28x1xf32>) -> tensor<1x28x28x1xf32>
    %3 = "xla_hlo.reshape"(%2) {name = "reshape.3"} : (tensor<1x28x28x1xf32>) -> tensor<1x784xf32>
    %cst_0 = constant {name = "constant.4"} dense<0.5> : tensor<784x128xf32>
    %4 = "xla_hlo.dot"(%3, %cst_0) {name = "dot.5", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x784xf32>, tensor<784x128xf32>) -> tensor<1x128xf32>
    %cst_1 = constant {name = "constant.6"} dense<0.5> : tensor<128xf32>
    %5 = "xla_hlo.broadcast_in_dim"(%cst_1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "broadcast.7"} : (tensor<128xf32>) -> tensor<1x128xf32>
    %6 = "xla_hlo.add"(%4, %5) {name = "add.8"} : (tensor<1x128xf32>, tensor<1x128xf32>) -> tensor<1x128xf32>
    %7 = "xla_hlo.max"(%0, %6) {name = "maximum.11"} : (tensor<1x128xf32>, tensor<1x128xf32>) -> tensor<1x128xf32>
    %cst_2 = constant {name = "constant.12"} dense<0.5> : tensor<128x10xf32>
    %8 = "xla_hlo.dot"(%7, %cst_2) {name = "dot.13", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x128xf32>, tensor<128x10xf32>) -> tensor<1x10xf32>
    %cst_3 = constant {name = "constant.14"} dense<0.5> : tensor<10xf32>
    %9 = "xla_hlo.broadcast_in_dim"(%cst_3) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "broadcast.15"} : (tensor<10xf32>) -> tensor<1x10xf32>
    %10 = "xla_hlo.add"(%8, %9) {name = "add.16"} : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
    %cst_4 = constant {name = "constant.17"} dense<0xFF800000> : tensor<f32>
    %11 = "xla_hlo.reduce"(%10, %cst_4) ( {
    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):   // no predecessors
      %20 = "xla_hlo.max"(%arg1, %arg2) {name = "maximum.21"} : (tensor<f32>, tensor<f32>) -> tensor<f32>
      "xla_hlo.return"(%20) : (tensor<f32>) -> ()
    }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<f32>) -> tensor<1xf32>
    %12 = "xla_hlo.broadcast_in_dim"(%11) {broadcast_dimensions = dense<0> : tensor<1xi64>, name = "broadcast.23"} : (tensor<1xf32>) -> tensor<1x10xf32>
    %13 = "xla_hlo.sub"(%10, %12) {name = "subtract.24"} : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
    %14 = "xla_hlo.exp"(%13) {name = "exponential.25"} : (tensor<1x10xf32>) -> tensor<1x10xf32>
    %cst_5 = constant {name = "constant.27"} dense<0.5> : tensor<f32>
    %15 = "xla_hlo.reduce"(%14, %cst_5) ( {
    ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):   // no predecessors
      %21 = "xla_hlo.add"(%arg3, %arg4) {name = "add.31"} : (tensor<f32>, tensor<f32>) -> tensor<f32>
      "xla_hlo.return"(%21) : (tensor<f32>) -> ()
    }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<f32>) -> tensor<1xf32>
    %16 = "xla_hlo.broadcast_in_dim"(%15) {broadcast_dimensions = dense<0> : tensor<1xi64>, name = "broadcast.34"} : (tensor<1xf32>) -> tensor<1x10xf32>
    %17 = "xla_hlo.div"(%14, %16) {name = "divide.35"} : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
    %18 = "xla_hlo.reshape"(%17) {name = "reshape.36"} : (tensor<1x10xf32>) -> tensor<1x10xf32>
    %19 = "xla_hlo.tuple"(%18) {name = "tuple.37"} : (tensor<1x10xf32>) -> tuple<tensor<1x10xf32>>
    return %19 : tuple<tensor<1x10xf32>>
  }
}
```
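For readers less familiar with HLO, the graph above is just the MLP math with softmax expanded into its numerically stable form (subtract the row max, exponentiate, normalize by the row sum); note that every constant, including the ReLU threshold and the reduce init values, appears as the 0.5 placeholder in the dump. A rough NumPy sketch of the computation the real (non-placeholder) model performs, with op names from the dump in the comments:

```python
import numpy as np

def mnist_mlp_reference(x, w0, b0, w1, b1):
  """Plain-math equivalent of the HLO graph above (real constants, not 0.5)."""
  h = x.reshape(1, 784)                    # reshape.2 / reshape.3
  h = np.maximum(h @ w0 + b0, 0.0)         # dot.5, add.8, maximum.11 (ReLU)
  logits = h @ w1 + b1                     # dot.13, add.16
  m = logits.max(axis=1, keepdims=True)    # reduce (maximum.21)
  e = np.exp(logits - m)                   # subtract.24, exponential.25
  return e / e.sum(axis=1, keepdims=True)  # reduce (add.31), divide.35

# Example with random weights standing in for the trained ones.
rng = np.random.default_rng(0)
out = mnist_mlp_reference(
    rng.standard_normal((1, 28, 28, 1), np.float32),
    rng.standard_normal((784, 128), np.float32),
    np.zeros(128, np.float32),
    rng.standard_normal((128, 10), np.float32),
    np.zeros(10, np.float32))
print(out.shape, out.sum())  # (1, 10) ~1.0
```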
## IREE IR (pre-backend lowering)
Here's the lowered, outlined, and compiler-annotated version of the above in the
IREE sequencer dialect.
```mlir
module {
  iree.multi_arch_executable @main_ex_dispatch_0[0]() {
    iree.executable[0](Unspecified) {
      module {
        func @main_entry_dispatch_0(%arg0: memref<1x28x28x1xf32>, %arg1: memref<1x784xf32>)
            attributes {iree.executable.export, iree.executable.workload = dense<[784, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
          %0 = iree.load_input(%arg0 : memref<1x28x28x1xf32>) : tensor<1x28x28x1xf32>
          %1 = "xla_hlo.copy"(%0) {name = "copy.1"} : (tensor<1x28x28x1xf32>) -> tensor<1x28x28x1xf32>
          %2 = "xla_hlo.reshape"(%1) {name = "reshape.3"} : (tensor<1x28x28x1xf32>) -> tensor<1x784xf32>
          iree.store_output(%2 : tensor<1x784xf32>, %arg1 : memref<1x784xf32>)
          iree.return
        }
      }
    }
  }
  iree.multi_arch_executable @main_ex_dispatch_1[1]() {
    iree.executable[1](Unspecified) {
      module {
        func @main_entry_dispatch_1(%arg0: memref<1x784xf32>, %arg1: memref<784x128xf32>, %arg2: memref<1x128xf32>)
            attributes {iree.executable.export, iree.executable.workload = dense<[128, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
          %0 = iree.load_input(%arg0 : memref<1x784xf32>) : tensor<1x784xf32>
          %1 = iree.load_input(%arg1 : memref<784x128xf32>) : tensor<784x128xf32>
          %2 = "xla_hlo.dot"(%0, %1) {name = "dot.5", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x784xf32>, tensor<784x128xf32>) -> tensor<1x128xf32>
          iree.store_output(%2 : tensor<1x128xf32>, %arg2 : memref<1x128xf32>)
          iree.return
        }
      }
    }
  }
  iree.multi_arch_executable @main_ex_dispatch_2[2]() {
    iree.executable[2](Unspecified) {
      module {
        func @main_entry_dispatch_2(%arg0: memref<1x128xf32>, %arg1: memref<1x128xf32>)
            attributes {iree.executable.export, iree.executable.workload = dense<[128, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
          %0 = iree.load_input(%arg0 : memref<1x128xf32>) : tensor<1x128xf32>
          %cst = constant dense<5.000000e-01> : tensor<128xf32>
          %cst_0 = constant dense<5.000000e-01> : tensor<f32>
          %1 = "xla_hlo.broadcast_in_dim"(%cst_0) {name = "broadcast.10"} : (tensor<f32>) -> tensor<1x128xf32>
          %2 = "xla_hlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "broadcast.7"} : (tensor<128xf32>) -> tensor<1x128xf32>
          %3 = addf %0, %2 : tensor<1x128xf32>
          %4 = xla_hlo.max %1, %3 {name = "maximum.11"} : tensor<1x128xf32>
          iree.store_output(%4 : tensor<1x128xf32>, %arg1 : memref<1x128xf32>)
          iree.return
        }
      }
    }
  }
  iree.multi_arch_executable @main_ex_dispatch_3[3]() {
    iree.executable[3](Unspecified) {
      module {
        func @main_entry_dispatch_3(%arg0: memref<1x128xf32>, %arg1: memref<128x10xf32>, %arg2: memref<1x10xf32>)
            attributes {iree.executable.export, iree.executable.workload = dense<[10, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
          %0 = iree.load_input(%arg0 : memref<1x128xf32>) : tensor<1x128xf32>
          %1 = iree.load_input(%arg1 : memref<128x10xf32>) : tensor<128x10xf32>
          %2 = "xla_hlo.dot"(%0, %1) {name = "dot.13", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x128xf32>, tensor<128x10xf32>) -> tensor<1x10xf32>
          iree.store_output(%2 : tensor<1x10xf32>, %arg2 : memref<1x10xf32>)
          iree.return
        }
      }
    }
  }
  iree.multi_arch_executable @main_ex_dispatch_4[4]() {
    iree.executable[4](Unspecified) {
      module {
        func @main_entry_dispatch_4(%arg0: memref<1x10xf32>, %arg1: memref<1x10xf32>)
            attributes {iree.executable.export, iree.executable.workload = dense<[10, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
          %0 = iree.load_input(%arg0 : memref<1x10xf32>) : tensor<1x10xf32>
          %cst = constant dense<5.000000e-01> : tensor<10xf32>
          %1 = "xla_hlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "broadcast.15"} : (tensor<10xf32>) -> tensor<1x10xf32>
          %2 = addf %0, %1 : tensor<1x10xf32>
          iree.store_output(%2 : tensor<1x10xf32>, %arg1 : memref<1x10xf32>)
          iree.return
        }
      }
    }
  }
  iree.multi_arch_executable @main_ex_dispatch_5[5]() {
    iree.executable[5](Unspecified) {
      module {
        func @main_entry_dispatch_5(%arg0: memref<1x10xf32>, %arg1: memref<1xf32>)
            attributes {iree.executable.export, iree.executable.workload = dense<1> : tensor<3xi32>, iree.ordinal = 0 : i32} {
          %0 = iree.load_input(%arg0 : memref<1x10xf32>) : tensor<1x10xf32>
          %cst = constant dense<0xFF800000> : tensor<f32>
          %1 = "xla_hlo.reduce"(%0, %cst) ( {
          ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):   // no predecessors
            %2 = xla_hlo.max %arg2, %arg3 {name = "maximum.21"} : tensor<f32>
            "xla_hlo.return"(%2) : (tensor<f32>) -> ()
          }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<f32>) -> tensor<1xf32>
          iree.store_output(%1 : tensor<1xf32>, %arg1 : memref<1xf32>)
          iree.return
        }
      }
    }
  }
  iree.multi_arch_executable @main_ex_dispatch_6[6]() {
    iree.executable[6](Unspecified) {
      module {
        func @main_entry_dispatch_6(%arg0: memref<1x10xf32>, %arg1: memref<1xf32>, %arg2: memref<1x10xf32>)
            attributes {iree.executable.export, iree.executable.workload = dense<[10, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
          %0 = iree.load_input(%arg0 : memref<1x10xf32>) : tensor<1x10xf32>
          %1 = iree.load_input(%arg1 : memref<1xf32>) : tensor<1xf32>
          %2 = "xla_hlo.broadcast_in_dim"(%1) {broadcast_dimensions = dense<0> : tensor<1xi64>, name = "broadcast.23"} : (tensor<1xf32>) -> tensor<1x10xf32>
          %3 = subf %0, %2 : tensor<1x10xf32>
          %4 = "xla_hlo.exp"(%3) {name = "exponential.25"} : (tensor<1x10xf32>) -> tensor<1x10xf32>
          iree.store_output(%4 : tensor<1x10xf32>, %arg2 : memref<1x10xf32>)
          iree.return
        }
      }
    }
  }
  iree.multi_arch_executable @main_ex_dispatch_7[7]() {
    iree.executable[7](Unspecified) {
      module {
        func @main_entry_dispatch_7(%arg0: memref<1x10xf32>, %arg1: memref<1xf32>)
            attributes {iree.executable.export, iree.executable.workload = dense<1> : tensor<3xi32>, iree.ordinal = 0 : i32} {
          %0 = iree.load_input(%arg0 : memref<1x10xf32>) : tensor<1x10xf32>
          %cst = constant dense<5.000000e-01> : tensor<f32>
          %1 = "xla_hlo.reduce"(%0, %cst) ( {
          ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):   // no predecessors
            %2 = addf %arg2, %arg3 : tensor<f32>
            "xla_hlo.return"(%2) : (tensor<f32>) -> ()
          }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<f32>) -> tensor<1xf32>
          iree.store_output(%1 : tensor<1xf32>, %arg1 : memref<1xf32>)
          iree.return
        }
      }
    }
  }
  iree.multi_arch_executable @main_ex_dispatch_8[8]() {
    iree.executable[8](Unspecified) {
      module {
        func @main_entry_dispatch_8(%arg0: memref<1xf32>, %arg1: memref<1x10xf32>, %arg2: memref<1x10xf32>)
            attributes {iree.executable.export, iree.executable.workload = dense<[10, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
          %0 = iree.load_input(%arg0 : memref<1xf32>) : tensor<1xf32>
          %1 = iree.load_input(%arg1 : memref<1x10xf32>) : tensor<1x10xf32>
          %2 = "xla_hlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<0> : tensor<1xi64>, name = "broadcast.34"} : (tensor<1xf32>) -> tensor<1x10xf32>
          %3 = divf %1, %2 : tensor<1x10xf32>
          iree.store_output(%3 : tensor<1x10xf32>, %arg2 : memref<1x10xf32>)
          iree.return
        }
      }
    }
  }
  func @main(%arg0: memref<1x28x28x1xf32>) -> memref<1x10xf32>
      attributes {iree.module.export} {
    %0 = "iree_ll_seq.constant"() {value = dense<5.000000e-01> : tensor<784x128xf32>} : () -> memref<784x128xf32>
    %1 = "iree_ll_seq.constant"() {value = dense<5.000000e-01> : tensor<128x10xf32>} : () -> memref<128x10xf32>
    %2 = "iree_ll_seq.alloc_heap"() : () -> memref<1x784xf32>
    iree_ll_seq.static_dispatch main_ex_dispatch_0::main_entry_dispatch_0[dense<[784, 1, 1]> : tensor<3xi32>](%arg0, %2) : (memref<1x28x28x1xf32>, memref<1x784xf32>) -> ()
    %3 = "iree_ll_seq.alloc_heap"() : () -> memref<1x128xf32>
    iree_ll_seq.static_dispatch main_ex_dispatch_1::main_entry_dispatch_1[dense<[128, 1, 1]> : tensor<3xi32>](%2, %0, %3) : (memref<1x784xf32>, memref<784x128xf32>, memref<1x128xf32>) -> ()
    %4 = "iree_ll_seq.alloc_heap"() : () -> memref<1x128xf32>
    iree_ll_seq.static_dispatch main_ex_dispatch_2::main_entry_dispatch_2[dense<[128, 1, 1]> : tensor<3xi32>](%3, %4) : (memref<1x128xf32>, memref<1x128xf32>) -> ()
    %5 = "iree_ll_seq.alloc_heap"() : () -> memref<1x10xf32>
    iree_ll_seq.static_dispatch main_ex_dispatch_3::main_entry_dispatch_3[dense<[10, 1, 1]> : tensor<3xi32>](%4, %1, %5) : (memref<1x128xf32>, memref<128x10xf32>, memref<1x10xf32>) -> ()
    %6 = "iree_ll_seq.alloc_heap"() : () -> memref<1x10xf32>
    iree_ll_seq.static_dispatch main_ex_dispatch_4::main_entry_dispatch_4[dense<[10, 1, 1]> : tensor<3xi32>](%5, %6) : (memref<1x10xf32>, memref<1x10xf32>) -> ()
    %7 = "iree_ll_seq.alloc_heap"() : () -> memref<1xf32>
    iree_ll_seq.static_dispatch main_ex_dispatch_5::main_entry_dispatch_5[dense<1> : tensor<3xi32>](%6, %7) : (memref<1x10xf32>, memref<1xf32>) -> ()
    %8 = "iree_ll_seq.alloc_heap"() : () -> memref<1x10xf32>
    iree_ll_seq.static_dispatch main_ex_dispatch_6::main_entry_dispatch_6[dense<[10, 1, 1]> : tensor<3xi32>](%6, %7, %8) : (memref<1x10xf32>, memref<1xf32>, memref<1x10xf32>) -> ()
    %9 = "iree_ll_seq.alloc_heap"() : () -> memref<1xf32>
    iree_ll_seq.static_dispatch main_ex_dispatch_7::main_entry_dispatch_7[dense<1> : tensor<3xi32>](%8, %9) : (memref<1x10xf32>, memref<1xf32>) -> ()
    %10 = "iree_ll_seq.alloc_heap"() : () -> memref<1x10xf32>
    iree_ll_seq.static_dispatch main_ex_dispatch_8::main_entry_dispatch_8[dense<[10, 1, 1]> : tensor<3xi32>](%9, %8, %10) : (memref<1xf32>, memref<1x10xf32>, memref<1x10xf32>) -> ()
    iree_ll_seq.return %10 : memref<1x10xf32>
  }
}
```
**NOTE**: this is effectively compiled at -O0, which is why the buffers are not
aliased and some dispatch region fusion is not performed. As things progress
we'll add simple optimizations over this IR that elide almost all copies and
externalize allocations to transient pooled memory.
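To make the structure concrete: each `iree.multi_arch_executable` above is a standalone kernel over device buffers, and `@main` is a straight-line schedule that allocates an intermediate buffer per step (no aliasing yet, per the note) and dispatches the kernels in order. A loose Python analogy of that schedule, keeping the out-parameter style of `iree.store_output` and the 0.5 placeholder constants (the function names here are illustrative only):

```python
import numpy as np

# Placeholder constants, matching the dense<5.000000e-01> fills above.
W0 = np.full((784, 128), 0.5, np.float32)   # iree_ll_seq.constant %0
W1 = np.full((128, 10), 0.5, np.float32)    # iree_ll_seq.constant %1

# One function per dispatch; each writes into a caller-provided buffer,
# mirroring the iree.load_input / iree.store_output pattern over memrefs.
def dispatch_0(x, out): out[...] = x.reshape(1, 784)          # reshape.3
def dispatch_1(x, w, out): out[...] = x @ w                   # dot.5
def dispatch_2(x, out): out[...] = np.maximum(0.5, x + 0.5)   # bias add + "ReLU"
def dispatch_3(x, w, out): out[...] = x @ w                   # dot.13
def dispatch_4(x, out): out[...] = x + 0.5                    # bias add
def dispatch_5(x, out): out[...] = x.max(axis=1)              # reduce-max
def dispatch_6(x, m, out): out[...] = np.exp(x - m[:, None])  # subtract + exp
def dispatch_7(x, out): out[...] = x.sum(axis=1)              # reduce-add
def dispatch_8(s, e, out): out[...] = e / s[:, None]          # divide

def main(arg0):
  """Mirrors @main: allocate every intermediate, then dispatch in sequence."""
  buf2 = np.empty((1, 784), np.float32);  dispatch_0(arg0, buf2)
  buf3 = np.empty((1, 128), np.float32);  dispatch_1(buf2, W0, buf3)
  buf4 = np.empty((1, 128), np.float32);  dispatch_2(buf3, buf4)
  buf5 = np.empty((1, 10), np.float32);   dispatch_3(buf4, W1, buf5)
  buf6 = np.empty((1, 10), np.float32);   dispatch_4(buf5, buf6)
  buf7 = np.empty((1,), np.float32);      dispatch_5(buf6, buf7)
  buf8 = np.empty((1, 10), np.float32);   dispatch_6(buf6, buf7, buf8)
  buf9 = np.empty((1,), np.float32);      dispatch_7(buf8, buf9)
  buf10 = np.empty((1, 10), np.float32);  dispatch_8(buf9, buf8, buf10)
  return buf10

print(main(np.zeros((1, 28, 28, 1), np.float32)))  # ten equal values (0.1 each)
```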
## Final IREE Module with SPIR-V
TODO(benvanik): once reductions are done.