Merge pull request #2650 from rsuderman:main-to-google
PiperOrigin-RevId: 323056877
diff --git a/.github/workflows/update_tf.yml b/.github/workflows/update_tf.yml
index 86c55d9..b116f99 100644
--- a/.github/workflows/update_tf.yml
+++ b/.github/workflows/update_tf.yml
@@ -54,6 +54,4 @@
Automated submodule bump from .github/workflows/update_tf.yml
committer: "Submodule Update Action <iree-github-actions-bot@google.com>"
- # TODO(gcmn): Figure out a way to assign this to someone dynamically.
- reviewers: gmngeoffrey
branch: "auto_submodule_update"
diff --git a/CMakeLists.txt b/CMakeLists.txt
index aae95bf..155f08d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -93,6 +93,10 @@
if( IREE_HAL_DRIVERS_TO_BUILD STREQUAL "all" )
set( IREE_HAL_DRIVERS_TO_BUILD ${IREE_ALL_HAL_DRIVERS} )
+ # For cross compilation towards Android, we don't want the LLVM JIT HAL driver.
+ if(ANDROID)
+ list(REMOVE_ITEM IREE_HAL_DRIVERS_TO_BUILD LLVM)
+ endif()
endif()
message(STATUS "Building HAL drivers ${IREE_HAL_DRIVERS_TO_BUILD}")
@@ -112,8 +116,8 @@
# List of all target backends to be built by default:
set(IREE_ALL_TARGET_BACKENDS
# TODO(scotttodd): LLVMAOT
- LLVMIR
- Vulkan_SPIRV
+ LLVM-IR
+ Vulkan-SPIRV
VMLA
)
diff --git a/build_tools/cmake/iree_check_test.cmake b/build_tools/cmake/iree_check_test.cmake
index aa794bd..93fd934 100644
--- a/build_tools/cmake/iree_check_test.cmake
+++ b/build_tools/cmake/iree_check_test.cmake
@@ -173,6 +173,17 @@
${ARGN}
)
+
+ string(TOUPPER ${_RULE_DRIVER} _UPPERCASE_DRIVER)
+ if(NOT IREE_HAL_DRIVER_${_UPPERCASE_DRIVER})
+ return()
+ endif()
+
+ string(TOUPPER ${_RULE_TARGET_BACKEND} _UPPERCASE_TARGET_BACKEND)
+ if(NOT IREE_TARGET_BACKEND_${_UPPERCASE_TARGET_BACKEND})
+ return()
+ endif()
+
foreach(_SRC IN LISTS _RULE_SRCS)
set(_TEST_NAME "${_RULE_NAME}_${_SRC}")
iree_check_test(
diff --git a/build_tools/docker/build_and_update_gcr.py b/build_tools/docker/build_and_update_gcr.py
index 8ca90ad..531088f 100755
--- a/build_tools/docker/build_and_update_gcr.py
+++ b/build_tools/docker/build_and_update_gcr.py
@@ -122,7 +122,11 @@
# Ensure the user has the correct authorization if they try to push to GCR.
if args.push:
- subprocess.check_output(['gcloud', 'auth', 'configure-docker'])
+ if run_command(['which', 'gcloud']) != 0:
+ print('gcloud not found.'
+ ' See https://cloud.google.com/sdk/install for installation.')
+ sys.exit(1)
+ check_command(['gcloud', 'auth', 'configure-docker'])
# Check if any images depend on `args.images` and update them if they do.
images_to_update_set = set()
diff --git a/build_tools/kokoro/gcp_ubuntu/docker_common.sh b/build_tools/kokoro/gcp_ubuntu/docker_common.sh
index 81dc0b3..fbc896f 100644
--- a/build_tools/kokoro/gcp_ubuntu/docker_common.sh
+++ b/build_tools/kokoro/gcp_ubuntu/docker_common.sh
@@ -32,13 +32,10 @@
mkdir -p "${fake_etc_dir?}"
local fake_group="${fake_etc_dir?}/group"
- local fake_passwd="${fake_etc_dir?}/group"
+ local fake_passwd="${fake_etc_dir?}/passwd"
- cp /etc/passwd "${fake_group?}"
- cp /etc/group "${fake_passwd?}"
- getent group "$(id -g)" >> "${fake_group?}"
- getent passwd "$(id -u)" >> "${fake_passwd?}"
-
+ getent group > "${fake_group?}"
+ getent passwd > "${fake_passwd?}"
local workdir="${KOKORO_ARTIFACTS_DIR?}/github/iree"
diff --git a/docs/design_docs/codegen_passes.md b/docs/design_docs/codegen_passes.md
new file mode 100644
index 0000000..83e37fc
--- /dev/null
+++ b/docs/design_docs/codegen_passes.md
@@ -0,0 +1,640 @@
+# IREE CPU/GPU Code Generation Pipeline
+
+This document provides an overview of the codegen pipeline within IREE used to
+generate CPU/GPU code. It describes the main passes used, the objective of each
+pass, its current implementation, and what it is expected to achieve in the
+long term.
+
+Note that while the code generation pipeline supports dynamic shapes, this
+support is very preliminary and is not covered here.
+
+## Input to the codegen pipeline
+
+The input to the code generation pipeline is the module within the
+`hal.executable.target` operation. Functions within this module that do __not__
+have `Visibility::Private` are the *entry point* functions of the dispatch
+region. These are the functions that are *invoked* by the IREE runtime. In
+addition, each dispatch region also contains a `hal.interface` operation that
+describes the ABI to use for the dispatch region. Two examples of the input to
+the code generation pipeline are shown below. In both of these, a single
+dispatch function contains a sequence of MHLO operations that the dispatch
+region creation has grouped into a single region. Ideally the grouped operations
+are fused into a single kernel.
+
+```mlir
+hal.executable.target "vulkan*" {
+ module attributes {spv.target_env = ...} {
+ func @main_ex_dispatch() {
+ %c0 = constant 0 : index
+ %0 = hal.interface.load.tensor @legacy_io::@arg0,
+ offset = %c0 : tensor<4x5xf32>
+ %1 = hal.interface.load.tensor @legacy_io::@arg1,
+ offset = %c0 : tensor<5x10xf32>
+ %2 = "mhlo.dot"(%0, %1) {precision_config = ["DEFAULT", "DEFAULT"]} :
+ (tensor<4x5xf32>, tensor<5x10xf32>) -> tensor<4x10xf32>
+ hal.interface.store.tensor %2, @legacy_io::@ret0,
+ offset = %c0 : tensor<4x10xf32>
+ return
+ }
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0,
+ type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1,
+ type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2,
+ type="StorageBuffer", access="Write|Discard"
+ }
+ }
+}
+```
+
+<a name="snippet1"></a> Snippet 1 : Dispatch region with matrix-matrix multiply
+operation.
+
+```mlir
+hal.executable.target "vulkan*" {
+ module attributes {spv.target_env = ...} {
+ func @main_ex_dispatch() {
+ %c0 = constant 0 : index
+ %0 = hal.interface.load.tensor @legacy_io::@arg0,
+ offset = %c0 : tensor<10x5xf32>
+ %1 = hal.interface.load.tensor @legacy_io::@arg1,
+ offset = %c0 : tensor<10x5xf32>
+ %2 = hal.interface.load.tensor @legacy_io::@arg2,
+ offset = %c0 : tensor<10x5xf32>
+ %3 = "mhlo.add"(%0, %1) :
+ (tensor<10x5xf32>, tensor<10x5xf32>) -> tensor<10x5xf32>
+ %4 = "mhlo.multiply"(%3, %2) :
+ (tensor<10x5xf32>, tensor<10x5xf32>) -> tensor<10x5xf32>
+ hal.interface.store.tensor %4, @legacy_io::@ret0,
+ offset = %c0 : tensor<10x5xf32>
+ return
+ }
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0,
+ type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1,
+ type="StorageBuffer", access="Read"
+ hal.interface.binding @arg2, set=0, binding=2,
+ type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=3,
+ type="StorageBuffer", access="Write|Discard"
+ }
+ }
+}
+```
+
+<a name="snippet2"></a> Snippet 2 : Dispatch region with element-wise
+operations.
+
+__Roadmap Note__: The current implementation might not actually fuse the
+operations grouped into a dispatch region into a single kernel. It is possible
+to end up with multiple kernels per dispatch region. Over time we plan to
+address this by using fusion at different levels (see below).
+
+The inputs to the dispatch region are materialized within the entry point
+function using the `hal.interface.load.tensor` operation. This operation returns
+a `tensor` view of the buffer used to store the inputs. Similarly, the results
+of the dispatch region are *written* out using the `hal.interface.store.tensor`
+operation.
+
+The main constraint that the code generation operates under is that it should
+not require additional (temporary) buffers to execute the operations grouped
+together within a dispatch region. The rationale behind this constraint is that
+buffer allocation/synchronization in IREE happens at the granularity of dispatch
+regions, allowing the scheduler to make better decisions about where to insert
+appropriate synchronizations.
+
+The IR after all the passes used in the lowering from MHLO to SPIR-V for the
+above two examples can be found here ([matrix-matrix multiply op][DotAfterAll],
+[elementwise ops][PwAfterAll]). Below is a description of the major passes used.
+
+## Conversion from MHLO dialect to Linalg on buffers
+
+The code generation pipeline relies heavily on the use of
+[Structured Operations][LinalgRationale], specifically the
+[Linalg Dialect][LinalgDialect]. Both the Linalg operations on `tensor`s and on
+`memref`s are central to the progressive lowering approach followed here. The
+first part of the code generation pipeline converts the MHLO operations on
+`tensor`s to Linalg operations on `memref`s. This part of the pipeline is common
+to both CPU and GPU code generation.
+
+The steps involved in this conversion are shown below. Each of the arrows
+represents a pass in the pipeline:
+
+![MHLO to Linalg on buffers](./hlo_to_linalg.png)
+
+The next sections describe each of these passes in more detail.
+
+### MHLO to Linalg on tensors
+
+The first step is to convert MHLO operations to Linalg on tensors. This is done
+using the [HLOToLinalgPass][HLOToLinalgPass] from TensorFlow. An example of the
+conversion is shown below, where the `mhlo.add` and `mhlo.multiply` operations
+are each converted to a `linalg.generic` operation on tensors.
+
+```mlir
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+%3 = linalg.generic
+ {args_in = 2 : i64, args_out = 1 : i64,
+ indexing_maps = [#map0, #map0, #map0],
+ iterator_types = ["parallel", "parallel"]} %0, %1 {
+ ^bb0(%arg0: f32, %arg1: f32): // no predecessors
+ %5 = addf %arg0, %arg1 : f32
+ linalg.yield %5 : f32
+ } : tensor<10x5xf32>, tensor<10x5xf32> -> tensor<10x5xf32>
+%4 = linalg.generic
+ {args_in = 2 : i64, args_out = 1 : i64,
+ indexing_maps = [#map0, #map0, #map0],
+ iterator_types = ["parallel", "parallel"]} %3, %2 {
+ ^bb0(%arg0: f32, %arg1: f32): // no predecessors
+ %5 = mulf %arg0, %arg1 : f32
+ linalg.yield %5 : f32
+ }: tensor<10x5xf32>, tensor<10x5xf32> -> tensor<10x5xf32>
+```
+
+<a name="snippet3"></a> Snippet 3 : MHLO to Linalg conversion for
+[element-wise operations](#snippet2)
+
+At the time of writing, the representation of Linalg on `tensor`s does not model
+reduction iterator types completely. Specifically, reduction in Linalg is
+modeled using a read-modify-write approach, i.e. each iteration of the reduction
+loop reads the value stored in the output, adds its contribution, and writes
+back to the same location. This means the output has to be *initialized* to the
+null element of the reduction operator (i.e. 0 if the reduction is done using
+addition). This works for operations on buffers. Since tensors are SSA values
+they cannot be updated in-place. As a result, the reduction semantics does not
+map as well to `tensor`s. For now it is treated as a convention that when the
+Linalg operation is converted to use `memref`s it has to be initialized
+appropriately before performing the reduction. Due to this, the conversion from
+MHLO op to Linalg op is only done for operations which do not need a *reduction*
+iterator type in the converted Linalg op. Consequently, only element-wise
+operations, broadcast operations and data movement operations (like copy and
+transpose) are converted to Linalg operations at this stage.
+
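+As a sketch of this convention (reusing the buffer forms that appear later in
+Snippets 5 and 7; the SSA names are those of Snippet 5), the output `memref` of
+a reduction-style operation is initialized with the null element before the
+operation executes:
+
+```mlir
+%cst = constant 0.000000e+00 : f32
+// Initialize the output to the null element of the reduction (addition).
+linalg.fill(%0, %cst) : memref<4x10xf32>, f32
+// The matmul then accumulates into %0 with read-modify-write semantics.
+linalg.matmul(%1, %2, %0) :
+  memref<4x5xf32>, memref<5x10xf32>, memref<4x10xf32>
+```
+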
+__Roadmap note__: One long-term solution for the above is to have operations on
+tensors with a *reduction* iterator type take an additional argument that
+contains the initial value of the result tensor. When the operation is converted
+to use `memref`s, the buffer for the initial value operand can be reused for the
+result. The details involved have not been fully worked out yet.
+
+### Fusion of Linalg on tensor operations
+
+The Linalg on `tensor` operations generated in the previous step are fused using
+the [LinalgFusionOfTensorOps][LinalgFusionOfTensorOps] pass from MLIR. Since
+`tensor`s are SSA values, fusion at this stage can be done without using alias
+analysis or dependence analysis based on reads and writes. Instead, the use-def
+chains for the `tensor` values can be used to implement producer-consumer
+fusion. This stage fuses most elementwise operations, broadcast operations and
+data movement operations. An example of the fused op is shown below.
+
+```mlir
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+%3 = linalg.generic
+ {args_in = 3 : i64, args_out = 1 : i64,
+ indexing_maps = [#map0, #map0, #map0, #map0],
+ iterator_types = ["parallel", "parallel"]} %0, %1, %2 {
+ ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): // no predecessors
+ %4 = addf %arg0, %arg1 : f32
+ %5 = mulf %4, %arg2 : f32
+ linalg.yield %5 : f32
+ }: tensor<?x5xf32>, tensor<?x5xf32>, tensor<?x5xf32> -> tensor<?x5xf32>
+```
+
+<a name="snippet4"></a> Snippet 4: Fusion of Linalg operation on tensors for
+element-wise operations shown in [Snippet 3](#snippet3)
+
+### Conversion of Linalg on tensors to Linalg on buffers
+
+Post fusion, all the operations on `tensor`s are converted to analogous
+operations on `memref`s. In general, this requires a buffer allocation pass. In
+IREE, buffer allocation happens at the granularity of dispatch regions, and as
+mentioned [earlier](#input-to-the-codegen-pipeline), the dispatch region is not
+expected to use any additional temporary buffers. So instead of having another
+buffer allocation pass within the code generation pipeline, a simpler approach
+is used within IREE (a small before/after sketch follows the list below):
+
+- For each `hal.interface.store.tensor` an `iree.placeholder` operation is
+ created. The latter uses the same `hal.interface.binding` as the former, but
+ returns a `memref` view of the output of the dispatch region instead of a
+  `tensor` view. This `iree.placeholder` operation is added to the start of the
+ entry point function.
+
+- A map is constructed that records, for a given `tensor`, the `memref` value to
+ use during the conversion. In this map the `tensor` value used in the
+ `hal.interface.store.tensor` is mapped to the `memref` value returned by the
+ created `iree.placeholder` operation.
+
+- The Dialect Conversion framework is used to implement a set of patterns that
+  convert operations on `tensor`s to operations on `memref`s:
+
+  - A `hal.interface.load.tensor` is replaced with an `iree.placeholder` to
+ get the `memref` view of the input to the dispatch region.
+  - All Linalg operations on `tensor`s (expected to be just `linalg.generic`
+    or `linalg.indexed_generic` operations) are converted to the
+    corresponding operations on `memref`s. Instead of returning a `tensor`
+    value, the converted operation takes an additional `memref` operand as
+    an argument. This `memref` is where the result of the operation is
+    populated. The current implementation looks up the `memref` to use in the
+    map constructed previously. If there is no `memref` associated with the
+    result `tensor`, the conversion fails.
+  - At this stage, any `mhlo` operations not converted to Linalg operations
+    are directly converted to Linalg operations on buffers. This is done
+    for operations that, when converted to Linalg, have a *reduction* iterator
+    type. Some examples of ops converted this way are
+
+ - `mhlo.dot`
+ - `mhlo.reduce`
+ - `mhlo.conv`
+ - `mhlo.reduce_window`.
+
+    Since the specification of the Linalg operations requires the output
+ `memref` to be initialized appropriately, a `linalg.fill` operation is
+ used to achieve this.
+
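+As a small before/after sketch of this rewrite, consider the matrix-matrix
+multiply from [Snippet 1](#snippet1) (the SSA names in the "after" form are
+illustrative):
+
+```mlir
+// Before: the result is produced as a tensor and explicitly written out.
+%2 = "mhlo.dot"(%0, %1) {precision_config = ["DEFAULT", "DEFAULT"]} :
+       (tensor<4x5xf32>, tensor<5x10xf32>) -> tensor<4x10xf32>
+hal.interface.store.tensor %2, @legacy_io::@ret0,
+  offset = %c0 : tensor<4x10xf32>
+
+// After: an iree.placeholder gives a memref view of the same binding, and the
+// converted operation writes its result directly into that buffer.
+%ret0 = iree.placeholder for "interface buffer"
+          {binding = @legacy_io::@ret0} : memref<4x10xf32>
+linalg.matmul(%arg0, %arg1, %ret0) :
+  memref<4x5xf32>, memref<5x10xf32>, memref<4x10xf32>
+```
+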
+__Roadmap Note__ : Right now the code-generation pipeline relies on fusion of
+operations at the tensor level. In the near future, we want to be able to fuse
+operations like `linalg.matmul` and `linalg.conv` with consumers/producers that
+are element-wise operations using the
+[fusion of Linalg operation on `memref`s][LinalgFusionOnBuffers].
+
+At this stage of the compilation, all operations must have been converted to
+Linalg operations on buffers. Shown below is the IR at the end of this stage
+for the two examples in Snippets 1 and 2.
+
+```mlir
+func @main_ex_dispatch() {
+ %0 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@ret0} : memref<4x10xf32>
+ %c0 = constant 0 : index
+ %1 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg0} : memref<4x5xf32>
+ %2 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg1} : memref<5x10xf32>
+ %cst = constant 0.000000e+00 : f32
+ linalg.matmul(%1, %2, %0) :
+ memref<4x5xf32>, memref<5x10xf32>, memref<4x10xf32>
+ return
+}
+```
+
+<a name="snippet5"></a> Snippet 5 : Matrix-matrix multiply after conversion to
+Linalg operation on `memref`s.
+
+```mlir
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func @main_ex_dispatch() {
+ %0 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@ret0} : memref<10x5xf32>
+ %c0 = constant 0 : index
+ %1 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg0} : memref<10x5xf32>
+ %2 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg1} : memref<10x5xf32>
+ %3 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg2} : memref<10x5xf32>
+ linalg.generic
+ {args_in = 3 : i64, args_out = 1 : i64,
+ indexing_maps = [#map0, #map0, #map0],
+ iterator_types = ["parallel", "parallel"]} %1, %2, %3, %0 {
+ ^bb0(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32): // no predecessors
+ %4 = addf %arg0, %arg1 : f32
+ %5 = mulf %4, %arg2 : f32
+ linalg.yield %5 : f32
+ }: memref<10x5xf32>, memref<10x5xf32>, memref<10x5xf32>, memref<10x5xf32>
+ return
+}
+```
+
+<a name="snippet6"></a> Snippet 6 : Elementwise operations after conversion to
+Linalg operation on `memref`s
+
+The rest of the code-generation differs on whether the compilation is for CPU
+(using LLVM) or for GPU (using SPIR-V).
+
+## Conversion from Linalg on buffers to SPIR-V dialect
+
+The following sections describe the progressive lowering of Linalg operations on
+buffers to the SPIR-V dialect. Once lowered to the SPIR-V dialect, the IR can be
+serialized into a SPIR-V binary using the
+[serialization mechanism provided by the SPIR-V dialect][SpirvSerialization].
+The steps involved in the lowering are described below, with each of the arrows
+representing a pass.
+
+![Linalg on buffers to SPIR-V](./linalg_to_spirv.png)
+
+These passes are described below in more detail.
+
+### Tiling and fusion on buffer operations
+
+GPU hardware typically provides multiple levels of compute hierarchy, namely
+the *workgroup* level, *subgroup* level and *workitem* level. These map to
+blocks, warps and threads, respectively, in CUDA terminology. Tiling is a way to
+map the computations to each level of the compute hierarchy. For example, 3-D
+tiling of a `linalg.matmul` operation decomposes the computation into several
+tiled matrix-matrix multiplies.
+The [tiling transformation in the Linalg dialect][LinalgTiling] generates the
+outer loops that iterate over tiled `linalg.matmul` operations. These outer
+loops can be mapped to different workgroups if they are parallel. The tiled
+`linalg.matmul` operation can be further tiled to map to subgroups. Finally, the
+tiled operation can be lowered to loops with individual iterations mapped to
+workitems. The [LinalgTileAndFusePass][LinalgTileAndFuse] uses the Linalg Tiling
+patterns ([defined here][LinalgTilingPatterns]) to tile operations like
+`linalg.matmul`, `linalg.conv` and `linalg.*_pooling`. The result of tiling the
+code in Snippet 5 is shown below. As expected, there are two parallel loops that
+iterate over tiles of the original iteration space (i.e. inter-tile loops) and
+can be distributed to workgroups.
+
+```mlir
+func @main_ex_dispatch_0()
+ attributes {
+ spv.entry_point_abi = {local_size = dense<[8, 8, 1]> : vector<3xi32>}} {
+ %cst = constant 0.000000e+00 : f32
+  %c0 = constant 0 : index
+  %c4 = constant 4 : index
+  %c5 = constant 5 : index
+  %c8 = constant 8 : index
+  %c10 = constant 10 : index
+ %0 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@ret0} : memref<4x10xf32>
+ %1 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg0} : memref<4x5xf32>
+ %2 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg1} : memref<5x10xf32>
+ linalg.fill(%0, %cst) : memref<4x10xf32>, f32
+ scf.parallel (%arg0, %arg1) = (%c0, %c0) to (%c4, %c10) step (%c8, %c8) {
+ scf.for %arg2 = %c0 to %c5 step %c4 {
+ ...
+ %5 = subview %1[%arg0, %arg2]...
+ ...
+ %8 = subview %2[%arg2, %arg1]...
+ ...
+ %11 = subview %0[%arg0, %arg1]..
+ linalg.matmul {__internal_linalg_transform__ = "workgroup"} %5, %8, %11...
+ }
+ scf.yield
+ }
+ return
+}
+```
+
+<a name="snippet7"></a> Snippet 7 : `linalg.matmul` after tiling.
+
+#### Tile Size and Workgroup Size
+
+When operations that are to be tiled exist within the dispatch function (like
+`linalg.matmul` or `linalg.conv`), this pass also decides
+
+1.  the tile sizes to be used for the tiling, and
+1.  the workgroup size to be used.
+
+The tile size and workgroup size are closely linked since the code within the
+tiled loops is to be collectively executed by the entire workgroup. In other
+words, all workitems in the workgroup collaborate to execute the tiled
+`linalg.matmul`.
+
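+For instance, in [Snippet 7](#snippet7) the tile sizes along the two parallel
+dimensions and the workgroup size are both 8x8. The relevant fragments of that
+snippet, juxtaposed as a sketch:
+
+```mlir
+// Workgroup size chosen by the pass: 8 x 8 x 1 workitems.
+spv.entry_point_abi = {local_size = dense<[8, 8, 1]> : vector<3xi32>}
+// The inter-tile loops step by the same tile sizes, so each iteration of the
+// parallel loop is the tile processed by one workgroup.
+scf.parallel (%arg0, %arg1) = (%c0, %c0) to (%c4, %c10) step (%c8, %c8) { ... }
+```
+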
+__Roadmap Note__ : Currently the tile sizes used in this pass are hard-wired.
+Not much effort has been put into finding ideal tile sizes for each operation on
+different hardware. The values used are meant to be a baseline to test
+functionality, with performance considerations addressed over time.
+
+#### Markers
+
+Downstream passes have to handle tiled Linalg operations and untiled Linalg
+operations that might exist in the same function in different ways. For example,
+while the former are to be executed collectively by workitems within a
+workgroup, the latter have to be executed by all workitems across workgroups.
+One way to distinguish these two operations is to use the marker mechanism in
+Linalg ([LinalgMarker][LinalgTilingPatterns]). This is a `StrAttr` whose value
+can be used to encode the scope of the operation. For example, in Snippet 7
+above, the tiled `linalg.matmul` operation has a marker `workgroup` to indicate
+that this operation needs to be executed by a workgroup in a collective manner.
+At this time, the code-generation pipeline uses only the `workgroup` marker.
+
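+As a small sketch of the distinction (the marked `linalg.matmul` is the form
+visible in [Snippet 7](#snippet7); the unmarked `linalg.copy` and its operands
+are hypothetical):
+
+```mlir
+// Tiled: marked, executed collectively by the workitems of one workgroup.
+linalg.matmul {__internal_linalg_transform__ = "workgroup"} %5, %8, %11...
+// Untiled: no marker, later distributed across all workitems of the dispatch.
+linalg.copy(%src, %dst) : memref<10x5xf32>, memref<10x5xf32>
+```
+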
+__Roadmap Note__ : Markers are meant to be short-lived, ideally set and consumed
+within the same pass. In the current pipeline the lifetime spans passes to allow
+lowering to different hierarchies. The separate passes that implement the
+lowering from Linalg to SPIR-V could be combined into a single pass, relying on
+the A -> B -> C translation mechanism of the Dialect Conversion framework to
+implement the progressive lowering. In the interest of separation of concerns
+and for better debuggability, these passes are kept separate at the cost of
+having the lifetimes of markers span passes.
+
+#### Promoting subviews to use workgroup local memory and use of synchronizations
+
+`Workgroup` memory (or `shared memory` in CUDA terminology) can be used to
+prefetch the inputs to the tiled operation. For example in the matrix-matrix
+multiply case, the same data row (column) of the LHS (RHS) matrix is read by
+multiple workitems. Prefetching the data into `Workgroup` memory can reduce the
+number of loads to `StorageClass` memory by an order of magnitude. This
+transformation can be achieved by using the
+[`Linalg Promotion`][LinalgPromotionPatterns] patterns, which modify the
+`subview`s that are the operands to the tiled Linalg operation to use a new
+`memref` object. The size of this `memref` is computed from the size of the
+`subview`. This `memref` object is later lowered to use the `Workgroup` memory
+storage class. The snippet below shows this transformation when applied to
+`linalg.matmul` (along with tiling). The newly created `memref` objects are
+annotated with the memory space `3` to indicate that they are to be lowered to
+use `Workgroup` memory. The copy of data from the original `memref` into the new
+`memref`, as well as the necessary synchronization constructs, are generated as
+well. Note that the memory space annotation used here is consistent with the
+[address space annotations used in NVVM][NVVMAddressSpace].
+
+```mlir
+func @matmul_tile()
+ attributes {
+ spv.entry_point_abi = {local_size = dense<[8, 8, 1]> : vector<3xi32>}} {
+ %c96 = constant 96 : index
+ %c4 = constant 4 : index
+ %c8 = constant 8 : index
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %0 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg0} : memref<96x96xf32>
+ %1 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg1} : memref<96x96xf32>
+ %2 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@ret0} : memref<96x96xf32>
+ scf.parallel (%arg0, %arg1) = (%c0, %c0) to (%c96, %c96) step (%c8, %c8) {
+ scf.for %arg2 = %c0 to %c96 step %c4 {
+ ...
+ %5 = subview %0[%arg0, %arg2]...
+ ...
+ %8 = subview %1[%arg2, %arg1]...
+ ...
+ %11 = subview %2[%arg0, %arg1]...
+ %12 = alloc(%c8, %c4) : memref<?x?xf32, 3>
+ %13 = subview %12[%c0, %c0]...
+ %14 = alloc(%c4, %c8) : memref<?x?xf32, 3>
+ %15 = subview %14[%c0, %c0]...
+ linalg.copy(%5, %13) {__internal_linalg_transform__ = "workgroup"}
+ : memref<?x?xf32, #map2>, memref<?x?xf32, #map2, 3>
+ spv.ControlBarrier "Workgroup", "Workgroup", "AcquireRelease"
+ linalg.copy(%8, %15) {__internal_linalg_transform__ = "workgroup"}
+ : memref<?x?xf32, #map2>, memref<?x?xf32, #map2, 3>
+ spv.ControlBarrier "Workgroup", "Workgroup", "AcquireRelease"
+ linalg.matmul {__internal_linalg_transform__ = "workgroup"} %13, %15, %11...
+ spv.ControlBarrier "Workgroup", "Workgroup", "AcquireRelease"
+ dealloc %12 : memref<?x?xf32, 3>
+ dealloc %14 : memref<?x?xf32, 3>
+ }
+ scf.yield
+ }
+ return
+}
+```
+
+<a name="snippet8"></a> Snippet 8: `linalg.matmul` after tiling and promotion of
+operand subviews to use `Workgroup` memory.
+
+### Distributing to workgroups and workitems
+
+After tiling, the operations within the dispatch function are either
+`scf.parallel` operations or Linalg operations.
+
+- The outer `scf.parallel` operations represent parallel loops that are to be
+ distributed across workgroups. The distribution here assumes that the number
+ of workgroups along each dimension is equal to the number of iterations of
+ the `scf.parallel` operation.
+
+- Linalg operations that are not tiled, and are therefore __not within__ `scf`
+  operations, are lowered to loops. The resulting outer `scf.parallel`
+  operations are collapsed to have a single induction variable. This loop is
+  then distributed across workitems using their `GlobalInvocationId` (which
+  is the same as `blockIdx * blockDim + threadIdx` in CUDA terminology; a
+  sketch of this computation follows the list).
+
+- Linalg operations that are tiled, and are therefore __within__ `scf`
+ operations, are lowered to loops and the iterations of the `scf.parallel`
+  operations are mapped to workitems using their `LocalInvocationId` (which is
+  the same as `threadIdx` in CUDA terminology). Note that these operations are
+ tagged with the `workgroup` marker which makes it easy to disambiguate from
+ the case where Linalg operations are outside of `scf` operations. Here too,
+ the distribution assumes that the workgroup size is greater than or equal to
+ the number of iterations of the partitioned loop.
+
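+A minimal sketch of the `GlobalInvocationId` computation mentioned above, using
+the same `gpu` dialect operations that appear in [Snippet 10](#snippet10) (the
+SSA names are illustrative):
+
+```mlir
+// GlobalInvocationId along x, i.e. blockIdx.x * blockDim.x + threadIdx.x.
+%bid = "gpu.block_id"() {dimension = "x"} : () -> index
+%bdim = "gpu.block_dim"() {dimension = "x"} : () -> index
+%tid = "gpu.thread_id"() {dimension = "x"} : () -> index
+%tmp = muli %bid, %bdim : index
+%gid = addi %tmp, %tid : index
+```
+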
+These transformations are applied by the [`ConvertToGPUPass`][ConvertToGPU].
+Below is the result of applying this pass to Snippet 7. The outer `scf.parallel`
+loop is distributed across workgroups. The tiled `linalg.matmul` operation is
+lowered to loops, and the outer `scf.parallel` operation generated during this
+lowering is distributed across workitems within the workgroup.
+
+```mlir
+func @main_ex_dispatch_0_dispatch_1()
+ attributes {
+ spv.entry_point_abi = {local_size = dense<[8, 8, 1]> : vector<3xi32>}} {
+ %c5 = constant 5 : index
+ %c8 = constant 8 : index
+ %c4 = constant 4 : index
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %0 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@ret0} : memref<4x10xf32>
+ %1 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg0} : memref<4x5xf32>
+ %2 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg1} : memref<5x10xf32>
+ %3 = "gpu.block_id"() {dimension = "x"} : () -> index
+ %4 = muli %3, %c8 : index
+ scf.for %arg0 = %c0 to %c5 step %c4 {
+ ...
+ %9 = subview %1[0, %arg0]
+ ...
+ %14 = subview %2[%arg0, %4]
+ %15 = subview %0[0, %4]
+ %16 = "gpu.thread_id"() {dimension = "x"} : () -> index
+ %17 = "gpu.thread_id"() {dimension = "y"} : () -> index
+ %18 = cmpi "slt", %17, %c4 : index
+ %19 = cmpi "slt", %16, %13 : index
+ %20 = and %18, %19 : i1
+ scf.if %20 {
+ scf.for %arg1 = %c0 to %8 step %c1 {
+ %21 = load %9[%17, %arg1] : memref<4x?xf32, #map0>
+ %22 = load %14[%arg1, %16] : memref<?x?xf32, #map1>
+ %23 = load %15[%17, %16] : memref<4x?xf32, #map1>
+ %24 = mulf %21, %22 : f32
+ %25 = addf %23, %24 : f32
+ store %25, %15[%17, %16] : memref<4x?xf32, #map1>
+ }
+ }
+ }
+ return
+}
+```
+
+<a name="snippet9"></a> Snippet 9: `linalg.matmul` after distributing parallel
+inter-tile loops to workgroups and intra-tile loops to workitems.
+
+[Snippet 6](#snippet6) shows the fused element-wise operations represented using
+a `linalg.generic` operation. This operation is not tiled in the
+`LinalgTileAndFusePass`, so the `ConvertToGPUPass` lowers it to
+`scf.parallel` loops, which are collapsed into a `scf.parallel` operation with a
+single induction variable. This loop is then distributed across workitems using
+the `GlobalInvocationId`. The resulting IR is shown below.
+
+```mlir
+func @main_ex_dispatch_0()
+ attributes {
+ spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} {
+ %c50 = constant 50 : index
+ %c5 = constant 5 : index
+ %0 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@ret0} : memref<10x5xf32>
+ %1 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg0} : memref<10x5xf32>
+ %2 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg1} : memref<10x5xf32>
+ %3 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg2} : memref<10x5xf32>
+ %4 = "gpu.block_id"() {dimension = "x"} : () -> index
+ %5 = "gpu.block_dim"() {dimension = "x"} : () -> index
+ %6 = "gpu.thread_id"() {dimension = "x"} : () -> index
+ %7 = muli %4, %5 : index
+ %8 = addi %7, %6 : index
+ %9 = cmpi "slt", %8, %c50 : index
+ scf.if %9 {
+ %10 = divi_signed %8, %c5 : index
+ %11 = remi_signed %8, %c5 : index
+ %12 = load %1[%10, %11] : memref<10x5xf32>
+ %13 = load %2[%10, %11] : memref<10x5xf32>
+ %14 = load %3[%10, %11] : memref<10x5xf32>
+ %15 = addf %12, %13 : f32
+ %16 = mulf %15, %14 : f32
+ store %16, %0[%10, %11] : memref<10x5xf32>
+ }
+ return
+}
+```
+
+<a name="snippet10"></a> Snippet 10: Distributing the iterations for pointwise
+operations for GPU execution.
+
+### Lowering to SPIR-V dialect
+
+The last step is to take the result of the previous pass and lower it to the
+SPIR-V dialect. Since the SPIR-V dialect is *closed*, i.e. it has its own type
+system, it is best to lower all the operations to SPIR-V in one step. This is
+done by applying all the patterns that lower the different IR constructs into
+SPIR-V within the [`ConvertToSPIRVPass`][ConvertToSPIRV]. These are:
+
+- [GPU dialect to SPIR-V conversion][GPUToSPIRV].
+- [SCF dialect to SPIR-V conversion][SCFToSPIRV].
+- [Standard dialect to SPIR-V conversion][StandardToSPIRV].
+- Patterns that lower the `iree.placeholder` operation into SPIR-V.
+
+Once these patterns are applied, the resulting IR is in the SPIR-V dialect and
+can be serialized to a SPIR-V binary.
+
+[ConvertToGPU]: https://github.com/google/iree/blob/main/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
+[ConvertToSPIRV]: https://github.com/google/iree/blob/main/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
+[DotAfterAll]: https://gist.github.com/MaheshRavishankar/9e2d406296f469515c4a79bf1e7eef44
+[GPUToSPIRV]: https://github.com/llvm/llvm-project/blob/master/mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h
+[HLOToLinalgPass]: https://github.com/tensorflow/tensorflow/blob/75c40f6bff2faa3d90a375dfa4025b2e6e2d7a3d/tensorflow/compiler/mlir/xla/transforms/passes.h#L67
+[LinalgDialect]: https://mlir.llvm.org/docs/Dialects/Linalg/
+[LinalgFusionOnBuffers]: https://github.com/llvm/llvm-project/blob/ef868a848e6def288d2df7a1b3ebe09463afc8d0/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h#L86
+[LinalgFusionOfTensorOps]: https://github.com/llvm/llvm-project/blob/80cb25cbd555f9634836b766c86aead435b60eaa/mlir/include/mlir/Dialect/Linalg/Passes.td#L30
+[LinalgPromotionPatterns]: https://github.com/llvm/llvm-project/blob/303a7f7a26e2aae1cb85f49dccbc0b5d14e0b2e0/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h#L358
+[LinalgRationale]: https://mlir.llvm.org/docs/Rationale/RationaleLinalgDialect/
+[LinalgTileAndFuse]: https://github.com/google/iree/blob/main/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
+[LinalgTiling]: https://mlir.llvm.org/docs/Dialects/Linalg/#set-of-key-transformationsa-namekey_transformationsa
+[LinalgTilingPatterns]: https://github.com/llvm/llvm-project/blob/master/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+[NVVMAddressSpace]: https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#address-space
+[PwAfterAll]: https://gist.github.com/MaheshRavishankar/02cdd22f7c99e568f933244b5a679510
+[SCFToSPIRV]: https://github.com/llvm/llvm-project/blob/master/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h
+[SpirvSerialization]: https://mlir.llvm.org/docs/Dialects/SPIR-V/#serialization-and-deserialization
+[StandardToSPIRV]: https://github.com/llvm/llvm-project/blob/master/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h
diff --git a/docs/design_docs/hlo_to_linalg.png b/docs/design_docs/hlo_to_linalg.png
new file mode 100755
index 0000000..469ed26
--- /dev/null
+++ b/docs/design_docs/hlo_to_linalg.png
Binary files differ
diff --git a/docs/design_docs/linalg_to_spirv.png b/docs/design_docs/linalg_to_spirv.png
new file mode 100755
index 0000000..fd6aee7
--- /dev/null
+++ b/docs/design_docs/linalg_to_spirv.png
Binary files differ
diff --git a/docs/get_started/cmake_options_and_variables.md b/docs/get_started/cmake_options_and_variables.md
index 8e121fc..3f3dbdd 100644
--- a/docs/get_started/cmake_options_and_variables.md
+++ b/docs/get_started/cmake_options_and_variables.md
@@ -63,17 +63,21 @@
#### `IREE_HAL_DRIVERS_TO_BUILD`:STRING
-*This does not have any effect at the moment, but will be supported in the
-future!* Semicolon-separated list of HAL drivers to build, or `all` for building
-all HAL drivers. Case-insensitive. Defaults to `all`. Example:
+*Right now this only affects whether tests are enabled when compiling for
+Android; it will be fully supported in the future!*
+
+Semicolon-separated list of HAL drivers to build, or `all` for building all HAL
+drivers. Case-insensitive. Defaults to `all`. Example:
`-DIREE_HAL_DRIVERS_TO_BUILD="Vulkan;VMLA"`.
#### `IREE_TARGET_BACKENDS_TO_BUILD`:STRING
-*This does not have any effect at the moment, but will be supported in the
-future!* Semicolon-separated list of HAL drivers to build, or `all` for building
-all HAL drivers. Case-insensitive. Defaults to `all`. Example:
-`-DIREE_HAL_DRIVERS_TO_BUILD="Vulkan_SPIRV;VMLA"`.
+*Right now this only affects whether tests are enabled when compiling for
+Android; it will be fully supported in the future!*
+
+Semicolon-separated list of compiler target backends to build, or `all` for
+building all compiler target backends. Case-insensitive. Defaults to `all`.
+Example: `-DIREE_TARGET_BACKENDS_TO_BUILD="Vulkan-SPIRV;VMLA"`.
#### `IREE_ENABLE_LLD`:BOOL
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
index 8f5ec8d..46a3785 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -39,6 +39,16 @@
np.random.seed(seed)
+def backends_to_str(target_backends):
+ """Creates a flattened and normalized string representing target_backends."""
+ normalized_backends = []
+ for backend in target_backends:
+ # Remove unusual characters and ensure names don't end or start in "_".
+ backend = re.sub("[^0-9a-zA-Z_]+", "_", backend)
+ normalized_backends.append(backend.strip("_"))
+ return "__".join(normalized_backends)
+
+
def compile_tf_module(tf_module,
target_backends=(),
exported_names=(),
@@ -52,9 +62,9 @@
saved_model:
A TF SavedModel directory containing the files used translate the
tf.Module into an IREE module.
- tf_input__backends.mlir:
+ tf_input.mlir:
MLIR for the module in TF's input dialect.
- iree_input__backends.mlir:
+ iree_input.mlir:
The MLIR above translated to IREE via compiler.TF_IMPORT_PASS_PIPELINE.
compiled__backends.vmfb:
A VM FlatBuffer compiled to the target backends from the IREE MLIR above.
@@ -77,14 +87,6 @@
# We break up the compilation here so we can save intermediary artifacts.
compiler_context = compiler.Context()
- if artifacts_dir is not None:
- normalized_backends = []
- for backend in target_backends:
- # Remove unusual characters and ensure names don't end or start in "_".
- backend = re.sub("[^0-9a-zA-Z_]+", "_", backend)
- normalized_backends.append(backend.strip("_"))
- backends_string = "__".join(normalized_backends)
-
# Convert the tf_module into raw TF input MLIR.
compiler_module = compiler.tf_load_saved_model(
sm_path,
@@ -93,8 +95,7 @@
pass_pipeline=())
if artifacts_dir is not None:
- tf_mlir_path = os.path.join(artifacts_dir,
- f"tf_input__{backends_string}.mlir")
+ tf_mlir_path = os.path.join(artifacts_dir, "tf_input.mlir")
logging.info("Saving raw TF input MLIR to: %s", tf_mlir_path)
with open(tf_mlir_path, "w") as f:
f.write(compiler_module.to_asm())
@@ -103,16 +104,15 @@
compiler_module.run_pass_pipeline(compiler.TF_IMPORT_PASS_PIPELINE)
if artifacts_dir is not None:
- iree_mlir_path = os.path.join(artifacts_dir,
- f"iree_input__{backends_string}.mlir")
+ iree_mlir_path = os.path.join(artifacts_dir, "iree_input.mlir")
logging.info("Saving IREE input MLIR to: %s", iree_mlir_path)
with open(iree_mlir_path, "w") as f:
f.write(compiler_module.to_asm())
compiled_module = compiler_module.compile(target_backends=target_backends)
if artifacts_dir is not None:
- compiled_path = os.path.join(artifacts_dir,
- f"compiled__{backends_string}.vmfb")
+ compiled_name = f"compiled__{backends_to_str(target_backends)}.vmfb"
+ compiled_path = os.path.join(artifacts_dir, compiled_name)
logging.info("Saving compiled IREE module to: %s", compiled_path)
with open(compiled_path, "wb") as f:
f.write(compiled_module)
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
index 25645df..b1d9adb 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
@@ -17,6 +17,7 @@
import os
import tempfile
+from absl import logging
from absl.testing import parameterized
from pyiree.tf.support import tf_utils
import tensorflow as tf
@@ -52,7 +53,7 @@
},
{
'testcase_name': 'multiple_backends',
- 'target_backends': ['vmla', 'llvm'],
+ 'target_backends': ['vmla', 'llvm-ir'],
},
])
def test_artifact_saving(self, target_backends):
@@ -65,12 +66,14 @@
artifacts_to_check = [
'saved_model',
- f'tf_input__{"__".join(target_backends)}.mlir',
- f'iree_input__{"__".join(target_backends)}.mlir',
- f'compiled__{"__".join(target_backends)}.vmfb',
+ 'tf_input.mlir',
+ 'iree_input.mlir',
+ f'compiled__{tf_utils.backends_to_str(target_backends)}.vmfb',
]
for artifact in artifacts_to_check:
- self.assertTrue(os.path.exists(os.path.join(artifacts_dir, artifact)))
+ artifact_path = os.path.join(artifacts_dir, artifact)
+ logging.info('Checking path: %s', artifact_path)
+ self.assertTrue(os.path.exists(artifact_path))
@parameterized.named_parameters([
{
diff --git a/iree/compiler/Conversion/LinalgToLLVM/BUILD b/iree/compiler/Conversion/LinalgToLLVM/BUILD
index 41d9988..fcd9476 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/BUILD
+++ b/iree/compiler/Conversion/LinalgToLLVM/BUILD
@@ -22,7 +22,6 @@
name = "LinalgToLLVM",
srcs = [
"ConvertToLLVM.cpp",
- "HALInterfaceToMemrefArguments.cpp",
"Passes.cpp",
],
hdrs = [
diff --git a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
index d21dc19..bc31e4e 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
@@ -21,7 +21,6 @@
"Passes.h"
SRCS
"ConvertToLLVM.cpp"
- "HALInterfaceToMemrefArguments.cpp"
"Passes.cpp"
DEPS
MLIRAffineToStandard
diff --git a/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp b/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp
index 2134171..76fb2c5 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp
@@ -125,6 +125,180 @@
}
};
+/// Returns true if `aOp` has a descriptor (set, binding) pair smaller than
+/// `bOp`. Note that this ignores the offset.
+bool operator<(IREE::HAL::InterfaceBindingOp aOp,
+ IREE::HAL::InterfaceBindingOp bOp) {
+ if (aOp.set().getZExtValue() == bOp.set().getZExtValue())
+ return aOp.binding().getZExtValue() < bOp.binding().getZExtValue();
+ return aOp.set().getZExtValue() < bOp.set().getZExtValue();
+}
+
+// Change signature of entry function to func
+// entry_func(%packed_buffers_arg_ptr:
+// !<llvm.int8**>, %push_constant: !<llvm.int64*>) and lower IREE and HAL ops to
+// corresponding LLVMIR ops to construct memref descriptors and load
+// push_constant values.
+class ConvertFuncWithHALInterface : public ConvertToLLVMPattern {
+ public:
+ explicit ConvertFuncWithHALInterface(MLIRContext *context,
+ LLVMTypeConverter &typeconverter)
+ : ConvertToLLVMPattern(FuncOp::getOperationName(), context,
+ typeconverter) {}
+
+ LogicalResult matchAndRewrite(
+ Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ if (SymbolTable::getSymbolVisibility(op) != SymbolTable::Visibility::Public)
+ return failure();
+ auto funcOp = dyn_cast_or_null<FuncOp>(op);
+ FunctionType fnType = funcOp.getType();
+ if (fnType.getNumInputs() != 0) {
+ return rewriter.notifyMatchFailure(
+ funcOp, "entry function should not have inputs");
+ }
+
+ // Get interface buffers from all the blocks.
+ SmallVector<IREE::PlaceholderOp, 8> bufferOps;
+ SmallVector<IREE::HAL::InterfaceLoadConstantOp, 8> loadOps;
+ for (Block &block : funcOp.getBlocks()) {
+ for (Operation &op : block) {
+ if (auto phOp = dyn_cast<IREE::PlaceholderOp>(op))
+ bufferOps.push_back(phOp);
+ if (auto phOp = dyn_cast<IREE::HAL::InterfaceLoadConstantOp>(op)) {
+ loadOps.push_back(phOp);
+ }
+ }
+ }
+
+ if (bufferOps.empty()) return failure();
+
+ // A map from buffer ops to their corresponding interface binding ops.
+ llvm::DenseMap<Operation *, IREE::HAL::InterfaceBindingOp> bufferBindingMap;
+ for (auto bufferOp : bufferOps) {
+ auto symbol = SymbolTable::lookupNearestSymbolFrom(
+ bufferOp, bufferOp.getAttrOfType<SymbolRefAttr>("binding"));
+ bufferBindingMap[bufferOp] = cast<IREE::HAL::InterfaceBindingOp>(symbol);
+ }
+
+ // Sort buffers according to their descriptor (set, binding) pair.
+ llvm::sort(bufferOps, [&bufferBindingMap](IREE::PlaceholderOp aBuffer,
+ IREE::PlaceholderOp bBuffer) {
+ return bufferBindingMap[aBuffer] < bufferBindingMap[bBuffer];
+ });
+
+ // A map from buffer ops to their corresponding function argument indices.
+ llvm::DenseMap<Operation *, unsigned> bufferArgMap;
+ // A map from binding ops to their corresponding function argument indices.
+ llvm::DenseMap<Operation *, unsigned> bindingArgMap;
+ llvm::SmallVector<MemRefType, 4> inputMemRefTypes;
+ llvm::SmallVector<LLVM::LLVMType, 4> inputStructPtrs;
+ unsigned argIndex = 0;
+ for (auto bufferOp : bufferOps) {
+ auto binding = bufferBindingMap[bufferOp];
+ auto it = bindingArgMap.find(binding);
+ if (it != bindingArgMap.end()) {
+ bufferArgMap[bufferOp] = it->second;
+ } else {
+ bindingArgMap[binding] = argIndex;
+ bufferArgMap[bufferOp] = argIndex;
+ ++argIndex;
+ }
+
+ auto memrefType = bufferOp.getType().dyn_cast_or_null<MemRefType>();
+ inputMemRefTypes.push_back(memrefType);
+ auto elementType = typeConverter.convertType(memrefType.getElementType())
+ .dyn_cast<LLVM::LLVMType>();
+ if (!elementType) return failure();
+ inputStructPtrs.push_back(
+ elementType.getPointerTo(memrefType.getMemorySpace()));
+ }
+
+ TypeConverter::SignatureConversion signatureConverter(/*numOrigInputs=*/0);
+
+ // func foo(%packed_buffer_args: !llvm<i8**>, %push_constant: !llvm<i64*>)
+ auto packedBuffersArgsTy =
+ LLVM::LLVMType::getInt8PtrTy(typeConverter.getDialect()).getPointerTo();
+ auto pushConstantArgTy =
+ LLVM::LLVMType::getInt64Ty(typeConverter.getDialect()).getPointerTo();
+ signatureConverter.addInputs(packedBuffersArgsTy);
+ signatureConverter.addInputs(pushConstantArgTy);
+
+ // Create the new function's signature.
+ Location loc = funcOp.getLoc();
+ auto newFuncOp = rewriter.create<FuncOp>(
+ loc, funcOp.getName(),
+ rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
+ llvm::None),
+ ArrayRef<NamedAttribute>());
+
+ // Move all ops in the old function's region to the new function.
+ rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+ newFuncOp.end());
+ rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
+
+ auto builder = OpBuilder::atBlockBegin(&(newFuncOp.getBlocks().front()));
+
+ // Cast and unpack input packed_buffer_arguments and construct memref
+ // descriptors.
+ Value packedBuffersArgsPtr = builder.create<LLVM::BitcastOp>(
+ loc,
+ LLVM::LLVMType::getStructTy(typeConverter.getDialect(), inputStructPtrs)
+ .getPointerTo(),
+ newFuncOp.getArgument(0));
+ Value packedBuffersArgs =
+ builder.create<LLVM::LoadOp>(loc, packedBuffersArgsPtr);
+ for (auto bufferOp : bufferOps) {
+ MemRefType memrefType = bufferOp.getType().dyn_cast_or_null<MemRefType>();
+ if (!memrefType) return failure();
+ const auto index = bufferArgMap[bufferOp];
+ Value bufferPtr = builder.create<LLVM::ExtractValueOp>(
+ loc, inputStructPtrs[index], packedBuffersArgs,
+ rewriter.getI64ArrayAttr(index));
+ if (memrefType.hasStaticShape()) {
+ auto desc = MemRefDescriptor::fromStaticShape(
+ builder, loc, typeConverter, memrefType, bufferPtr);
+ rewriter.replaceOp(bufferOp, {desc});
+ } else {
+ auto desc = MemRefDescriptor::undef(
+ builder, loc, typeConverter.convertType(memrefType));
+ desc.setAllocatedPtr(builder, loc, bufferPtr);
+ desc.setAlignedPtr(builder, loc, bufferPtr);
+ rewriter.replaceOp(bufferOp, {desc});
+ }
+ }
+
+ // Lower hal.interface.load.constant ops into llvm.getelementptr, llvm.load
+ for (auto loadOp : loadOps) {
+ Value offset = builder.create<LLVM::ConstantOp>(
+ loc, LLVM::LLVMType::getInt64Ty(typeConverter.getDialect()),
+ builder.getI64IntegerAttr(loadOp.offset().getZExtValue()));
+ Value constPtr = builder.create<LLVM::GEPOp>(loc, pushConstantArgTy,
+ newFuncOp.getArgument(1),
+ ArrayRef<Value>({offset}));
+ Value dimConstant = builder.create<LLVM::LoadOp>(loc, constPtr);
+ rewriter.replaceOp(loadOp, dimConstant);
+ }
+
+ rewriter.eraseOp(funcOp);
+ return success();
+ }
+};
+
+class RemoveInterfaceOpPattern : public ConvertToLLVMPattern {
+ public:
+ explicit RemoveInterfaceOpPattern(MLIRContext *context,
+ LLVMTypeConverter &typeconverter)
+ : ConvertToLLVMPattern(IREE::HAL::InterfaceOp::getOperationName(),
+ context, typeconverter) {}
+ LogicalResult matchAndRewrite(
+ Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
namespace {
struct ConvertToLLVMPass
: public PassWrapper<ConvertToLLVMPass, OperationPass<ModuleOp>> {
@@ -151,11 +325,12 @@
populateVectorToLLVMConversionPatterns(converter, patterns);
populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());
// The following patterns resolves dynamic shapes by substituting tie_shape
- // ops with an updated memref descriptors and replacing RankDimOp with actual
- // index loaded from memref<?xi32> that holds all dynamic shapes
- // push constants.
- patterns.insert<ConvertRankedDimPattern, ConvertTieShapePattern,
- RemoveMakeRankedShape>(&getContext(), converter);
+ // ops with an updated memref descriptors and replacing RankDimOp with
+ // actual index loaded from memref<?xi32> that holds all dynamic shapes push
+ // constants.
+ patterns.insert<ConvertFuncWithHALInterface, ConvertRankedDimPattern,
+ ConvertTieShapePattern, RemoveMakeRankedShape,
+ RemoveInterfaceOpPattern>(&getContext(), converter);
LLVMConversionTarget target(getContext());
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
if (failed(applyPartialConversion(module, target, patterns)))
@@ -168,7 +343,8 @@
static PassRegistration<ConvertToLLVMPass> pass(
"iree-codegen-convert-to-llvm",
- "Perform final conversion from Linalg/HAL/Shape/Vector/Standard to LLVMIR "
+ "Perform final conversion from Linalg/HAL/Shape/Vector/Standard to "
+ "LLVMIR "
"dialect",
[] { return std::make_unique<ConvertToLLVMPass>(); });
diff --git a/iree/compiler/Conversion/LinalgToLLVM/HALInterfaceToMemrefArguments.cpp b/iree/compiler/Conversion/LinalgToLLVM/HALInterfaceToMemrefArguments.cpp
deleted file mode 100644
index ac968e4..0000000
--- a/iree/compiler/Conversion/LinalgToLLVM/HALInterfaceToMemrefArguments.cpp
+++ /dev/null
@@ -1,231 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <memory>
-
-#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
-#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace {
-
-/// Returns true if the given function contains interface related operations
-/// that are used by other ops.
-bool containsUsedInterfaceOp(FuncOp funcOp) {
- for (Block& block : funcOp.getBlocks()) {
- for (Operation& op : block) {
- if (!op.getUses().empty() &&
- (isa<IREE::PlaceholderOp>(op) ||
- isa<IREE::HAL::InterfaceLoadConstantOp>(op))) {
- return true;
- }
- }
- }
- return false;
-}
-
-/// Returns true if `aOp` has a desciptor (set, binding) pair smaller than
-/// `bOp`. Note that this ignores the offset.
-bool operator<(IREE::HAL::InterfaceBindingOp aOp,
- IREE::HAL::InterfaceBindingOp bOp) {
- if (aOp.set().getZExtValue() == bOp.set().getZExtValue())
- return aOp.binding().getZExtValue() < bOp.binding().getZExtValue();
- return aOp.set().getZExtValue() < bOp.set().getZExtValue();
-}
-
-/// A pattern to process function interface. It replaces interface related ops
-/// with function arguments to match LLVM's CodeGen's ABI contract.
-///
-/// IREE scheduler passes interface ABI information via hal.interface.* ops to
-/// all backends. We create iree.placeholder ops to represent buffers behind
-/// those hal.interface.* ops. However the LLVM CodeGen uses function parameters
-/// and memref descriptors for ABI. So we need to bridge the gap somewhere.
-///
-/// This pass finds all interface buffers used in the function, sort them
-/// according to the descriptor (set, binding) pair, and put unique ones as
-/// function parameters in order.
-/// Note: This should be kept consistent with LLVM's HAL backend.
-struct ProcessFuncInterfacePattern : public OpConversionPattern<FuncOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- FuncOp funcOp, ArrayRef<Value> Operands,
- ConversionPatternRewriter& rewriter) const override {
- // Only process entry functions.
- if (SymbolTable::getSymbolVisibility(funcOp) !=
- SymbolTable::Visibility::Public)
- return failure();
-
- FunctionType fnType = funcOp.getType();
- if (fnType.getNumInputs() != 0)
- return rewriter.notifyMatchFailure(
- funcOp, "entry function should not have inputs");
-
- // Get interface buffers from all the blocks.
- SmallVector<IREE::PlaceholderOp, 8> bufferOps;
- SmallVector<IREE::HAL::InterfaceLoadConstantOp, 8> loadOps;
- for (Block& block : funcOp.getBlocks()) {
- for (Operation& op : block) {
- if (auto phOp = dyn_cast<IREE::PlaceholderOp>(op))
- bufferOps.push_back(phOp);
- if (auto phOp = dyn_cast<IREE::HAL::InterfaceLoadConstantOp>(op)) {
- loadOps.push_back(phOp);
- }
- }
- }
-
- if (bufferOps.empty()) return failure();
-
- // A map from buffer ops to their corresponding interface binding ops.
- llvm::DenseMap<Operation*, IREE::HAL::InterfaceBindingOp> bufferBindingMap;
- for (auto bufferOp : bufferOps) {
- auto symbol = SymbolTable::lookupNearestSymbolFrom(
- bufferOp, bufferOp.getAttrOfType<SymbolRefAttr>("binding"));
- bufferBindingMap[bufferOp] = cast<IREE::HAL::InterfaceBindingOp>(symbol);
- }
-
- // Sort buffers according to their descriptor (set, binding) pair.
- llvm::sort(bufferOps, [&bufferBindingMap](IREE::PlaceholderOp aBuffer,
- IREE::PlaceholderOp bBuffer) {
- return bufferBindingMap[aBuffer] < bufferBindingMap[bBuffer];
- });
-
- // Create a function argument for each of the unique binding pointed by the
- // buffer ops.
- TypeConverter::SignatureConversion signatureConverter(/*numOrigInputs=*/0);
- // A map from buffer ops to their corresponding function argument indices.
- llvm::DenseMap<Operation*, unsigned> bufferArgMap;
- // A map from binding ops to their corresponding function argument indices.
- llvm::DenseMap<Operation*, unsigned> bindingArgMap;
- unsigned argIndex = 0;
- for (auto bufferOp : bufferOps) {
- auto binding = bufferBindingMap[bufferOp];
- auto it = bindingArgMap.find(binding);
- if (it != bindingArgMap.end()) {
- bufferArgMap[bufferOp] = it->second;
- } else {
- bindingArgMap[binding] = argIndex;
- bufferArgMap[bufferOp] = argIndex;
- signatureConverter.addInputs(bufferOp.getType());
- ++argIndex;
- }
- }
- Type dynamicDimsBufferType =
- MemRefType::get(ShapedType::kDynamicSize, rewriter.getIntegerType(32));
- signatureConverter.addInputs(dynamicDimsBufferType);
-
- // Create the new function's signature.
- Location loc = funcOp.getLoc();
- auto newFuncOp = rewriter.create<FuncOp>(
- loc, funcOp.getName(),
- rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
- llvm::None),
- ArrayRef<NamedAttribute>());
- newFuncOp.setAttr("llvm.emit_c_interface",
- mlir::UnitAttr::get(funcOp.getContext()));
-
- // Move all ops in the old function's region to the new function.
- rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
- newFuncOp.end());
- rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
-
- // Replace all buffer ops' uses with the newly created function arguments
- // and erase them.
- for (auto bufferOp : bufferOps) {
- bufferOp.replaceAllUsesWith(
- newFuncOp.getArgument(bufferArgMap[bufferOp]));
-
- rewriter.eraseOp(bufferOp);
- }
-
- // Lower all hal.interface.load.constant ops into std.load
- // from the last buffer holding all dynamic dimensions with the proper
- // offset.
- Type indexType = rewriter.getIndexType();
- auto builder = OpBuilder::atBlockBegin(&(newFuncOp.getBlocks().front()));
- auto newLoc = newFuncOp.front().front().getLoc();
- for (auto loadOp : loadOps) {
- SmallVector<Value, 1> indices;
- Value constantOffset = builder.create<ConstantOp>(
- newLoc, indexType,
- rewriter.getIntegerAttr(indexType, loadOp.offset().getZExtValue()));
- indices.push_back(constantOffset);
- Value loadDim = builder.create<LoadOp>(
- newLoc, newFuncOp.getArgument(newFuncOp.getNumArguments() - 1),
- indices);
- Value loadDimIndex =
- builder.create<IndexCastOp>(newLoc, loadDim, indexType);
- loadOp.replaceAllUsesWith(loadDimIndex);
- rewriter.eraseOp(loadOp);
- }
- rewriter.eraseOp(funcOp);
- return success();
- }
-};
-
-struct RemoveInterfaceOpPattern
- : public OpRewritePattern<IREE::HAL::InterfaceOp> {
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(IREE::HAL::InterfaceOp interfaceOp,
- PatternRewriter& rewriter) const override {
- rewriter.eraseOp(interfaceOp);
- return success();
- }
-};
-
-/// Converting from Linalg to LLVM needs to run on a module and since it
-/// applies a full conversion, make a module with jst the impl function.
-struct HALInterfaceToMemrefArgumentsPass
- : PassWrapper<HALInterfaceToMemrefArgumentsPass, OperationPass<ModuleOp>> {
- void runOnOperation() override {
- MLIRContext& context = getContext();
-
- OwningRewritePatternList patterns;
- patterns.insert<ProcessFuncInterfacePattern>(&context);
- patterns.insert<RemoveInterfaceOpPattern>(&context);
-
- ConversionTarget target(context);
- // Convert the interface related ops away.
- target.addDynamicallyLegalOp<FuncOp>(
- [](FuncOp funcOp) { return !containsUsedInterfaceOp(funcOp); });
- target.addIllegalOp<IREE::PlaceholderOp>();
- target.addIllegalDialect<IREE::HAL::HALDialect>();
- // Allow the rest.
- target.markUnknownOpDynamicallyLegal([](Operation*) { return true; });
-
- if (failed(applyFullConversion(getOperation(), target, patterns)))
- return signalPassFailure();
- }
-};
-
-} // namespace
-
-std::unique_ptr<OperationPass<ModuleOp>>
-createHALInterfaceToMemrefArgumentsPass() {
- return std::make_unique<HALInterfaceToMemrefArgumentsPass>();
-}
-
-static PassRegistration<HALInterfaceToMemrefArgumentsPass> pass(
- "iree-codegen-hal-interface-to-memref-arguments-pass",
- "Convert a function with HAL bindings interface to memref arguments",
- [] { return std::make_unique<HALInterfaceToMemrefArgumentsPass>(); });
-
-} // namespace iree_compiler
-} // namespace mlir
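For reference, the core mechanism of the removed pass is MLIR's `TypeConverter::SignatureConversion`: collect the converted argument types, create a replacement `FuncOp`, inline the old body, and remap the entry block arguments. Below is a condensed, hypothetical sketch of that flow; the pattern name and the single memref argument are illustrative only, and it assumes the same MLIR headers and `using namespace mlir;` conventions as the deleted file above.

```c++
// Condensed illustration of the signature-conversion flow used by the deleted
// ProcessFuncInterfacePattern; binding bookkeeping and error handling omitted.
struct RewriteFuncSignatureSketch : public OpConversionPattern<FuncOp> {
  using OpConversionPattern<FuncOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      FuncOp funcOp, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // One converted input per unique binding; here a single dynamic memref.
    TypeConverter::SignatureConversion signatureConverter(/*numOrigInputs=*/0);
    signatureConverter.addInputs(
        MemRefType::get({ShapedType::kDynamicSize}, rewriter.getF32Type()));

    // Build the new function with the converted signature.
    auto newFuncOp = rewriter.create<FuncOp>(
        funcOp.getLoc(), funcOp.getName(),
        rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
                                 llvm::None),
        ArrayRef<NamedAttribute>());

    // Move the old body over and remap its entry block arguments.
    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                                newFuncOp.end());
    rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
    rewriter.eraseOp(funcOp);
    return success();
  }
};
```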
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
index e8c6d9c..8c8eb21 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
@@ -35,10 +35,7 @@
passManager.addPass(createCanonicalizerPass());
passManager.addPass(createCSEPass());
- // Convert ExecutableOp entry function to use memref arguments.
- passManager.addPass(createHALInterfaceToMemrefArgumentsPass());
-
- // (Linalg, STD) -> LLVM
+ // (HAL, IREE, Linalg, STD) -> LLVM
// OpPassManager& llvmPassManager = passManager.nest<ModuleOp>();
passManager.addPass(createConvertToLLVMPass());
passManager.addPass(createCanonicalizerPass());
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.h b/iree/compiler/Conversion/LinalgToLLVM/Passes.h
index 5bfb893..fdad0e6 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/Passes.h
+++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.h
@@ -20,11 +20,6 @@
namespace mlir {
namespace iree_compiler {
-/// Converts the function signature from HAL interface op annotations to
-/// memref arguments.
-std::unique_ptr<OperationPass<ModuleOp>>
-createHALInterfaceToMemrefArgumentsPass();
-
/// Pass to perform final conversion to LLVM dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertToLLVMPass();
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
index 98e91fa..81514ec 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
@@ -563,7 +563,7 @@
ConversionPatternRewriter &rewriter) const override {
// Check for marker that specifies that the linalg op is to be partitioned
// across threads within a workgroup.
- if (!hasWorkItemMarker(linalgOp)) return failure();
+ if (!hasWorkGroupMarker(linalgOp)) return failure();
Optional<linalg::LinalgLoops> loops =
linalg::linalgLowerOpToLoops<scf::ParallelOp>(rewriter, linalgOp);
if (!loops) return failure();
@@ -587,7 +587,7 @@
LogicalResult matchAndRewrite(
LinalgOpTy linalgOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- if (!hasWorkItemMarker(linalgOp)) return failure();
+ if (!hasWorkGroupMarker(linalgOp)) return failure();
Optional<linalg::LinalgLoops> loops =
linalg::linalgLowerOpToLoops<scf::ParallelOp>(rewriter, linalgOp);
if (!loops) return failure();
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
index e9dddd6..934e5ae 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -314,7 +314,7 @@
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
- if (!hasWorkItemMarker(op)) return failure();
+ if (!hasWorkGroupMarker(op)) return failure();
return linalg::LinalgPromotionPattern<linalg::MatmulOp>::matchAndRewrite(
op, rewriter);
}
@@ -365,7 +365,7 @@
.setLoopType(linalg::LinalgTilingLoopType::ParallelLoops),
tileSizeCalculator.getWorkGroupSize(),
linalg::LinalgMarker(ArrayRef<Identifier>(),
- Identifier::get(getWorkItemMarker(), context)));
+ Identifier::get(getWorkGroupMarker(), context)));
applyPatternsAndFoldGreedily(getOperation(), tilingPatterns);
if (useWorkgroupMemory) {
@@ -385,7 +385,7 @@
[&](OpBuilder &b, Value src, Value dst) -> LogicalResult {
return copyToFromWorkgroupMemory(b, src, dst);
}),
- linalg::LinalgMarker(Identifier::get(getWorkItemMarker(), context),
+ linalg::LinalgMarker(Identifier::get(getWorkGroupMarker(), context),
Identifier::get(PromotionMarker, context)));
applyPatternsAndFoldGreedily(getOperation(), promotionPatterns);
}
@@ -394,7 +394,7 @@
OpBuilder builder(context);
funcOp.walk([&builder](linalg::LinalgOp linalgOp) {
if (hasMarker(linalgOp, PromotionMarker)) {
- setWorkItemMarker(linalgOp);
+ setWorkGroupMarker(linalgOp);
insertBarrierAfter(builder, linalgOp.getLoc(), linalgOp);
}
});
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp
index c874234..47747de 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp
@@ -36,8 +36,6 @@
StringRef getWorkGroupMarker() { return "workgroup"; }
-StringRef getWorkItemMarker() { return "workitem"; }
-
bool hasMarker(Operation *op, StringRef marker) {
return checkMarkerValue(op, marker);
}
@@ -46,10 +44,6 @@
return checkMarkerValue(op, getWorkGroupMarker());
}
-bool hasWorkItemMarker(Operation *op) {
- return checkMarkerValue(op, getWorkItemMarker());
-}
-
void setMarker(Operation *op, StringRef marker) {
op->setAttr(linalg::LinalgTransforms::kLinalgTransformMarker,
StringAttr::get(marker, op->getContext()));
@@ -57,6 +51,5 @@
void setWorkGroupMarker(Operation *op) { setMarker(op, getWorkGroupMarker()); }
-void setWorkItemMarker(Operation *op) { setMarker(op, getWorkItemMarker()); }
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h b/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h
index 36dccca..e512ead 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h
@@ -31,7 +31,7 @@
namespace iree_compiler {
/// Marker to denote that a linalg operation is to be partitioned to workitems.
-StringRef getWorkItemMarker();
+StringRef getWorkGroupMarker();
/// Returns true if an operation has the specified `marker`. When `marker` is
/// empty, returns true if the operation has any marker.
@@ -39,14 +39,14 @@
/// Returns true if an operation has marker to denote that it is to be
/// partitioned to workitems.
-bool hasWorkItemMarker(Operation *);
+bool hasWorkGroupMarker(Operation *);
/// Sets a given marker on an operation.
void setMarker(Operation *, StringRef);
/// Sets marker to denote that a linalg operation is to be partitioned to
/// workitems.
-void setWorkItemMarker(Operation *);
+void setWorkGroupMarker(Operation *);
} // namespace iree_compiler
} // namespace mlir
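For readers tracking the workitem → workgroup rename, the marker is simply the `__internal_linalg_transform__` string attribute that the tests below now expect to read "workgroup". A small sketch using the helpers declared in MarkerUtils.h; the walk and the matmul filter are illustrative only, and it assumes the usual MLIR/linalg headers from the surrounding files.

```c++
// Tag every linalg.matmul in a function for workgroup-level lowering, then
// check whether any op carries the marker. setWorkGroupMarker() stores
// __internal_linalg_transform__ = "workgroup"; hasWorkGroupMarker() reads it.
static void tagMatmulsForWorkgroups(FuncOp funcOp) {
  funcOp.walk([](linalg::LinalgOp linalgOp) {
    if (isa<linalg::MatmulOp>(linalgOp.getOperation()))
      setWorkGroupMarker(linalgOp);
  });
}

static bool hasAnyWorkGroupTaggedOp(FuncOp funcOp) {
  bool found = false;
  funcOp.walk([&found](linalg::LinalgOp linalgOp) {
    found = found || hasWorkGroupMarker(linalgOp);
  });
  return found;
}
```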
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
index 679f523..64621f3 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
@@ -162,7 +162,7 @@
%12 = dim %arg2, %c1 : memref<?x?xf32>
%13 = affine.min #map0(%arg4)[%12]
%14 = subview %arg2[%arg3, %arg4] [%11, %13] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map2>
- linalg.matmul %5, %9, %14 {__internal_linalg_transform__ = "workitem"} : (memref<?x?xf32, #map2>, memref<?x?xf32, #map2>, memref<?x?xf32, #map2>)
+ linalg.matmul %5, %9, %14 {__internal_linalg_transform__ = "workgroup"} : (memref<?x?xf32, #map2>, memref<?x?xf32, #map2>, memref<?x?xf32, #map2>)
}
scf.yield
}
@@ -235,7 +235,7 @@
%13 = affine.min #map5(%arg5)[%4]
%14 = dim %arg2, %c3 : memref<?x?x?x?xf32>
%15 = subview %arg2[%arg3, %arg4, %arg5, 0] [%11, %12, %13, %14] [1, 1, 1, 1] : memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map3>
- linalg.conv(%arg0, %9, %15) {__internal_linalg_transform__ = "workitem", dilations = [1, 1], strides = [1, 1]} : memref<?x?x?x?xf32>, memref<?x?x?x?xf32, #map3>, memref<?x?x?x?xf32, #map3>
+ linalg.conv(%arg0, %9, %15) {__internal_linalg_transform__ = "workgroup", dilations = [1, 1], strides = [1, 1]} : memref<?x?x?x?xf32>, memref<?x?x?x?xf32, #map3>, memref<?x?x?x?xf32, #map3>
scf.yield
}
return
@@ -364,7 +364,7 @@
%9 = affine.min #map3(%arg3)[%2]
%10 = affine.min #map4(%arg4)[%3]
%11 = subview %arg2[%arg3, %arg4] [%9, %10] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map2>
- linalg.pooling_max(%8, %arg1, %11) {__internal_linalg_transform__ = "workitem", dilations = [1, 1], strides = [1, 1]} : memref<?x?xf32, #map2>, memref<?x?xf32>, memref<?x?xf32, #map2>
+ linalg.pooling_max(%8, %arg1, %11) {__internal_linalg_transform__ = "workgroup", dilations = [1, 1], strides = [1, 1]} : memref<?x?xf32, #map2>, memref<?x?xf32>, memref<?x?xf32, #map2>
scf.yield
}
return
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu_option.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu_option.mlir
index 1701535..63f8aa5 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu_option.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu_option.mlir
@@ -32,7 +32,7 @@
%13 = affine.min #map5(%arg5)[%4]
%14 = dim %arg2, %c3 : memref<?x?x?x?xf32>
%15 = subview %arg2[%arg3, %arg4, %arg5, 0] [%11, %12, %13, %14] [1, 1, 1, 1] : memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map3>
- linalg.conv(%arg0, %9, %15) {__internal_linalg_transform__ = "workitem", dilations = [1, 1], strides = [1, 1]} : memref<?x?x?x?xf32>, memref<?x?x?x?xf32, #map3>, memref<?x?x?x?xf32, #map3>
+ linalg.conv(%arg0, %9, %15) {__internal_linalg_transform__ = "workgroup", dilations = [1, 1], strides = [1, 1]} : memref<?x?x?x?xf32>, memref<?x?x?x?xf32, #map3>, memref<?x?x?x?xf32, #map3>
scf.yield
}
return
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/cyclic_to_workgroup.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/cyclic_to_workgroup.mlir
index 110ac24..cac18ab 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/cyclic_to_workgroup.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/cyclic_to_workgroup.mlir
@@ -27,7 +27,7 @@
%12 = dim %arg2, %c1 : memref<?x?xf32>
%13 = affine.min #map0(%arg4)[%12]
%14 = subview %arg2[%arg3, %arg4] [%11, %13] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map2>
- linalg.matmul %5, %9, %14 {__internal_linalg_transform__ = "workitem"} : (memref<?x?xf32, #map2>, memref<?x?xf32, #map2>, memref<?x?xf32, #map2>)
+ linalg.matmul %5, %9, %14 {__internal_linalg_transform__ = "workgroup"} : (memref<?x?xf32, #map2>, memref<?x?xf32, #map2>, memref<?x?xf32, #map2>)
}
scf.yield
}
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
index 0e2fe6d..1728d35 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
@@ -51,7 +51,7 @@
// CHECK: %[[VIEW2:.+]] = subview %[[ARG2]]
// CHECK: linalg.conv
// CHECK-SAME: %[[ARG0]], %[[VIEW1]], %[[VIEW2]]
-// CHECK-SAME: "workitem"
+// CHECK-SAME: "workgroup"
// -----
@@ -81,7 +81,7 @@
// CHECK: %[[VIEW1:.+]] = subview %[[ARG1]]
// CHECK: %[[VIEW2:.+]] = subview %[[ARG2]]
// CHECK: linalg.matmul
-// CHECK-SAME: "workitem"
+// CHECK-SAME: "workgroup"
// CHECK-SAME: %[[VIEW0]], %[[VIEW1]], %[[VIEW2]]
// -----
@@ -111,4 +111,4 @@
// CHECK: %[[VIEW2:.+]] = subview %[[ARG2]]
// CHECK: linalg.pooling_max
// CHECK-SAME: %[[VIEW0]], %[[ARG1]], %[[VIEW2]]
-// CHECK-SAME: "workitem"
+// CHECK-SAME: "workgroup"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
index 76cfcb8..a24c77b 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
@@ -36,12 +36,12 @@
// CHECK: %[[ALLOC2:.+]] = alloc(%[[C4]], %[[C8]]) : memref<?x?xf32, 3>
// CHECK: %[[SUBVIEW2:.+]] = subview %[[ALLOC2]]
// CHECK: linalg.copy(%[[ARG0SV]], %[[SUBVIEW1]])
-// CHECK-SAME: "workitem"
+// CHECK-SAME: "workgroup"
// CHECK: spv.ControlBarrier "Workgroup", "Workgroup", "AcquireRelease"
// CHECK: linalg.copy(%[[ARG1SV]], %[[SUBVIEW2]])
-// CHECK-SAME: "workitem"
+// CHECK-SAME: "workgroup"
// CHECK: spv.ControlBarrier "Workgroup", "Workgroup", "AcquireRelease"
-// CHECK: linalg.matmul {{.*}}"workitem"{{.*}} %[[SUBVIEW1]], %[[SUBVIEW2]], %[[RET0SV]]
+// CHECK: linalg.matmul {{.*}}"workgroup"{{.*}} %[[SUBVIEW1]], %[[SUBVIEW2]], %[[RET0SV]]
// CHECK: spv.ControlBarrier "Workgroup", "Workgroup", "AcquireRelease"
// CHECK-DAG: dealloc %[[ALLOC1]] : memref<?x?xf32, 3>
// CHECK-DAG: dealloc %[[ALLOC2]] : memref<?x?xf32, 3>
diff --git a/iree/compiler/Conversion/init_conversions.h b/iree/compiler/Conversion/init_conversions.h
index 7a190e7..259e3d5 100644
--- a/iree/compiler/Conversion/init_conversions.h
+++ b/iree/compiler/Conversion/init_conversions.h
@@ -47,7 +47,6 @@
inline void registerLinalgToLLVMPasses() {
static bool init_once = []() {
// LinalgToLLVM
- createHALInterfaceToMemrefArgumentsPass();
return true;
}();
(void)init_once;
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
index 7269089..8cb47b5 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
@@ -65,13 +65,10 @@
auto executableOp = cast<ExecutableOp>(targetOp.getParentOp());
auto entryPointOps =
executableOp.getBlock().getOps<ExecutableEntryPointOp>();
- const bool addCInterface = true;
+
for (auto entryPointOp : entryPointOps) {
- std::string funcName =
- addCInterface ? "_mlir_ciface_" + std::string(entryPointOp.sym_name())
- : std::string(entryPointOp.sym_name());
- dyLibExecutableDef.entry_points.push_back("invoke_" + funcName);
- createLLVMInvocationFunc(funcName, llvmModule.get());
+ dyLibExecutableDef.entry_points.push_back(
+ std::string(entryPointOp.sym_name()));
}
// LLVMIR opt passes.
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.cpp
index 24d5877..e91441d 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.cpp
@@ -44,44 +44,6 @@
return machine;
}
-void createLLVMInvocationFunc(const std::string& name, llvm::Module* module) {
- // TODO(ataei): This is written as a stub in LLVM IR. It would be easier to
- // have this using MLIR and lower it to LLVM like the dispatch function
- // implementation is.
-
- auto& ctx = module->getContext();
- llvm::IRBuilder<> builder(ctx);
- auto var_func = module->getFunction(name);
-
- auto new_type = llvm::FunctionType::get(
- builder.getVoidTy(), builder.getInt8PtrTy()->getPointerTo(),
- /*isVarArg=*/false);
-
- auto new_name = "invoke_" + name;
- auto func_cst = module->getOrInsertFunction(new_name, new_type);
- llvm::Function* interface_func =
- llvm::cast<llvm::Function>(func_cst.getCallee());
-
- auto bb = llvm::BasicBlock::Create(ctx);
- bb->insertInto(interface_func);
- builder.SetInsertPoint(bb);
- llvm::Value* argList = interface_func->arg_begin();
- llvm::SmallVector<llvm::Value*, 8> args;
- args.reserve(llvm::size(var_func->args()));
- for (auto& indexedArg : llvm::enumerate(var_func->args())) {
- llvm::Value* arg_index = llvm::Constant::getIntegerValue(
- builder.getInt64Ty(), llvm::APInt(64, indexedArg.index()));
- llvm::Value* arg_ptr_ptr = builder.CreateGEP(argList, arg_index);
- llvm::Value* arg_ptr = builder.CreateLoad(arg_ptr_ptr);
- arg_ptr = builder.CreateBitCast(
- arg_ptr, indexedArg.value().getType()->getPointerTo());
- llvm::Value* arg = builder.CreateLoad(arg_ptr);
- args.push_back(arg);
- }
- builder.CreateCall(var_func, args);
- builder.CreateRetVoid();
-}
-
LogicalResult runLLVMIRPasses(const LLVMTargetOptions& options,
llvm::TargetMachine* machine,
llvm::Module* module) {
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.h b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.h
index 199e36f..37ee1ba 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.h
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.h
@@ -31,9 +31,6 @@
std::unique_ptr<llvm::TargetMachine> createTargetMachine(
const LLVMTargetOptions& options);
-// Creates an invocation function in a module for the given function name.
-void createLLVMInvocationFunc(const std::string& name, llvm::Module* module);
-
// Creates and runs LLVMIR optimization passes defined in LLVMTargetOptions.
LogicalResult runLLVMIRPasses(const LLVMTargetOptions& options,
llvm::TargetMachine* machine,
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRTarget.cpp
index 98c0bf4..96bb5ac 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRTarget.cpp
@@ -58,13 +58,9 @@
auto executableOp = cast<IREE::HAL::ExecutableOp>(targetOp.getParentOp());
auto entryPointOps =
executableOp.getBlock().getOps<IREE::HAL::ExecutableEntryPointOp>();
- const bool addCInterface = true;
for (auto entryPointOp : entryPointOps) {
- std::string funcName =
- addCInterface ? "_mlir_ciface_" + std::string(entryPointOp.sym_name())
- : std::string(entryPointOp.sym_name());
- llvmIrExecutableDef.entry_points.push_back(funcName);
- createLLVMInvocationFunc(funcName, llvmModule.get());
+ llvmIrExecutableDef.entry_points.push_back(
+ std::string(entryPointOp.sym_name()));
}
// LLVMIR opt passes.
@@ -74,8 +70,9 @@
options_.targetTriple);
return failure();
}
- if (failed(
- runLLVMIRPasses(options_, targetMachine.get(), llvmModule.get()))) {
+ LogicalResult translationResult =
+ runLLVMIRPasses(options_, targetMachine.get(), llvmModule.get());
+ if (failed(translationResult)) {
return targetOp.emitError(
"Can't build LLVMIR opt passes for ExecutableOp module");
}
diff --git a/iree/hal/dylib/BUILD b/iree/hal/dylib/BUILD
index 25c08ea..fc3ccb9 100644
--- a/iree/hal/dylib/BUILD
+++ b/iree/hal/dylib/BUILD
@@ -60,7 +60,6 @@
srcs = ["dylib_executable.cc"],
hdrs = ["dylib_executable.h"],
deps = [
- ":memref_runtime",
"//iree/base:dynamic_library",
"//iree/base:file_io",
"//iree/base:status",
@@ -89,10 +88,3 @@
"//iree/hal:executable_format",
],
)
-
-cc_library(
- name = "memref_runtime",
- hdrs = [
- "memref_runtime.h",
- ],
-)
diff --git a/iree/hal/dylib/CMakeLists.txt b/iree/hal/dylib/CMakeLists.txt
index 7644d92..d720435 100644
--- a/iree/hal/dylib/CMakeLists.txt
+++ b/iree/hal/dylib/CMakeLists.txt
@@ -65,7 +65,6 @@
SRCS
"dylib_executable.cc"
DEPS
- ::memref_runtime
absl::inlined_vector
absl::span
flatbuffers
@@ -97,11 +96,3 @@
iree::hal::executable_format
PUBLIC
)
-
-iree_cc_library(
- NAME
- memref_runtime
- HDRS
- "memref_runtime.h"
- PUBLIC
-)
diff --git a/iree/hal/dylib/dylib_executable.cc b/iree/hal/dylib/dylib_executable.cc
index e06bb19..e58a003 100644
--- a/iree/hal/dylib/dylib_executable.cc
+++ b/iree/hal/dylib/dylib_executable.cc
@@ -17,7 +17,6 @@
#include "flatbuffers/flatbuffers.h"
#include "iree/base/file_io.h"
#include "iree/base/tracing.h"
-#include "iree/hal/dylib/memref_runtime.h"
#include "iree/schemas/dylib_executable_def_generated.h"
namespace iree {
@@ -96,15 +95,9 @@
struct DyLibDispatchState : public HostExecutable::DispatchState {
DyLibDispatchState() = default;
- ~DyLibDispatchState() override {
- for (int i = 0; i < descriptors.size(); ++i) {
- freeUnrankedDescriptor(descriptors[i]);
- }
- }
-
void* entry_function = nullptr;
- absl::InlinedVector<UnrankedMemRefType<uint32_t>*, 4> descriptors;
absl::InlinedVector<void*, 4> args;
+ absl::InlinedVector<int64_t, 4> push_constant;
};
StatusOr<ref_ptr<HostExecutable::DispatchState>>
@@ -127,17 +120,14 @@
MemoryAccessBitfield::kWrite,
io_binding.offset, io_binding.length));
auto data = memory.mutable_data();
- auto descriptor = allocUnrankedDescriptor<uint32_t>(data);
- dispatch_state->descriptors.push_back(descriptor);
- dispatch_state->args.push_back(&descriptor->descriptor);
+
+ dispatch_state->args.push_back(data);
}
}
-
- auto push_constants_descriptor = allocUnrankedDescriptor<uint32_t>(
- const_cast<uint32_t*>(params.push_constants->values.data()),
- {static_cast<int64_t>(params.push_constants->values.size())});
- dispatch_state->descriptors.push_back(push_constants_descriptor);
- dispatch_state->args.push_back(&push_constants_descriptor->descriptor);
+ // TODO(ataei): Consider moving this casting to the codegen side.
+ for (int i = 0; i < params.push_constants->values.size(); ++i) {
+ dispatch_state->push_constant.push_back(params.push_constants->values[i]);
+ }
return std::move(dispatch_state);
}
@@ -147,8 +137,10 @@
IREE_TRACE_SCOPE0("DyLibExecutable::DispatchTile");
auto* dispatch_state = static_cast<DyLibDispatchState*>(state);
- auto entry_function = (void (*)(void**))dispatch_state->entry_function;
- entry_function(dispatch_state->args.data());
+ auto entry_function =
+ (void (*)(void**, int64_t*))dispatch_state->entry_function;
+ entry_function(dispatch_state->args.data(),
+ dispatch_state->push_constant.data());
return OkStatus();
}
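The net effect of these dylib changes (and the matching llvmjit changes below) is a simpler entry-point ABI: instead of MLIR memref descriptors, the executable now receives an array of raw binding pointers plus an array of push constants, i.e. `void(void**, int64_t*)`. A hedged sketch of a dispatch function compatible with that convention follows; the function name, buffer layout, and the use of the first push constant as an element count are all hypothetical.

```c++
#include <cstdint>

// Hypothetical generated entry point matching the new calling convention:
// args[i] is the raw pointer of the i-th interface binding and push_constants
// carries the scalar parameters pushed by the command buffer.
extern "C" void simple_mul_dispatch(void** args, int64_t* push_constants) {
  const float* lhs = static_cast<const float*>(args[0]);
  const float* rhs = static_cast<const float*>(args[1]);
  float* out = static_cast<float*>(args[2]);
  const int64_t element_count = push_constants[0];  // assumed meaning
  for (int64_t i = 0; i < element_count; ++i) out[i] = lhs[i] * rhs[i];
}

// Host side, mirroring DyLibExecutable::DispatchTile above.
void InvokeDispatch(void* entry_function, void** args, int64_t* push_constants) {
  auto entry = (void (*)(void**, int64_t*))entry_function;
  entry(args, push_constants);
}
```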
diff --git a/iree/hal/dylib/memref_runtime.h b/iree/hal/dylib/memref_runtime.h
deleted file mode 100644
index 50d3987..0000000
--- a/iree/hal/dylib/memref_runtime.h
+++ /dev/null
@@ -1,177 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-
-#ifndef IREE_HAL_DYLIB_MEMREF_RUNTIME_H_
-#define IREE_HAL_DYLIB_MEMREF_RUNTIME_H_
-
-#include <assert.h>
-
-#include <cstdint>
-#include <vector>
-
-namespace iree {
-namespace hal {
-namespace dylib {
-
-template <int N>
-void dropFront(int64_t arr[N], int64_t *res) {
- for (unsigned i = 1; i < N; ++i) *(res + i - 1) = arr[i];
-}
-
-/// StridedMemRef descriptor type with static rank.
-template <typename T, int N>
-struct StridedMemRefType {
- T *basePtr;
- T *data;
- int64_t offset;
- int64_t sizes[N];
- int64_t strides[N];
- // This operator[] is extremely slow and only for sugaring purposes.
- StridedMemRefType<T, N - 1> operator[](int64_t idx) {
- StridedMemRefType<T, N - 1> res;
- res.basePtr = basePtr;
- res.data = data;
- res.offset = offset + idx * strides[0];
- dropFront<N>(sizes, res.sizes);
- dropFront<N>(strides, res.strides);
- return res;
- }
-};
-
-/// StridedMemRef descriptor type specialized for rank 1.
-template <typename T>
-struct StridedMemRefType<T, 1> {
- T *basePtr;
- T *data;
- int64_t offset;
- int64_t sizes[1];
- int64_t strides[1];
- T &operator[](int64_t idx) { return *(data + offset + idx * strides[0]); }
-};
-
-/// StridedMemRef descriptor type specialized for rank 0.
-template <typename T>
-struct StridedMemRefType<T, 0> {
- T *basePtr;
- T *data;
- int64_t offset;
-};
-
-// Unranked MemRef
-template <typename T>
-struct UnrankedMemRefType {
- int64_t rank;
- void *descriptor;
-};
-
-// Given a shape with sizes greater than 0 along all dimensions,
-// returns the distance, in number of elements, between a slice in a dimension
-// and the next slice in the same dimension.
-// e.g. shape[3, 4, 5] -> strides[20, 5, 1]
-inline std::vector<int64_t> makeStrides(const std::vector<int64_t> &shape) {
- std::vector<int64_t> tmp;
- if (shape.empty()) return tmp;
- tmp.reserve(shape.size());
- int64_t running = 1;
- for (auto rit = shape.rbegin(), reit = shape.rend(); rit != reit; ++rit) {
- assert(*rit > 0 &&
- "size must be greater than 0 along all dimensions of shape");
- tmp.push_back(running);
- running *= *rit;
- }
- return std::vector<int64_t>(tmp.rbegin(), tmp.rend());
-}
-
-// Mallocs a StridedMemRefDescriptor<T, N>* that matches the MLIR ABI.
-// This is an implementation detail that is kept in sync with MLIR codegen
-// conventions.
-template <typename T, int N>
-StridedMemRefType<T, N> *makeStridedMemRefDescriptor(
- void *ptr, const std::vector<int64_t> &shape) {
- StridedMemRefType<T, N> *descriptor = static_cast<StridedMemRefType<T, N> *>(
- malloc(sizeof(StridedMemRefType<T, N>)));
- descriptor->basePtr = static_cast<T *>(ptr);
- descriptor->data = static_cast<T *>(ptr);
- descriptor->offset = 0;
- std::copy(shape.begin(), shape.end(), descriptor->sizes);
- auto strides = makeStrides(shape);
- std::copy(strides.begin(), strides.end(), descriptor->strides);
- return descriptor;
-}
-
-// Mallocs a StridedMemRefDescriptor<T, 0>* (i.e. a pointer to scalar) that
-// matches the MLIR ABI. This is an implementation detail that is kept in sync
-// with MLIR codegen conventions.
-template <typename T>
-StridedMemRefType<T, 0> *makeStridedMemRefDescriptor(
- void *ptr, const std::vector<int64_t> &shape) {
- StridedMemRefType<T, 0> *descriptor = static_cast<StridedMemRefType<T, 0> *>(
- malloc(sizeof(StridedMemRefType<T, 0>)));
- descriptor->basePtr = static_cast<T *>(ptr);
- descriptor->data = static_cast<T *>(ptr);
- descriptor->offset = 0;
- return descriptor;
-}
-
-// Mallocs an UnrankedMemRefType<T>* that contains a ranked
-// StridedMemRefDescriptor<T, Rank>* and matches the MLIR ABI. This is an
-// implementation detail that is kept in sync with MLIR codegen conventions.
-template <typename T>
-UnrankedMemRefType<T> *allocUnrankedDescriptor(
- void *data, const std::vector<int64_t> &shape) {
- UnrankedMemRefType<T> *res = static_cast<UnrankedMemRefType<T> *>(
- malloc(sizeof(UnrankedMemRefType<T>)));
- res->rank = shape.size();
- if (res->rank == 0)
- res->descriptor = makeStridedMemRefDescriptor<T>(data, shape);
- else if (res->rank == 1)
- res->descriptor = makeStridedMemRefDescriptor<T, 1>(data, shape);
- else if (res->rank == 2)
- res->descriptor = makeStridedMemRefDescriptor<T, 2>(data, shape);
- else if (res->rank == 3)
- res->descriptor = makeStridedMemRefDescriptor<T, 3>(data, shape);
- else if (res->rank == 4)
- res->descriptor = makeStridedMemRefDescriptor<T, 4>(data, shape);
- else if (res->rank == 5)
- res->descriptor = makeStridedMemRefDescriptor<T, 5>(data, shape);
- else if (res->rank == 6)
- res->descriptor = makeStridedMemRefDescriptor<T, 6>(data, shape);
- else
- assert(false && "Unsupported 6+D memref descriptor");
- return res;
-}
-
-// Shape and strides aren't used in the generated code (yet).
-// TODO(ataei): Delete this version once we can pass shapes.
-template <typename T>
-UnrankedMemRefType<T> *allocUnrankedDescriptor(void *data) {
- UnrankedMemRefType<T> *res = static_cast<UnrankedMemRefType<T> *>(
- malloc(sizeof(UnrankedMemRefType<T>)));
- res->descriptor = makeStridedMemRefDescriptor<T>(data, {});
- return res;
-}
-
-// Frees an UnrankedMemRefType<T>*
-template <typename T>
-void freeUnrankedDescriptor(UnrankedMemRefType<T> *desc) {
- free(desc->descriptor);
- free(desc);
-}
-
-} // namespace dylib
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_DYLIB_MEMREF_RUNTIME_H_
diff --git a/iree/hal/llvmjit/BUILD b/iree/hal/llvmjit/BUILD
index 088bb8b..3ebd609 100644
--- a/iree/hal/llvmjit/BUILD
+++ b/iree/hal/llvmjit/BUILD
@@ -64,7 +64,6 @@
srcs = ["llvmjit_executable.cc"],
hdrs = ["llvmjit_executable.h"],
deps = [
- ":memref_runtime",
"//iree/base:status",
"//iree/base:tracing",
"//iree/hal:buffer",
@@ -95,10 +94,3 @@
"//iree/hal:executable_format",
],
)
-
-cc_library(
- name = "memref_runtime",
- hdrs = [
- "memref_runtime.h",
- ],
-)
diff --git a/iree/hal/llvmjit/CMakeLists.txt b/iree/hal/llvmjit/CMakeLists.txt
index 8418745..ca40941 100644
--- a/iree/hal/llvmjit/CMakeLists.txt
+++ b/iree/hal/llvmjit/CMakeLists.txt
@@ -68,7 +68,6 @@
SRCS
"llvmjit_executable.cc"
DEPS
- ::memref_runtime
LLVMAsmParser
LLVMCore
LLVMOrcJIT
@@ -102,11 +101,3 @@
iree::hal::executable_format
PUBLIC
)
-
-iree_cc_library(
- NAME
- memref_runtime
- HDRS
- "memref_runtime.h"
- PUBLIC
-)
diff --git a/iree/hal/llvmjit/llvmjit_executable.cc b/iree/hal/llvmjit/llvmjit_executable.cc
index 1596b9e..7d26ccd 100644
--- a/iree/hal/llvmjit/llvmjit_executable.cc
+++ b/iree/hal/llvmjit/llvmjit_executable.cc
@@ -21,7 +21,6 @@
#include "iree/base/tracing.h"
#include "iree/hal/buffer.h"
#include "iree/hal/executable.h"
-#include "iree/hal/llvmjit/memref_runtime.h"
#include "iree/schemas/llvmir_executable_def_generated.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
@@ -82,13 +81,11 @@
make_ref<LLVMJITExecutable>(spec, std::move(ll_jit), allow_aliasing_data);
for (const auto func_name : *entry_points) {
- auto func_symbol =
- executable->ll_jit_->lookup("invoke_" + func_name->str());
+ auto func_symbol = executable->ll_jit_->lookup(func_name->str());
if (!func_symbol) {
return NotFoundErrorBuilder(IREE_LOC)
<< "Can't JIT compile function : " << func_name;
}
- // Map function to its invoke_ symbol.
executable->symbols_.push_back(func_symbol.get());
}
@@ -111,15 +108,10 @@
struct LLVMJITDispatchState : public HostExecutable::DispatchState {
LLVMJITDispatchState() = default;
- ~LLVMJITDispatchState() override {
- for (int i = 0; i < descriptors.size(); ++i) {
- freeUnrankedDescriptor(descriptors[i]);
- }
- }
llvm::JITEvaluatedSymbol symbol;
- llvm::SmallVector<UnrankedMemRefType<uint32_t>*, 4> descriptors;
llvm::SmallVector<void*, 4> args;
+ llvm::SmallVector<int64_t, 4> push_constant;
};
StatusOr<ref_ptr<HostExecutable::DispatchState>>
@@ -142,17 +134,13 @@
MemoryAccessBitfield::kWrite,
io_binding.offset, io_binding.length));
auto data = memory.mutable_data();
- auto descriptor = allocUnrankedDescriptor<uint32_t>(data);
- dispatch_state->descriptors.push_back(descriptor);
- dispatch_state->args.push_back(&descriptor->descriptor);
+ dispatch_state->args.push_back(data);
}
}
-
- auto push_constants_descriptor = allocUnrankedDescriptor<uint32_t>(
- const_cast<uint32_t*>(params.push_constants->values.data()),
- {static_cast<int64_t>(params.push_constants->values.size())});
- dispatch_state->descriptors.push_back(push_constants_descriptor);
- dispatch_state->args.push_back(&push_constants_descriptor->descriptor);
+ // TODO(ataei): Consider moving this casting to the codegen side.
+ for (int i = 0; i < params.push_constants->values.size(); ++i) {
+ dispatch_state->push_constant.push_back(params.push_constants->values[i]);
+ }
return std::move(dispatch_state);
}
@@ -162,8 +150,9 @@
IREE_TRACE_SCOPE0("LLVMJITExecutable::DispatchTile");
auto* dispatch_state = static_cast<LLVMJITDispatchState*>(state);
- auto func_ptr = (void (*)(void**))dispatch_state->symbol.getAddress();
- func_ptr(dispatch_state->args.data());
+ auto func_ptr =
+ (void (*)(void**, int64_t*))dispatch_state->symbol.getAddress();
+ func_ptr(dispatch_state->args.data(), dispatch_state->push_constant.data());
return OkStatus();
}
diff --git a/iree/hal/llvmjit/memref_runtime.h b/iree/hal/llvmjit/memref_runtime.h
deleted file mode 100644
index 6b94410..0000000
--- a/iree/hal/llvmjit/memref_runtime.h
+++ /dev/null
@@ -1,177 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-
-#ifndef IREE_HAL_LLVMJIT_LLVMJIT_MEMREF_RUNTIME_H_
-#define IREE_HAL_LLVMJIT_LLVMJIT_MEMREF_RUNTIME_H_
-
-#include <assert.h>
-
-#include <cstdint>
-#include <vector>
-
-namespace iree {
-namespace hal {
-namespace llvmjit {
-
-template <int N>
-void dropFront(int64_t arr[N], int64_t *res) {
- for (unsigned i = 1; i < N; ++i) *(res + i - 1) = arr[i];
-}
-
-/// StridedMemRef descriptor type with static rank.
-template <typename T, int N>
-struct StridedMemRefType {
- T *basePtr;
- T *data;
- int64_t offset;
- int64_t sizes[N];
- int64_t strides[N];
- // This operator[] is extremely slow and only for sugaring purposes.
- StridedMemRefType<T, N - 1> operator[](int64_t idx) {
- StridedMemRefType<T, N - 1> res;
- res.basePtr = basePtr;
- res.data = data;
- res.offset = offset + idx * strides[0];
- dropFront<N>(sizes, res.sizes);
- dropFront<N>(strides, res.strides);
- return res;
- }
-};
-
-/// StridedMemRef descriptor type specialized for rank 1.
-template <typename T>
-struct StridedMemRefType<T, 1> {
- T *basePtr;
- T *data;
- int64_t offset;
- int64_t sizes[1];
- int64_t strides[1];
- T &operator[](int64_t idx) { return *(data + offset + idx * strides[0]); }
-};
-
-/// StridedMemRef descriptor type specialized for rank 0.
-template <typename T>
-struct StridedMemRefType<T, 0> {
- T *basePtr;
- T *data;
- int64_t offset;
-};
-
-// Unranked MemRef
-template <typename T>
-struct UnrankedMemRefType {
- int64_t rank;
- void *descriptor;
-};
-
-// Given a shape with sizes greater than 0 along all dimensions,
-// returns the distance, in number of elements, between a slice in a dimension
-// and the next slice in the same dimension.
-// e.g. shape[3, 4, 5] -> strides[20, 5, 1]
-inline std::vector<int64_t> makeStrides(const std::vector<int64_t> &shape) {
- std::vector<int64_t> tmp;
- if (shape.empty()) return tmp;
- tmp.reserve(shape.size());
- int64_t running = 1;
- for (auto rit = shape.rbegin(), reit = shape.rend(); rit != reit; ++rit) {
- assert(*rit > 0 &&
- "size must be greater than 0 along all dimensions of shape");
- tmp.push_back(running);
- running *= *rit;
- }
- return std::vector<int64_t>(tmp.rbegin(), tmp.rend());
-}
-
-// Mallocs a StridedMemRefDescriptor<T, N>* that matches the MLIR ABI.
-// This is an implementation detail that is kept in sync with MLIR codegen
-// conventions.
-template <typename T, int N>
-StridedMemRefType<T, N> *makeStridedMemRefDescriptor(
- void *ptr, const std::vector<int64_t> &shape) {
- StridedMemRefType<T, N> *descriptor = static_cast<StridedMemRefType<T, N> *>(
- malloc(sizeof(StridedMemRefType<T, N>)));
- descriptor->basePtr = static_cast<T *>(ptr);
- descriptor->data = static_cast<T *>(ptr);
- descriptor->offset = 0;
- std::copy(shape.begin(), shape.end(), descriptor->sizes);
- auto strides = makeStrides(shape);
- std::copy(strides.begin(), strides.end(), descriptor->strides);
- return descriptor;
-}
-
-// Mallocs a StridedMemRefDescriptor<T, 0>* (i.e. a pointer to scalar) that
-// matches the MLIR ABI. This is an implementation detail that is kept in sync
-// with MLIR codegen conventions.
-template <typename T>
-StridedMemRefType<T, 0> *makeStridedMemRefDescriptor(
- void *ptr, const std::vector<int64_t> &shape) {
- StridedMemRefType<T, 0> *descriptor = static_cast<StridedMemRefType<T, 0> *>(
- malloc(sizeof(StridedMemRefType<T, 0>)));
- descriptor->basePtr = static_cast<T *>(ptr);
- descriptor->data = static_cast<T *>(ptr);
- descriptor->offset = 0;
- return descriptor;
-}
-
-// Mallocs an UnrankedMemRefType<T>* that contains a ranked
-// StridedMemRefDescriptor<T, Rank>* and matches the MLIR ABI. This is an
-// implementation detail that is kept in sync with MLIR codegen conventions.
-template <typename T>
-UnrankedMemRefType<T> *allocUnrankedDescriptor(
- void *data, const std::vector<int64_t> &shape) {
- UnrankedMemRefType<T> *res = static_cast<UnrankedMemRefType<T> *>(
- malloc(sizeof(UnrankedMemRefType<T>)));
- res->rank = shape.size();
- if (res->rank == 0)
- res->descriptor = makeStridedMemRefDescriptor<T>(data, shape);
- else if (res->rank == 1)
- res->descriptor = makeStridedMemRefDescriptor<T, 1>(data, shape);
- else if (res->rank == 2)
- res->descriptor = makeStridedMemRefDescriptor<T, 2>(data, shape);
- else if (res->rank == 3)
- res->descriptor = makeStridedMemRefDescriptor<T, 3>(data, shape);
- else if (res->rank == 4)
- res->descriptor = makeStridedMemRefDescriptor<T, 4>(data, shape);
- else if (res->rank == 5)
- res->descriptor = makeStridedMemRefDescriptor<T, 5>(data, shape);
- else if (res->rank == 6)
- res->descriptor = makeStridedMemRefDescriptor<T, 6>(data, shape);
- else
- assert(false && "Unsupported 6+D memref descriptor");
- return res;
-}
-
-// Shape and strides aren't used in the generated code (yet).
-// TODO(ataei): Delete this version once we can pass shapes.
-template <typename T>
-UnrankedMemRefType<T> *allocUnrankedDescriptor(void *data) {
- UnrankedMemRefType<T> *res = static_cast<UnrankedMemRefType<T> *>(
- malloc(sizeof(UnrankedMemRefType<T>)));
- res->descriptor = makeStridedMemRefDescriptor<T>(data, {});
- return res;
-}
-
-// Frees an UnrankedMemRefType<T>*
-template <typename T>
-void freeUnrankedDescriptor(UnrankedMemRefType<T> *desc) {
- free(desc->descriptor);
- free(desc);
-}
-
-} // namespace llvmjit
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_LLVMJIT_LLVMJIT_MEMREF_RUNTIME_H_
diff --git a/iree/test/e2e/xla_ops/BUILD b/iree/test/e2e/xla_ops/BUILD
index 9af6bd1..b492e10 100644
--- a/iree/test/e2e/xla_ops/BUILD
+++ b/iree/test/e2e/xla_ops/BUILD
@@ -82,9 +82,7 @@
"tanh.mlir",
"torch_index_select.mlir",
"transpose.mlir",
-
- # TODO(#2022): fails on real devices.
- # "while.mlir",
+ "while.mlir",
],
driver = "vulkan",
target_backend = "vulkan-spirv",
diff --git a/iree/test/e2e/xla_ops/CMakeLists.txt b/iree/test/e2e/xla_ops/CMakeLists.txt
index 7f65b06..e2ab883 100644
--- a/iree/test/e2e/xla_ops/CMakeLists.txt
+++ b/iree/test/e2e/xla_ops/CMakeLists.txt
@@ -66,6 +66,7 @@
"tanh.mlir"
"torch_index_select.mlir"
"transpose.mlir"
+ "while.mlir"
TARGET_BACKEND
vulkan-spirv
DRIVER