Merge google -> main (#6167)
* Add DataType and tweak InterpreterTest for TFLite bindings.
* [MLIR][HLO] Fix `chlo.broadcast_select` lowering in case all operands are scal..
* LLVM integrates
diff --git a/bindings/python/iree/runtime/vm.cc b/bindings/python/iree/runtime/vm.cc
index 6e885b6..0cb03d8 100644
--- a/bindings/python/iree/runtime/vm.cc
+++ b/bindings/python/iree/runtime/vm.cc
@@ -409,6 +409,12 @@
case IREE_HAL_ELEMENT_TYPE_FLOAT_64:
dtype_code = "d";
break;
+ case IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_INTEGER_SIGNED, 1):
+ // Due to layering issues it is not uncommon to get i1 buffer views
+ // and we just silently promote them to i8 since that is what they are.
+ // Really i1 should not exist at this boundary.
+ dtype_code = "b";
+ break;
default:
throw RaiseValueError("Unsupported VM Buffer -> numpy dtype mapping");
}
diff --git a/bindings/tflite/model.c b/bindings/tflite/model.c
index df2c0b9..9e1051c 100644
--- a/bindings/tflite/model.c
+++ b/bindings/tflite/model.c
@@ -159,8 +159,13 @@
iree_atomic_ref_count_init(&model->ref_count);
model->allocator = allocator;
model->owned_model_data = (uint8_t*)model + file_size;
- (void)fread(model->owned_model_data, 1, file_size, file);
+ int ret = fread(model->owned_model_data, 1, file_size, file);
fclose(file);
+ if (ret != file_size) {
+ IREE_TRACE_MESSAGE(ERROR, "failed model+data read");
+ IREE_TRACE_ZONE_END(z0);
+ return NULL;
+ }
status = _TfLiteModelInitializeModule(model->owned_model_data, file_size,
allocator, model);
diff --git a/experimental/rocm/rocm_buffer.c b/experimental/rocm/rocm_buffer.c
index e11d069..5269b4f 100644
--- a/experimental/rocm/rocm_buffer.c
+++ b/experimental/rocm/rocm_buffer.c
@@ -118,7 +118,8 @@
return iree_ok_status();
}
-void **iree_hal_rocm_buffer_device_pointer(iree_hal_buffer_t *base_buffer) {
+hipDeviceptr_t iree_hal_rocm_buffer_device_pointer(
+ iree_hal_buffer_t *base_buffer) {
iree_hal_rocm_buffer_t *buffer = iree_hal_rocm_buffer_cast(base_buffer);
return buffer->device_ptr;
}
diff --git a/experimental/rocm/rocm_buffer.h b/experimental/rocm/rocm_buffer.h
index 4c097b1..7e24bc7 100644
--- a/experimental/rocm/rocm_buffer.h
+++ b/experimental/rocm/rocm_buffer.h
@@ -25,7 +25,7 @@
// Returns the rocm base pointer for the given |buffer|.
// This is the entire allocated_buffer and must be offset by the buffer
// byte_offset and byte_length when used.
-void **iree_hal_rocm_buffer_device_pointer(iree_hal_buffer_t *buffer);
+hipDeviceptr_t iree_hal_rocm_buffer_device_pointer(iree_hal_buffer_t *buffer);
#ifdef __cplusplus
} // extern "C"
diff --git a/experimental/rocm/rocm_headers.h b/experimental/rocm/rocm_headers.h
index e4aafd2..056e8c0 100644
--- a/experimental/rocm/rocm_headers.h
+++ b/experimental/rocm/rocm_headers.h
@@ -7,6 +7,10 @@
#ifndef IREE_HAL_ROCM_ROCM_HEADERS_H_
#define IREE_HAL_ROCM_ROCM_HEADERS_H_
+#if defined(IREE_PTR_SIZE_32)
+#error 32-bit not supported on ROCm
+#endif // defined(IREE_PTR_SIZE_32)
+
#include "hip/hip_runtime.h"
#endif // IREE_HAL_ROCM_ROCM_HEADERS_H_
diff --git a/iree/base/internal/BUILD b/iree/base/internal/BUILD
index 7636882..fba2c6a 100644
--- a/iree/base/internal/BUILD
+++ b/iree/base/internal/BUILD
@@ -68,6 +68,19 @@
#===------------------------------------------------------------------------===#
cc_library(
+ name = "arena",
+ srcs = ["arena.c"],
+ hdrs = ["arena.h"],
+ deps = [
+ ":atomic_slist",
+ ":synchronization",
+ "//iree/base",
+ "//iree/base:core_headers",
+ "//iree/base:tracing",
+ ],
+)
+
+cc_library(
name = "atomic_slist",
srcs = ["atomic_slist.c"],
hdrs = ["atomic_slist.h"],
diff --git a/iree/base/internal/CMakeLists.txt b/iree/base/internal/CMakeLists.txt
index 1eba07d..1300382 100644
--- a/iree/base/internal/CMakeLists.txt
+++ b/iree/base/internal/CMakeLists.txt
@@ -52,6 +52,22 @@
iree_cc_library(
NAME
+ arena
+ HDRS
+ "arena.h"
+ SRCS
+ "arena.c"
+ DEPS
+ ::atomic_slist
+ ::synchronization
+ iree::base
+ iree::base::core_headers
+ iree::base::tracing
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
atomic_slist
HDRS
"atomic_slist.h"
diff --git a/iree/hal/local/arena.c b/iree/base/internal/arena.c
similarity index 99%
rename from iree/hal/local/arena.c
rename to iree/base/internal/arena.c
index 0aadee4..9b48c77 100644
--- a/iree/hal/local/arena.c
+++ b/iree/base/internal/arena.c
@@ -4,7 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree/hal/local/arena.h"
+#include "iree/base/internal/arena.h"
#include <stdint.h>
#include <string.h>
diff --git a/iree/hal/local/arena.h b/iree/base/internal/arena.h
similarity index 98%
rename from iree/hal/local/arena.h
rename to iree/base/internal/arena.h
index f050da8..1d0afae 100644
--- a/iree/hal/local/arena.h
+++ b/iree/base/internal/arena.h
@@ -4,8 +4,8 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#ifndef IREE_HAL_LOCAL_ARENA_H_
-#define IREE_HAL_LOCAL_ARENA_H_
+#ifndef IREE_BASE_INTERNAL_ARENA_H_
+#define IREE_BASE_INTERNAL_ARENA_H_
#include <stddef.h>
@@ -150,4 +150,4 @@
} // extern "C"
#endif // __cplusplus
-#endif // IREE_HAL_LOCAL_ARENA_H_
+#endif // IREE_BASE_INTERNAL_ARENA_H_
diff --git a/iree/base/internal/call_once.h b/iree/base/internal/call_once.h
index 2e656a5..da411dd 100644
--- a/iree/base/internal/call_once.h
+++ b/iree/base/internal/call_once.h
@@ -72,6 +72,13 @@
InitOnceExecuteOnce(flag, iree_call_once_callback_impl, (PVOID)&param, NULL);
}
+#elif IREE_SYNCHRONIZATION_DISABLE_UNSAFE
+
+// No-op when the thread control is disabled.
+#define IREE_ONCE_FLAG_INIT 1
+#define iree_once_flag uint32_t
+static inline void iree_call_once(iree_once_flag* flag, void (*func)(void)) {}
+
#else
// Fallback using pthread_once:
diff --git a/iree/base/internal/synchronization.c b/iree/base/internal/synchronization.c
index d8a822f..53be754 100644
--- a/iree/base/internal/synchronization.c
+++ b/iree/base/internal/synchronization.c
@@ -154,7 +154,7 @@
#define iree_mutex_impl_initialize(mutex)
#define iree_mutex_impl_deinitialize(mutex)
#define iree_mutex_impl_lock(mutex)
-#define iree_mutex_impl_try_lock(mutex)
+#define iree_mutex_impl_try_lock(mutex) true
#define iree_mutex_impl_unlock(mutex)
#elif defined(IREE_PLATFORM_WINDOWS) && defined(IREE_MUTEX_USE_WIN32_SRW)
@@ -309,7 +309,9 @@
IREE_DISABLE_THREAD_SAFETY_ANALYSIS {}
bool iree_slim_mutex_try_lock(iree_slim_mutex_t* mutex)
- IREE_DISABLE_THREAD_SAFETY_ANALYSIS {}
+ IREE_DISABLE_THREAD_SAFETY_ANALYSIS {
+ return iree_mutex_try_lock((iree_mutex_t*)&mutex->reserved);
+}
void iree_slim_mutex_unlock(iree_slim_mutex_t* mutex)
IREE_DISABLE_THREAD_SAFETY_ANALYSIS {}
diff --git a/iree/compiler/Conversion/CodegenUtils/FunctionUtils.cpp b/iree/compiler/Conversion/CodegenUtils/FunctionUtils.cpp
index dbb95d7..fa45076 100644
--- a/iree/compiler/Conversion/CodegenUtils/FunctionUtils.cpp
+++ b/iree/compiler/Conversion/CodegenUtils/FunctionUtils.cpp
@@ -16,17 +16,6 @@
bool isEntryPoint(FuncOp func) { return func.isPublic(); }
-FailureOr<FuncOp> getSingleEntryPointFunction(ModuleOp module) {
- auto entryPointFns = llvm::to_vector<1>(llvm::make_filter_range(
- module.getOps<FuncOp>(), [&](FuncOp op) { return isEntryPoint(op); }));
- if (!llvm::hasSingleElement(entryPointFns)) {
- module.emitError(
- "cannot handle modules with multiple entry point functions.");
- return {};
- }
- return entryPointFns[0];
-}
-
unsigned getNumOuterParallelLoops(linalg::LinalgOp op) {
return op.iterator_types()
.getValue()
@@ -47,6 +36,17 @@
return nullptr;
}
+llvm::StringMap<IREE::HAL::ExecutableEntryPointOp> getAllEntryPoints(
+ ModuleOp module) {
+ auto targetOp =
+ module.getOperation()->getParentOfType<IREE::HAL::ExecutableTargetOp>();
+ llvm::StringMap<IREE::HAL::ExecutableEntryPointOp> entryPointOps;
+ for (auto op : targetOp.getOps<IREE::HAL::ExecutableEntryPointOp>()) {
+ entryPointOps[op.sym_name()] = op;
+ }
+ return entryPointOps;
+}
+
/// Walk up the defs of the view, to get the untiled value. Either walks up
/// `ViewOpInterface` op-chains or the `subtensor` op-chains.
static Value getViewSource(Value view) {
diff --git a/iree/compiler/Conversion/CodegenUtils/FunctionUtils.h b/iree/compiler/Conversion/CodegenUtils/FunctionUtils.h
index 8cf1899..dda4b4f 100644
--- a/iree/compiler/Conversion/CodegenUtils/FunctionUtils.h
+++ b/iree/compiler/Conversion/CodegenUtils/FunctionUtils.h
@@ -8,6 +8,7 @@
#define IREE_COMPILER_CONVERSION_CODEGENUTILS_FUNCTIONUTILS_H_
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "llvm/ADT/StringMap.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/IR/BuiltinOps.h"
@@ -17,15 +18,16 @@
/// Returns true if the given `func` is a kernel dispatch entry point.
bool isEntryPoint(FuncOp func);
-/// Given a module returns the entrypoint function within the module.
-FailureOr<FuncOp> getSingleEntryPointFunction(ModuleOp module);
-
-/// Returns the number of outer parallel loops of a linalgOp.
-unsigned getNumOuterParallelLoops(linalg::LinalgOp op);
+/// Returns a map from function symbol name to corresponding entry point op.
+llvm::StringMap<IREE::HAL::ExecutableEntryPointOp> getAllEntryPoints(
+ ModuleOp module);
/// Returns the entry point op for the `funcOp`. Returns `nullptr` on failure.
IREE::HAL::ExecutableEntryPointOp getEntryPoint(FuncOp funcOp);
+/// Returns the number of outer parallel loops of a linalgOp.
+unsigned getNumOuterParallelLoops(linalg::LinalgOp op);
+
/// Returns the untiled type of a tiled view for both tensor and memref
/// types. Either walks the `ViewOpInterface` chain (for memrefs) or the
/// `subtensor` op chain (for tensors).
diff --git a/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp b/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp
index 087bfea..f385a22 100644
--- a/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp
+++ b/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp
@@ -623,6 +623,16 @@
} else {
return nullptr;
}
+
+ // If it's a static allocation, hoist it all the way up to the beginning of
+ // the function.
+ if (dynamicDims.empty()) {
+ auto funcOp = op->getParentOfType<FuncOp>();
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPointToStart(&funcOp.front());
+ return allocationFn(b, loc, resultType.getShape(),
+ resultType.getElementType(), dynamicDims);
+ }
return allocationFn(b, loc, resultType.getShape(),
resultType.getElementType(), dynamicDims);
}
diff --git a/iree/compiler/Conversion/Common/SetNumWorkgroupsPass.cpp b/iree/compiler/Conversion/Common/SetNumWorkgroupsPass.cpp
index edd6415..3afe1dc 100644
--- a/iree/compiler/Conversion/Common/SetNumWorkgroupsPass.cpp
+++ b/iree/compiler/Conversion/Common/SetNumWorkgroupsPass.cpp
@@ -30,14 +30,14 @@
}
SetNumWorkgroupsPass(ArrayRef<int64_t> ws = {})
- : workgroupSize(ws.begin(), ws.end()) {}
+ : workloadPerWorkgroup(ws.begin(), ws.end()) {}
SetNumWorkgroupsPass(const SetNumWorkgroupsPass &pass)
- : workgroupSize(pass.workgroupSize) {}
+ : workloadPerWorkgroup(pass.workloadPerWorkgroup) {}
void runOnOperation() override;
private:
- SmallVector<int64_t> workgroupSize;
+ SmallVector<int64_t> workloadPerWorkgroup;
};
} // namespace
@@ -46,33 +46,50 @@
IREE::HAL::ExecutableTargetOp targetOp = getOperation();
ModuleOp module = targetOp.getInnerModule();
- if (workgroupSize.empty()) {
- // If no workgroup size is specified, leave the workgroup size as is, just
- // set the number of workgroups to be 1, 1, 1 to have a single invocation.
- WorkgroupCountRegionBuilder regionBuilder =
- [](OpBuilder &b, Location loc,
- std::array<Value, 3> workload) -> std::array<Value, 3> {
- Value one = b.create<ConstantIndexOp>(loc, 1);
- return {one, one, one};
- };
- OpBuilder builder(context);
- for (auto funcOp : module.getOps<FuncOp>()) {
- if (failed(defineWorkgroupCountRegion(builder, funcOp, regionBuilder))) {
+ llvm::StringMap<IREE::HAL::ExecutableEntryPointOp> entryPoints =
+ getAllEntryPoints(module);
+ for (auto funcOp : module.getOps<FuncOp>()) {
+ auto entryPointOp = entryPoints.lookup(funcOp.getName());
+ if (!entryPointOp) continue;
+ SmallVector<int64_t, 4> currWorkloadPerWorkgroup;
+
+ // First check if there is a workload provided.
+ if (!workloadPerWorkgroup.empty()) {
+ currWorkloadPerWorkgroup.assign(workloadPerWorkgroup.begin(),
+ workloadPerWorkgroup.end());
+ } else if (IREE::HAL::TranslationInfo translationInfo =
+ getTranslationInfo(entryPointOp)) {
+ if (ArrayAttr workloadPerWorkgroupAttr =
+ translationInfo.workloadPerWorkgroup()) {
+ currWorkloadPerWorkgroup = llvm::to_vector<4>(llvm::map_range(
+ workloadPerWorkgroupAttr,
+ [](Attribute attr) { return attr.cast<IntegerAttr>().getInt(); }));
+ }
+ }
+
+ if (currWorkloadPerWorkgroup.empty()) {
+ // If no workgroup size is specified, leave the workgroup size as is, just
+ // set the number of workgroups to be 1, 1, 1 to have a single invocation.
+ WorkgroupCountRegionBuilder regionBuilder =
+ [](OpBuilder &b, Location loc,
+ std::array<Value, 3> workload) -> std::array<Value, 3> {
+ Value one = b.create<ConstantIndexOp>(loc, 1);
+ return {one, one, one};
+ };
+ OpBuilder builder(context);
+ for (auto funcOp : module.getOps<FuncOp>()) {
+ if (failed(
+ defineWorkgroupCountRegion(builder, funcOp, regionBuilder))) {
+ return signalPassFailure();
+ }
+ }
+ } else {
+ if (failed(materializeStaticLaunchInformation(
+ funcOp, currWorkloadPerWorkgroup))) {
+ funcOp.emitError("failed to materialize constant workgroup size");
return signalPassFailure();
}
}
- return;
- }
-
- auto entryPointFn = getSingleEntryPointFunction(module);
- if (failed(entryPointFn)) {
- return signalPassFailure();
- }
- auto funcOp = entryPointFn.getValue();
-
- if (failed(materializeStaticLaunchInformation(funcOp, workgroupSize))) {
- funcOp.emitError("failed to materialize constant workgroup size");
- return signalPassFailure();
}
// Apply post distribution canonicalization passes.
diff --git a/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir b/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir
index e8bdb0a..9a78512 100644
--- a/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir
+++ b/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir
@@ -177,6 +177,7 @@
hal.interface.binding @ret0, set=0, binding=3, type="StorageBuffer", access="Write|Discard"
}
// CHECK-LABEL: func @tile_from_pointwise_lhs()
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1x3xf32>
// CHECK-DAG: %[[TENSOR_LHS:.+]] = hal.interface.binding.subspan @io::@TENSOR_LHS
// CHECK-DAG: %[[TENSOR_RHS:.+]] = hal.interface.binding.subspan @io::@TENSOR_RHS
// CHECK-DAG: %[[TENSOR_INIT:.+]] = hal.interface.binding.subspan @io::@TENSOR_INIT
@@ -185,7 +186,6 @@
// CHECK: scf.for %[[IV1:.+]] = {{.+}} {
// CHECK-DAG: %[[LHS:.+]] = memref.subview %[[TENSOR_LHS]][%[[IV0]], 0] [1, 3] [1, 1]
// CHECK-DAG: %[[RHS:.+]] = memref.subview %[[TENSOR_RHS]][0, %[[IV1]]] [3, 1] [1, 1]
-// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1x3xf32>
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[LHS]] :
// CHECK-SAME: outs(%[[ALLOC]]
@@ -234,6 +234,7 @@
hal.interface.binding @TENSOR_INIT, set=0, binding=2, type="StorageBuffer", access="Read|Write"
}
// CHECK-LABEL: func @tile_from_pointwise_lhs_inplace()
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1x3xf32>
// CHECK-DAG: %[[TENSOR_LHS:.+]] = hal.interface.binding.subspan @io::@TENSOR_LHS
// CHECK-DAG: %[[TENSOR_RHS:.+]] = hal.interface.binding.subspan @io::@TENSOR_RHS
// CHECK-DAG: %[[RETURN:.+]] = hal.interface.binding.subspan @io::@TENSOR_INIT
@@ -241,7 +242,6 @@
// CHECK: scf.for %[[IV1:.+]] = {{.+}} {
// CHECK-DAG: %[[LHS:.+]] = memref.subview %[[TENSOR_LHS]][%[[IV0]], 0] [1, 3] [1, 1]
// CHECK-DAG: %[[RHS:.+]] = memref.subview %[[TENSOR_RHS]][0, %[[IV1]]] [3, 1] [1, 1]
-// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1x3xf32>
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[LHS]] :
// CHECK-SAME: outs(%[[ALLOC]]
@@ -1172,12 +1172,12 @@
hal.interface.binding @wo2, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
}
// CHECK-LABEL: func @pooling_nhwc_sum
+// CHECK: %[[WINDOW:.+]] = memref.alloc() : memref<2x3xf32>
// CHECK-DAG: %[[INPUT:.+]] = hal.interface.binding.subspan @io::@ro1[%c0] : memref<1x4x6x1xf32>
// CHECK-DAG: %[[INIT:.+]] = hal.interface.binding.subspan @io::@ro0[%c0] : memref<f32>
// CHECK-DAG: %[[RET0:.+]] = hal.interface.binding.subspan @io::@wo2[%c0] : memref<1x2x2x1xf32>
// CHECK: %[[INIT_VAL:.+]] = memref.load %[[INIT]][] : memref<f32>
// CHECK: linalg.fill(%[[RET0]], %[[INIT_VAL]]) : memref<1x2x2x1xf32>, f32
-// CHECK: %[[WINDOW:.+]] = memref.alloc() : memref<2x3xf32>
// CHECK: linalg.pooling_nhwc_sum
// CHECK-SAME: dilations = dense<1> : vector<2xi64>
// CHECK-SAME: strides = dense<[2, 3]> : vector<2xi64>
@@ -1381,10 +1381,9 @@
}
// CHECK-LABEL: func @dont_use_buffer_for_operand_when_output_tensor_used()
-
-// CHECK: %[[OUTPUT:.+]] = hal.interface.binding.subspan @interface_io::@wo3
// CHECK: %[[ALLOC:.+]] = memref.alloc
-// CHECK-NEXT: linalg.fill(%[[ALLOC]], %{{.+}})
+// CHECK: %[[OUTPUT:.+]] = hal.interface.binding.subspan @interface_io::@wo3
+// CHECK: linalg.fill(%[[ALLOC]], %{{.+}})
// CHECK-NEXT: linalg.conv_2d_input_nhwc_filter_hwcf
// CHECK-SAME: outs(%[[ALLOC]] : memref<1x112x112x32xf32>)
// CHECK-NEXT: linalg.fill(%[[OUTPUT]], %{{.+}})
@@ -1745,6 +1744,8 @@
}
// CHECK-LABEL: func @padded_matmul()
+// CHECK-DAG: %[[LHS_PADDED:.+]] = memref.alloc() : memref<64x32xf32>
+// CHECK-DAG: %[[RHS_PADDED:.+]] = memref.alloc() : memref<32x16xf32>
// CHECK-DAG: %[[C0:.+]] = constant 0.000000e+00 : f32
// CHECK-DAG: %[[LHS:.+]] = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : memref<12544x27xf32>
// CHECK-DAG: %[[RHS:.+]] = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0] : memref<27x16xf32>
@@ -1753,11 +1754,9 @@
// CHECK-DAG: %[[RHS_V:.+]] = memref.subview %[[RHS]][0, %{{.*}}] [27, 16] [1, 1]
// CHECK-DAG: %[[DST_V:.+]] = memref.subview %[[DST]][%{{.*}}, %{{.*}}] [64, 16] [1, 1]
// CHECK: linalg.fill(%[[DST_V]], %[[C0]])
-// CHECK: %[[LHS_PADDED:.+]] = memref.alloc() : memref<64x32xf32>
// CHECK: linalg.fill(%[[LHS_PADDED]], %[[C0]]) : memref<64x32xf32>, f32
// CHECK: %[[LHS_PADDED_INTER:.+]] = memref.subview %[[LHS_PADDED]][0, 0] [64, 27] [1, 1]
// CHECK: linalg.copy(%[[LHS_V]], %[[LHS_PADDED_INTER]])
-// CHECK: %[[RHS_PADDED:.+]] = memref.alloc() : memref<32x16xf32>
// CHECK: linalg.fill(%[[RHS_PADDED]], %[[C0]]) : memref<32x16xf32>, f32
// CHECK: %[[RHS_PADDED_INTER:.+]] = memref.subview %[[RHS_PADDED]][0, 0] [27, 16] [1, 1]
// CHECK: linalg.copy(%[[RHS_V]], %[[RHS_PADDED_INTER]])
@@ -1815,9 +1814,12 @@
}
// CHECK: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (4, -d0 + s0)>
// CHECK: func @dot_general_padded
-// CHECK-DAG: %[[ARG0:.+]] = hal.interface.binding.subspan @io::@arg0
-// CHECK-DAG: %[[ARG1:.+]] = hal.interface.binding.subspan @io::@arg1
-// CHECK-DAG: %[[RET0:.+]] = hal.interface.binding.subspan @io::@ret0
+// CHECK-DAG: %[[ALLOC_RET0:.+]] = memref.alloc
+// CHECK-DAG: %[[ALLOC_ARG1:.+]] = memref.alloc
+// CHECK-DAG: %[[ALLOC_ARG0:.+]] = memref.alloc
+// CHECK-DAG: %[[ARG0:.+]] = hal.interface.binding.subspan @io::@arg0[{{.*}}] : memref<?x?xf32>
+// CHECK-DAG: %[[ARG1:.+]] = hal.interface.binding.subspan @io::@arg1[{{.*}}] : memref<?x?xf32>
+// CHECK-DAG: %[[RET0:.+]] = hal.interface.binding.subspan @io::@ret0[{{.*}}] : memref<?x?xf32>
// CHECK-DAG: %[[M:.+]] = hal.interface.load.constant offset = 0
// CHECK-DAG: %[[N:.+]] = hal.interface.load.constant offset = 1
// CHECK: scf.for %[[IV0:.+]] = %{{.+}} to %[[M]]
@@ -1826,15 +1828,11 @@
// CHECK-DAG: %[[TILE_N:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[N]]]
// CHECK-DAG: %[[ARG0_SV:.+]] = memref.subview %[[ARG0]]
// CHECK-DAG: %[[ARG1_SV:.+]] = memref.subview %[[ARG1]]
-// CHECK: %[[ALLOC_ARG0:.+]] = memref.alloc
-// CHECK: linalg.fill(%[[ALLOC_ARG0]]
+// CHECK: linalg.fill(%[[ALLOC_ARG0]]
// CHECK: %[[ALLOC_ARG0_SV:.+]] = memref.subview %[[ALLOC_ARG0]]
-// CHECK: linalg.copy(%[[ARG0_SV]]
-// CHECK: %[[ALLOC_ARG1:.+]] = memref.alloc
+// CHECK: linalg.copy(%[[ARG0_SV]], %[[ALLOC_ARG0_SV]])
// CHECK: linalg.fill(%[[ALLOC_ARG1]]
-// CHECK: %[[ALLOC_ARG1_SV:.+]] = memref.subview %[[ALLOC_ARG1]]
// CHECK: linalg.copy(%[[ARG1_SV]]
-// CHECK: %[[ALLOC_RET0:.+]] = memref.alloc
// CHECK: linalg.fill(%[[ALLOC_RET0]]
// CHECK: linalg.matmul
// CHECK-SAME: ins(%[[ALLOC_ARG0]], %[[ALLOC_ARG1]]
@@ -1904,16 +1902,16 @@
// CHECK-DAG: %[[ARG0:.+]] = hal.interface.binding.subspan @io::@arg0
// CHECK-DAG: %[[ARG1:.+]] = hal.interface.binding.subspan @io::@arg1
// CHECK-DAG: %[[RET0:.+]] = hal.interface.binding.subspan @io::@ret0
+// CHECK-DAG: %[[ALLOC_ARG0:.+]] = memref.alloc() : memref<1x16x16x3x3x8xf32>
+// CHECK-DAG: %[[ALLOC_ARG1:.+]] = memref.alloc() : memref<3x3x8x4xf32>
+// CHECK-DAG: %[[ALLOC_RET0:.+]] = memref.alloc() : memref<1x16x16x4xf32>
// CHECK: scf.for
// CHECK: scf.for
// CHECK: scf.for
// CHECK-DAG: %[[ARG0_SV:.+]] = memref.subview %[[ARG0]]
// CHECK-DAG: %[[ARG1_SV:.+]] = memref.subview %[[ARG1]]
-// CHECK-DAG: %[[ALLOC_ARG1:.+]] = memref.alloc()
// CHECK-DAG: linalg.copy(%[[ARG1_SV]], %[[ALLOC_ARG1]])
-// CHECK-DAG: %[[ALLOC_RET0:.+]] = memref.alloc()
// CHECK-DAG: linalg.fill(%[[ALLOC_RET0]]
-// CHECK-DAG: %[[ALLOC_ARG0:.+]] = memref.alloc()
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[ARG0_SV]]
// CHECK-SAME: outs(%[[ALLOC_ARG0]]
diff --git a/iree/compiler/Conversion/LinalgToLLVM/BUILD b/iree/compiler/Conversion/LinalgToLLVM/BUILD
index 1fe6525..d6c96f2 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/BUILD
+++ b/iree/compiler/Conversion/LinalgToLLVM/BUILD
@@ -21,6 +21,7 @@
"PadLinalgWorkgroupTiles.cpp",
"Passes.cpp",
"PlanConvLoopOrder.cpp",
+ "TilePadAndVectorizeWorkgroups.cpp",
"UnfuseFMAOps.cpp",
],
hdrs = [
diff --git a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
index 5a8318e..9a2ed21 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
@@ -25,6 +25,7 @@
"PadLinalgWorkgroupTiles.cpp"
"Passes.cpp"
"PlanConvLoopOrder.cpp"
+ "TilePadAndVectorizeWorkgroups.cpp"
"UnfuseFMAOps.cpp"
DEPS
LLVMSupport
diff --git a/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.cpp b/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.cpp
index f2361eb..0b74cf7 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.cpp
@@ -10,6 +10,7 @@
#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
#include "iree/compiler/Conversion/Common/Transforms.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/HAL/IR/LoweringConfig.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
@@ -69,7 +70,7 @@
/// Usually the tile sizes for the first level of tiling decides the workgroup
/// size for the dispatch on the CPU backend. This is a general helper that
/// converts tile sizes of the first level into workgroup sizes.
-static SmallVector<int64_t, 3> getWorkgroupSizeFromTileSizes(
+static SmallVector<int64_t, 3> getWorkloadPerWorkgroup(
ArrayRef<int64_t> distributedTileSizes) {
if (distributedTileSizes.size() > kNumMaxParallelDims) {
distributedTileSizes = distributedTileSizes.take_back(kNumMaxParallelDims);
@@ -77,12 +78,24 @@
return llvm::to_vector<3>(llvm::reverse(distributedTileSizes));
}
+/// Sets the translation info on the `hal.executable.entry_point` op
+/// corresponding to the `entryPointFn`. Returns failure if a translation info
+/// is already set on the entry point op and is incompatible with what is being
+/// set.
+static LogicalResult setTranslationInfo(
+ FuncOp entryPointFn, IREE::HAL::DispatchLoweringPassPipeline passPipeline,
+ ArrayRef<int64_t> workloadPerWorkgroup) {
+ auto entryPointOp = getEntryPoint(entryPointFn);
+ auto translationInfo = buildTranslationInfo(
+ passPipeline, workloadPerWorkgroup, entryPointFn.getContext());
+ return setTranslationInfo(entryPointOp, translationInfo);
+}
+
/// Sets the lowering configuration for dispatch region with root op that
/// implements the contraction operation interface.
-static Optional<IREE::HAL::TranslateExecutableInfo> setRootConfig(
- linalg::ContractionOpInterface contractionOp) {
- assert(!hasLoweringConfig(contractionOp) &&
- "illegal to update configuration of root");
+static LogicalResult setRootConfig(
+ FuncOp entryPointFn, linalg::ContractionOpInterface contractionOp) {
+ if (hasLoweringConfig(entryPointFn)) return success();
if (contractionOp.isRowMajorMatmul()) {
int mWorkgroupSize = matmulWorkgroupTileSize;
int nWorkgroupSize = matmulWorkgroupTileSize;
@@ -115,12 +128,12 @@
{matmulVectorSize, matmulVectorSize, matmulVectorSize}};
SmallVector<int64_t, 4> nativeVectorSize = {
matmulVectorSize, matmulVectorSize, matmulVectorSize};
- IREE::HAL::LoweringConfig config =
- getConfigAttr(tileSizes, nativeVectorSize, contractionOp->getContext());
+ IREE::HAL::LoweringConfig config = buildConfigAttr(
+ tileSizes, nativeVectorSize, contractionOp->getContext());
setLoweringConfig(contractionOp, config);
- return IREE::HAL::TranslateExecutableInfo{
- IREE::HAL::DispatchLoweringPassPipeline::CPUVectorization,
- getWorkgroupSizeFromTileSizes(tileSizes[0])};
+ return setTranslationInfo(
+ entryPointFn, IREE::HAL::DispatchLoweringPassPipeline::CPUVectorization,
+ getWorkloadPerWorkgroup(tileSizes[0]));
}
if (contractionOp.isRowMajorBatchMatmul()) {
// TODO(ataei, ravishankarm): This should just use the configuration for
@@ -133,14 +146,14 @@
batchMatmulL2TileSize}};
SmallVector<int64_t, 4> nativeVectorSize = {
1, batchMatmulL2TileSize, batchMatmulL2TileSize, batchMatmulL2TileSize};
- IREE::HAL::LoweringConfig config =
- getConfigAttr(tileSizes, nativeVectorSize, contractionOp->getContext());
+ IREE::HAL::LoweringConfig config = buildConfigAttr(
+ tileSizes, nativeVectorSize, contractionOp->getContext());
setLoweringConfig(contractionOp, config);
- return IREE::HAL::TranslateExecutableInfo{
- IREE::HAL::DispatchLoweringPassPipeline::CPUVectorization,
- getWorkgroupSizeFromTileSizes(tileSizes[0])};
+ return setTranslationInfo(
+ entryPointFn, IREE::HAL::DispatchLoweringPassPipeline::CPUVectorization,
+ getWorkloadPerWorkgroup(tileSizes[0]));
}
- return llvm::None;
+ return success();
}
/// Legalized the tile sizes for the first-level of tiling
@@ -159,8 +172,9 @@
/// Sets the lowering configuration for dispatch region with root op being a
/// generic op.
-static Optional<IREE::HAL::TranslateExecutableInfo> setRootConfig(
- linalg::GenericOp genericOp) {
+static LogicalResult setRootConfig(FuncOp entryPointFn,
+ linalg::GenericOp genericOp) {
+ if (hasLoweringConfig(genericOp)) return success();
int64_t numOuterParallelLoops = getNumOuterParallelLoops(genericOp);
SmallVector<int64_t, 4> workgroupTileSizes(numOuterParallelLoops,
genericOpsWorkgroupTileSize);
@@ -168,177 +182,88 @@
numOuterParallelLoops, workgroupTileSizes);
TileSizesListType tileSizes = {workgroupTileSizes};
IREE::HAL::LoweringConfig config =
- getConfigAttr(tileSizes, ArrayRef<int64_t>{}, genericOp->getContext());
+ buildConfigAttr(tileSizes, ArrayRef<int64_t>{}, genericOp->getContext());
setLoweringConfig(genericOp, config);
- return IREE::HAL::TranslateExecutableInfo{
- IREE::HAL::DispatchLoweringPassPipeline::CPUVectorization,
- getWorkgroupSizeFromTileSizes(tileSizes[0])};
-}
-
-/// Sets the configuration for a linalg op that is not the root of the
-/// dispatch. The configuration should use the tile sizes of the first level of
-/// tiling passed in through `firstLevelTileSizes` for correctness.
-// TODO(ravishankarm): This method should be deleted. The root configuration
-// must be enough. The pass pipeline must use the root configuration as the
-// driver of transformations. Leave it as is for now.
-LogicalResult setNonRootConfig(linalg::LinalgOp linalgOp,
- ArrayRef<int64_t> parallelLoopTileSizes) {
- int64_t numOuterParallelLoops = getNumOuterParallelLoops(linalgOp);
- if (parallelLoopTileSizes.size() != numOuterParallelLoops) {
- return linalgOp.emitError(
- "expected non root ops to have same number of outer-parallel loops as "
- "root op");
- }
- // TODO(ravishankarm): For now just set the first level of tile-size, but need
- // to extend this to make op-specific decision.
- auto vec = llvm::to_vector<4>(parallelLoopTileSizes);
- TileSizesListType tileSizes = {vec};
- IREE::HAL::LoweringConfig config =
- getConfigAttr(tileSizes, ArrayRef<int64_t>{}, linalgOp->getContext());
- setLoweringConfig(linalgOp, config);
- return success();
+ return setTranslationInfo(
+ entryPointFn, IREE::HAL::DispatchLoweringPassPipeline::CPUVectorization,
+ getWorkloadPerWorkgroup(tileSizes[0]));
}
/// Finds the root operation in the given list of linalg operations and sets its
/// configuration. Returns the root operation.
-static FailureOr<IREE::HAL::TranslateExecutableInfo> setRootConfig(
- ArrayRef<linalg::LinalgOp> linalgOps,
- SmallVectorImpl<int64_t> &parallelLoopTileSizes)
- // First iterate over all operations to find the root operations and set its
- // lowering configuration (that are not linalg.generic).
+static LogicalResult setRootConfig(FuncOp entryPointFn,
+ ArrayRef<linalg::LinalgOp> linalgOps) {
linalg::LinalgOp rootOp = nullptr;
-
- Optional<IREE::HAL::TranslateExecutableInfo> passPipeline;
- auto checkOrUpdatePassPipeline =
- [&](linalg::LinalgOp linalgOp,
- Optional<IREE::HAL::TranslateExecutableInfo> opPassPipeline)
- -> LogicalResult {
- if (!opPassPipeline) return success();
- if (passPipeline && passPipeline.getValue() != opPassPipeline.getValue()) {
- return linalgOp.emitError(
- "mismatch in translation configuration for ops in dispatch region");
- }
- if (!passPipeline) {
- passPipeline = opPassPipeline.getValue();
- rootOp = linalgOp;
- }
- return success();
- };
-
for (auto linalgOp : linalgOps) {
if (!hasMarker(linalgOp, getWorkgroupMarker())) continue;
- auto opPassPipeline =
- TypeSwitch<Operation *, Optional<IREE::HAL::TranslateExecutableInfo>>(
- linalgOp.getOperation())
+ auto status =
+ TypeSwitch<Operation *, LogicalResult>(linalgOp.getOperation())
.Case<linalg::ContractionOpInterface>(
- [&](auto op) { return setRootConfig(op); })
- .Default([](Operation *)
- -> Optional<IREE::HAL::TranslateExecutableInfo> {
- return llvm::None;
- });
- auto status = checkOrUpdatePassPipeline(linalgOp, opPassPipeline);
+ [&](auto op) { return setRootConfig(entryPointFn, op); })
+ .Default([](Operation *) { return success(); });
if (failed(status)) {
return status;
}
+ if (hasLoweringConfig(linalgOp)) {
+ if (rootOp) {
+ return linalgOp.emitError(
+ "unhandled multiple roots in dispatch region");
+ }
+ rootOp = linalgOp;
+ continue;
+ }
}
// If no root operation found, check if the dispatch region contains a single
// generic op and chose pipeline based on that.
- if (!passPipeline) {
+ if (!rootOp) {
for (auto linalgOp : linalgOps) {
if (!hasMarker(linalgOp, getWorkgroupMarker())) continue;
auto genericOp = dyn_cast<linalg::GenericOp>(linalgOp.getOperation());
if (!genericOp) continue;
- auto opPassPipeline = setRootConfig(genericOp);
- auto status = checkOrUpdatePassPipeline(linalgOp, opPassPipeline);
- if (failed(status)) {
- return status;
+ if (failed(setRootConfig(entryPointFn, genericOp))) {
+ return failure();
+ }
+ if (hasLoweringConfig(genericOp)) {
+ if (rootOp) {
+ return genericOp.emitError(
+ "unhandled multiple roots in dispatch region");
+ }
+ rootOp = genericOp;
+ continue;
}
}
}
-
- // If still no root operation, use default.
- if (!passPipeline) return failure();
-
- parallelLoopTileSizes =
- getTileSizes(rootOp, static_cast<unsigned>(TilingLevel::WorkGroupTiles));
-
- // Some consistency checks.
- int64_t numOuterParallelLoops = getNumOuterParallelLoops(rootOp);
- if (parallelLoopTileSizes.size() != numOuterParallelLoops) {
- LogicalResult status = rootOp.emitError(
- "expected as many tiles sizes as the parallel loops of the operation");
- return status;
- }
- auto distributedStart =
- std::max<int64_t>(0, numOuterParallelLoops - kNumMaxParallelDims);
- ArrayRef<int64_t> parallelLoopTileSizesRef(parallelLoopTileSizes);
- // THe outer non-distributed paralle loops must be zero.
- if (distributedStart &&
- llvm::any_of(parallelLoopTileSizesRef.take_front(distributedStart),
- [](int64_t v) -> bool { return v; })) {
- LogicalResult status = rootOp.emitError(
- "expected non-distributed parallel loop tile size to be 0");
- return status;
- }
- if (llvm::any_of(parallelLoopTileSizesRef.take_back(numOuterParallelLoops -
- distributedStart),
- [](int64_t v) -> bool { return !v; })) {
- LogicalResult status = rootOp.emitError(
- "expected distributed parallel loop tile size to be non-zero");
- return status;
- }
- return passPipeline.getValue();
+ return success();
}
-IREE::HAL::TranslateExecutableInfo initCPULaunchConfig(ModuleOp moduleOp) {
- // By default, the CPU backend will just lower the ops in the dispatch region
- // as is with no distribution.
- IREE::HAL::TranslateExecutableInfo pipelineOnFailure = {
- IREE::HAL::DispatchLoweringPassPipeline::CPUDefault, {}};
- // The current linalg based lowering only tested for a single function case.
- auto entryPointFn = getSingleEntryPointFunction(moduleOp);
- if (failed(entryPointFn)) {
- return pipelineOnFailure;
- }
- auto funcOps = moduleOp.getOps<FuncOp>();
- if (!llvm::hasSingleElement(funcOps)) {
- return pipelineOnFailure;
- }
- FuncOp funcOp = *funcOps.begin();
- SmallVector<linalg::LinalgOp, 4> linalgOps;
- SmallVector<Operation *, 4> tiledLoops;
- // If there are no linalg ops, not using Linalg based lowering.
- if (failed(getLinalgOps(funcOp, linalgOps, tiledLoops)) ||
- linalgOps.empty()) {
- return pipelineOnFailure;
- }
+LogicalResult initCPULaunchConfig(ModuleOp moduleOp) {
+ llvm::StringMap<IREE::HAL::ExecutableEntryPointOp> entryPointOps =
+ getAllEntryPoints(moduleOp);
+ for (auto funcOp : moduleOp.getOps<FuncOp>()) {
+ auto entryPointOp = entryPointOps.lookup(funcOp.getName());
+ if (!entryPointOp) continue;
+ SmallVector<linalg::LinalgOp, 4> linalgOps;
+ SmallVector<Operation *, 4> tiledLoops;
+ // If there are no linalg ops, not using Linalg based lowering.
+ if (succeeded(getLinalgOps(funcOp, linalgOps, tiledLoops)) &&
+ !linalgOps.empty()) {
+ if (failed(setRootConfig(funcOp, linalgOps))) {
+ return failure();
+ }
+ }
- SmallVector<int64_t> parallelLoopTileSizes;
- auto passPipeline = setRootConfig(linalgOps, parallelLoopTileSizes);
- if (failed(passPipeline)) {
- return pipelineOnFailure;
- }
-
- // Set the configuration of all other linalg operations that are not the root
- // operation.
- // TODO(ravishankarm): The configuration of the root must drive the lowering
- // completely. This step should be removed.
- LogicalResult status = success();
- for (auto linalgOp : linalgOps) {
- if (hasLoweringConfig(linalgOp)) continue;
- status = setNonRootConfig(linalgOp, parallelLoopTileSizes);
- if (failed(status)) {
- break;
+ // If the function entry point already doesnt have a lowering info attribute
+ // on it, just add the default.
+ if (!getTranslationInfo(entryPointOp)) {
+ if (failed(setTranslationInfo(
+ funcOp, IREE::HAL::DispatchLoweringPassPipeline::CPUDefault,
+ {}))) {
+ return failure();
+ }
}
}
- if (failed(status)) {
- for (auto linalgOp : linalgOps) {
- eraseLoweringConfig(linalgOp);
- }
- return pipelineOnFailure;
- }
- return passPipeline.getValue();
+ return success();
}
} // namespace iree_compiler
diff --git a/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h b/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h
index 54ac9ec..a9a1da8 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h
+++ b/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h
@@ -20,7 +20,7 @@
NumTileLevels = 3
};
-IREE::HAL::TranslateExecutableInfo initCPULaunchConfig(ModuleOp moduleOp);
+LogicalResult initCPULaunchConfig(ModuleOp moduleOp);
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToLLVM/LowerExecutableTargetPass.cpp b/iree/compiler/Conversion/LinalgToLLVM/LowerExecutableTargetPass.cpp
index ee91566..78528e8 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/LowerExecutableTargetPass.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/LowerExecutableTargetPass.cpp
@@ -57,11 +57,11 @@
"expected to work on the std.module op within the "
"hal.executable.target operation")};
- ListOption<int> workgroupSizes{
- *this, "workgroup-size", llvm::cl::MiscFlags::CommaSeparated,
+ ListOption<int> workloadPerWorkgroup{
+ *this, "workload-per-workgroup", llvm::cl::MiscFlags::CommaSeparated,
llvm::cl::desc(
- "Specifies the workgroup size to use in x, y, z order. Is expected "
- "for use only with use-lowering-pipeline option")};
+ "Specifies the workload per workgroup to use in x, y, z order. Is "
+ "expected for use only with use-lowering-pipeline option")};
/// TODO(ravishankarm): Option to not generate any `vector.` instructions. The
/// VMVX backend uses the same lowering as the CPU pass but there is no
@@ -96,10 +96,11 @@
if (!useLoweringPipeline.empty()) {
// Use the pass pipeline specified in the command line.
- SmallVector<int64_t, 4> dispatchWorkgroupSize;
- dispatchWorkgroupSize.assign(workgroupSizes.begin(), workgroupSizes.end());
+ SmallVector<int64_t, 4> workloadPerWorkgroupVec;
+ workloadPerWorkgroupVec.assign(workloadPerWorkgroup.begin(),
+ workloadPerWorkgroup.end());
executableLoweringPipeline.addPass(
- createSetNumWorkgroupsPass(dispatchWorkgroupSize));
+ createSetNumWorkgroupsPass(workloadPerWorkgroupVec));
OpPassManager &nestedModulePM = executableLoweringPipeline.nest<ModuleOp>();
if (failed(parsePassPipeline(sanitizePipelineString(useLoweringPipeline),
nestedModulePM))) {
@@ -107,23 +108,49 @@
}
} else {
// Use default heuristics.
- FailureOr<IREE::HAL::TranslateExecutableInfo> translationInfo =
- initCPULaunchConfig(moduleOp);
- if (failed(translationInfo)) {
+ if (failed(initCPULaunchConfig(moduleOp))) {
return signalPassFailure();
}
- executableLoweringPipeline.addPass(
- createSetNumWorkgroupsPass(translationInfo->workgroupSize));
+ // There might be multiple entry points in the module. Currently, all of
+ // them need to have the same pipeline.
+ // TODO(ravishankarm): This is strange that this is not enforced
+ // structurally, but something to address later on. For now this restriction
+ // is fine.
+ llvm::StringMap<IREE::HAL::ExecutableEntryPointOp> entryPoints =
+ getAllEntryPoints(moduleOp);
+ Optional<IREE::HAL::DispatchLoweringPassPipeline> passPipeline;
+ for (auto &it : entryPoints) {
+ auto entryPointOp = it.second;
+ if (IREE::HAL::TranslationInfo translationInfo =
+ getTranslationInfo(entryPointOp)) {
+ IREE::HAL::DispatchLoweringPassPipeline currPipeline =
+ translationInfo.passPipeline().getValue();
+ if (passPipeline) {
+ if (currPipeline != passPipeline.getValue()) {
+ moduleOp.emitError(
+ "unhandled compilation of entry point function with different "
+ "pass pipelines within a module");
+ return signalPassFailure();
+ }
+ continue;
+ }
+ passPipeline = currPipeline;
+ }
+ }
+
+ executableLoweringPipeline.addPass(createSetNumWorkgroupsPass());
OpPassManager &nestedModulePM = executableLoweringPipeline.nest<ModuleOp>();
if (!testLoweringConfiguration) {
- switch (translationInfo->passPipeline) {
+ switch (passPipeline.getValue()) {
case IREE::HAL::DispatchLoweringPassPipeline::CPUDefault:
addCPUDefaultPassPipeline(nestedModulePM);
break;
case IREE::HAL::DispatchLoweringPassPipeline::CPUVectorization:
addCPUVectorizationPassPipeline(nestedModulePM, lowerToVectors);
break;
+ default:
+ llvm_unreachable("Unsupported pipeline on CPU target.");
}
}
}
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
index 3264e6d..6483c10 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
@@ -17,6 +17,11 @@
namespace mlir {
namespace iree_compiler {
+static llvm::cl::opt<bool> clUseTensorPadTileAndVectorize(
+ "iree-codegen-linalg-to-llvm-use-tensor-to-vectors",
+ llvm::cl::desc("If enabled will use tensor -> vector transformation pass"),
+ llvm::cl::init(false));
+
static Value cpuAllocationFunction(OpBuilder &builder, Location loc,
ArrayRef<int64_t> staticShape,
Type elementType,
@@ -33,14 +38,27 @@
// re-enable.
// passManager.addNestedPass<FuncOp>(createPadLinalgWorkgroupTilesPass());
+ if (clUseTensorPadTileAndVectorize) {
+ // Tile and vectorize linalg ops on tensors.
+ passManager.addNestedPass<FuncOp>(
+ createTilePadAndVectorizeWorkgroupsPass());
+ passManager.addNestedPass<FuncOp>(createCSEPass());
+ passManager.addNestedPass<FuncOp>(createCanonicalizerPass());
+ }
+
// Use stack allocation on CPU side.
addLinalgBufferizePasses(passManager, cpuAllocationFunction);
+ passManager.addNestedPass<FuncOp>(createCSEPass());
+ passManager.addNestedPass<FuncOp>(createCanonicalizerPass());
- // Tile and vectorize linalg ops.
- passManager.addNestedPass<FuncOp>(createCanonicalizerPass());
- passManager.addNestedPass<FuncOp>(
- createLinalgTileAndVectorizeWorkgroupsPass(lowerToVectors));
- passManager.addNestedPass<FuncOp>(createCanonicalizerPass());
+ if (!clUseTensorPadTileAndVectorize) {
+ // Tile and vectorize linalg ops on buffers.
+ passManager.addNestedPass<FuncOp>(
+ createLinalgTileAndVectorizeWorkgroupsPass(lowerToVectors));
+ passManager.addNestedPass<FuncOp>(createCSEPass());
+ passManager.addNestedPass<FuncOp>(createCanonicalizerPass());
+ }
+
passManager.addNestedPass<FuncOp>(createForOpCanonicalizationPass());
passManager.addNestedPass<FuncOp>(createPlanConvLoopOrderPass());
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.h b/iree/compiler/Conversion/LinalgToLLVM/Passes.h
index 252c2b0..19a2e0e 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/Passes.h
+++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.h
@@ -23,6 +23,9 @@
/// vector size.
std::unique_ptr<OperationPass<FuncOp>> createPadLinalgWorkgroupTilesPass();
+/// Multi-level tiling, padding and vectorization of linalg ops on tensors.
+std::unique_ptr<FunctionPass> createTilePadAndVectorizeWorkgroupsPass();
+
/// Vectorizes linalg ops executed in the same hal.interface.workgroup.
std::unique_ptr<FunctionPass> createLinalgTileAndVectorizeWorkgroupsPass(
bool lowerToVectors = true);
diff --git a/iree/compiler/Conversion/LinalgToLLVM/TilePadAndVectorizeWorkgroups.cpp b/iree/compiler/Conversion/LinalgToLLVM/TilePadAndVectorizeWorkgroups.cpp
new file mode 100644
index 0000000..d971231
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToLLVM/TilePadAndVectorizeWorkgroups.cpp
@@ -0,0 +1,210 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
+#include "iree/compiler/Conversion/CodegenUtils/TransformUtils.h"
+#include "iree/compiler/Conversion/Common/Transforms.h"
+#include "iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-linalg-to-llvm-tile-and-pad-workgroups"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+// Could just be linalg::TilingPattern with a ContractionOpInterface filter, but
+// that is always templated on an op.
+struct TileWorkgroups : public linalg::LinalgBaseTilingPattern {
+ using Base = linalg::LinalgBaseTilingPattern;
+ TileWorkgroups(MLIRContext *context, linalg::LinalgTilingOptions options,
+ linalg::LinalgTransformationFilter marker)
+ : LinalgBaseTilingPattern(context, options, marker) {}
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ auto contractionOp = dyn_cast<linalg::ContractionOpInterface>(op);
+ if (!contractionOp) return failure();
+
+ linalg::TiledLinalgOp tiledLinalgOp;
+ if (failed(Base::matchAndRewriteBase(op, rewriter, tiledLinalgOp))) {
+ return failure();
+ }
+ rewriter.replaceOp(op, tiledLinalgOp.tensorResults);
+ return success();
+ }
+};
+
+} // namespace
+
+namespace {
+struct TilePadAndVectorizeWorkgroupsPass
+ : public PassWrapper<TilePadAndVectorizeWorkgroupsPass, FunctionPass> {
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<linalg::LinalgDialect, memref::MemRefDialect,
+ vector::VectorDialect>();
+ }
+ void runOnFunction() override;
+};
+} // namespace
+
+void TilePadAndVectorizeWorkgroupsPass::runOnFunction() {
+ MLIRContext *context = &getContext();
+ auto funcOp = getOperation();
+
+ // First level of tiling patterns {
+ {
+ OwningRewritePatternList l1patterns(&getContext());
+ l1patterns.insert<TileWorkgroups>(
+ context,
+ linalg::LinalgTilingOptions()
+ .setTileSizeComputationFunction(
+ [](OpBuilder &builder,
+ Operation *operation) -> SmallVector<Value, 4> {
+ return getTileSizes(
+ builder, operation,
+ static_cast<unsigned>(TilingLevel::Level1Tiles));
+ })
+ .setPaddingValueComputationFunction(
+ [](OpBuilder &b, OpOperand &op) -> Value {
+ auto t = getElementTypeOrSelf(op.get().getType());
+ return b.create<ConstantOp>(op.getOwner()->getLoc(), t,
+ b.getZeroAttr(t));
+ }),
+ linalg::LinalgTransformationFilter(
+ Identifier::get(getWorkgroupMarker(), context),
+ Identifier::get(getWorkgroupL1TileMarker(), context)));
+
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(l1patterns)))) {
+ return signalPassFailure();
+ }
+ }
+
+ // Apply canonicalization
+ {
+ OwningRewritePatternList canonicalizationPatterns(&getContext());
+ linalg::populateLinalgTilingCanonicalizationPatterns(
+ canonicalizationPatterns);
+ memref::DimOp::getCanonicalizationPatterns(canonicalizationPatterns,
+ context);
+ canonicalizationPatterns.add<linalg::AffineMinSCFCanonicalizationPattern>(
+ &getContext());
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(canonicalizationPatterns)))) {
+ return signalPassFailure();
+ }
+ }
+ // Second level of tiling patterns {
+ {
+ OwningRewritePatternList l1patterns(&getContext());
+ l1patterns.insert<TileWorkgroups>(
+ context,
+ linalg::LinalgTilingOptions().setTileSizeComputationFunction(
+ [](OpBuilder &builder,
+ Operation *operation) -> SmallVector<Value, 4> {
+ return getTileSizes(
+ builder, operation,
+ static_cast<unsigned>(TilingLevel::Level2Tiles));
+ }),
+ linalg::LinalgTransformationFilter(
+ Identifier::get(getWorkgroupL1TileMarker(), context),
+ Identifier::get(getVectorizeMarker(), context)));
+
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(l1patterns)))) {
+ return signalPassFailure();
+ }
+ }
+ // Apply canonicalization
+ {
+ OwningRewritePatternList canonicalizationPatterns(&getContext());
+ linalg::populateLinalgTilingCanonicalizationPatterns(
+ canonicalizationPatterns);
+ memref::DimOp::getCanonicalizationPatterns(canonicalizationPatterns,
+ context);
+ canonicalizationPatterns.add<linalg::AffineMinSCFCanonicalizationPattern>(
+ &getContext());
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(canonicalizationPatterns)))) {
+ return signalPassFailure();
+ }
+ }
+ // Apply vectorization patterns.
+ {
+ OwningRewritePatternList vectorizationPatterns(&getContext());
+ linalg::insertVectorizationPatterns<linalg::ContractionOpInterface,
+ linalg::CopyOp, linalg::FillOp>(
+ vectorizationPatterns, linalg::LinalgVectorizationOptions(),
+ linalg::LinalgTransformationFilter(
+ Identifier::get(getVectorizeMarker(), context)));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(vectorizationPatterns)))) {
+ return signalPassFailure();
+ }
+ }
+
+ // TODO: This should be a folding of Add into Contract in core but while
+ // they live in different dialects, it is not possible without unnatural
+ // dependencies.
+ funcOp.walk([&](Operation *op) {
+ if (auto contract = canonicalizeContractionAdd(op))
+ op->replaceAllUsesWith(contract);
+ });
+ // Apply vector specific operation lowering.
+ {
+ vector::VectorTransformsOptions vectorTransformsOptions =
+ vector::VectorTransformsOptions().setVectorTransformsOptions(
+ vector::VectorContractLowering::OuterProduct);
+ OwningRewritePatternList vectorContractLoweringPatterns(&getContext());
+ vectorContractLoweringPatterns
+ .insert<ContractionOpToOuterProductOpLowering,
+ ContractionOpToMatmulOpLowering, ContractionOpLowering>(
+ vectorTransformsOptions, context);
+ vector::populateVectorTransferLoweringPatterns(
+ vectorContractLoweringPatterns);
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(vectorContractLoweringPatterns)))) {
+ return signalPassFailure();
+ }
+ }
+ //
+ // Hoist hierarchical tiling indexing and other loop invariant transfer
+ // ops computation.
+ //
+ // Programmatic controlled lowering of vector.transfer only.
+ {
+ VectorTransferToSCFOptions vectorToSCFOptions =
+ VectorTransferToSCFOptions().setUnroll(true);
+ OwningRewritePatternList vectorToLoopsPatterns(&getContext());
+ populateVectorToSCFConversionPatterns(vectorToLoopsPatterns,
+ vectorToSCFOptions);
+ // Hoist hierarchical tiling indexing and other loop invariant transfer
+ // ops computation.
+ linalg::hoistRedundantVectorTransfers(funcOp);
+
+ memref::populateFoldSubViewOpPatterns(vectorToLoopsPatterns);
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(vectorToLoopsPatterns)))) {
+ return signalPassFailure();
+ }
+ }
+}
+
+std::unique_ptr<FunctionPass> createTilePadAndVectorizeWorkgroupsPass() {
+ return std::make_unique<TilePadAndVectorizeWorkgroupsPass>();
+}
+
+static PassRegistration<TilePadAndVectorizeWorkgroupsPass> pass(
+ "iree-codegen-linalg-to-llvm-tile-pad-and-vectorize-workgroups",
+ "Tile and pad workgroups tiles",
+ [] { return std::make_unique<TilePadAndVectorizeWorkgroupsPass>(); });
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/BUILD b/iree/compiler/Conversion/LinalgToLLVM/test/BUILD
index 6aefad2..a9fceec 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/test/BUILD
+++ b/iree/compiler/Conversion/LinalgToLLVM/test/BUILD
@@ -27,6 +27,7 @@
"matmul_vectorization.mlir",
"pad_linalg_workgroup_tiles.mlir",
"plan_conv_loop_order.mlir",
+ "tile_pad_and_vectorize_workgroups.mlir",
"unfused_fma.mlir",
],
include = ["*.mlir"],
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/CMakeLists.txt b/iree/compiler/Conversion/LinalgToLLVM/test/CMakeLists.txt
index 7b0821f..f833305 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/test/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToLLVM/test/CMakeLists.txt
@@ -22,6 +22,7 @@
"matmul_vectorization.mlir"
"pad_linalg_workgroup_tiles.mlir"
"plan_conv_loop_order.mlir"
+ "tile_pad_and_vectorize_workgroups.mlir"
"unfused_fma.mlir"
DATA
iree::tools::IreeFileCheck
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/materialize_launch_configuration.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/materialize_launch_configuration.mlir
index eaebb09..0247bdd 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/test/materialize_launch_configuration.mlir
+++ b/iree/compiler/Conversion/LinalgToLLVM/test/materialize_launch_configuration.mlir
@@ -52,8 +52,7 @@
}
}
-// CHECK-DAG: #[[CONFIG0:.+]] = {tileSizes = {{\[}}[64, 64]{{\]}}}
-// CHECK-DAG: #[[CONFIG1:.+]] = {nativeVectorSize = [4, 4, 4], tileSizes = {{\[}}[64, 64], [32, 32, 32], [4, 4, 4]{{\]}}}
+// CHECK-DAG: #[[CONFIG:.+]] = {nativeVectorSize = [4, 4, 4], tileSizes = {{\[}}[64, 64], [32, 32, 32], [4, 4, 4]{{\]}}}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
// CHECK: hal.executable.entry_point @matmul_tensors
// CHECK-NEXT: (%[[ARG0:[a-zA-Z0-9_]+]]: index
@@ -63,12 +62,8 @@
// CHECK-DAG: %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
// CHECK-DAG: %[[D1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]]]
// CHECK: hal.return %[[D0]], %[[D1]], %[[C1]] : index, index, index
-// CHECK: linalg.copy
-// CHECK-SAME: lowering.config = #[[CONFIG0]]
// CHECK: linalg.matmul
-// CHECK-SAME: lowering.config = #[[CONFIG1]]
-// CHECK: linalg.copy
-// CHECK-SAME: lowering.config = #[[CONFIG0]]
+// CHECK-SAME: lowering.config = #[[CONFIG]]
// -----
@@ -91,7 +86,7 @@
%0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<?x?xf32>
%1 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<?xf32>
%2 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<?x?xf32>
- linalg.generic {__internal_linalg_transform__ = "workgroup"} {
+ linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d1)>,
affine_map<(d0, d1) -> (d0, d1)>],
@@ -346,8 +341,7 @@
}
}
}
-// CHECK-DAG: #[[CONFIG0:.+]] = {tileSizes = {{\[}}[1, 32, 32]{{\]}}
-// CHECK-DAG: #[[CONFIG1:.+]] = {nativeVectorSize = [1, 4, 4, 4], tileSizes = {{\[}}[1, 32, 32], [1, 16, 16, 16], [1, 4, 4, 4]{{\]}}
+// CHECK-DAG: #[[CONFIG:.+]] = {nativeVectorSize = [1, 4, 4, 4], tileSizes = {{\[}}[1, 32, 32], [1, 16, 16, 16], [1, 4, 4, 4]{{\]}}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)>
// CHECK: hal.executable.entry_point @batch_matmul_tensors
// CHECK-NEXT: (%[[ARG0:[a-zA-Z0-9]+]]: index
@@ -356,9 +350,5 @@
// CHECK-DAG: %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
// CHECK-DAG: %[[D1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]]]
// CHECK: hal.return %[[D0]], %[[D1]], %[[ARG2]]
-// CHECK: linalg.copy
-// CHECK-SAME: lowering.config = #[[CONFIG0]]
// CHECK: linalg.batch_matmul
-// CHECK-SAME: lowering.config = #[[CONFIG1]]
-// CHECK: linalg.copy
-// CHECK-SAME: lowering.config = #[[CONFIG0]]
+// CHECK-SAME: lowering.config = #[[CONFIG]]
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/tile_pad_and_vectorize_workgroups.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/tile_pad_and_vectorize_workgroups.mlir
new file mode 100644
index 0000000..1cb8997
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToLLVM/test/tile_pad_and_vectorize_workgroups.mlir
@@ -0,0 +1,80 @@
+// RUN: iree-opt %s -cse -iree-codegen-linalg-to-llvm-tile-pad-and-vectorize-workgroups -cse -canonicalize -split-input-file | IreeFileCheck %s
+
+#config0 = {tileSizes = [[64, 64]]}
+#config1 = {nativeVectorSize = [4, 4, 4], tileSizes = [[64, 64], [32, 32, 32], [4, 4, 4]]}
+#map0 = affine_map<()[s0] -> (s0 * 64)>
+#map1 = affine_map<(d0) -> (64, -d0 + 383)>
+#map2 = affine_map<(d0) -> (64, -d0 + 513)>
+#map3 = affine_map<(d0) -> (-d0 + 383, 64)>
+#map4 = affine_map<(d0) -> (-d0 + 513, 64)>
+module {
+ func @dot_383x383x513_dispatch_0() {
+ %c0 = constant 0 : index
+ %c513 = constant 513 : index
+ %c383 = constant 383 : index
+ %cst = constant 0.000000e+00 : f32
+ %0 = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : !flow.dispatch.tensor<readonly:383x383xf32>
+ %1 = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0] : !flow.dispatch.tensor<readonly:383x513xf32>
+ %2 = hal.interface.binding.subspan @io::@s0b2_xw_external[%c0] : !flow.dispatch.tensor<writeonly:383x513xf32>
+ %workgroup_id_x = hal.interface.workgroup.id[0] : index
+ %workgroup_count_x = hal.interface.workgroup.count[0] : index
+ %workgroup_id_y = hal.interface.workgroup.id[1] : index
+ %workgroup_count_y = hal.interface.workgroup.count[1] : index
+ %3 = affine.apply #map0()[%workgroup_id_y]
+ %4 = affine.apply #map0()[%workgroup_count_y]
+ scf.for %arg0 = %3 to %c383 step %4 {
+ %5 = affine.apply #map0()[%workgroup_id_x]
+ %6 = affine.apply #map0()[%workgroup_count_x]
+ scf.for %arg1 = %5 to %c513 step %6 {
+ %7 = affine.min #map1(%arg0)
+ %8 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [%7, 383], strides = [1, 1] : !flow.dispatch.tensor<readonly:383x383xf32> -> tensor<?x383xf32>
+ %9 = affine.min #map2(%arg1)
+ %10 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [383, %9], strides = [1, 1] : !flow.dispatch.tensor<readonly:383x513xf32> -> tensor<383x?xf32>
+ %11 = affine.min #map3(%arg0)
+ %12 = affine.min #map4(%arg1)
+ %13 = linalg.init_tensor [%11, %12] : tensor<?x?xf32>
+ %14 = linalg.fill(%13, %cst) {__internal_linalg_transform__ = "workgroup", lowering.config = #config0} : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
+ %15 = linalg.matmul {__internal_linalg_transform__ = "workgroup", lowering.config = #config1} ins(%8, %10 : tensor<?x383xf32>, tensor<383x?xf32>) outs(%14 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ flow.dispatch.tensor.store %15, %2, offsets = [%arg0, %arg1], sizes = [%7, %9], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:383x513xf32>
+ }
+ }
+ return
+ }
+ hal.interface @io attributes {sym_visibility = "private"} {
+ hal.interface.binding @s0b0_ro_external, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @s0b1_ro_external, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @s0b2_xw_external, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ }
+}
+// CHECK-LABEL: @dot_383x383x513_dispatch_0
+// CHECK-DAG: %[[CST:.+]] = constant 0.000000e+00 : f32
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C4:.+]] = constant 4 : index
+// CHECK-DAG: %[[C383:.+]] = constant 383 : index
+// CHECK-DAG: %[[C513:.+]] = constant 513 : index
+// CHECK-DAG: %[[C32:.+]] = constant 32 : index
+// CHECK: %[[LHS:.+]] = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : !flow.dispatch.tensor<readonly:383x383xf32>
+// CHECK: %[[RHS:.+]] = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0] : !flow.dispatch.tensor<readonly:383x513xf32>
+// CHECK: %[[DST:.+]] = hal.interface.binding.subspan @io::@s0b2_xw_external[%c0] : !flow.dispatch.tensor<writeonly:383x513xf32>
+// CHECK: %[[LHS_WG_TILE:.+]] = flow.dispatch.tensor.load %[[LHS]]
+// CHECK: %[[RHS_WG_TILE:.+]] = flow.dispatch.tensor.load %[[RHS]]
+// CHECK: %[[DST_WG_TILE_INIT:.+]] = linalg.init_tensor
+// CHECK: %[[DST_WG_TILE_INIT_C0:.+]] = linalg.fill(%[[DST_WG_TILE_INIT]], %[[CST]])
+// CHECK: {{.*}} = scf.for {{.*}} = %[[C0]] to %[[C383]] step %[[C32]] iter_args(%[[DST_WG_TILE_0:.+]] = %[[DST_WG_TILE_INIT_C0]])
+// CHECK: {{.*}} = scf.for {{.*}} = %[[C0]] to %[[C513]] step %[[C32]] iter_args(%[[DST_WG_TILE_1:.+]] = %[[DST_WG_TILE_0]])
+// CHECK: {{.*}} = scf.for {{.*}} = %[[C0]] to %[[C383]] step %[[C32]] iter_args(%[[DST_WG_TILE_2:.+]] = %[[DST_WG_TILE_1]])
+// CHECK: %[[LHS_L1_TILE:.+]] = subtensor %[[LHS_WG_TILE]]
+// CHECK: %[[RHS_L1_TILE:.+]] = subtensor %[[RHS_WG_TILE]]
+// CHECK: %[[DST_L1_TILE:.+]] = subtensor %[[DST_WG_TILE_2]]
+// CHECK: %[[LHS_L1_TILE_PADDED:.+]] = linalg.pad_tensor %[[LHS_L1_TILE]]
+// CHECK: %[[RHS_L1_TILE_PADDED:.+]] = linalg.pad_tensor %[[RHS_L1_TILE]]
+// CHECK: %[[DST_L1_TILE_PADDED:.+]] = linalg.pad_tensor %[[DST_L1_TILE]]
+// CHECK: {{.*}} = scf.for {{.*}} = %[[C0]] to %[[C32]] step %[[C4]] iter_args(%[[DST_VEC_TILE_0:.+]] = %[[DST_L1_TILE_PADDED]])
+// CHECK: {{.*}} = scf.for {{.*}} = %[[C0]] to %[[C32]] step %[[C4]] iter_args(%[[DST_VEC_TILE_1:.+]] = %[[DST_VEC_TILE_0]])
+// CHECK: {{.*}} = scf.for {{.*}} = %[[C0]] to %[[C32]] step %[[C4]] iter_args(%[[DST_VEC_TILE_2:.+]] = %[[DST_VEC_TILE_1]])
+// CHECK: %[[LHS_VEC_TILE:.+]] = subtensor %[[LHS_L1_TILE_PADDED]]
+// CHECK: %[[RHS_VEC_TILE:.+]] = subtensor %[[RHS_L1_TILE_PADDED]]
+// CHECK: %[[DST_VEC_TILE:.+]] = subtensor %[[DST_VEC_TILE_2]]
+// CHECK: %[[LHS_VEC:.+]] = vector.transfer_read %[[LHS_VEC_TILE]]
+// CHECK: %[[RHS_VEC:.+]] = vector.transfer_read %[[RHS_VEC_TILE]]
+// CHECK: %[[DST_VEC:.+]] = vector.transfer_read %[[DST_VEC_TILE]]
diff --git a/iree/compiler/Conversion/LinalgToLLVMGPU/BUILD b/iree/compiler/Conversion/LinalgToLLVMGPU/BUILD
index 88236b7..42e637d 100644
--- a/iree/compiler/Conversion/LinalgToLLVMGPU/BUILD
+++ b/iree/compiler/Conversion/LinalgToLLVMGPU/BUILD
@@ -17,6 +17,7 @@
"ConvertToNVVM.cpp",
"ConvertToROCDL.cpp",
"KernelConfig.cpp",
+ "LowerExecutableTargetPass.cpp",
"Passes.cpp",
"RemoveTrivialLoops.cpp",
"TileAndDistribute.cpp",
@@ -35,6 +36,7 @@
"//iree/compiler/Dialect/Shape/Transforms",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:AffineToStandard",
+ "@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:GPUToNVVMTransforms",
"@llvm-project//mlir:GPUToROCDLTransforms",
"@llvm-project//mlir:GPUTransforms",
diff --git a/iree/compiler/Conversion/LinalgToLLVMGPU/CMakeLists.txt b/iree/compiler/Conversion/LinalgToLLVMGPU/CMakeLists.txt
index 71a1a72..5e9459f 100644
--- a/iree/compiler/Conversion/LinalgToLLVMGPU/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToLLVMGPU/CMakeLists.txt
@@ -22,6 +22,7 @@
"ConvertToNVVM.cpp"
"ConvertToROCDL.cpp"
"KernelConfig.cpp"
+ "LowerExecutableTargetPass.cpp"
"Passes.cpp"
"RemoveTrivialLoops.cpp"
"TileAndDistribute.cpp"
diff --git a/iree/compiler/Conversion/LinalgToLLVMGPU/KernelConfig.cpp b/iree/compiler/Conversion/LinalgToLLVMGPU/KernelConfig.cpp
index 7ec713f..224398e 100644
--- a/iree/compiler/Conversion/LinalgToLLVMGPU/KernelConfig.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVMGPU/KernelConfig.cpp
@@ -16,18 +16,40 @@
static constexpr unsigned cudaWarpSize = 32;
-static LaunchConfig getOpLaunchConfig(linalg::GenericOp op) {
- LaunchConfig config;
+static void setConfig(TileSizesListType tileSizes, Operation* op) {
+ IREE::HAL::LoweringConfig config =
+ buildConfigAttr(tileSizes, ArrayRef<int64_t>{}, op->getContext());
+ setLoweringConfig(op, config);
+}
+
+/// Sets the translation info on the `hal.executable.entry_point` op
+/// corresponding to the `entryPointFn`. Returns failure if a translation info
+/// is already set on the entry point op and is incompatible with what is being
+/// set.
+static LogicalResult setTranslationInfo(
+ FuncOp entryPointFn, IREE::HAL::DispatchLoweringPassPipeline passPipeline,
+ ArrayRef<int64_t> workloadPerWorkgroup) {
+ auto entryPointOp = getEntryPoint(entryPointFn);
+ auto translationInfo = buildTranslationInfo(
+ passPipeline, workloadPerWorkgroup, entryPointFn.getContext());
+ return setTranslationInfo(entryPointOp, translationInfo);
+}
+
+static LogicalResult setRootConfig(FuncOp entryPoint, linalg::GenericOp op) {
+ IREE::HAL::DispatchLoweringPassPipeline passPipeline =
+ IREE::HAL::DispatchLoweringPassPipeline::LLVMGPUVectorize;
+ std::array<int64_t, 3> workgroupSize = {1, 1, 1};
+ TileSizesListType tileSizes;
size_t numLoops = getNumOuterParallelLoops(op);
if (numLoops == 0) {
// Pure reduction, we serialize the operation on a single thread.
// TODO: Use atomic to allow distributing reduction loops.
- config.setWorkgroupSize({1, 1, 1});
- config.setTileSizes(op, {}, 0);
- return config;
+ tileSizes.push_back({});
+ setConfig(tileSizes, op);
+ return setTranslationInfo(entryPoint, passPipeline, workgroupSize);
}
- config.setWorkgroupSize({cudaWarpSize, 1, 1});
+ workgroupSize = {cudaWarpSize, 1, 1};
// Pick a fixed tile size independent of the original shape.
// TODO(thomasraoux): Currently the original shape information is lost during
// tiling at the flow level. We need way to access it to be able to make a
@@ -36,82 +58,101 @@
SmallVector<int64_t, 4> ts;
ts.resize(numLoops, 1);
ts.back() = lowerTs;
- config.setTileSizes(op, ts, 0); // Workgroup level.
- config.setTileSizes(op, {}, 1); // Subgroup level.
+ tileSizes.push_back(ts); // Workgroup level.
+ tileSizes.push_back({}); // Subgroup level.
ts.back() = lowerTs / cudaWarpSize;
- config.setTileSizes(op, ts, 2); // Thread level.
- return config;
+ tileSizes.push_back(ts); // Thread level.
+ setConfig(tileSizes, op);
+ return setTranslationInfo(entryPoint, passPipeline, workgroupSize);
}
-static LaunchConfig getOpLaunchConfig(linalg::MatmulOp op) {
- LaunchConfig config;
+static LogicalResult setRootConfig(FuncOp entryPoint, linalg::MatmulOp op) {
+ IREE::HAL::DispatchLoweringPassPipeline passPipeline =
+ IREE::HAL::DispatchLoweringPassPipeline::LLVMGPUVectorize;
+ TileSizesListType tileSizes;
const int64_t numWarp = 2;
- std::array<int64_t, 3> workgroupSize = {numWarp * cudaWarpSize, 1, 1};
- config.setWorkgroupSize(workgroupSize);
+ SmallVector<int64_t, 3> workgroupSize = {numWarp * cudaWarpSize, 1, 1};
// Currently just a basic tile size to enable tiling and vectorization.
// TODO: pick a more efficient tile size and tile at subgroup level.
SmallVector<int64_t, 4> ts = {2, 256, 4};
- config.setTileSizes(op, ts, 0); // Workgroup level.
- config.setTileSizes(op, {}, 1); // Subgroup level.
+ tileSizes.push_back(ts); // Workgroup level.
+ tileSizes.push_back({}); // Subgroup level.
SmallVector<int64_t, 4> invocationLevelTs = {ts[0] / workgroupSize[1],
ts[1] / workgroupSize[0]};
- config.setTileSizes(op, invocationLevelTs, 2); // Thread level.
- return config;
+ tileSizes.push_back(invocationLevelTs); // Thread level.
+ setConfig(tileSizes, op);
+ return setTranslationInfo(entryPoint, passPipeline, workgroupSize);
}
-static LaunchConfig getOpLaunchConfig(linalg::BatchMatmulOp op) {
- LaunchConfig config;
+static LogicalResult setRootConfig(FuncOp entryPoint,
+ linalg::BatchMatmulOp op) {
+ IREE::HAL::DispatchLoweringPassPipeline passPipeline =
+ IREE::HAL::DispatchLoweringPassPipeline::LLVMGPUVectorize;
+ TileSizesListType tileSizes;
const int64_t numWarp = 2;
- std::array<int64_t, 3> workgroupSize = {numWarp * cudaWarpSize, 1, 1};
- config.setWorkgroupSize(workgroupSize);
+ SmallVector<int64_t, 3> workgroupSize = {numWarp * cudaWarpSize, 1, 1};
SmallVector<int64_t, 4> ts = {1, 2, 256, 4};
- config.setTileSizes(op, ts, 0); // Workgroup level.
- config.setTileSizes(op, {}, 1); // Subgroup level.
+ tileSizes.push_back(ts); // Workgroup level.
+ tileSizes.push_back({}); // Subgroup level.
SmallVector<int64_t, 4> invocationLevelTs = {ts[0], ts[1] / workgroupSize[1],
ts[2] / workgroupSize[0]};
- config.setTileSizes(op, invocationLevelTs, 2); // Thread level.
- return config;
+ tileSizes.push_back(invocationLevelTs); // Thread level.
+ setConfig(tileSizes, op);
+ return setTranslationInfo(entryPoint, passPipeline, workgroupSize);
}
// Basic default properties for linalg ops that haven't been tuned.
-static LaunchConfig getDefaultOpLaunchConfig(linalg::LinalgOp op) {
- LaunchConfig config;
+static LogicalResult setRootDefaultConfig(FuncOp entryPoint,
+ linalg::LinalgOp op) {
+ IREE::HAL::DispatchLoweringPassPipeline passPipeline =
+ IREE::HAL::DispatchLoweringPassPipeline::LLVMGPUDistribute;
+ std::array<int64_t, 3> workgroupSize = {1, 1, 1};
+ TileSizesListType tileSizes;
size_t numLoops = getNumOuterParallelLoops(op);
- if (numLoops == 0) return config;
+ if (numLoops == 0) {
+ return setTranslationInfo(entryPoint, passPipeline, workgroupSize);
+ }
- config.setWorkgroupSize({cudaWarpSize, 1, 1});
+ workgroupSize = {cudaWarpSize, 1, 1};
int64_t lowerTs = 4 * cudaWarpSize;
SmallVector<int64_t, 4> ts;
ts.resize(numLoops, 1);
ts.back() = lowerTs;
- config.setTileSizes(op, ts, 0); // Workgroup level.
- config.setTileSizes(op, {}, 1); // Subgroup level.
+ tileSizes.push_back(ts);
+ tileSizes.push_back({}); // Subgroup level.
ts.back() = lowerTs / cudaWarpSize;
- config.setTileSizes(op, ts, 2); // Thread level.
- return config;
+ tileSizes.push_back(ts); // Thread level.
+ setConfig(tileSizes, op);
+ return setTranslationInfo(entryPoint, passPipeline, workgroupSize);
}
-static LaunchConfig getOpLaunchConfig(linalg::LinalgOp linalgOp) {
+static LogicalResult setRootConfig(FuncOp entryPointFn,
+ linalg::LinalgOp linalgOp) {
if (auto genericOp = dyn_cast<linalg::GenericOp>(linalgOp.getOperation()))
- return getOpLaunchConfig(genericOp);
+ return setRootConfig(entryPointFn, genericOp);
if (auto matmul = dyn_cast<linalg::MatmulOp>(linalgOp.getOperation()))
- return getOpLaunchConfig(matmul);
+ return setRootConfig(entryPointFn, matmul);
if (auto batchMatmul =
dyn_cast<linalg::BatchMatmulOp>(linalgOp.getOperation()))
- return getOpLaunchConfig(batchMatmul);
- return getDefaultOpLaunchConfig(linalgOp);
+ return setRootConfig(entryPointFn, batchMatmul);
+ return setRootDefaultConfig(entryPointFn, linalgOp);
}
namespace mlir {
namespace iree_compiler {
-Optional<LaunchConfig> getLLVMGPULaunchConfig(
- MLIRContext *context, const linalg::LinalgDependenceGraph &dependenceGraph,
- ArrayRef<linalg::LinalgOp> linalgOps) {
- LaunchConfig launchConfig;
-
+LogicalResult initGPULaunchConfig(ModuleOp moduleOp) {
linalg::LinalgOp rootOperation;
- if (linalgOps.empty()) return llvm::None;
+ auto funcOps = moduleOp.getOps<FuncOp>();
+ assert(llvm::hasSingleElement(funcOps));
+ FuncOp funcOp = *funcOps.begin();
+ SmallVector<linalg::LinalgOp, 4> linalgOps;
+ funcOp.walk([&](linalg::LinalgOp op) { linalgOps.push_back(op); });
+ if (linalgOps.empty()) {
+ return ::setTranslationInfo(
+ funcOp, IREE::HAL::DispatchLoweringPassPipeline::LLVMGPUDistribute,
+ {1, 1, 1});
+ }
if (linalgOps.size() == 1) rootOperation = *linalgOps.begin();
// if there is more than one linalg op, look for the root one.
for (linalg::LinalgOp op : linalgOps) {
@@ -142,14 +183,13 @@
}
}
}
- launchConfig = getOpLaunchConfig(rootOperation);
- launchConfig.setRootOperation(rootOperation.getOperation());
- if (!launchConfig.getTileSizes(rootOperation, 0).empty()) {
- if (failed(propogateRootOperationLaunchConfig(launchConfig, rootOperation,
- dependenceGraph)))
- return llvm::None;
+ if (failed(setRootConfig(funcOp, rootOperation))) return failure();
+ IREE::HAL::LoweringConfig config = getLoweringConfig(rootOperation);
+ for (linalg::LinalgOp op : linalgOps) {
+ if (op == rootOperation) continue;
+ setLoweringConfig(op, config);
}
- return launchConfig;
+ return success();
}
} // namespace iree_compiler
diff --git a/iree/compiler/Conversion/LinalgToLLVMGPU/KernelConfig.h b/iree/compiler/Conversion/LinalgToLLVMGPU/KernelConfig.h
index 0711fd9..f10af2f 100644
--- a/iree/compiler/Conversion/LinalgToLLVMGPU/KernelConfig.h
+++ b/iree/compiler/Conversion/LinalgToLLVMGPU/KernelConfig.h
@@ -7,16 +7,13 @@
#ifndef IREE_COMPILER_CONVERSION_LINALGTOLLVMGPU_KERNELCONFIG_H_
#define IREE_COMPILER_CONVERSION_LINALGTOLLVMGPU_KERNELCONFIG_H_
-#include "iree/compiler/Conversion/Common/LaunchConfig.h"
-#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
-#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "iree/compiler/Dialect/HAL/IR/LoweringConfig.h"
+#include "mlir/IR/BuiltinOps.h"
namespace mlir {
namespace iree_compiler {
-Optional<LaunchConfig> getLLVMGPULaunchConfig(
- MLIRContext *context, const linalg::LinalgDependenceGraph &dependenceGraph,
- ArrayRef<linalg::LinalgOp> linalgOps);
+LogicalResult initGPULaunchConfig(ModuleOp moduleOp);
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToLLVMGPU/LowerExecutableTargetPass.cpp b/iree/compiler/Conversion/LinalgToLLVMGPU/LowerExecutableTargetPass.cpp
new file mode 100644
index 0000000..45ebece
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToLLVMGPU/LowerExecutableTargetPass.cpp
@@ -0,0 +1,125 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
+#include "iree/compiler/Conversion/Common/Passes.h"
+#include "iree/compiler/Conversion/LinalgToLLVMGPU/KernelConfig.h"
+#include "iree/compiler/Conversion/LinalgToLLVMGPU/Passes.h"
+#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassRegistry.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+/// Lowers an hal.executable.target operation to scalar/native-vector
+/// code. Invokes different compilation pipeline to
+/// - first lower to scalar/native-vector code
+/// - then convert to NVVM/ROCDL dialect.
+/// This should be merged with the equivalent pass in LinalgToLLVM. For
+/// simplicity it is currently a separate pass.
+class LowerExecutableTargetPass
+ : public PassWrapper<LowerExecutableTargetPass,
+ OperationPass<IREE::HAL::ExecutableTargetOp>> {
+ public:
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<IREE::HAL::HALDialect, linalg::LinalgDialect,
+ vector::VectorDialect, gpu::GPUDialect>();
+ }
+
+ LowerExecutableTargetPass() = default;
+ LowerExecutableTargetPass(const LowerExecutableTargetPass &pass) = default;
+
+ void runOnOperation() override;
+};
+} // namespace
+
+void LowerExecutableTargetPass::runOnOperation() {
+ IREE::HAL::ExecutableTargetOp targetOp = getOperation();
+ ModuleOp moduleOp = targetOp.getInnerModule();
+
+ OpPassManager executableLoweringPipeline(
+ IREE::HAL::ExecutableTargetOp::getOperationName());
+
+ if (failed(initGPULaunchConfig(moduleOp))) {
+ return signalPassFailure();
+ }
+ // There might be multiple entry points in the module. Currently, all of
+ // them need to have the same pipeline.
+ // TODO(ravishankarm): It is strange that this is not enforced
+ // structurally, but something to address later on. For now this restriction
+ // is fine.
+ llvm::StringMap<IREE::HAL::ExecutableEntryPointOp> entryPoints =
+ getAllEntryPoints(moduleOp);
+ Optional<IREE::HAL::DispatchLoweringPassPipeline> passPipeline;
+ SmallVector<int64_t, 4> workloadPerWorkgroup;
+ for (auto &it : entryPoints) {
+ auto entryPointOp = it.second;
+ if (IREE::HAL::TranslationInfo translationInfo =
+ getTranslationInfo(entryPointOp)) {
+ IREE::HAL::DispatchLoweringPassPipeline currPipeline =
+ translationInfo.passPipeline().getValue();
+ if (ArrayAttr workloadPerWorkgroupAttr =
+ translationInfo.workloadPerWorkgroup()) {
+ workloadPerWorkgroup = llvm::to_vector<4>(llvm::map_range(
+ workloadPerWorkgroupAttr,
+ [](Attribute attr) { return attr.cast<IntegerAttr>().getInt(); }));
+ }
+ if (passPipeline) {
+ if (currPipeline != passPipeline.getValue()) {
+ moduleOp.emitError(
+ "unhandled compilation of entry point function with different "
+ "pass pipelines within a module");
+ return signalPassFailure();
+ }
+ continue;
+ }
+ passPipeline = currPipeline;
+ }
+ }
+ auto funcOps = moduleOp.getOps<FuncOp>();
+ FuncOp funcOp = *funcOps.begin();
+ // Attach the workgroup size as an attribute. This will be used when
+ // creating the flatbuffer.
+ funcOp->setAttr(
+ "llvmgpu_workgroup_size",
+ DenseElementsAttr::get<int64_t>(
+ VectorType::get(3, IntegerType::get(moduleOp.getContext(), 64)),
+ workloadPerWorkgroup));
+
+ switch (*passPipeline) {
+ case IREE::HAL::DispatchLoweringPassPipeline::LLVMGPUDistribute:
+ addGPUSimpleDistributePassPipeline(executableLoweringPipeline);
+ break;
+ case IREE::HAL::DispatchLoweringPassPipeline::LLVMGPUVectorize:
+ addGPUVectorizationPassPipeline(executableLoweringPipeline);
+ break;
+ default:
+ llvm_unreachable("Unsupported pipeline on GPU target.");
+ }
+
+ if (failed(runPipeline(executableLoweringPipeline, targetOp))) {
+ return signalPassFailure();
+ }
+}
+
+std::unique_ptr<OperationPass<IREE::HAL::ExecutableTargetOp>>
+createLowerExecutableTargetGPUPass() {
+ return std::make_unique<LowerExecutableTargetPass>();
+}
+
+static PassRegistration<LowerExecutableTargetPass> pass(
+ "iree-lower-executable-target-gpu-pass",
+ "Perform lowering of executable target using one of the "
+ "IREE::HAL::DispatchLoweringPassPipeline",
+ [] { return std::make_unique<LowerExecutableTargetPass>(); });
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToLLVMGPU/Passes.cpp b/iree/compiler/Conversion/LinalgToLLVMGPU/Passes.cpp
index 86ef6b5..348526c 100644
--- a/iree/compiler/Conversion/LinalgToLLVMGPU/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVMGPU/Passes.cpp
@@ -21,7 +21,17 @@
namespace mlir {
namespace iree_compiler {
-static void addLinalgToLLVMGPUPasses(OpPassManager &pm, bool useROCM) {
+static Value gpuAllocationFunction(OpBuilder &builder, Location loc,
+ ArrayRef<int64_t> staticShape,
+ Type elementType,
+ ArrayRef<Value> dynamicSizes) {
+ MemRefType allocType = MemRefType::get(staticShape, elementType, {}, 3);
+ return builder.create<memref::AllocOp>(loc, allocType, dynamicSizes);
+}
+
+void addGPUVectorizationPassPipeline(OpPassManager &pm) {
+ // Convert tensor to buffers.
+ addLinalgBufferizePasses(pm.nest<ModuleOp>(), gpuAllocationFunction);
//===--------------------------------------------------------------------===//
// Initial clean up.
//===--------------------------------------------------------------------===//
@@ -40,13 +50,32 @@
pm.nest<ModuleOp>().addNestedPass<FuncOp>(createVectorizationPass());
pm.nest<ModuleOp>().addNestedPass<FuncOp>(createCanonicalizerPass());
pm.nest<ModuleOp>().addNestedPass<FuncOp>(createCSEPass());
+}
+void addGPUSimpleDistributePassPipeline(OpPassManager &pm) {
+ // Convert tensor to buffers.
+ addLinalgBufferizePasses(pm.nest<ModuleOp>(), gpuAllocationFunction);
+
+ //===--------------------------------------------------------------------===//
+ // Initial clean up.
+ //===--------------------------------------------------------------------===//
+ pm.addNestedPass<ModuleOp>(createCanonicalizerPass());
+ pm.addNestedPass<ModuleOp>(createCSEPass());
+
+ // Distribute linalg onto threads within the workgroup.
+ pm.addPass(createTileAndDistributeToThreads());
+ pm.addNestedPass<ModuleOp>(createCanonicalizerPass());
+ pm.addNestedPass<ModuleOp>(createCSEPass());
+
+ pm.nest<ModuleOp>().addNestedPass<FuncOp>(
+ createRemoveSingleIterationLoopPass());
+}
+
+static void addLowerToLLVMGPUPasses(OpPassManager &pm, bool useROCM) {
pm.addNestedPass<ModuleOp>(createLowerAffinePass());
pm.addNestedPass<ModuleOp>(createCanonicalizerPass());
pm.addNestedPass<ModuleOp>(createCSEPass());
- // TODO: This currently maps to a single thread. We should share Tile and
- // distribute with other GPU backends.
// Linalg -> SCF
pm.nest<ModuleOp>().addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
pm.nest<ModuleOp>().addNestedPass<FuncOp>(createCanonicalizerPass());
@@ -76,17 +105,7 @@
}
void buildLLVMGPUTransformPassPipeline(OpPassManager &pm, bool useROCM) {
- OpPassManager &nestedModulePM = pm.nest<ModuleOp>();
- nestedModulePM.addPass(createInlinerPass());
-
- WorkgroupMemoryAllocationFn allocationFn =
- [](OpBuilder &builder, Location loc, ArrayRef<int64_t> staticShape,
- Type elementType, ArrayRef<Value> dynamicSizes) {
- MemRefType allocType = MemRefType::get(staticShape, elementType, {}, 3);
- return builder.create<memref::AllocOp>(loc, allocType, dynamicSizes);
- };
- addLinalgBufferizePasses(nestedModulePM, allocationFn);
-
+ pm.addPass(createLowerExecutableTargetGPUPass());
//===--------------------------------------------------------------------===//
// Convert Linalg ops to LLVM+NVVM/ROCDL ops.
//
@@ -94,35 +113,19 @@
// - All Linalg/Loops/GPU/Affine/Standard ops are converted away.
// - The module contains the final llvm.module ready to be serialized.
//===--------------------------------------------------------------------===//
- addLinalgToLLVMGPUPasses(pm, useROCM);
+ addLowerToLLVMGPUPasses(pm, useROCM);
}
-static PassPipelineRegistration<> linalgToNVVMPipeline(
+static PassPipelineRegistration<> LinalgNVVMPipeline(
"iree-codegen-linalg-to-nvvm-pipeline",
"Runs the progressive lowering pipeline from Linalg to NVVM",
[](OpPassManager &passManager) {
- addLinalgToLLVMGPUPasses(passManager, false);
- });
-
-static PassPipelineRegistration<> linalgToROCDLPipeline(
- "iree-codegen-linalg-to-rocdl-pipeline",
- "Runs the progressive lowering pipeline from Linalg to ROCDL",
- [](OpPassManager &passManager) {
- addLinalgToLLVMGPUPasses(passManager, true);
- });
-
-static PassPipelineRegistration<> hloToLinalgNVVMPipeline(
- "iree-codegen-hlo-to-nvvm-pipeline",
- "Runs the progressive lowering pipeline from XLA HLO to Linalg to "
- "NVVM",
- [](OpPassManager &passManager) {
buildLLVMGPUTransformPassPipeline(passManager, false);
});
-static PassPipelineRegistration<> hloToLinalgROCDLPipeline(
- "iree-codegen-hlo-to-rocdl-pipeline",
- "Runs the progressive lowering pipeline from XLA HLO to Linalg to "
- "ROCDL",
+static PassPipelineRegistration<> LinalgROCDLPipeline(
+ "iree-codegen-linalg-to-rocdl-pipeline",
+ "Runs the progressive lowering pipeline from Linalg to ROCDL",
[](OpPassManager &passManager) {
buildLLVMGPUTransformPassPipeline(passManager, true);
});
diff --git a/iree/compiler/Conversion/LinalgToLLVMGPU/Passes.h b/iree/compiler/Conversion/LinalgToLLVMGPU/Passes.h
index 39ed813..4c5aa4d 100644
--- a/iree/compiler/Conversion/LinalgToLLVMGPU/Passes.h
+++ b/iree/compiler/Conversion/LinalgToLLVMGPU/Passes.h
@@ -28,6 +28,17 @@
std::unique_ptr<OperationPass<FuncOp>> createRemoveSingleIterationLoopPass();
+/// Create pass calling the dynamic pipeline for LLVMGPU.
+std::unique_ptr<OperationPass<IREE::HAL::ExecutableTargetOp>>
+createLowerExecutableTargetGPUPass();
+
+/// Lowering pipeline that applies vectorization patterns.
+void addGPUVectorizationPassPipeline(OpPassManager &passManager);
+
+/// Simple lowering that only distributes linalg ops on blocks and threads. This
+/// will result in scalar operations.
+void addGPUSimpleDistributePassPipeline(OpPassManager &passManager);
+
/// Populates passes needed to lower a XLA HLO op to NVVM/ROCDL dialect via the
/// structured ops path. The pass manager `pm` in here should operate on the
/// module within the IREE::HAL::ExecutableOp.
diff --git a/iree/compiler/Conversion/LinalgToLLVMGPU/TileAndDistribute.cpp b/iree/compiler/Conversion/LinalgToLLVMGPU/TileAndDistribute.cpp
index 3ebf53d..a0cf908 100644
--- a/iree/compiler/Conversion/LinalgToLLVMGPU/TileAndDistribute.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVMGPU/TileAndDistribute.cpp
@@ -9,6 +9,7 @@
#include "iree/compiler/Conversion/Common/Transforms.h"
#include "iree/compiler/Conversion/LinalgToLLVMGPU/KernelConfig.h"
#include "iree/compiler/Conversion/LinalgToLLVMGPU/Passes.h"
+#include "iree/compiler/Dialect/HAL/IR/LoweringConfig.h"
#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
@@ -62,11 +63,11 @@
/// Patterns for thread level tiling.
static void populateTilingToInvocationPatterns(
MLIRContext *context, OwningRewritePatternList &patterns,
- const LaunchConfig &launchConfig) {
+ ArrayRef<int64_t> workgroupSize) {
linalg::TileSizeComputationFunction getInnerTileSizeFn =
- [launchConfig](OpBuilder &builder, Operation *operation) {
+ [](OpBuilder &builder, Operation *operation) {
SmallVector<Value, 4> tileSizesVal;
- ArrayRef<int64_t> tileSizes = launchConfig.getTileSizes(operation, 2);
+ SmallVector<int64_t, 4> tileSizes = getTileSizes(operation, 2);
if (tileSizes.empty()) return SmallVector<Value, 4>();
tileSizesVal.reserve(tileSizes.size());
for (auto val : llvm::enumerate(tileSizes)) {
@@ -80,11 +81,11 @@
return tileSizesVal;
};
- auto getThreadProcInfoFn = [launchConfig](
+ auto getThreadProcInfoFn = [workgroupSize](
OpBuilder &builder, Location loc,
ArrayRef<Range> parallelLoopRanges) {
return getGPUThreadIdsAndCounts(builder, loc, parallelLoopRanges.size(),
- launchConfig.getWorkgroupSize());
+ workgroupSize);
};
linalg::LinalgLoopDistributionOptions invocationDistributionOptions;
invocationDistributionOptions.procInfo = getThreadProcInfoFn;
@@ -122,10 +123,10 @@
/// picked by the root op.
static void populateTilingCopyToWorkgroupMemPatterns(
MLIRContext *context, OwningRewritePatternList &patterns,
- const LaunchConfig &launchConfig) {
+ ArrayRef<int64_t> workgroupSize) {
// Tile and distribute copy to workgroup memory.
linalg::TileSizeComputationFunction wgCopyTileSizeFn =
- [launchConfig](OpBuilder &builder, Operation *operation) {
+ [](OpBuilder &builder, Operation *operation) {
const int64_t copyTileSize = 4;
// We tile to 4 as we want each thread to load 4 element in a cyclic
// distribution.
@@ -141,7 +142,7 @@
builder.create<ConstantIndexOp>(operation->getLoc(), copyTileSize));
return tileSizesVal;
};
- auto getCopyThreadProcInfoFn = [launchConfig](
+ auto getCopyThreadProcInfoFn = [workgroupSize](
OpBuilder &builder, Location loc,
ArrayRef<Range> parallelLoopRanges) {
SmallVector<std::array<int64_t, 3>, 2> staticRanges;
@@ -158,15 +159,14 @@
staticRanges.push_back(
{cstOffset.getValue(), cstSize.getValue(), cstStride.getValue()});
}
- ArrayRef<int64_t> wokgroupSize = launchConfig.getWorkgroupSize();
// Only support static dimension with 1D workgroups for now. Fall back to
// the naive distribution for other cases.
- if (hasDynamicRange || wokgroupSize[1] != 1 || wokgroupSize[2] != 1)
+ if (hasDynamicRange || workgroupSize[1] != 1 || workgroupSize[2] != 1)
return getGPUThreadIdsAndCounts(builder, loc, parallelLoopRanges.size(),
- launchConfig.getWorkgroupSize());
+ workgroupSize);
Value serializedId =
builder.create<gpu::ThreadIdOp>(loc, builder.getIndexType(), "x");
- int64_t numIds = wokgroupSize[0];
+ int64_t numIds = workgroupSize[0];
int numDims = parallelLoopRanges.size();
SmallVector<linalg::ProcInfo, 2> procInfo(numDims);
assert(numDims <= kNumGPUDims);
@@ -337,32 +337,25 @@
if (failed(getLinalgOps(funcOp, linalgOps, tiledLoops))) {
return signalPassFailure();
}
- linalg::Aliases aliases;
- linalg::LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
- auto config = getLLVMGPULaunchConfig(context, dependenceGraph, linalgOps);
- if (!config) return signalPassFailure();
-
- // Attach the workgroup size as an attribute. This will be used when
- // creating the flatbuffer.
- funcOp->setAttr("llvmgpu_workgroup_size",
- DenseElementsAttr::get<int64_t>(
- VectorType::get(3, IntegerType::get(context, 64)),
- config->getWorkgroupSize()));
-
- Operation *rootOp = config->getRootOperation(llvm::to_vector<4>(
- llvm::map_range(linalgOps, [](linalg::LinalgOp op) {
- return op.getOperation();
- })));
- SmallVector<int64_t, 4> wgTileSize =
- llvm::to_vector<4>(config->getTileSizes(rootOp, 0));
+ if (linalgOps.empty() || !hasLoweringConfig(*linalgOps.begin())) return;
+ std::array<int64_t, 3> workgroupSize;
+ for (auto it : llvm::enumerate(funcOp->getAttr("llvmgpu_workgroup_size")
+ .cast<DenseIntElementsAttr>()
+ .getIntValues())) {
+ workgroupSize[it.index()] = it.value().getZExtValue();
+ }
+ IREE::HAL::LoweringConfig config = getLoweringConfig(*linalgOps.begin());
+ SmallVector<int64_t, 4> wgTileSize = getTileSizes(config, 0);
// If there is no tile size, skip tiling.
if (wgTileSize.empty()) return;
- unsigned numOuterParallelLoops =
- getNumOuterParallelLoops(cast<linalg::LinalgOp>(rootOp));
- size_t numContractionLoops =
- wgTileSize.size() > numOuterParallelLoops
- ? wgTileSize.size() - numOuterParallelLoops
- : 0;
+ unsigned numOuterParallelLoops = 0;
+ unsigned numContractionLoops = 0;
+ for (linalg::LinalgOp linalgOp : linalgOps) {
+ numOuterParallelLoops =
+ std::max(getNumOuterParallelLoops(linalgOp), numOuterParallelLoops);
+ numContractionLoops =
+ std::max(linalgOp.getNumReductionLoops(), numContractionLoops);
+ }
size_t numTilableDims =
std::min(kWorkgroupDimCount, numOuterParallelLoops);
wgTileSize.resize(numTilableDims);
@@ -386,7 +379,7 @@
// same size.
OwningRewritePatternList wgTilingPatterns(context);
populateTilingReductionPatterns(context, wgTilingPatterns,
- config->getTileSizes(rootOp, 0));
+ getTileSizes(config, 0));
(void)applyPatternsAndFoldGreedily(funcOp, std::move(wgTilingPatterns));
applyCanonicalizationPatternsForTiling(context, funcOp);
}
@@ -411,9 +404,9 @@
// Apply last level of tiling and distribute to threads.
OwningRewritePatternList threadLevelTilingPatterns(context);
populateTilingToInvocationPatterns(context, threadLevelTilingPatterns,
- *config);
+ workgroupSize);
populateTilingCopyToWorkgroupMemPatterns(
- context, threadLevelTilingPatterns, *config);
+ context, threadLevelTilingPatterns, workgroupSize);
(void)applyPatternsAndFoldGreedily(
funcOp, std::move(threadLevelTilingPatterns));
applyCanonicalizationPatternsForTiling(context, funcOp);
diff --git a/iree/compiler/Conversion/LinalgToLLVMGPU/test/distribute_to_thread.mlir b/iree/compiler/Conversion/LinalgToLLVMGPU/test/distribute_to_thread.mlir
index edb214c..b92cf7d 100644
--- a/iree/compiler/Conversion/LinalgToLLVMGPU/test/distribute_to_thread.mlir
+++ b/iree/compiler/Conversion/LinalgToLLVMGPU/test/distribute_to_thread.mlir
@@ -4,7 +4,7 @@
hal.executable.target @cuda, filter="cuda" {
hal.executable.entry_point @add_dispatch_0 attributes {interface = @io, ordinal = 0 : index}
module {
- func @add_dispatch_0() {
+ func @add_dispatch_0() attributes {llvmgpu_workgroup_size = dense<[32, 1, 1]> : vector<3xi64>} {
%c0 = constant 0 : index
%c1024 = constant 1024 : index
%0 = hal.interface.binding.subspan @io::@ro0[%c0] : memref<1024xf32>
@@ -20,7 +20,7 @@
%6 = memref.subview %0[%arg0] [%5] [1] : memref<1024xf32> to memref<?xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
%7 = memref.subview %1[%arg0] [%5] [1] : memref<1024xf32> to memref<?xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
%8 = memref.subview %2[%arg0] [%5] [1] : memref<1024xf32> to memref<?xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
- linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%6, %7 : memref<?xf32, affine_map<(d0)[s0] -> (d0 + s0)>>, memref<?xf32, affine_map<(d0)[s0] -> (d0 + s0)>>) outs(%8 : memref<?xf32, affine_map<(d0)[s0] -> (d0 + s0)>>) attrs = {__internal_linalg_transform__ = "workgroup"} {
+ linalg.generic {lowering.config = {tileSizes = [[128], [], [4]]}, indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%6, %7 : memref<?xf32, affine_map<(d0)[s0] -> (d0 + s0)>>, memref<?xf32, affine_map<(d0)[s0] -> (d0 + s0)>>) outs(%8 : memref<?xf32, affine_map<(d0)[s0] -> (d0 + s0)>>) attrs = {__internal_linalg_transform__ = "workgroup"} {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32): // no predecessors
%9 = addf %arg1, %arg2 : f32
linalg.yield %9 : f32
@@ -44,7 +44,7 @@
hal.executable.target @cuda, filter="cuda" {
hal.executable.entry_point @dot_dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : index}
module {
- func @dot_dispatch_0() {
+ func @dot_dispatch_0() attributes {llvmgpu_workgroup_size = dense<[64, 1, 1]> : vector<3xi64>} {
%cst = constant 0.000000e+00 : f32
%c0 = constant 0 : index
%c1024 = constant 1024 : index
@@ -68,8 +68,8 @@
%9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1024)>(%arg1)[%workgroup_size_x]
%10 = memref.subview %1[0, %arg1] [1024, %9] [1, 1] : memref<1024x1024xf32> to memref<1024x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
%11 = memref.subview %2[%arg0, %arg1] [%7, %9] [1, 1] : memref<1024x1024xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
- linalg.fill(%11, %cst) {__internal_linalg_transform__ = "workgroup"} : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>, f32
- linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%8, %10 : memref<?x1024xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>, memref<1024x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>) outs(%11 : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>)
+ linalg.fill(%11, %cst) {__internal_linalg_transform__ = "workgroup", lowering.config = {tileSizes = [[2, 256, 4], [], [2, 4]]}} : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>, f32
+ linalg.matmul {__internal_linalg_transform__ = "workgroup", lowering.config = {tileSizes = [[2, 256, 4], [], [2, 4]]}} ins(%8, %10 : memref<?x1024xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>, memref<1024x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>) outs(%11 : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>)
}
}
return
@@ -115,7 +115,7 @@
// CHECK-DAG: %[[A:.+]] = memref.subview %17[%[[IND0]], 0] [2, 4] [1, 1] : memref<2x4xf32, #{{.*}}, 3> to memref<2x4xf32, #{{.*}}, 3>
// CHECK-DAG: %[[B:.+]] = memref.subview %18[0, %[[IND1]]] [4, 4] [1, 1] : memref<4x256xf32, #{{.*}}, 3> to memref<4x4xf32, #{{.*}}, 3>
// CHECK-DAG: %[[C:.+]] = memref.subview %11[%[[IND0]], %[[IND1]]] [2, 4] [1, 1] : memref<2x256xf32, #{{.*}}> to memref<2x4xf32, #{{.*}}>
-// CHECK: linalg.matmul {__internal_linalg_transform__ = "vectorize", is_root_op, launch_info_key = "__op_num_0__"} ins(%[[A]], %[[B]] : memref<2x4xf32, #{{.*}}, 3>, memref<4x4xf32, #{{.*}}, 3>) outs(%[[C]] : memref<2x4xf32, #{{.*}}>)
+// CHECK: linalg.matmul {__internal_linalg_transform__ = "vectorize", {{.*}}} ins(%[[A]], %[[B]] : memref<2x4xf32, #{{.*}}, 3>, memref<4x4xf32, #{{.*}}, 3>) outs(%[[C]] : memref<2x4xf32, #{{.*}}>)
// CHECK: }
// CHECK: }
@@ -126,7 +126,7 @@
hal.executable.target @cuda, filter="cuda" {
hal.executable.entry_point @predict_dispatch_153 attributes {interface = @io, ordinal = 0 : index}
module {
- func @predict_dispatch_153() {
+ func @predict_dispatch_153() attributes {llvmgpu_workgroup_size = dense<[32, 1, 1]> : vector<3xi64>} {
%c0 = constant 0 : index
%cst = constant 0x7FC00000 : f32
%cst_0 = constant 0xFF800000 : f32
diff --git a/iree/compiler/Conversion/LinalgToLLVMGPU/test/nvvm_pipeline_test.mlir b/iree/compiler/Conversion/LinalgToLLVMGPU/test/nvvm_pipeline_test.mlir
index d55ddbb..65d1f54 100644
--- a/iree/compiler/Conversion/LinalgToLLVMGPU/test/nvvm_pipeline_test.mlir
+++ b/iree/compiler/Conversion/LinalgToLLVMGPU/test/nvvm_pipeline_test.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-hlo-to-nvvm-pipeline))" %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-linalg-to-nvvm-pipeline))" %s | IreeFileCheck %s
// Verify that a simple element wise op gets lowered succefully all the way to
// nvvm/llvm dialect.
diff --git a/iree/compiler/Conversion/LinalgToLLVMGPU/test/rocdl_pipeline_test.mlir b/iree/compiler/Conversion/LinalgToLLVMGPU/test/rocdl_pipeline_test.mlir
index 83e42b9..f56b022 100644
--- a/iree/compiler/Conversion/LinalgToLLVMGPU/test/rocdl_pipeline_test.mlir
+++ b/iree/compiler/Conversion/LinalgToLLVMGPU/test/rocdl_pipeline_test.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-hlo-to-rocdl-pipeline))" %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-linalg-to-rocdl-pipeline))" %s | IreeFileCheck %s
// Verify that a simple element wise op gets lowered succefully all the way to
// nvvm/llvm dialect.
diff --git a/iree/compiler/Conversion/init_conversions.h b/iree/compiler/Conversion/init_conversions.h
index 7a74a17..d23c3d8 100644
--- a/iree/compiler/Conversion/init_conversions.h
+++ b/iree/compiler/Conversion/init_conversions.h
@@ -63,6 +63,7 @@
createLinalgTileAndVectorizeWorkgroupsPass();
createUnfusedFMAOpsPass();
createPadLinalgWorkgroupTilesPass();
+ createTilePadAndVectorizeWorkgroupsPass();
return true;
}();
(void)init_once;
diff --git a/iree/compiler/Dialect/Flow/Transforms/BUILD b/iree/compiler/Dialect/Flow/Transforms/BUILD
index 3359517..a78165f 100644
--- a/iree/compiler/Dialect/Flow/Transforms/BUILD
+++ b/iree/compiler/Dialect/Flow/Transforms/BUILD
@@ -81,6 +81,7 @@
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:SCFToStandard",
"@llvm-project//mlir:Shape",
+ "@llvm-project//mlir:ShapeToStandard",
"@llvm-project//mlir:ShapeTransforms",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
diff --git a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index 729a12e..39fc71b 100644
--- a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -57,6 +57,7 @@
MLIRSCFToStandard
MLIRShape
MLIRShapeOpsTransforms
+ MLIRShapeToStandard
MLIRStandard
MLIRSupport
MLIRTensor
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index 8130008..833ac10b 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -14,6 +14,7 @@
#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
+#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
#include "mlir/Conversion/TosaToStandard/TosaToStandard.h"
@@ -68,6 +69,14 @@
// TODO: Currently recurses into SCF in Linalg generic - with hilarity.
passManager.addNestedPass<FuncOp>(mlir::createLowerToCFGPass());
+ // Various shape functions may have been materialized in the `shape.shape_of`
+ // style of treating shapes as tensors. We prefer to legalize these to
+ // scalar ops as early as possible to avoid having them persist as tensor
+ // computations.
+ passManager.addNestedPass<FuncOp>(createShapeToShapeLowering());
+ passManager.addPass(createConvertShapeToStandardPass());
+ passManager.addNestedPass<FuncOp>(mlir::createCanonicalizerPass());
+
// Now that control flow has been lowered, promote and extract_element
// to tensor loads. This will be done again later once everything that can
// be is lowered to device.
diff --git a/iree/compiler/Dialect/Flow/Transforms/VerifyCompilerInputLegality.cpp b/iree/compiler/Dialect/Flow/Transforms/VerifyCompilerInputLegality.cpp
index ca21ea7..323624d 100644
--- a/iree/compiler/Dialect/Flow/Transforms/VerifyCompilerInputLegality.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/VerifyCompilerInputLegality.cpp
@@ -8,6 +8,7 @@
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
@@ -31,6 +32,7 @@
conversionTarget.addIllegalDialect<tosa::TosaDialect>();
conversionTarget.addIllegalDialect<mhlo::MhloDialect>();
conversionTarget.addIllegalDialect<chlo::HloClientDialect>();
+ conversionTarget.addIllegalDialect<mlir::shape::ShapeDialect>();
// Exception: ApplyScaleOp is actually a lowered op on par with standard
// dialect.
diff --git a/iree/compiler/Dialect/HAL/IR/LoweringConfig.cpp b/iree/compiler/Dialect/HAL/IR/LoweringConfig.cpp
index f1af230..52f5757 100644
--- a/iree/compiler/Dialect/HAL/IR/LoweringConfig.cpp
+++ b/iree/compiler/Dialect/HAL/IR/LoweringConfig.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
static const char kConfigAttrName[] = "lowering.config";
+static const char kTranslationInfoAttrName[] = "translation.info";
#include "iree/compiler/Dialect/HAL/IR/LoweringConfig.cpp.inc"
#include "iree/compiler/Dialect/HAL/IR/LoweringConfigEnums.cpp.inc"
@@ -18,6 +19,47 @@
namespace iree_compiler {
//===----------------------------------------------------------------------===//
+// Helpers for getting/setting information needed to lower an executable. These
+// are information that are stored as attributes on the
+// `hal.executable.entry_point`
+//===----------------------------------------------------------------------===//
+
+IREE::HAL::TranslationInfo buildTranslationInfo(
+ IREE::HAL::DispatchLoweringPassPipeline passPipeline,
+ ArrayRef<int64_t> workloadPerWorkgroup, MLIRContext *context) {
+ OpBuilder builder(context);
+ auto pipelineAttr =
+ IREE::HAL::DispatchLoweringPassPipelineAttr::get(context, passPipeline);
+ ArrayAttr workloadPerWorkgroupAttr =
+ builder.getI64ArrayAttr(workloadPerWorkgroup);
+ return IREE::HAL::TranslationInfo::get(pipelineAttr, workloadPerWorkgroupAttr,
+ context);
+}
+
+IREE::HAL::TranslationInfo getTranslationInfo(
+ IREE::HAL::ExecutableEntryPointOp entryPointOp) {
+ return entryPointOp->getAttrOfType<IREE::HAL::TranslationInfo>(
+ kTranslationInfoAttrName);
+}
+
+LogicalResult setTranslationInfo(IREE::HAL::ExecutableEntryPointOp entryPointOp,
+ IREE::HAL::TranslationInfo translationInfo) {
+ auto currTranslationAttr =
+ entryPointOp->getAttrOfType<IREE::HAL::TranslationInfo>(
+ kTranslationInfoAttrName);
+ if (currTranslationAttr) {
+ if (currTranslationAttr != translationInfo) {
+ return entryPointOp.emitError(
+ "illegal to override set translation information");
+ }
+ }
+ if (!currTranslationAttr) {
+ entryPointOp->setAttr(kTranslationInfoAttrName, translationInfo);
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// Helpers for getting/setting the `hal.lowering.*` attributes that drive the
// linalg-based lowering.
// ===----------------------------------------------------------------------===//
@@ -42,9 +84,9 @@
// Helpers for accessing values from the LoweringConfig attribute.
//===----------------------------------------------------------------------===//
-IREE::HAL::LoweringConfig getConfigAttr(TileSizesListTypeRef tileSizes,
- ArrayRef<int64_t> nativeVectorSize,
- MLIRContext *context) {
+IREE::HAL::LoweringConfig buildConfigAttr(TileSizesListTypeRef tileSizes,
+ ArrayRef<int64_t> nativeVectorSize,
+ MLIRContext *context) {
OpBuilder builder(context);
ArrayAttr tileSizesAttr = nullptr;
if (!tileSizes.empty()) {
diff --git a/iree/compiler/Dialect/HAL/IR/LoweringConfig.h b/iree/compiler/Dialect/HAL/IR/LoweringConfig.h
index 54e7e47..315521f 100644
--- a/iree/compiler/Dialect/HAL/IR/LoweringConfig.h
+++ b/iree/compiler/Dialect/HAL/IR/LoweringConfig.h
@@ -15,13 +15,14 @@
#ifndef IREE_COMPILER_CONVERSION_COMMON_LOWERINGCONFIG_H_
#define IREE_COMPILER_CONVERSION_COMMON_LOWERINGCONFIG_H_
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
// clang-format off
-#include "iree/compiler/Dialect/HAL/IR/LoweringConfig.h.inc"
#include "iree/compiler/Dialect/HAL/IR/LoweringConfigEnums.h.inc"
+#include "iree/compiler/Dialect/HAL/IR/LoweringConfig.h.inc"
// clang-format on
namespace mlir {
@@ -29,22 +30,13 @@
namespace IREE {
namespace HAL {
-/// Struct that for a given hal.target.executable defines how it is translated.
-// TODO(ravishankarm): This could also be converted to an attribute on the
-// hal.executable.target
-struct TranslateExecutableInfo {
- DispatchLoweringPassPipeline passPipeline;
- SmallVector<int64_t, 3> workgroupSize;
-};
-inline bool operator==(const TranslateExecutableInfo &lhs,
- const TranslateExecutableInfo &rhs) {
- return lhs.passPipeline == rhs.passPipeline &&
- lhs.workgroupSize == rhs.workgroupSize;
+inline bool operator==(const TranslationInfo &lhs, const TranslationInfo &rhs) {
+ return lhs.passPipeline() == rhs.passPipeline() &&
+ lhs.workloadPerWorkgroup() == rhs.workloadPerWorkgroup();
}
-inline bool operator!=(const TranslateExecutableInfo &lhs,
- const TranslateExecutableInfo &rhs) {
+inline bool operator!=(const TranslationInfo &lhs, const TranslationInfo &rhs) {
return !(lhs == rhs);
}
@@ -52,6 +44,31 @@
} // namespace IREE
//===----------------------------------------------------------------------===//
+// Helpers for getting/setting information needed to lower an executable. These
+// are information that are stored as attributes on the
+// `hal.executable.entry_point`
+//===----------------------------------------------------------------------===//
+
+/// Builder method for IREE::HAL::TranslationInfoAttr.
+IREE::HAL::TranslationInfo buildTranslationInfo(
+ IREE::HAL::DispatchLoweringPassPipeline passPipeline,
+ ArrayRef<int64_t> workloadPerWorkgroup, MLIRContext *context);
+
+/// Gets the translate executable info attribute value associated with
+/// `entryPointOp`.
+IREE::HAL::TranslationInfo getTranslationInfo(
+ IREE::HAL::ExecutableEntryPointOp entryPointOp);
+
+/// Set the translate executable info with the entry point op. Returns a failure
+/// if these have already been set for the `entryPointOp` and are incompatible
+/// with what is being set.
+// TODO(ravishankarm, benvanik): Eventually all the information needed for the
+// lowering will be consolidated into a single attribute with richer
+// information.
+LogicalResult setTranslationInfo(IREE::HAL::ExecutableEntryPointOp entryPointOp,
+ IREE::HAL::TranslationInfo translationInfo);
+
+//===----------------------------------------------------------------------===//
// Helpers for getting/setting the `hal.lowering.*` attributes that drive the
// linalg-based lowering.
// ===----------------------------------------------------------------------===//
@@ -86,9 +103,9 @@
using TileSizesListTypeRef = ArrayRef<SmallVector<int64_t, 4>>;
/// Construct a lowering configuration.
-IREE::HAL::LoweringConfig getConfigAttr(TileSizesListTypeRef tileSizes,
- ArrayRef<int64_t> nativeVectorSize,
- MLIRContext *context);
+IREE::HAL::LoweringConfig buildConfigAttr(TileSizesListTypeRef tileSizes,
+ ArrayRef<int64_t> nativeVectorSize,
+ MLIRContext *context);
/// Get the tile sizes for all levels.
TileSizesListType getTileSizes(IREE::HAL::LoweringConfig config);
diff --git a/iree/compiler/Dialect/HAL/IR/LoweringConfig.td b/iree/compiler/Dialect/HAL/IR/LoweringConfig.td
index cd7ecb1..1bf8744 100644
--- a/iree/compiler/Dialect/HAL/IR/LoweringConfig.td
+++ b/iree/compiler/Dialect/HAL/IR/LoweringConfig.td
@@ -9,17 +9,23 @@
// Putting this in HAL dialect for now.
include "iree/compiler/Dialect/HAL/IR/HALDialect.td"
+// List of pre-existing pipelines for translating executables.
def CPU_Default
: I32EnumAttrCase<"CPUDefault", 0>;
def CPU_Vectorization
: I32EnumAttrCase<"CPUVectorization", 1>;
+def LLVMGPU_SimpleDistribute
+ : I32EnumAttrCase<"LLVMGPUDistribute", 2>;
+def LLVMGPU_Vectorize
+ : I32EnumAttrCase<"LLVMGPUVectorize", 3>;
// EnumAttrCase for all known lowerings for ops within dispatch region
// to scalar/native-vector code.
def DispatchLoweringPassPipelineEnum : I32EnumAttr<
"DispatchLoweringPassPipeline",
"identifier for pass pipeline use to lower dispatch region",
- [CPU_Default, CPU_Vectorization]> {
+ [CPU_Default, CPU_Vectorization, LLVMGPU_SimpleDistribute,
+ LLVMGPU_Vectorize]> {
let cppNamespace = "::mlir::iree_compiler::IREE::HAL";
}
@@ -27,6 +33,14 @@
TypedArrayAttrBase<I64ArrayAttr,
"list of tile sizes for all levels"> { }
+// Attribute that captures information needed for translating the executables.
+def TranslationInfoAttr :
+ StructAttr<"TranslationInfo", HAL_Dialect, [
+ StructFieldAttr<"passPipeline", DispatchLoweringPassPipelineEnum>,
+ StructFieldAttr<"workloadPerWorkgroup",
+ DefaultValuedAttr<I64ArrayAttr, "{}">>,
+ ]>;
+
// Attribute that carries information needed to perform
// tiling/vectorization, etc.
def HAL_LoweringConfigAttr :
diff --git a/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTarget.cpp b/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTarget.cpp
index e900630..5a86b3b 100644
--- a/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTarget.cpp
@@ -150,7 +150,8 @@
iree_ROCMExecutableDef_start_as_root(builder);
// Link module to Device Library
- if (options_.ROCMLinkBC) LinkROCDLIfNecessary(llvmModule.get());
+ if (options_.ROCMLinkBC)
+ LinkROCDLIfNecessary(llvmModule.get(), options_.ROCMTargetChip);
// Serialize hsaco kernel into the binary that we will embed in the
// final flatbuffer.
diff --git a/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTarget.h b/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTarget.h
index 962b803..5a28c06 100644
--- a/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTarget.h
+++ b/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTarget.h
@@ -29,7 +29,7 @@
std::function<ROCMTargetOptions()> queryOptions);
// Links LLVM module to ROC Device Library Bit Code
-void LinkROCDLIfNecessary(llvm::Module *module);
+void LinkROCDLIfNecessary(llvm::Module *module, std::string targetChip);
// Compiles ISAToHsaco Code
std::string createHsaco(const std::string isa, StringRef name);
diff --git a/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTargetUtils.cpp b/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTargetUtils.cpp
index c1849cc..d4b798b 100644
--- a/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTargetUtils.cpp
+++ b/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTargetUtils.cpp
@@ -90,10 +90,15 @@
return success();
}
-static std::vector<std::string> GetROCDLPaths() {
- // AMDGPU version-neutral bitcodes.
+static std::vector<std::string> GetROCDLPaths(std::string targetChip) {
+ // AMDGPU bitcodes.
+ int lenOfChipPrefix = 3;
+ std::string chipId = targetChip.substr(lenOfChipPrefix);
+ std::string chip_isa_bc = "oclc_isa_version_" + chipId + ".bc";
static std::vector<std::string> *rocdl_filenames =
- new std::vector<std::string>({"ocml.bc", "ockl.bc"});
+ new std::vector<std::string>({"ocml.bc", "ockl.bc",
+ "oclc_unsafe_math_off.bc",
+ "oclc_daz_opt_off.bc", chip_isa_bc});
// Construct full path to ROCDL bitcode libraries.
std::string rocdl_dir_path = "/opt/rocm/amdgcn/bitcode";
@@ -106,11 +111,12 @@
}
// Links ROCm-Device-Libs into the given module if the module needs it.
-void LinkROCDLIfNecessary(llvm::Module *module) {
+void LinkROCDLIfNecessary(llvm::Module *module, std::string targetChip) {
if (!HAL::CouldNeedDeviceBitcode(*module)) {
return;
}
- if (!succeeded(HAL::LinkWithBitcodeVector(module, GetROCDLPaths()))) {
+ if (!succeeded(
+ HAL::LinkWithBitcodeVector(module, GetROCDLPaths(targetChip)))) {
llvm::WithColor::error(llvm::errs()) << "Fail to Link ROCDL.\n";
};
}
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/CMakeLists.txt b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/CMakeLists.txt
index 9c66971..956f58f 100644
--- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/CMakeLists.txt
+++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/CMakeLists.txt
@@ -12,6 +12,7 @@
VMToEmitC
HDRS
"ConvertVMToEmitC.h"
+ "EmitCTypeConverter.h"
SRCS
"ConvertVMToEmitC.cpp"
DEPS
@@ -19,6 +20,8 @@
MLIRPass
MLIREmitC
MLIRTransforms
+ iree::compiler::Dialect::IREE::Conversion::PreserveCompilerHints
+ iree::compiler::Dialect::IREE::IR
iree::compiler::Dialect::VM::Analysis
iree::compiler::Dialect::VM::IR
INCLUDES
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
index 16cc258..d2dba6c 100644
--- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
+++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
@@ -7,12 +7,14 @@
#include "iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.h"
#include "emitc/Dialect/EmitC/IR/EmitC.h"
+#include "iree/compiler/Dialect/IREE/Conversion/PreserveCompilerHints.h"
#include "iree/compiler/Dialect/IREE/IR/IREEDialect.h"
-#include "iree/compiler/Dialect/VM/Analysis/RegisterAllocation.h"
+#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
#include "iree/compiler/Dialect/VM/IR/VMOps.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
@@ -20,6 +22,61 @@
namespace {
+template <typename SrcOpTy>
+Optional<emitc::ApplyOp> createVmTypeDefPtr(ConversionPatternRewriter &rewriter,
+ SrcOpTy srcOp, Type elementType) {
+ auto ctx = srcOp.getContext();
+ auto loc = srcOp.getLoc();
+
+ // TODO(simon-camp): Clean this up
+ StringRef elementTypeConstructor;
+ std::string elementTypeConstructorArg;
+ if (auto intType = elementType.template dyn_cast<IntegerType>()) {
+ unsigned int bitWidth = intType.getIntOrFloatBitWidth();
+ elementTypeConstructor = "iree_vm_type_def_make_value_type";
+ elementTypeConstructorArg =
+ std::string("IREE_VM_VALUE_TYPE_I") + std::to_string(bitWidth);
+ } else if (auto refType =
+ elementType.template dyn_cast<IREE::VM::RefType>()) {
+ auto objType = refType.getObjectType();
+
+ elementTypeConstructor = "iree_vm_type_def_make_ref_type";
+ if (objType.template isa<IREE::VM::BufferType>()) {
+ elementTypeConstructorArg = "iree_vm_buffer_type_id()";
+ } else if (objType.template isa<IREE::VM::ListType>()) {
+ elementTypeConstructorArg = "iree_vm_list_type_id()";
+ } else {
+ srcOp.emitError() << "Unhandled ref object type " << objType;
+ return None;
+ }
+ } else if (auto opaqueType =
+ elementType.template dyn_cast<IREE::VM::OpaqueType>()) {
+ elementTypeConstructor = "iree_vm_type_def_make_variant_type";
+ elementTypeConstructorArg = "";
+ } else {
+ srcOp.emitError() << "Unhandled element type " << elementType;
+ return None;
+ }
+
+ auto elementTypeOp = rewriter.template create<emitc::CallOp>(
+ /*location=*/loc,
+ /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_type_def_t"),
+ /*callee=*/StringAttr::get(ctx, elementTypeConstructor),
+ /*args=*/
+ ArrayAttr::get(ctx,
+ {emitc::OpaqueAttr::get(ctx, elementTypeConstructorArg)}),
+ /*templateArgs=*/ArrayAttr{},
+ /*operands=*/ArrayRef<Value>{});
+
+ auto elementTypePtrOp = rewriter.template create<emitc::ApplyOp>(
+ /*location=*/loc,
+ /*result=*/emitc::OpaqueType::get(ctx, "iree_vm_type_def_t*"),
+ /*applicableOperator=*/StringAttr::get(ctx, "&"),
+ /*operand=*/elementTypeOp.getResult(0));
+
+ return elementTypePtrOp;
+}
+
/// Generate two calls which resemble the IREE_RETURN_IF_ERROR macro. We need
/// to split it here because we cannot produce a macro invocation with a
/// function call as argument in emitc.
@@ -55,15 +112,15 @@
}));
}
-template <typename AccessOpTy, typename GlobalOpTy>
-GlobalOpTy lookupGlobalOp(AccessOpTy accessOp) {
+template <typename AccessOpTy, typename ResultOpTy>
+ResultOpTy lookupSymbolRef(AccessOpTy accessOp, StringRef attrName) {
FlatSymbolRefAttr globalAttr =
accessOp.getOperation()->template getAttrOfType<FlatSymbolRefAttr>(
- "global");
- GlobalOpTy globalOp =
+ attrName);
+ ResultOpTy globalOp =
accessOp.getOperation()
->template getParentOfType<IREE::VM::ModuleOp>()
- .template lookupSymbol<GlobalOpTy>(globalAttr.getValue());
+ .template lookupSymbol<ResultOpTy>(globalAttr.getValue());
return globalOp;
}
@@ -81,8 +138,10 @@
LogicalResult matchAndRewrite(
SrcOpTy op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
+ auto ctx = op.getContext();
+
auto type = op.getOperation()->getResultTypes();
- StringAttr callee = rewriter.getStringAttr(funcName);
+ StringAttr callee = StringAttr::get(ctx, funcName);
// Default to an empty args attribute, which results in the operands being
// printed as the arguments to the function call.
@@ -112,6 +171,124 @@
StringRef funcName;
};
+template <typename CmpOpTy>
+class CompareRefOpConversion : public OpConversionPattern<CmpOpTy> {
+ public:
+ using OpConversionPattern<CmpOpTy>::OpConversionPattern;
+
+ CompareRefOpConversion(MLIRContext *context, StringRef funcName,
+ VMAnalysisCache &vmAnalysisCache)
+ : OpConversionPattern<CmpOpTy>(context),
+ funcName(funcName),
+ vmAnalysisCache(vmAnalysisCache) {}
+
+ private:
+ LogicalResult matchAndRewrite(
+ CmpOpTy cmpOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto ctx = cmpOp.getContext();
+ auto loc = cmpOp.getLoc();
+
+ auto funcOp =
+ cmpOp.getOperation()->template getParentOfType<IREE::VM::FuncOp>();
+ auto ptr = vmAnalysisCache.find(funcOp.getOperation());
+ if (ptr == vmAnalysisCache.end()) {
+ return cmpOp.emitError() << "parent func op not found in cache.";
+ }
+ ValueLiveness &valueLiveness = ptr->second.valueLiveness;
+
+ bool moveLhs =
+ valueLiveness.isLastValueUse(cmpOp.lhs(), cmpOp.getOperation());
+ bool moveRhs =
+ valueLiveness.isLastValueUse(cmpOp.rhs(), cmpOp.getOperation());
+
+ rewriter.replaceOpWithNewOp<emitc::CallOp>(
+ /*op=*/cmpOp,
+ /*type=*/cmpOp.getType(),
+ /*callee=*/StringAttr::get(ctx, funcName),
+ /*args=*/ArrayAttr{},
+ /*templateArgs=*/ArrayAttr{},
+ /*operands=*/operands);
+
+ if (moveLhs) {
+ rewriter.create<emitc::CallOp>(
+ /*loc=*/loc,
+ /*type=*/TypeRange{},
+ /*callee=*/StringAttr::get(ctx, "iree_vm_ref_release"),
+ /*args=*/ArrayAttr{},
+ /*templateArgs=*/ArrayAttr{},
+ /*operands=*/ArrayRef<Value>{operands[0]});
+ }
+
+ // NOTE: If lhs and rhs alias we call release twice on the same argument.
+ if (moveRhs) {
+ rewriter.create<emitc::CallOp>(
+ /*loc=*/loc,
+ /*type=*/TypeRange{},
+ /*callee=*/StringAttr::get(ctx, "iree_vm_ref_release"),
+ /*args=*/ArrayAttr{},
+ /*templateArgs=*/ArrayAttr{},
+ /*operands=*/ArrayRef<Value>{operands[1]});
+ }
+
+ return success();
+ }
+
+ StringRef funcName;
+ VMAnalysisCache &vmAnalysisCache;
+};
+
+class CompareRefNotZeroOpConversion
+ : public OpConversionPattern<IREE::VM::CmpNZRefOp> {
+ using OpConversionPattern<IREE::VM::CmpNZRefOp>::OpConversionPattern;
+
+ public:
+ CompareRefNotZeroOpConversion(MLIRContext *context,
+ VMAnalysisCache &vmAnalysisCache)
+ : OpConversionPattern<IREE::VM::CmpNZRefOp>(context),
+ vmAnalysisCache(vmAnalysisCache) {}
+
+ private:
+ LogicalResult matchAndRewrite(
+ IREE::VM::CmpNZRefOp cmpOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto ctx = cmpOp.getContext();
+ auto loc = cmpOp.getLoc();
+
+ auto funcOp = cmpOp.getOperation()->getParentOfType<IREE::VM::FuncOp>();
+ auto ptr = vmAnalysisCache.find(funcOp.getOperation());
+ if (ptr == vmAnalysisCache.end()) {
+ return cmpOp.emitError() << "parent func op not found in cache.";
+ }
+ ValueLiveness &valueLiveness = ptr->second.valueLiveness;
+
+ bool move =
+ valueLiveness.isLastValueUse(cmpOp.operand(), cmpOp.getOperation());
+
+ rewriter.replaceOpWithNewOp<emitc::CallOp>(
+ /*op=*/cmpOp,
+ /*type=*/cmpOp.getType(),
+ /*callee=*/StringAttr::get(ctx, "vm_cmp_nz_ref"),
+ /*args=*/ArrayAttr{},
+ /*templateArgs=*/ArrayAttr{},
+ /*operands=*/operands);
+
+ if (move) {
+ rewriter.create<emitc::CallOp>(
+ /*loc=*/loc,
+ /*type=*/TypeRange{},
+ /*callee=*/StringAttr::get(ctx, "iree_vm_ref_release"),
+ /*args=*/ArrayAttr{},
+ /*templateArgs=*/ArrayAttr{},
+ /*operands=*/ArrayRef<Value>{operands[0]});
+ }
+
+ return success();
+ }
+
+ VMAnalysisCache &vmAnalysisCache;
+};
+
template <typename ConstOpTy>
class ConstOpConversion : public OpRewritePattern<ConstOpTy> {
public:
@@ -153,19 +330,137 @@
public:
using OpRewritePattern<IREE::VM::ConstRefZeroOp>::OpRewritePattern;
+ ConstRefZeroOpConversion(MLIRContext *context,
+ VMAnalysisCache &vmAnalysisCache)
+ : OpRewritePattern<IREE::VM::ConstRefZeroOp>(context),
+ vmAnalysisCache(vmAnalysisCache) {}
+
LogicalResult matchAndRewrite(IREE::VM::ConstRefZeroOp constRefZeroOp,
PatternRewriter &rewriter) const final {
auto ctx = constRefZeroOp.getContext();
+ auto loc = constRefZeroOp.getLoc();
- StringRef typeString = "iree_vm_ref_t";
- auto type = emitc::OpaqueType::get(constRefZeroOp.getContext(), typeString);
+ auto funcOp =
+ constRefZeroOp.getOperation()->getParentOfType<IREE::VM::FuncOp>();
- StringRef valueString = "{0}";
- emitc::OpaqueAttr value = emitc::OpaqueAttr::get(ctx, valueString);
+ auto ptr = vmAnalysisCache.find(funcOp.getOperation());
+ if (ptr == vmAnalysisCache.end()) {
+ return constRefZeroOp.emitError() << "parent func op not found in cache.";
+ }
+ RegisterAllocation ®isterAllocation = ptr->second.registerAllocation;
- rewriter.replaceOpWithNewOp<emitc::ConstOp>(constRefZeroOp, type, value);
+ int32_t ordinal =
+ registerAllocation.mapToRegister(constRefZeroOp.getResult()).ordinal();
+
+ auto refPtrOp = rewriter.replaceOpWithNewOp<emitc::CallOp>(
+ /*op=*/constRefZeroOp,
+ /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_t*"),
+ /*callee=*/StringAttr::get(ctx, "VM_ARRAY_ELEMENT_ADDRESS"),
+ /*args=*/
+ ArrayAttr::get(ctx, {emitc::OpaqueAttr::get(ctx, "local_refs"),
+ rewriter.getI32IntegerAttr(ordinal)}),
+ /*templateArgs=*/ArrayAttr{},
+ /*operands=*/ArrayRef<Value>{});
+
+ rewriter.create<emitc::CallOp>(
+ /*loc=*/loc,
+ /*type=*/TypeRange{},
+ /*callee=*/StringAttr::get(ctx, "iree_vm_ref_release"),
+ /*args=*/ArrayAttr{},
+ /*templateArgs=*/ArrayAttr{},
+ /*operands=*/ArrayRef<Value>{constRefZeroOp.result()});
return success();
}
+
+ VMAnalysisCache &vmAnalysisCache;
+};
+
+class ConstRefRodataOpConversion
+ : public OpConversionPattern<IREE::VM::ConstRefRodataOp> {
+ public:
+ using OpConversionPattern<IREE::VM::ConstRefRodataOp>::OpConversionPattern;
+
+ ConstRefRodataOpConversion(MLIRContext *context,
+ VMAnalysisCache &vmAnalysisCache)
+ : OpConversionPattern<IREE::VM::ConstRefRodataOp>(context),
+ vmAnalysisCache(vmAnalysisCache) {}
+
+ LogicalResult matchAndRewrite(
+ IREE::VM::ConstRefRodataOp constRefRodataOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ auto ctx = constRefRodataOp.getContext();
+ auto loc = constRefRodataOp.getLoc();
+
+ auto rodataOp =
+ lookupSymbolRef<IREE::VM::ConstRefRodataOp, IREE::VM::RodataOp>(
+ constRefRodataOp, "rodata");
+ if (!rodataOp) {
+ return constRefRodataOp.emitError() << "Unable to find RodataOp";
+ }
+
+ // TODO(simon-camp): We can't represent structs in emitc (yet maybe), so
+ // the buffer where rodatas live after code generation as well as the
+ // state struct argument name are hardcoded here.
+ auto byteBufferPtrOp = rewriter.create<emitc::CallOp>(
+ /*location=*/loc,
+ /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_buffer_t*"),
+ /*callee=*/StringAttr::get(ctx, "VM_ARRAY_ELEMENT_ADDRESS"),
+ /*args=*/
+ ArrayAttr::get(ctx,
+ {emitc::OpaqueAttr::get(ctx, "state->rodata_buffers"),
+ rewriter.getUI32IntegerAttr(static_cast<uint32_t>(
+ rodataOp.ordinal().getValue().getZExtValue()))}),
+ /*templateArgs=*/ArrayAttr{},
+ /*operands=*/ArrayRef<Value>{});
+
+ auto typeIdOp = rewriter.create<emitc::CallOp>(
+ /*location=*/loc,
+ /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_type_t"),
+ /*callee=*/StringAttr::get(ctx, "iree_vm_buffer_type_id"),
+ /*args=*/ArrayAttr{},
+ /*templateArgs=*/ArrayAttr{},
+ /*operands=*/ArrayRef<Value>{});
+
+ auto funcOp =
+ constRefRodataOp.getOperation()->getParentOfType<IREE::VM::FuncOp>();
+
+ auto ptr = vmAnalysisCache.find(funcOp.getOperation());
+ if (ptr == vmAnalysisCache.end()) {
+ return constRefRodataOp.emitError()
+ << "parent func op not found in cache.";
+ }
+ RegisterAllocation ®isterAllocation = ptr->second.registerAllocation;
+
+ int32_t ordinal =
+ registerAllocation.mapToRegister(constRefRodataOp.getResult())
+ .ordinal();
+
+ auto refPtrOp = rewriter.replaceOpWithNewOp<emitc::CallOp>(
+ /*op=*/constRefRodataOp,
+ /*type=*/
+ emitc::OpaqueType::get(ctx, "iree_vm_ref_t*"),
+ // /*type=*/typeConverter->convertType(constRefRodataOp.getResult().getType()),
+ /*callee=*/StringAttr::get(ctx, "VM_ARRAY_ELEMENT_ADDRESS"),
+ /*args=*/
+ ArrayAttr::get(ctx, {emitc::OpaqueAttr::get(ctx, "local_refs"),
+ rewriter.getI32IntegerAttr(ordinal)}),
+ /*templateArgs=*/ArrayAttr{},
+ /*operands=*/ArrayRef<Value>{});
+
+ failableCall(
+ /*rewriter=*/rewriter,
+ /*loc=*/loc,
+ /*callee=*/StringAttr::get(ctx, "iree_vm_ref_wrap_retain"),
+ /*args=*/ArrayAttr{},
+ /*templateArgs=*/ArrayAttr{},
+ /*operands=*/
+ ArrayRef<Value>{byteBufferPtrOp.getResult(0), typeIdOp.getResult(0),
+ refPtrOp.getResult(0)});
+
+ return success();
+ }
+
+ VMAnalysisCache &vmAnalysisCache;
};
template <typename LoadOpTy, typename GlobalOpTy>
@@ -182,13 +477,14 @@
ConversionPatternRewriter &rewriter) const override {
auto ctx = loadOp.getContext();
- GlobalOpTy globalOp = lookupGlobalOp<LoadOpTy, GlobalOpTy>(loadOp);
+ GlobalOpTy globalOp =
+ lookupSymbolRef<LoadOpTy, GlobalOpTy>(loadOp, "global");
if (!globalOp) {
return loadOp.emitError() << "Unable to find GlobalOp";
}
auto type = loadOp.getOperation()->getResultTypes();
- StringAttr callee = rewriter.getStringAttr(funcName);
+ StringAttr callee = StringAttr::get(ctx, funcName);
// TODO(simon-camp): We can't represent structs in emitc (yet maybe), so
// the buffer where globals live after code generation as well as the
@@ -222,13 +518,14 @@
ConversionPatternRewriter &rewriter) const override {
auto ctx = storeOp.getContext();
- GlobalOpTy globalOp = lookupGlobalOp<StoreOpTy, GlobalOpTy>(storeOp);
+ GlobalOpTy globalOp =
+ lookupSymbolRef<StoreOpTy, GlobalOpTy>(storeOp, "global");
if (!globalOp) {
return storeOp.emitError() << "Unable to find GlobalOp";
}
auto type = storeOp.getOperation()->getResultTypes();
- StringAttr callee = rewriter.getStringAttr(funcName);
+ StringAttr callee = StringAttr::get(ctx, funcName);
// TODO(simon-camp): We can't represent structs in emitc (yet maybe), so
// the buffer where globals live after code generation as well as the
@@ -275,19 +572,19 @@
return op.emitError() << " index for list argument out of range";
}
- Value listOperand = op.getOperation()->getOperand(listArgumentIndex);
+ Value listOperand = operands[listArgumentIndex];
// deref
auto refOp = rewriter.create<emitc::ApplyOp>(
/*location=*/loc,
/*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_t"),
- /*applicableOperator=*/rewriter.getStringAttr("*"),
+ /*applicableOperator=*/StringAttr::get(ctx, "*"),
/*operand=*/listOperand);
auto listDerefOp = rewriter.create<emitc::CallOp>(
/*location=*/loc,
/*type=*/emitc::OpaqueType::get(ctx, "iree_vm_list_t*"),
- /*callee=*/rewriter.getStringAttr("iree_vm_list_deref"),
+ /*callee=*/StringAttr::get(ctx, "iree_vm_list_deref"),
/*args=*/ArrayAttr{},
/*templateArgs=*/ArrayAttr{},
/*operands=*/ArrayRef<Value>{refOp.getResult()});
@@ -295,7 +592,7 @@
rewriter.create<emitc::CallOp>(
/*location=*/loc,
/*type=*/TypeRange{},
- /*callee=*/rewriter.getStringAttr("VM_RETURN_IF_LIST_NULL"),
+ /*callee=*/StringAttr::get(ctx, "VM_RETURN_IF_LIST_NULL"),
/*args=*/
ArrayAttr::get(ctx, {rewriter.getIndexAttr(0),
emitc::OpaqueAttr::get(ctx, "local_refs")}),
@@ -317,7 +614,7 @@
auto callOp = failableCall(
/*rewriter=*/rewriter,
/*loc=*/loc,
- /*callee=*/rewriter.getStringAttr(funcName),
+ /*callee=*/StringAttr::get(ctx, funcName),
/*args=*/ArrayAttr{},
/*templateArgs=*/ArrayAttr{},
/*operands=*/ArrayRef<Value>(updatedOperands));
@@ -327,7 +624,7 @@
rewriter.replaceOpWithNewOp<emitc::CallOp>(
/*op=*/op,
/*type=*/op.getOperation()->getResultTypes(),
- /*callee=*/rewriter.getStringAttr(funcName),
+ /*callee=*/StringAttr::get(ctx, funcName),
/*args=*/ArrayAttr{},
/*templateArgs=*/ArrayAttr{},
/*operands=*/ArrayRef<Value>(updatedOperands));
@@ -347,8 +644,14 @@
class ListAllocOpConversion
: public OpConversionPattern<IREE::VM::ListAllocOp> {
+ public:
using OpConversionPattern<IREE::VM::ListAllocOp>::OpConversionPattern;
+ ListAllocOpConversion(TypeConverter &typeConverter, MLIRContext *context,
+ VMAnalysisCache &vmAnalysisCache)
+ : OpConversionPattern<IREE::VM::ListAllocOp>(typeConverter, context),
+ vmAnalysisCache(vmAnalysisCache) {}
+
private:
LogicalResult matchAndRewrite(
IREE::VM::ListAllocOp allocOp, ArrayRef<Value> operands,
@@ -370,36 +673,24 @@
auto ctx = allocOp.getContext();
auto loc = allocOp.getLoc();
- auto listType = allocOp.getType()
- .cast<IREE::VM::RefType>()
- .getObjectType()
- .cast<IREE::VM::ListType>();
- auto elementType = listType.getElementType();
- std::string elementTypeStr;
- StringRef elementTypeConstructor;
- if (elementType.isa<IntegerType>()) {
- unsigned int bitWidth = elementType.getIntOrFloatBitWidth();
- elementTypeStr =
- std::string("IREE_VM_VALUE_TYPE_I") + std::to_string(bitWidth);
- elementTypeConstructor = "iree_vm_type_def_make_value_type";
- } else {
- return allocOp.emitError() << "Unhandeled element type " << elementType;
+ Type convertedType = typeConverter->convertType(allocOp.getType());
+
+ if (!convertedType) {
+ return allocOp.emitOpError() << "type conversion failed";
}
- auto elementTypeOp = rewriter.create<emitc::CallOp>(
- /*location=*/loc,
- /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_type_def_t"),
- /*callee=*/rewriter.getStringAttr(elementTypeConstructor),
- /*args=*/
- ArrayAttr::get(ctx, {emitc::OpaqueAttr::get(ctx, elementTypeStr)}),
- /*templateArgs=*/ArrayAttr{},
- /*operands=*/ArrayRef<Value>{});
+ auto elementType = allocOp.getType()
+ .cast<IREE::VM::RefType>()
+ .getObjectType()
+ .cast<IREE::VM::ListType>()
+ .getElementType();
- auto elementTypePtrOp = rewriter.create<emitc::ApplyOp>(
- /*location=*/loc,
- /*result=*/emitc::OpaqueType::get(ctx, "iree_vm_type_def_t*"),
- /*applicableOperator=*/rewriter.getStringAttr("&"),
- /*operand=*/elementTypeOp.getResult(0));
+ Optional<emitc::ApplyOp> elementTypePtrOp =
+ createVmTypeDefPtr(rewriter, allocOp, elementType);
+
+ if (!elementTypePtrOp.hasValue()) {
+ return failure();
+ }
auto listOp = rewriter.create<emitc::ConstOp>(
/*location=*/loc,
@@ -409,40 +700,38 @@
auto listPtrOp = rewriter.create<emitc::ApplyOp>(
/*location=*/loc,
/*result=*/emitc::OpaqueType::get(ctx, "iree_vm_list_t**"),
- /*applicableOperator=*/rewriter.getStringAttr("&"),
+ /*applicableOperator=*/StringAttr::get(ctx, "&"),
/*operand=*/listOp.getResult());
failableCall(
/*rewriter=*/rewriter,
/*location=*/loc,
- /*callee=*/rewriter.getStringAttr("iree_vm_list_create"),
+ /*callee=*/StringAttr::get(ctx, "iree_vm_list_create"),
/*args=*/
ArrayAttr::get(ctx, {rewriter.getIndexAttr(0), rewriter.getIndexAttr(1),
emitc::OpaqueAttr::get(ctx, "state->allocator"),
rewriter.getIndexAttr(2)}),
/*templateArgs=*/ArrayAttr{},
/*operands=*/
- ArrayRef<Value>{elementTypePtrOp.getResult(), operands[0],
+ ArrayRef<Value>{elementTypePtrOp.getValue().getResult(), operands[0],
listPtrOp.getResult()});
- // TODO(simon-camp): This is expensive as we recalculate the
- // RegisterAllocation for every alloc in a function. We could make it
- // compatible with the analysis framework in MLIR which would cache it
- // automatically IIUC. See here for reference
- // https://mlir.llvm.org/docs/PassManagement/#analysis-management
auto funcOp = allocOp.getOperation()->getParentOfType<IREE::VM::FuncOp>();
- RegisterAllocation registerAllocation;
- if (failed(registerAllocation.recalculate(funcOp))) {
- return allocOp.emitOpError() << "unable to perform register allocation";
+
+ auto ptr = vmAnalysisCache.find(funcOp.getOperation());
+ if (ptr == vmAnalysisCache.end()) {
+ return allocOp.emitError() << "parent func op not found in cache.";
}
+ RegisterAllocation ®isterAllocation = ptr->second.registerAllocation;
int32_t ordinal =
registerAllocation.mapToRegister(allocOp.getResult()).ordinal();
auto refPtrOp = rewriter.replaceOpWithNewOp<emitc::CallOp>(
/*op=*/allocOp,
- /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_t*"),
- /*callee=*/rewriter.getStringAttr("VM_ARRAY_ELEMENT_ADDRESS"),
+ // /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_t*"),
+ /*type=*/convertedType,
+ /*callee=*/StringAttr::get(ctx, "VM_ARRAY_ELEMENT_ADDRESS"),
/*args=*/
ArrayAttr::get(ctx, {emitc::OpaqueAttr::get(ctx, "local_refs"),
rewriter.getI32IntegerAttr(ordinal)}),
@@ -452,7 +741,7 @@
auto refTypeOp = rewriter.create<emitc::CallOp>(
/*location=*/loc,
/*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_type_t"),
- /*callee=*/rewriter.getStringAttr("iree_vm_list_type_id"),
+ /*callee=*/StringAttr::get(ctx, "iree_vm_list_type_id"),
/*args=*/ArrayAttr{},
/*templateArgs=*/ArrayAttr{},
/*operands=*/ArrayRef<Value>{});
@@ -460,7 +749,7 @@
failableCall(
/*rewriter=*/rewriter,
/*location=*/loc,
- /*callee=*/rewriter.getStringAttr("iree_vm_ref_wrap_assign"),
+ /*callee=*/StringAttr::get(ctx, "iree_vm_ref_wrap_assign"),
/*args=*/ArrayAttr{},
/*templateArgs=*/ArrayAttr{},
/*operands=*/
@@ -469,6 +758,8 @@
return success();
}
+
+ VMAnalysisCache &vmAnalysisCache;
};
template <typename GetOpTy>
@@ -511,19 +802,19 @@
auto valuePtrOp = rewriter.create<emitc::ApplyOp>(
/*location=*/loc,
/*result=*/emitc::OpaqueType::get(ctx, "iree_vm_value_t*"),
- /*applicableOperator=*/rewriter.getStringAttr("&"),
+ /*applicableOperator=*/StringAttr::get(ctx, "&"),
/*operand=*/valueOp.getResult());
auto refOp = rewriter.create<emitc::ApplyOp>(
/*location=*/loc,
/*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_t"),
- /*applicableOperator=*/rewriter.getStringAttr("*"),
- /*operand=*/getOp.list());
+ /*applicableOperator=*/StringAttr::get(ctx, "*"),
+ /*operand=*/operands[0]);
auto listDerefOp = rewriter.create<emitc::CallOp>(
/*location=*/loc,
/*type=*/emitc::OpaqueType::get(ctx, "iree_vm_list_t*"),
- /*callee=*/rewriter.getStringAttr("iree_vm_list_deref"),
+ /*callee=*/StringAttr::get(ctx, "iree_vm_list_deref"),
/*args=*/ArrayAttr{},
/*templateArgs=*/ArrayAttr{},
/*operands=*/ArrayRef<Value>{refOp.getResult()});
@@ -531,7 +822,7 @@
rewriter.create<emitc::CallOp>(
/*location=*/loc,
/*type=*/TypeRange{},
- /*callee=*/rewriter.getStringAttr("VM_RETURN_IF_LIST_NULL"),
+ /*callee=*/StringAttr::get(ctx, "VM_RETURN_IF_LIST_NULL"),
/*args=*/
ArrayAttr::get(ctx, {rewriter.getIndexAttr(0),
emitc::OpaqueAttr::get(ctx, "local_refs")}),
@@ -541,7 +832,7 @@
auto getValueOp = failableCall(
/*rewriter=*/rewriter,
/*location=*/loc,
- /*callee=*/rewriter.getStringAttr("iree_vm_list_get_value_as"),
+ /*callee=*/StringAttr::get(ctx, "iree_vm_list_get_value_as"),
/*args=*/
ArrayAttr::get(ctx,
{rewriter.getIndexAttr(0), rewriter.getIndexAttr(1),
@@ -555,7 +846,7 @@
rewriter.replaceOpWithNewOp<emitc::CallOp>(
/*op=*/getOp,
/*type=*/getOp.getType(),
- /*callee=*/rewriter.getStringAttr(valueExtractor.getValue()),
+ /*callee=*/StringAttr::get(ctx, valueExtractor.getValue()),
/*args=*/ArrayAttr{},
/*templateArgs=*/ArrayAttr{},
/*operands=*/ArrayRef<Value>{valuePtrOp.getResult()});
@@ -564,6 +855,102 @@
}
};
+class ListGetRefOpConversion
+ : public OpConversionPattern<IREE::VM::ListGetRefOp> {
+ public:
+ using OpConversionPattern<IREE::VM::ListGetRefOp>::OpConversionPattern;
+
+ ListGetRefOpConversion(MLIRContext *context, VMAnalysisCache &vmAnalysisCache)
+ : OpConversionPattern<IREE::VM::ListGetRefOp>(context),
+ vmAnalysisCache(vmAnalysisCache) {}
+
+ private:
+ LogicalResult matchAndRewrite(
+ IREE::VM::ListGetRefOp getOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto ctx = getOp.getContext();
+ auto loc = getOp.getLoc();
+
+ auto refOp = rewriter.create<emitc::ApplyOp>(
+ /*location=*/loc,
+ /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_t"),
+ /*applicableOperator=*/StringAttr::get(ctx, "*"),
+ /*operand=*/operands[0]);
+
+ auto listDerefOp = rewriter.create<emitc::CallOp>(
+ /*location=*/loc,
+ /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_list_t*"),
+ /*callee=*/StringAttr::get(ctx, "iree_vm_list_deref"),
+ /*args=*/ArrayAttr{},
+ /*templateArgs=*/ArrayAttr{},
+ /*operands=*/ArrayRef<Value>{refOp.getResult()});
+
+ rewriter.create<emitc::CallOp>(
+ /*location=*/loc,
+ /*type=*/TypeRange{},
+ /*callee=*/StringAttr::get(ctx, "VM_RETURN_IF_LIST_NULL"),
+ /*args=*/
+ ArrayAttr::get(ctx, {rewriter.getIndexAttr(0),
+ emitc::OpaqueAttr::get(ctx, "local_refs")}),
+ /*templateArgs=*/ArrayAttr{},
+ /*operands=*/ArrayRef<Value>{listDerefOp.getResult(0)});
+
+ auto funcOp = getOp.getOperation()->getParentOfType<IREE::VM::FuncOp>();
+
+ auto ptr = vmAnalysisCache.find(funcOp.getOperation());
+ if (ptr == vmAnalysisCache.end()) {
+ return getOp.emitError() << "parent func op not found in cache.";
+ }
+ RegisterAllocation ®isterAllocation = ptr->second.registerAllocation;
+
+ int32_t ordinal =
+ registerAllocation.mapToRegister(getOp.getResult()).ordinal();
+
+ auto refPtrOp = rewriter.replaceOpWithNewOp<emitc::CallOp>(
+ /*op=*/getOp,
+ /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_t*"),
+ /*callee=*/StringAttr::get(ctx, "VM_ARRAY_ELEMENT_ADDRESS"),
+ /*args=*/
+ ArrayAttr::get(ctx, {emitc::OpaqueAttr::get(ctx, "local_refs"),
+ rewriter.getI32IntegerAttr(ordinal)}),
+ /*templateArgs=*/ArrayAttr{},
+ /*operands=*/ArrayRef<Value>{});
+
+ failableCall(
+ /*rewriter=*/rewriter,
+ /*loc=*/loc,
+ /*callee=*/StringAttr::get(ctx, "iree_vm_list_get_ref_retain"),
+ /*args=*/ArrayAttr{},
+ /*templateArgs=*/ArrayAttr{},
+ /*operands=*/
+ ArrayRef<Value>{listDerefOp.getResult(0), getOp.index(),
+ refPtrOp.getResult(0)});
+
+ Type elementType = getOp.getResult().getType();
+
+ auto elementTypePtrOp = createVmTypeDefPtr(rewriter, getOp, elementType);
+
+ if (!elementTypePtrOp.hasValue()) {
+ return failure();
+ }
+
+ rewriter.create<emitc::CallOp>(
+ /*location=*/loc,
+ /*type=*/TypeRange{},
+ /*callee=*/StringAttr::get(ctx, "VM_REF_RELEASE_IF_TYPE_MISMATCH"),
+ /*args=*/
+ ArrayAttr{},
+ /*templateArgs=*/ArrayAttr{},
+ /*operands=*/
+ ArrayRef<Value>{refPtrOp.getResult(0),
+ elementTypePtrOp.getValue().getResult()});
+
+ return success();
+ }
+
+ VMAnalysisCache &vmAnalysisCache;
+};
+
template <typename SetOpTy>
class ListSetOpConversion : public OpConversionPattern<SetOpTy> {
using OpConversionPattern<SetOpTy>::OpConversionPattern;
@@ -584,13 +971,13 @@
.Default([](Operation *) { return None; });
if (!valueConstructor.hasValue()) {
- return setOp.emitOpError() << " not handeled";
+ return setOp.emitOpError() << " not handled";
}
auto valueOp = rewriter.create<emitc::CallOp>(
/*location=*/loc,
/*type=*/emitc::OpaqueType::get(ctx, "iree_vm_value_t"),
- /*callee=*/rewriter.getStringAttr(valueConstructor.getValue()),
+ /*callee=*/StringAttr::get(ctx, valueConstructor.getValue()),
/*args=*/ArrayAttr{},
/*templateArgs=*/ArrayAttr{},
/*operands=*/ArrayRef<Value>{setOp.value()});
@@ -598,19 +985,19 @@
auto valuePtrOp = rewriter.create<emitc::ApplyOp>(
/*location=*/loc,
/*result=*/emitc::OpaqueType::get(ctx, "iree_vm_value_t*"),
- /*applicableOperator=*/rewriter.getStringAttr("&"),
+ /*applicableOperator=*/StringAttr::get(ctx, "&"),
/*operand=*/valueOp.getResult(0));
auto refOp = rewriter.create<emitc::ApplyOp>(
/*location=*/loc,
/*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_t"),
- /*applicableOperator=*/rewriter.getStringAttr("*"),
- /*operand=*/setOp.list());
+ /*applicableOperator=*/StringAttr::get(ctx, "*"),
+ /*operand=*/operands[0]);
auto listDerefOp = rewriter.create<emitc::CallOp>(
/*location=*/loc,
/*type=*/emitc::OpaqueType::get(ctx, "iree_vm_list_t*"),
- /*callee=*/rewriter.getStringAttr("iree_vm_list_deref"),
+ /*callee=*/StringAttr::get(ctx, "iree_vm_list_deref"),
/*args=*/ArrayAttr{},
/*templateArgs=*/ArrayAttr{},
/*operands=*/ArrayRef<Value>{refOp.getResult()});
@@ -618,7 +1005,7 @@
rewriter.create<emitc::CallOp>(
/*location=*/loc,
/*type=*/TypeRange{},
- /*callee=*/rewriter.getStringAttr("VM_RETURN_IF_LIST_NULL"),
+ /*callee=*/StringAttr::get(ctx, "VM_RETURN_IF_LIST_NULL"),
/*args=*/
ArrayAttr::get(ctx, {rewriter.getIndexAttr(0),
emitc::OpaqueAttr::get(ctx, "local_refs")}),
@@ -628,22 +1015,93 @@
auto callOp = failableCall(
/*rewriter=*/rewriter,
/*loc=*/loc,
- /*callee=*/rewriter.getStringAttr("iree_vm_list_set_value"),
+ /*callee=*/StringAttr::get(ctx, "iree_vm_list_set_value"),
/*args=*/ArrayAttr{},
/*templateArgs=*/ArrayAttr{},
/*operands=*/
ArrayRef<Value>{listDerefOp.getResult(0), setOp.index(),
valuePtrOp.getResult()});
- rewriter.replaceOp(setOp, ArrayRef<Value>{});
+ rewriter.eraseOp(setOp);
return success();
}
};
+
+class ListSetRefOpConversion
+ : public OpConversionPattern<IREE::VM::ListSetRefOp> {
+ public:
+ using OpConversionPattern<IREE::VM::ListSetRefOp>::OpConversionPattern;
+
+ ListSetRefOpConversion(MLIRContext *context, VMAnalysisCache &vmAnalysisCache)
+ : OpConversionPattern<IREE::VM::ListSetRefOp>(context),
+ vmAnalysisCache(vmAnalysisCache) {}
+
+ private:
+ LogicalResult matchAndRewrite(
+ IREE::VM::ListSetRefOp setOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto ctx = setOp.getContext();
+ auto loc = setOp.getLoc();
+
+ auto refOp = rewriter.create<emitc::ApplyOp>(
+ /*location=*/loc,
+ /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_t"),
+ /*applicableOperator=*/StringAttr::get(ctx, "*"),
+ /*operand=*/operands[0]);
+
+ auto listDerefOp = rewriter.create<emitc::CallOp>(
+ /*location=*/loc,
+ /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_list_t*"),
+ /*callee=*/StringAttr::get(ctx, "iree_vm_list_deref"),
+ /*args=*/ArrayAttr{},
+ /*templateArgs=*/ArrayAttr{},
+ /*operands=*/ArrayRef<Value>{refOp.getResult()});
+
+ rewriter.create<emitc::CallOp>(
+ /*location=*/loc,
+ /*type=*/TypeRange{},
+ /*callee=*/StringAttr::get(ctx, "VM_RETURN_IF_LIST_NULL"),
+ /*args=*/
+ ArrayAttr::get(ctx, {rewriter.getIndexAttr(0),
+ emitc::OpaqueAttr::get(ctx, "local_refs")}),
+ /*templateArgs=*/ArrayAttr{},
+ /*operands=*/ArrayRef<Value>{listDerefOp.getResult(0)});
+
+ auto funcOp = setOp.getOperation()->getParentOfType<IREE::VM::FuncOp>();
+ auto ptr = vmAnalysisCache.find(funcOp.getOperation());
+ if (ptr == vmAnalysisCache.end()) {
+ return setOp.emitError() << "parent func op not found in cache.";
+ }
+ ValueLiveness &valueLiveness = ptr->second.valueLiveness;
+
+ bool move =
+ valueLiveness.isLastValueUse(setOp.value(), setOp.getOperation());
+
+ auto callOp = failableCall(
+ /*rewriter=*/rewriter,
+ /*loc=*/loc,
+ /*callee=*/StringAttr::get(ctx, "iree_vm_list_set_ref_retain"),
+ /*args=*/ArrayAttr{},
+ /*templateArgs=*/ArrayAttr{},
+ /*operands=*/
+ ArrayRef<Value>{listDerefOp.getResult(0), setOp.index(), operands[2]});
+
+ rewriter.eraseOp(setOp);
+
+ return success();
+ }
+
+ VMAnalysisCache &vmAnalysisCache;
+};
} // namespace
-void populateVMToCPatterns(MLIRContext *context,
- OwningRewritePatternList &patterns) {
+void populateVMToEmitCPatterns(MLIRContext *context,
+ IREE::VM::EmitCTypeConverter &typeConverter,
+ OwningRewritePatternList &patterns,
+ VMAnalysisCache &vmAnalysisCache) {
+ populatePreserveCompilerHintsPatterns(context, patterns);
+
// Globals
patterns.insert<
GlobalLoadOpConversion<IREE::VM::GlobalLoadI32Op, IREE::VM::GlobalI32Op>>(
@@ -655,10 +1113,12 @@
// Constants
patterns.insert<ConstOpConversion<IREE::VM::ConstI32Op>>(context);
patterns.insert<ConstZeroOpConversion<IREE::VM::ConstI32ZeroOp>>(context);
- patterns.insert<ConstRefZeroOpConversion>(context);
+ patterns.insert<ConstRefZeroOpConversion>(context, vmAnalysisCache);
+ patterns.insert<ConstRefRodataOpConversion>(context, vmAnalysisCache);
// List ops
- patterns.insert<ListAllocOpConversion>(context);
+ patterns.insert<ListAllocOpConversion>(typeConverter, context,
+ vmAnalysisCache);
patterns.insert<ListOpConversion<IREE::VM::ListReserveOp>>(
context, "iree_vm_list_reserve", 0, true);
patterns.insert<ListOpConversion<IREE::VM::ListResizeOp>>(
@@ -666,7 +1126,9 @@
patterns.insert<ListOpConversion<IREE::VM::ListSizeOp>>(
context, "iree_vm_list_size", 0, false);
patterns.insert<ListGetOpConversion<IREE::VM::ListGetI32Op>>(context);
+ patterns.insert<ListGetRefOpConversion>(context, vmAnalysisCache);
patterns.insert<ListSetOpConversion<IREE::VM::ListSetI32Op>>(context);
+ patterns.insert<ListSetRefOpConversion>(context, vmAnalysisCache);
// Conditional assignment ops
patterns.insert<CallOpConversion<IREE::VM::SelectI32Op>>(context,
@@ -722,6 +1184,11 @@
"vm_cmp_lt_i32u");
patterns.insert<CallOpConversion<IREE::VM::CmpNZI32Op>>(context,
"vm_cmp_nz_i32");
+ patterns.insert<CompareRefOpConversion<IREE::VM::CmpEQRefOp>>(
+ context, "vm_cmp_eq_ref", vmAnalysisCache);
+ patterns.insert<CompareRefOpConversion<IREE::VM::CmpNERefOp>>(
+ context, "vm_cmp_ne_ref", vmAnalysisCache);
+ patterns.insert<CompareRefNotZeroOpConversion>(context, vmAnalysisCache);
// ExtF32: Native floating-point constants
patterns.insert<ConstOpConversion<IREE::VM::ConstF32Op>>(context);
@@ -859,21 +1326,37 @@
void runOnOperation() override {
ConversionTarget target(getContext());
+ EmitCTypeConverter typeConverter;
+
+ VMAnalysisCache vmAnalysisCache;
+
+ for (auto funcOp : getOperation().getOps<IREE::VM::FuncOp>()) {
+ Operation *op = funcOp.getOperation();
+ vmAnalysisCache.insert(std::make_pair(
+ op, VMAnalysis{RegisterAllocation(op), ValueLiveness(op)}));
+ }
OwningRewritePatternList patterns(&getContext());
- populateVMToCPatterns(&getContext(), patterns);
+ populateVMToEmitCPatterns(&getContext(), typeConverter, patterns,
+ vmAnalysisCache);
target.addLegalDialect<mlir::emitc::EmitCDialect>();
- target.addLegalDialect<iree_compiler::IREEDialect>();
- target.addIllegalDialect<IREE::VM::VMDialect>();
+
+ target.addDynamicallyLegalOp<IREE::DoNotOptimizeOp>(
+ [&](IREE::DoNotOptimizeOp op) {
+ return typeConverter.isLegal(op.getResultTypes());
+ });
// Structural ops
target.addLegalOp<IREE::VM::ModuleOp>();
target.addLegalOp<IREE::VM::ModuleTerminatorOp>();
target.addLegalOp<IREE::VM::FuncOp>();
- target.addLegalOp<IREE::VM::GlobalI32Op>();
target.addLegalOp<IREE::VM::ExportOp>();
+ // Global ops
+ target.addLegalOp<IREE::VM::GlobalI32Op>();
+ target.addLegalOp<IREE::VM::RodataOp>();
+
// Control flow ops
target.addLegalOp<IREE::VM::BranchOp>();
target.addLegalOp<IREE::VM::CallOp>();
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.h b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.h
index 765a6ca..01630cd 100644
--- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.h
+++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.h
@@ -7,14 +7,26 @@
#ifndef IREE_COMPILER_DIALECT_VM_CONVERSION_VMTOEMITC_CONVERTVMTOEMITC_H_
#define IREE_COMPILER_DIALECT_VM_CONVERSION_VMTOEMITC_CONVERTVMTOEMITC_H_
+#include "iree/compiler/Dialect/VM/Analysis/RegisterAllocation.h"
+#include "iree/compiler/Dialect/VM/Analysis/ValueLiveness.h"
+#include "iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCTypeConverter.h"
#include "iree/compiler/Dialect/VM/IR/VMOps.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace iree_compiler {
-void populateVMToCPatterns(MLIRContext *context,
- OwningRewritePatternList &patterns);
+struct VMAnalysis {
+ RegisterAllocation registerAllocation;
+ ValueLiveness valueLiveness;
+};
+
+using VMAnalysisCache = DenseMap<Operation *, VMAnalysis>;
+
+void populateVMToEmitCPatterns(MLIRContext *context,
+ IREE::VM::EmitCTypeConverter &typeConverter,
+ OwningRewritePatternList &patterns,
+ VMAnalysisCache &vmAnalysisCache);
namespace IREE {
namespace VM {
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCTypeConverter.h b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCTypeConverter.h
new file mode 100644
index 0000000..ac9b057
--- /dev/null
+++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCTypeConverter.h
@@ -0,0 +1,38 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_VM_CONVERSION_VMTOEMITC_TYPECONVERTER_H_
+#define IREE_COMPILER_DIALECT_VM_CONVERSION_VMTOEMITC_TYPECONVERTER_H_
+
+#include "emitc/Dialect/EmitC/IR/EmitC.h"
+#include "iree/compiler/Dialect/VM/IR/VMTypes.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace VM {
+
+class EmitCTypeConverter : public mlir::TypeConverter {
+ public:
+ EmitCTypeConverter() {
+ // Return the incoming type in the default case.
+ addConversion([](Type type) { return type; });
+
+ addConversion([](emitc::OpaqueType type) { return type; });
+
+ addConversion([](IREE::VM::RefType type) {
+ return emitc::OpaqueType::get(type.getContext(), "iree_vm_ref_t*");
+ });
+ }
+};
+
+} // namespace VM
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_DIALECT_VM_CONVERSION_VMTOEMITC_TYPECONVERTER_H_
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops.mlir b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops.mlir
index 61a01c3..3746b24 100644
--- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops.mlir
+++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops.mlir
@@ -29,8 +29,9 @@
vm.module @my_module {
// CHECK-LABEL: vm.func @const_ref_zero
- vm.func @const_ref_zero() {
- /// CHECK: %[[NULL:.+]] = "emitc.const"() {value = #emitc.opaque<"{0}">} : () -> !emitc.opaque<"iree_vm_ref_t">
+ vm.func @const_ref_zero() -> !vm.ref<?> {
+ // CHECK: %[[REF:.+]] = emitc.call "VM_ARRAY_ELEMENT_ADDRESS"() {args = [#emitc.opaque<"local_refs">, 0 : i32]} : () -> !emitc.opaque<"iree_vm_ref_t*">
+ // CHECK: emitc.call "iree_vm_ref_release"(%[[REF]]) : (!emitc.opaque<"iree_vm_ref_t*">) -> ()
%null = vm.const.ref.zero : !vm.ref<?>
vm.return
}
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/type_conversion.mlir b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/type_conversion.mlir
new file mode 100644
index 0000000..31f172f
--- /dev/null
+++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/type_conversion.mlir
@@ -0,0 +1,38 @@
+// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-vm-ordinal-allocation),vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s
+
+vm.module @my_module {
+ // CHECK-LABEL: @list_alloc
+ vm.func @list_alloc(%arg0: i32) {
+ %list = vm.list.alloc %arg0 : (i32) -> !vm.list<i32>
+ // CHECK: %[[LIST:.+]] = emitc.call "VM_ARRAY_ELEMENT_ADDRESS"() {args = [#emitc.opaque<"local_refs">, 0 : i32]} : () -> !emitc.opaque<"iree_vm_ref_t*">
+ %list_dno = iree.do_not_optimize(%list) : !vm.list<i32>
+ // CHECK: iree.do_not_optimize(%[[LIST]]) : !emitc.opaque<"iree_vm_ref_t*">
+ vm.return
+ }
+
+ // CHECK-LABEL: @list_size
+ vm.func @list_size(%arg0: i32) {
+ %list = vm.list.alloc %arg0 : (i32) -> !vm.list<i32>
+ // CHECK: %[[LIST:.+]] = emitc.call "VM_ARRAY_ELEMENT_ADDRESS"() {args = [#emitc.opaque<"local_refs">, 0 : i32]} : () -> !emitc.opaque<"iree_vm_ref_t*">
+ %size = vm.list.size %list : (!vm.list<i32>) -> i32
+ // CHECK: %[[SIZE:.+]] = emitc.call "iree_vm_list_size"(%{{.+}})
+ %size_dno = iree.do_not_optimize(%size) : i32
+ // CHECK: iree.do_not_optimize(%[[SIZE]]) : i32
+ vm.return
+ }
+}
+
+// -----
+
+vm.module @my_module {
+ vm.rodata private @byte_buffer dense<[1, 2, 3]> : tensor<3xi32>
+ // CHECK-LABEL: @ref
+ vm.export @ref
+ vm.func @ref(%arg0: i32) {
+ %buffer = vm.const.ref.rodata @byte_buffer : !vm.buffer
+ // CHECK: %[[BUFFER:.+]] = emitc.call "VM_ARRAY_ELEMENT_ADDRESS"() {args = [#emitc.opaque<"local_refs">, 0 : i32]} : () -> !emitc.opaque<"iree_vm_ref_t*">
+ %buffer_dno = iree.do_not_optimize(%buffer) : !vm.buffer
+ // CHECK: iree.do_not_optimize(%[[BUFFER]]) : !emitc.opaque<"iree_vm_ref_t*">
+ vm.return
+ }
+}
diff --git a/iree/compiler/Dialect/VM/IR/VMOps.cpp b/iree/compiler/Dialect/VM/IR/VMOps.cpp
index 922b0db..caad188 100644
--- a/iree/compiler/Dialect/VM/IR/VMOps.cpp
+++ b/iree/compiler/Dialect/VM/IR/VMOps.cpp
@@ -732,6 +732,9 @@
StringRef visibility;
if (parser.parseOptionalKeyword(&visibility,
{"public", "private", "nested"})) {
+ parser.emitError(
+ parser.getCurrentLocation(),
+ "expected valid visibility specifier (public, private or nested)");
return failure();
}
StringAttr visibilityAttr = parser.getBuilder().getStringAttr(visibility);
diff --git a/iree/compiler/Dialect/VM/Target/BUILD b/iree/compiler/Dialect/VM/Target/BUILD
index a0a0707..2f2cd29 100644
--- a/iree/compiler/Dialect/VM/Target/BUILD
+++ b/iree/compiler/Dialect/VM/Target/BUILD
@@ -23,6 +23,16 @@
)
cc_library(
+ name = "ConstantEncodingUtils",
+ srcs = ["ConstantEncodingUtils.cpp"],
+ hdrs = ["ConstantEncodingUtils.h"],
+ deps = [
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Support",
+ ],
+)
+
+cc_library(
name = "init_targets",
hdrs = ["init_targets.h"],
deps = [
diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/BUILD b/iree/compiler/Dialect/VM/Target/Bytecode/BUILD
index 410601d..5a43410 100644
--- a/iree/compiler/Dialect/VM/Target/Bytecode/BUILD
+++ b/iree/compiler/Dialect/VM/Target/Bytecode/BUILD
@@ -25,6 +25,7 @@
"//iree/compiler/Dialect/VM/Analysis",
"//iree/compiler/Dialect/VM/IR",
"//iree/compiler/Dialect/VM/Target:CallingConventionUtils",
+ "//iree/compiler/Dialect/VM/Target:ConstantEncodingUtils",
"//iree/compiler/Dialect/VM/Transforms",
"//iree/compiler/Utils",
"//iree/schemas:bytecode_module_def_c_fbs",
diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/CMakeLists.txt b/iree/compiler/Dialect/VM/Target/Bytecode/CMakeLists.txt
index 416d443..9fe5439 100644
--- a/iree/compiler/Dialect/VM/Target/Bytecode/CMakeLists.txt
+++ b/iree/compiler/Dialect/VM/Target/Bytecode/CMakeLists.txt
@@ -36,6 +36,7 @@
iree::compiler::Dialect::VM::Analysis
iree::compiler::Dialect::VM::IR
iree::compiler::Dialect::VM::Target::CallingConventionUtils
+ iree::compiler::Dialect::VM::Target::ConstantEncodingUtils
iree::compiler::Dialect::VM::Transforms
iree::compiler::Utils
iree::schemas::bytecode_module_def_c_fbs
diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.cpp b/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.cpp
index 0bd5618..7e4031d 100644
--- a/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.cpp
+++ b/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.cpp
@@ -6,146 +6,25 @@
#include "iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.h"
+#include "iree/compiler/Dialect/VM/Target/ConstantEncodingUtils.h"
#include "llvm/Support/CRC.h"
-#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Diagnostics.h"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace VM {
-// TODO(benvanik): switch to LLVM's BinaryStreamWriter to handle endianness.
-
-static void serializeConstantI8Array(DenseIntElementsAttr attr,
- size_t alignment, FlatbufferBuilder &fbb) {
- // vm.rodata and other very large constants end up as this; since i8 is i8
- // everywhere (endianness doesn't matter when you have one byte :) we can
- // directly access the data and memcpy.
- int64_t totalSize = attr.getNumElements() * sizeof(int8_t);
- uint8_t *bytePtr = flatbuffers_uint8_vec_extend(fbb, totalSize);
- if (attr.isSplat()) {
- // NOTE: this is a slow path and we should have eliminated it earlier on
- // during constant op conversion.
- for (const APInt &value : attr.getIntValues()) {
- *(bytePtr++) = value.extractBitsAsZExtValue(8, 0) & UINT8_MAX;
- }
- } else {
- auto rawData = attr.getRawData();
- std::memcpy(bytePtr, rawData.data(), rawData.size());
- }
-}
-
-static void serializeConstantI16Array(DenseIntElementsAttr attr,
- size_t alignment,
- FlatbufferBuilder &fbb) {
- int64_t totalSize = attr.getNumElements() * sizeof(int16_t);
- uint8_t *bytePtr = flatbuffers_uint8_vec_extend(fbb, totalSize);
- uint16_t *nativePtr = reinterpret_cast<uint16_t *>(bytePtr);
- for (const APInt &value : attr.getIntValues()) {
- *(nativePtr++) = value.extractBitsAsZExtValue(16, 0) & UINT16_MAX;
- }
-}
-
-static void serializeConstantI32Array(DenseIntElementsAttr attr,
- size_t alignment,
- FlatbufferBuilder &fbb) {
- int64_t totalSize = attr.getNumElements() * sizeof(int32_t);
- uint8_t *bytePtr = flatbuffers_uint8_vec_extend(fbb, totalSize);
- uint32_t *nativePtr = reinterpret_cast<uint32_t *>(bytePtr);
- for (const APInt &value : attr.getIntValues()) {
- *(nativePtr++) = value.extractBitsAsZExtValue(32, 0) & UINT32_MAX;
- }
-}
-
-static void serializeConstantI64Array(DenseIntElementsAttr attr,
- size_t alignment,
- FlatbufferBuilder &fbb) {
- int64_t totalSize = attr.getNumElements() * sizeof(int64_t);
- uint8_t *bytePtr = flatbuffers_uint8_vec_extend(fbb, totalSize);
- uint64_t *nativePtr = reinterpret_cast<uint64_t *>(bytePtr);
- for (const APInt &value : attr.getIntValues()) {
- *(nativePtr++) = value.extractBitsAsZExtValue(64, 0) & UINT64_MAX;
- }
-}
-
-static void serializeConstantF16Array(DenseFPElementsAttr attr,
- size_t alignment,
- FlatbufferBuilder &fbb) {
- int64_t totalSize = attr.getNumElements() * sizeof(uint16_t);
- uint8_t *bytePtr = flatbuffers_uint8_vec_extend(fbb, totalSize);
- uint16_t *nativePtr = reinterpret_cast<uint16_t *>(bytePtr);
- for (const APFloat &value : attr.getFloatValues()) {
- *(nativePtr++) =
- value.bitcastToAPInt().extractBitsAsZExtValue(16, 0) & UINT16_MAX;
- }
-}
-
-static void serializeConstantF32Array(DenseFPElementsAttr attr,
- size_t alignment,
- FlatbufferBuilder &fbb) {
- int64_t totalSize = attr.getNumElements() * sizeof(float);
- uint8_t *bytePtr = flatbuffers_uint8_vec_extend(fbb, totalSize);
- float *nativePtr = reinterpret_cast<float *>(bytePtr);
- for (const APFloat &value : attr.getFloatValues()) {
- *(nativePtr++) = value.convertToFloat();
- }
-}
-
-static void serializeConstantF64Array(DenseFPElementsAttr attr,
- size_t alignment,
- FlatbufferBuilder &fbb) {
- int64_t totalSize = attr.getNumElements() * sizeof(double);
- uint8_t *bytePtr = flatbuffers_uint8_vec_extend(fbb, totalSize);
- double *nativePtr = reinterpret_cast<double *>(bytePtr);
- for (const APFloat &value : attr.getFloatValues()) {
- *(nativePtr++) = value.convertToDouble();
- }
-}
-
SerializedConstantRef serializeConstant(Location loc, ElementsAttr elementsAttr,
size_t alignment, bool calculateCRC32,
FlatbufferBuilder &fbb) {
flatcc_builder_start_vector(fbb, 1, alignment, FLATBUFFERS_COUNT_MAX(1));
- if (auto attr = elementsAttr.dyn_cast<DenseIntElementsAttr>()) {
- switch (attr.getType().getElementTypeBitWidth()) {
- case 8:
- serializeConstantI8Array(attr, alignment, fbb);
- break;
- case 16:
- serializeConstantI16Array(attr, alignment, fbb);
- break;
- case 32:
- serializeConstantI32Array(attr, alignment, fbb);
- break;
- case 64:
- serializeConstantI64Array(attr, alignment, fbb);
- break;
- default:
- emitError(loc) << "unhandled element bitwidth "
- << attr.getType().getElementTypeBitWidth();
- return {};
- }
- } else if (auto attr = elementsAttr.dyn_cast<DenseFPElementsAttr>()) {
- switch (attr.getType().getElementTypeBitWidth()) {
- case 16:
- serializeConstantF16Array(attr, alignment, fbb);
- break;
- case 32:
- serializeConstantF32Array(attr, alignment, fbb);
- break;
- case 64:
- serializeConstantF64Array(attr, alignment, fbb);
- break;
- default:
- emitError(loc) << "unhandled element bitwidth "
- << attr.getType().getElementTypeBitWidth();
- return {};
- }
- } else {
- emitError(loc) << "unimplemented attribute encoding: "
- << elementsAttr.getType();
+
+ int32_t bitwidth = elementsAttr.getType().getElementTypeBitWidth();
+ int64_t size = elementsAttr.getNumElements() * (bitwidth / 8);
+ uint8_t *bytePtr = flatbuffers_uint8_vec_extend(fbb, size);
+
+ if (failed(serializeConstantArray(loc, elementsAttr, alignment, bytePtr))) {
return {};
}
diff --git a/iree/compiler/Dialect/VM/Target/C/CMakeLists.txt b/iree/compiler/Dialect/VM/Target/C/CMakeLists.txt
index b9cabcd..6d560f2 100644
--- a/iree/compiler/Dialect/VM/Target/C/CMakeLists.txt
+++ b/iree/compiler/Dialect/VM/Target/C/CMakeLists.txt
@@ -28,6 +28,7 @@
iree::compiler::Dialect::VM::IR
iree::compiler::Dialect::VM::Conversion::VMToEmitC
iree::compiler::Dialect::VM::Target::CallingConventionUtils
+ iree::compiler::Dialect::VM::Target::ConstantEncodingUtils
INCLUDES
"${PROJECT_SOURCE_DIR}/third_party/mlir-emitc/include"
PUBLIC
diff --git a/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp b/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp
index fa3552d..490ce00 100644
--- a/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp
+++ b/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp
@@ -12,6 +12,7 @@
#include "iree/compiler/Dialect/VM/Analysis/RegisterAllocation.h"
#include "iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.h"
#include "iree/compiler/Dialect/VM/Target/CallingConventionUtils.h"
+#include "iree/compiler/Dialect/VM/Target/ConstantEncodingUtils.h"
#include "iree/compiler/Dialect/VM/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
@@ -48,7 +49,46 @@
"//"
<< std::string(77, '=') << "\n";
}
+static LogicalResult printRodataBuffers(IREE::VM::ModuleOp &moduleOp,
+ mlir::emitc::CppEmitter &emitter) {
+ llvm::raw_ostream &output = emitter.ostream();
+ std::string moduleName = moduleOp.getName().str();
+ for (auto rodataOp : moduleOp.getOps<IREE::VM::RodataOp>()) {
+ ElementsAttr value = rodataOp.value();
+ auto bitwidth = value.getType().getElementTypeBitWidth();
+ size_t size = value.getNumElements() * (bitwidth / 8);
+ SmallVector<uint8_t, 32> byteBuffer;
+ byteBuffer.resize(size);
+
+ constexpr size_t kDefaultRodataAlignment = 16;
+
+ size_t alignment =
+ rodataOp.alignment()
+ ? static_cast<size_t>(rodataOp.alignment().getValue())
+ : 0;
+ if (alignment == 0) alignment = kDefaultRodataAlignment;
+
+ if (failed(serializeConstantArray(rodataOp.getLoc(), value, alignment,
+ byteBuffer.data()))) {
+ return rodataOp.emitError() << "error during serialization";
+ }
+
+ std::string buffer_name =
+ moduleOp.getName().str() + "_" + rodataOp.getName().str();
+
+ output << "iree_alignas(" << alignment << ") static const uint8_t "
+ << buffer_name << "[] = {";
+ llvm::interleaveComma(byteBuffer, output, [&](uint8_t value) {
+ output << static_cast<unsigned int>(value);
+ });
+ output << "};\n";
+ }
+
+ output << "\n";
+
+ return success();
+}
static LogicalResult printStructDefinitions(IREE::VM::ModuleOp &moduleOp,
mlir::emitc::CppEmitter &emitter) {
llvm::raw_ostream &output = emitter.ostream();
@@ -62,6 +102,8 @@
<< moduleOp.ordinal_counts().getValue().global_bytes() << "];\n";
output << "iree_vm_ref_t refs["
<< moduleOp.ordinal_counts().getValue().global_refs() << "];\n";
+ output << "iree_vm_buffer_t rodata_buffers["
+ << moduleOp.ordinal_counts().getValue().rodatas() << "];\n";
output << "};\n";
output << "typedef struct " << moduleName << "_t " << moduleName << "_t;\n";
@@ -113,8 +155,8 @@
});
}
-static LogicalResult initializeGlobals(IREE::VM::ModuleOp moduleOp,
- mlir::emitc::CppEmitter &emitter) {
+static LogicalResult initializeState(IREE::VM::ModuleOp moduleOp,
+ mlir::emitc::CppEmitter &emitter) {
llvm::raw_ostream &output = emitter.ostream();
for (auto globalOp : moduleOp.getOps<IREE::VM::GlobalI32Op>()) {
@@ -135,8 +177,19 @@
<< "Initializers for globals not supported yet";
}
}
+ // TODO(simon-camp): Support globals with different element types
- // TODO(simon-camp): Support vm.global.i64 and vm.global.ref
+ for (auto rodataOp : moduleOp.getOps<IREE::VM::RodataOp>()) {
+ std::string buffer_name =
+ moduleOp.getName().str() + "_" + rodataOp.getName().str();
+ output << "iree_vm_buffer_initialize("
+ << "IREE_VM_BUFFER_ACCESS_ORIGIN_MODULE, "
+ << "iree_make_byte_span("
+ << "(void*)" << buffer_name << ", sizeof(" << buffer_name << ")), "
+ << "iree_allocator_null(), "
+ << "&state->rodata_buffers[" << rodataOp.ordinal() << "]"
+ << ");\n";
+ }
return success();
}
@@ -560,8 +613,8 @@
<< "state->allocator = allocator;\n";
// initialize globals
- if (failed(initializeGlobals(moduleOp, emitter))) {
- return moduleOp.emitError() << "Failed to emit global initialization";
+ if (failed(initializeState(moduleOp, emitter))) {
+ return moduleOp.emitError() << "Failed to emit state members";
}
output << "*out_module_state = (iree_vm_module_state_t*)state;\n"
@@ -701,6 +754,10 @@
/*forwardDeclareVariables=*/true);
mlir::emitc::CppEmitter::Scope scope(emitter);
+ if (failed(printRodataBuffers(moduleOp, emitter))) {
+ return failure();
+ }
+
// build struct definitions
if (failed(printStructDefinitions(moduleOp, emitter))) {
return failure();
diff --git a/iree/compiler/Dialect/VM/Target/C/test/constant_ops.mlir b/iree/compiler/Dialect/VM/Target/C/test/constant_ops.mlir
new file mode 100644
index 0000000..b6e6f66
--- /dev/null
+++ b/iree/compiler/Dialect/VM/Target/C/test/constant_ops.mlir
@@ -0,0 +1,24 @@
+// RUN: iree-translate -iree-vm-ir-to-c-module -iree-vm-c-module-optimize=false %s | IreeFileCheck %s
+
+vm.module @constant_ops {
+ // Check the generated arrays
+
+ // CHECK: iree_alignas(16) static const uint8_t constant_ops_buffer_1[] = {1, 2, 3};
+ // CHECK-NEXT: iree_alignas(16) static const uint8_t constant_ops_buffer_2[] = {1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0};
+
+ // Check the generated state struct
+ // CHECK-LABEL: struct constant_ops_state_t {
+ // CHECK-NEXT: iree_allocator_t allocator;
+ // CHECK-NEXT: uint8_t rwdata[0];
+ // CHECK-NEXT: iree_vm_ref_t refs[0];
+ // CHECK-NEXT: iree_vm_buffer_t rodata_buffers[2];
+ // CHECK-NEXT: };
+
+ vm.rodata private @buffer_1 dense<[1, 2, 3]> : tensor<3xi8>
+ vm.rodata private @buffer_2 dense<[1, 2, 3]> : tensor<3xi32>
+
+ // Check the state initialization inside the alloc_state function
+ // CHECK-LABEL: static iree_status_t constant_ops_alloc_state(
+ // CHECK: iree_vm_buffer_initialize(IREE_VM_BUFFER_ACCESS_ORIGIN_MODULE, iree_make_byte_span((void*)constant_ops_buffer_1, sizeof(constant_ops_buffer_1)), iree_allocator_null(), &state->rodata_buffers[0]);
+ // CHECK-NEXT: iree_vm_buffer_initialize(IREE_VM_BUFFER_ACCESS_ORIGIN_MODULE, iree_make_byte_span((void*)constant_ops_buffer_2, sizeof(constant_ops_buffer_2)), iree_allocator_null(), &state->rodata_buffers[1]);
+}
diff --git a/iree/compiler/Dialect/VM/Target/C/test/global_ops.mlir b/iree/compiler/Dialect/VM/Target/C/test/global_ops.mlir
index d568b12..029624c 100644
--- a/iree/compiler/Dialect/VM/Target/C/test/global_ops.mlir
+++ b/iree/compiler/Dialect/VM/Target/C/test/global_ops.mlir
@@ -6,6 +6,7 @@
// CHECK-NEXT: iree_allocator_t allocator;
// CHECK-NEXT: uint8_t rwdata[8];
// CHECK-NEXT: iree_vm_ref_t refs[0];
+ // CHECK-NEXT: iree_vm_buffer_t rodata_buffers[0];
// CHECK-NEXT: };
vm.global.i32 @c42 42 : i32
diff --git a/iree/compiler/Dialect/VM/Target/CMakeLists.txt b/iree/compiler/Dialect/VM/Target/CMakeLists.txt
index 3f4b4c2..7f5fa2e 100644
--- a/iree/compiler/Dialect/VM/Target/CMakeLists.txt
+++ b/iree/compiler/Dialect/VM/Target/CMakeLists.txt
@@ -29,6 +29,19 @@
iree_cc_library(
NAME
+ ConstantEncodingUtils
+ HDRS
+ "ConstantEncodingUtils.h"
+ SRCS
+ "ConstantEncodingUtils.cpp"
+ DEPS
+ MLIRIR
+ MLIRSupport
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
init_targets
HDRS
"init_targets.h"
diff --git a/iree/compiler/Dialect/VM/Target/ConstantEncodingUtils.cpp b/iree/compiler/Dialect/VM/Target/ConstantEncodingUtils.cpp
new file mode 100644
index 0000000..cbea13d
--- /dev/null
+++ b/iree/compiler/Dialect/VM/Target/ConstantEncodingUtils.cpp
@@ -0,0 +1,132 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/VM/Target/ConstantEncodingUtils.h"
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Location.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace VM {
+
+// TODO(benvanik): switch to LLVM's BinaryStreamWriter to handle endianness.
+
+static void serializeConstantI8Array(DenseIntElementsAttr attr,
+ size_t alignment, uint8_t *bytePtr) {
+ // vm.rodata and other very large constants end up as this; since i8 is i8
+ // everywhere (endianness doesn't matter when you have one byte :) we can
+ // directly access the data and memcpy.
+ if (attr.isSplat()) {
+ // NOTE: this is a slow path and we should have eliminated it earlier on
+ // during constant op conversion.
+ for (const APInt &value : attr.getIntValues()) {
+ *(bytePtr++) = value.extractBitsAsZExtValue(8, 0) & UINT8_MAX;
+ }
+ } else {
+ auto rawData = attr.getRawData();
+ std::memcpy(bytePtr, rawData.data(), rawData.size());
+ }
+}
+
+static void serializeConstantI16Array(DenseIntElementsAttr attr,
+ size_t alignment, uint8_t *bytePtr) {
+ uint16_t *nativePtr = reinterpret_cast<uint16_t *>(bytePtr);
+ for (const APInt &value : attr.getIntValues()) {
+ *(nativePtr++) = value.extractBitsAsZExtValue(16, 0) & UINT16_MAX;
+ }
+}
+
+static void serializeConstantI32Array(DenseIntElementsAttr attr,
+ size_t alignment, uint8_t *bytePtr) {
+ uint32_t *nativePtr = reinterpret_cast<uint32_t *>(bytePtr);
+ for (const APInt &value : attr.getIntValues()) {
+ *(nativePtr++) = value.extractBitsAsZExtValue(32, 0) & UINT32_MAX;
+ }
+}
+
+static void serializeConstantI64Array(DenseIntElementsAttr attr,
+ size_t alignment, uint8_t *bytePtr) {
+ uint64_t *nativePtr = reinterpret_cast<uint64_t *>(bytePtr);
+ for (const APInt &value : attr.getIntValues()) {
+ *(nativePtr++) = value.extractBitsAsZExtValue(64, 0) & UINT64_MAX;
+ }
+}
+
+static void serializeConstantF16Array(DenseFPElementsAttr attr,
+ size_t alignment, uint8_t *bytePtr) {
+ uint16_t *nativePtr = reinterpret_cast<uint16_t *>(bytePtr);
+ for (const APFloat &value : attr.getFloatValues()) {
+ *(nativePtr++) =
+ value.bitcastToAPInt().extractBitsAsZExtValue(16, 0) & UINT16_MAX;
+ }
+}
+
+static void serializeConstantF32Array(DenseFPElementsAttr attr,
+ size_t alignment, uint8_t *bytePtr) {
+ float *nativePtr = reinterpret_cast<float *>(bytePtr);
+ for (const APFloat &value : attr.getFloatValues()) {
+ *(nativePtr++) = value.convertToFloat();
+ }
+}
+
+static void serializeConstantF64Array(DenseFPElementsAttr attr,
+ size_t alignment, uint8_t *bytePtr) {
+ double *nativePtr = reinterpret_cast<double *>(bytePtr);
+ for (const APFloat &value : attr.getFloatValues()) {
+ *(nativePtr++) = value.convertToDouble();
+ }
+}
+
+LogicalResult serializeConstantArray(Location loc, ElementsAttr elementsAttr,
+ size_t alignment, uint8_t *dst) {
+ auto bitwidth = elementsAttr.getType().getElementTypeBitWidth();
+
+ if (auto attr = elementsAttr.dyn_cast<DenseIntElementsAttr>()) {
+ switch (bitwidth) {
+ case 8:
+ serializeConstantI8Array(attr, alignment, dst);
+ break;
+ case 16:
+ serializeConstantI16Array(attr, alignment, dst);
+ break;
+ case 32:
+ serializeConstantI32Array(attr, alignment, dst);
+ break;
+ case 64:
+ serializeConstantI64Array(attr, alignment, dst);
+ break;
+ default:
+ return emitError(loc) << "unhandled element bitwidth " << bitwidth;
+ }
+ } else if (auto attr = elementsAttr.dyn_cast<DenseFPElementsAttr>()) {
+ switch (bitwidth) {
+ case 16:
+ serializeConstantF16Array(attr, alignment, dst);
+ break;
+ case 32:
+ serializeConstantF32Array(attr, alignment, dst);
+ break;
+ case 64:
+ serializeConstantF64Array(attr, alignment, dst);
+ break;
+ default:
+ return emitError(loc) << "unhandled element bitwidth " << bitwidth;
+ }
+ } else {
+ return emitError(loc) << "unimplemented attribute encoding: "
+ << elementsAttr.getType();
+ }
+
+ return success();
+}
+
+} // namespace VM
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/VM/Target/ConstantEncodingUtils.h b/iree/compiler/Dialect/VM/Target/ConstantEncodingUtils.h
new file mode 100644
index 0000000..f6c8a69
--- /dev/null
+++ b/iree/compiler/Dialect/VM/Target/ConstantEncodingUtils.h
@@ -0,0 +1,27 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_VM_TARGET_CONSTANTENCODINGUTILS_H_
+#define IREE_COMPILER_DIALECT_VM_TARGET_CONSTANTENCODINGUTILS_H_
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Location.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace VM {
+
+LogicalResult serializeConstantArray(Location loc, ElementsAttr elementsAttr,
+ size_t alignment, uint8_t *dst);
+
+} // namespace VM
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_DIALECT_VM_TARGET_CONSTANTENCODINGUTILS_H_
diff --git a/iree/compiler/Translation/IREEVM.cpp b/iree/compiler/Translation/IREEVM.cpp
index 4289fff..ef86d73 100644
--- a/iree/compiler/Translation/IREEVM.cpp
+++ b/iree/compiler/Translation/IREEVM.cpp
@@ -263,10 +263,11 @@
// Exposed via the --iree-mlir-to-vm-c-module translation.
static LogicalResult translateFromMLIRToVMCModule(
ModuleOp moduleOp, BindingOptions bindingOptions,
+ InputDialectOptions inputOptions,
IREE::HAL::TargetOptions executableOptions,
IREE::VM::TargetOptions targetOptions,
IREE::VM::CTargetOptions cTargetOptions, llvm::raw_ostream &output) {
- auto result = translateFromMLIRToVM(moduleOp, bindingOptions,
+ auto result = translateFromMLIRToVM(moduleOp, bindingOptions, inputOptions,
executableOptions, targetOptions);
if (failed(result)) {
return result;
diff --git a/iree/hal/local/BUILD b/iree/hal/local/BUILD
index b587223..5ebd6ac 100644
--- a/iree/hal/local/BUILD
+++ b/iree/hal/local/BUILD
@@ -13,25 +13,6 @@
licenses = ["notice"], # Apache 2.0
)
-# TODO(benvanik): replace iree/base/arena.h with this one. We still want the
-# old-style arena for pure stack use; we may be able to do that with a change
-# to block pool that allows for on-stack initialization (iree_stack_arena_t
-# that has storage for one block inside itself and then dynamically allocates
-# new ones if needed). That way we have only one arena implementation and can
-# easily use the iree_allocator_t interface without worry.
-cc_library(
- name = "arena",
- srcs = ["arena.c"],
- hdrs = ["arena.h"],
- deps = [
- "//iree/base",
- "//iree/base:core_headers",
- "//iree/base:tracing",
- "//iree/base/internal:atomic_slist",
- "//iree/base/internal:synchronization",
- ],
-)
-
# TODO(benvanik): move into base/? may be useful for other backends or for other
# parts of the system (like modules handling IO/RPC).
cc_library(
@@ -127,12 +108,12 @@
"sync_semaphore.h",
],
deps = [
- ":arena",
":local",
"//iree/base",
"//iree/base:core_headers",
"//iree/base:tracing",
"//iree/base/internal",
+ "//iree/base/internal:arena",
"//iree/base/internal:synchronization",
"//iree/hal",
],
@@ -159,7 +140,6 @@
"task_semaphore.h",
],
deps = [
- ":arena",
":event_pool",
":executable_library",
":local",
@@ -167,6 +147,7 @@
"//iree/base:core_headers",
"//iree/base:tracing",
"//iree/base/internal",
+ "//iree/base/internal:arena",
"//iree/base/internal:synchronization",
"//iree/base/internal:wait_handle",
"//iree/hal",
diff --git a/iree/hal/local/CMakeLists.txt b/iree/hal/local/CMakeLists.txt
index 98313c1..a73aab3 100644
--- a/iree/hal/local/CMakeLists.txt
+++ b/iree/hal/local/CMakeLists.txt
@@ -12,22 +12,6 @@
iree_cc_library(
NAME
- arena
- HDRS
- "arena.h"
- SRCS
- "arena.c"
- DEPS
- iree::base
- iree::base::core_headers
- iree::base::internal::atomic_slist
- iree::base::internal::synchronization
- iree::base::tracing
- PUBLIC
-)
-
-iree_cc_library(
- NAME
event_pool
HDRS
"event_pool.h"
@@ -123,11 +107,11 @@
"sync_event.c"
"sync_semaphore.c"
DEPS
- ::arena
::local
iree::base
iree::base::core_headers
iree::base::internal
+ iree::base::internal::arena
iree::base::internal::synchronization
iree::base::tracing
iree::hal
@@ -154,13 +138,13 @@
"task_queue_state.c"
"task_semaphore.c"
DEPS
- ::arena
::event_pool
::executable_library
::local
iree::base
iree::base::core_headers
iree::base::internal
+ iree::base::internal::arena
iree::base::internal::synchronization
iree::base::internal::wait_handle
iree::base::tracing
diff --git a/iree/hal/local/task_command_buffer.c b/iree/hal/local/task_command_buffer.c
index 8d38c33..b6312c1 100644
--- a/iree/hal/local/task_command_buffer.c
+++ b/iree/hal/local/task_command_buffer.c
@@ -36,8 +36,8 @@
// and manager of the lifetime of the tasks.
typedef struct iree_hal_task_command_buffer_t {
iree_hal_resource_t resource;
+ iree_allocator_t host_allocator;
- iree_hal_device_t* device;
iree_task_scope_t* scope;
iree_hal_command_buffer_mode_t mode;
iree_hal_command_category_t allowed_categories;
@@ -105,13 +105,11 @@
}
iree_status_t iree_hal_task_command_buffer_create(
- iree_hal_device_t* device, iree_task_scope_t* scope,
- iree_hal_command_buffer_mode_t mode,
+ iree_task_scope_t* scope, iree_hal_command_buffer_mode_t mode,
iree_hal_command_category_t command_categories,
iree_hal_queue_affinity_t queue_affinity,
- iree_arena_block_pool_t* block_pool,
+ iree_arena_block_pool_t* block_pool, iree_allocator_t host_allocator,
iree_hal_command_buffer_t** out_command_buffer) {
- IREE_ASSERT_ARGUMENT(device);
IREE_ASSERT_ARGUMENT(out_command_buffer);
*out_command_buffer = NULL;
if (!iree_all_bits_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT)) {
@@ -131,13 +129,12 @@
IREE_TRACE_ZONE_BEGIN(z0);
iree_hal_task_command_buffer_t* command_buffer = NULL;
- iree_status_t status =
- iree_allocator_malloc(iree_hal_device_host_allocator(device),
- sizeof(*command_buffer), (void**)&command_buffer);
+ iree_status_t status = iree_allocator_malloc(
+ host_allocator, sizeof(*command_buffer), (void**)&command_buffer);
if (iree_status_is_ok(status)) {
iree_hal_resource_initialize(&iree_hal_task_command_buffer_vtable,
&command_buffer->resource);
- command_buffer->device = device;
+ command_buffer->host_allocator = host_allocator;
command_buffer->scope = scope;
command_buffer->mode = mode;
command_buffer->allowed_categories = command_categories;
@@ -165,8 +162,7 @@
iree_hal_command_buffer_t* base_command_buffer) {
iree_hal_task_command_buffer_t* command_buffer =
iree_hal_task_command_buffer_cast(base_command_buffer);
- iree_allocator_t host_allocator =
- iree_hal_device_host_allocator(command_buffer->device);
+ iree_allocator_t host_allocator = command_buffer->host_allocator;
IREE_TRACE_ZONE_BEGIN(z0);
iree_hal_task_command_buffer_reset(command_buffer);
diff --git a/iree/hal/local/task_command_buffer.h b/iree/hal/local/task_command_buffer.h
index 835ba3a..d3b1a4a 100644
--- a/iree/hal/local/task_command_buffer.h
+++ b/iree/hal/local/task_command_buffer.h
@@ -8,8 +8,8 @@
#define IREE_HAL_LOCAL_TASK_COMMAND_BUFFER_H_
#include "iree/base/api.h"
+#include "iree/base/internal/arena.h"
#include "iree/hal/api.h"
-#include "iree/hal/local/arena.h"
#include "iree/hal/local/task_queue_state.h"
#include "iree/task/scope.h"
#include "iree/task/task.h"
@@ -19,11 +19,10 @@
#endif // __cplusplus
iree_status_t iree_hal_task_command_buffer_create(
- iree_hal_device_t* device, iree_task_scope_t* scope,
- iree_hal_command_buffer_mode_t mode,
+ iree_task_scope_t* scope, iree_hal_command_buffer_mode_t mode,
iree_hal_command_category_t command_categories,
iree_hal_queue_affinity_t queue_affinity,
- iree_arena_block_pool_t* block_pool,
+ iree_arena_block_pool_t* block_pool, iree_allocator_t host_allocator,
iree_hal_command_buffer_t** out_command_buffer);
// Issues a recorded command buffer using the serial |queue_state|.
diff --git a/iree/hal/local/task_device.c b/iree/hal/local/task_device.c
index eb1026b..e11da39 100644
--- a/iree/hal/local/task_device.c
+++ b/iree/hal/local/task_device.c
@@ -10,8 +10,8 @@
#include <stdint.h>
#include <string.h>
+#include "iree/base/internal/arena.h"
#include "iree/base/tracing.h"
-#include "iree/hal/local/arena.h"
#include "iree/hal/local/event_pool.h"
#include "iree/hal/local/local_descriptor_set.h"
#include "iree/hal/local/local_descriptor_set_layout.h"
@@ -225,8 +225,9 @@
iree_host_size_t queue_index = iree_hal_task_device_select_queue(
device, command_categories, queue_affinity);
return iree_hal_task_command_buffer_create(
- base_device, &device->queues[queue_index].scope, mode, command_categories,
- queue_affinity, &device->large_block_pool, out_command_buffer);
+ &device->queues[queue_index].scope, mode, command_categories,
+ queue_affinity, &device->large_block_pool, device->host_allocator,
+ out_command_buffer);
}
static iree_status_t iree_hal_task_device_create_descriptor_set(
diff --git a/iree/hal/local/task_queue.h b/iree/hal/local/task_queue.h
index 77d7f69..7a60191 100644
--- a/iree/hal/local/task_queue.h
+++ b/iree/hal/local/task_queue.h
@@ -10,9 +10,9 @@
#include <stdint.h>
#include "iree/base/api.h"
+#include "iree/base/internal/arena.h"
#include "iree/base/internal/synchronization.h"
#include "iree/hal/api.h"
-#include "iree/hal/local/arena.h"
#include "iree/hal/local/task_queue_state.h"
#include "iree/task/executor.h"
#include "iree/task/scope.h"
diff --git a/iree/hal/local/task_semaphore.h b/iree/hal/local/task_semaphore.h
index f6daab2..452e5f0 100644
--- a/iree/hal/local/task_semaphore.h
+++ b/iree/hal/local/task_semaphore.h
@@ -10,8 +10,8 @@
#include <stdint.h>
#include "iree/base/api.h"
+#include "iree/base/internal/arena.h"
#include "iree/hal/api.h"
-#include "iree/hal/local/arena.h"
#include "iree/hal/local/event_pool.h"
#include "iree/task/submission.h"
#include "iree/task/task.h"
diff --git a/iree/hal/utils/BUILD b/iree/hal/utils/BUILD
new file mode 100644
index 0000000..234d6d3
--- /dev/null
+++ b/iree/hal/utils/BUILD
@@ -0,0 +1,24 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "deferred_command_buffer",
+ srcs = ["deferred_command_buffer.c"],
+ hdrs = ["deferred_command_buffer.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//iree/base",
+ "//iree/base:tracing",
+ "//iree/base/internal:arena",
+ "//iree/hal",
+ ],
+)
diff --git a/iree/hal/utils/CMakeLists.txt b/iree/hal/utils/CMakeLists.txt
new file mode 100644
index 0000000..6709717
--- /dev/null
+++ b/iree/hal/utils/CMakeLists.txt
@@ -0,0 +1,28 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# iree/hal/utils/BUILD #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ deferred_command_buffer
+ HDRS
+ "deferred_command_buffer.h"
+ SRCS
+ "deferred_command_buffer.c"
+ DEPS
+ iree::base
+ iree::base::internal::arena
+ iree::base::tracing
+ iree::hal
+ PUBLIC
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/hal/utils/deferred_command_buffer.c b/iree/hal/utils/deferred_command_buffer.c
new file mode 100644
index 0000000..d568c1b
--- /dev/null
+++ b/iree/hal/utils/deferred_command_buffer.c
@@ -0,0 +1,827 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/utils/deferred_command_buffer.h"
+
+#include "iree/base/internal/arena.h"
+#include "iree/base/tracing.h"
+
+//===----------------------------------------------------------------------===//
+// Command recording structures
+//===----------------------------------------------------------------------===//
+
+typedef enum iree_hal_command_type_e {
+ IREE_HAL_CMD_EXECUTION_BARRIER = 0,
+ IREE_HAL_CMD_SIGNAL_EVENT,
+ IREE_HAL_CMD_RESET_EVENT,
+ IREE_HAL_CMD_WAIT_EVENTS,
+ IREE_HAL_CMD_DISCARD_BUFFER,
+ IREE_HAL_CMD_FILL_BUFFER,
+ IREE_HAL_CMD_UPDATE_BUFFER,
+ IREE_HAL_CMD_COPY_BUFFER,
+ IREE_HAL_CMD_PUSH_CONSTANTS,
+ IREE_HAL_CMD_PUSH_DESCRIPTOR_SET,
+ IREE_HAL_CMD_BIND_DESCRIPTOR_SET,
+ IREE_HAL_CMD_DISPATCH,
+ IREE_HAL_CMD_DISPATCH_INDIRECT,
+} iree_hal_cmd_type_t;
+
+// Header prefixed to all commands, forming a linked-list.
+//
+// Each command is allocated from the arena and does *not* retain any resources.
+// We could elide some of these commands by keeping local state however that
+// requires knowing more about the target device (executable layouts, etc) and
+// prevents using this as a way to debug or benchmark command buffers. The
+// intent is that each command captures the exact information passed during the
+// call such that the target command buffer cannot tell they were deferred.
+//
+// As each command is variable sized we store pointers to the following command
+// to allow us to walk the list during replay. Storing just a size would be
+// insufficient as commands may be spread across many arena blocks from the
+// block pool.
+typedef struct iree_hal_cmd_header_t {
+ // Next command in the list or NULL if the end.
+ struct iree_hal_cmd_header_t* next;
+ // Type of the command that follows.
+ iree_hal_cmd_type_t type;
+} iree_hal_cmd_header_t;
+
+typedef iree_status_t (*iree_hal_cmd_apply_fn_t)(
+ iree_hal_command_buffer_t* target_command_buffer,
+ iree_hal_cmd_header_t* cmd_header);
+
+//===----------------------------------------------------------------------===//
+// Command list allocation and storage
+//===----------------------------------------------------------------------===//
+
+// A singly-linked list of commands allocated from an arena.
+typedef struct iree_hal_cmd_list_t {
+ // Arena used to hold the recorded commands using block_pool for storage.
+ // Will be reset as the command buffer is re-recorded.
+ iree_arena_allocator_t arena;
+
+ // Head of the command list.
+ iree_hal_cmd_header_t* head;
+ // Tail of the command list (may be head).
+ iree_hal_cmd_header_t* tail;
+} iree_hal_cmd_list_t;
+
+// Initializes a new command list that allocates from the given |block_pool|.
+// Upon return the command list is ready for recording.
+static void iree_hal_cmd_list_initialize(iree_arena_block_pool_t* block_pool,
+ iree_hal_cmd_list_t* out_cmd_list) {
+ iree_arena_initialize(block_pool, &out_cmd_list->arena);
+ out_cmd_list->head = NULL;
+ out_cmd_list->tail = NULL;
+}
+
+// Resets the command list and returns all arena blocks back to the block pool.
+// Upon return the command list is ready for recording.
+static void iree_hal_cmd_list_reset(iree_hal_cmd_list_t* cmd_list) {
+ // We could make reset retain a single block, since we know that we'll be
+ // adding more commands on this path and doing so would remove a round-trip
+ // through the pool.
+ iree_arena_reset(&cmd_list->arena);
+ cmd_list->head = NULL;
+ cmd_list->tail = NULL;
+}
+
+// Deinitializes the command list, preparing for destruction.
+static void iree_hal_cmd_list_deinitialize(iree_hal_cmd_list_t* cmd_list) {
+ iree_hal_cmd_list_reset(cmd_list);
+}
+
+// Appends a new command to the command list and returns the base pointer to its
+// storage. Callers must cast to the appropriate type and populate all fields.
+static iree_status_t iree_hal_cmd_list_append_command(
+ iree_hal_cmd_list_t* cmd_list, iree_hal_cmd_type_t command_type,
+ iree_host_size_t command_size, void** out_cmd) {
+ iree_hal_cmd_header_t* header = NULL;
+ IREE_RETURN_IF_ERROR(
+ iree_arena_allocate(&cmd_list->arena, command_size, (void**)&header));
+ header->next = NULL;
+ header->type = command_type;
+ if (!cmd_list->head) {
+ cmd_list->head = header;
+ } else if (cmd_list->tail) {
+ cmd_list->tail->next = header;
+ }
+ cmd_list->tail = header;
+ *out_cmd = header;
+ return iree_ok_status();
+}
+
+// Clones a source buffer and returns the pointer into the arena.
+static iree_status_t iree_hal_cmd_list_clone_data(iree_hal_cmd_list_t* cmd_list,
+ const void* source_data,
+ iree_host_size_t data_length,
+ void** out_target_data) {
+ void* target_data = NULL;
+ IREE_RETURN_IF_ERROR(
+ iree_arena_allocate(&cmd_list->arena, data_length, &target_data));
+ memcpy(target_data, source_data, data_length);
+ *out_target_data = target_data;
+ return iree_ok_status();
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hal_deferred_command_buffer_t implementation
+//===----------------------------------------------------------------------===//
+
+typedef struct iree_hal_deferred_command_buffer_t {
+ iree_hal_resource_t resource;
+ iree_allocator_t host_allocator;
+ iree_hal_command_buffer_mode_t mode;
+ iree_hal_command_category_t allowed_categories;
+ iree_hal_cmd_list_t cmd_list;
+} iree_hal_deferred_command_buffer_t;
+
+static const iree_hal_command_buffer_vtable_t
+ iree_hal_deferred_command_buffer_vtable;
+
+static iree_hal_deferred_command_buffer_t*
+iree_hal_deferred_command_buffer_cast(iree_hal_command_buffer_t* base_value) {
+ IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_deferred_command_buffer_vtable);
+ return (iree_hal_deferred_command_buffer_t*)base_value;
+}
+
+IREE_API_EXPORT iree_status_t iree_hal_deferred_command_buffer_create(
+ iree_hal_command_buffer_mode_t mode,
+ iree_hal_command_category_t command_categories,
+ iree_arena_block_pool_t* block_pool, iree_allocator_t host_allocator,
+ iree_hal_command_buffer_t** out_command_buffer) {
+ IREE_ASSERT_ARGUMENT(block_pool);
+ IREE_ASSERT_ARGUMENT(out_command_buffer);
+ *out_command_buffer = NULL;
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_hal_deferred_command_buffer_t* command_buffer = NULL;
+ iree_status_t status = iree_allocator_malloc(
+ host_allocator, sizeof(*command_buffer), (void**)&command_buffer);
+ if (iree_status_is_ok(status)) {
+ iree_hal_resource_initialize(&iree_hal_deferred_command_buffer_vtable,
+ &command_buffer->resource);
+ command_buffer->host_allocator = host_allocator;
+ command_buffer->mode = mode;
+ command_buffer->allowed_categories = command_categories;
+ iree_hal_cmd_list_initialize(block_pool, &command_buffer->cmd_list);
+ }
+
+ *out_command_buffer = (iree_hal_command_buffer_t*)command_buffer;
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+static void iree_hal_deferred_command_buffer_destroy(
+ iree_hal_command_buffer_t* base_command_buffer) {
+ iree_hal_deferred_command_buffer_t* command_buffer =
+ iree_hal_deferred_command_buffer_cast(base_command_buffer);
+ iree_allocator_t host_allocator = command_buffer->host_allocator;
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_hal_cmd_list_deinitialize(&command_buffer->cmd_list);
+ iree_allocator_free(host_allocator, command_buffer);
+
+ IREE_TRACE_ZONE_END(z0);
+}
+
+static iree_hal_command_category_t
+iree_hal_deferred_command_buffer_allowed_categories(
+ const iree_hal_command_buffer_t* base_command_buffer) {
+ iree_hal_deferred_command_buffer_t* command_buffer =
+ iree_hal_deferred_command_buffer_cast(
+ (iree_hal_command_buffer_t*)base_command_buffer);
+ return command_buffer->allowed_categories;
+}
+
+static iree_status_t iree_hal_deferred_command_buffer_begin(
+ iree_hal_command_buffer_t* base_command_buffer) {
+ iree_hal_deferred_command_buffer_t* command_buffer =
+ iree_hal_deferred_command_buffer_cast(base_command_buffer);
+ iree_hal_cmd_list_reset(&command_buffer->cmd_list);
+ return iree_ok_status();
+}
+
+static iree_status_t iree_hal_deferred_command_buffer_end(
+ iree_hal_command_buffer_t* base_command_buffer) {
+ return iree_ok_status();
+}
+
+//===----------------------------------------------------------------------===//
+// IREE_HAL_CMD_EXECUTION_BARRIER
+//===----------------------------------------------------------------------===//
+
+// Recorded arguments of an execution_barrier call. The barrier arrays point
+// into storage cloned into the command list arena so stack-allocated caller
+// arrays remain valid until replay.
+typedef struct iree_hal_cmd_execution_barrier_t {
+  iree_hal_cmd_header_t header;
+  iree_hal_execution_stage_t source_stage_mask;
+  iree_hal_execution_stage_t target_stage_mask;
+  iree_hal_execution_barrier_flags_t flags;
+  iree_host_size_t memory_barrier_count;
+  const iree_hal_memory_barrier_t* memory_barriers;
+  iree_host_size_t buffer_barrier_count;
+  const iree_hal_buffer_barrier_t* buffer_barriers;
+} iree_hal_cmd_execution_barrier_t;
+
+// Records an execution barrier into the in-memory command list, deep-copying
+// the caller-provided barrier arrays.
+static iree_status_t iree_hal_deferred_command_buffer_execution_barrier(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_execution_stage_t source_stage_mask,
+    iree_hal_execution_stage_t target_stage_mask,
+    iree_hal_execution_barrier_flags_t flags,
+    iree_host_size_t memory_barrier_count,
+    const iree_hal_memory_barrier_t* memory_barriers,
+    iree_host_size_t buffer_barrier_count,
+    const iree_hal_buffer_barrier_t* buffer_barriers) {
+  iree_hal_cmd_list_t* cmd_list =
+      &iree_hal_deferred_command_buffer_cast(base_command_buffer)->cmd_list;
+  iree_hal_cmd_execution_barrier_t* cmd = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_cmd_list_append_command(
+      cmd_list, IREE_HAL_CMD_EXECUTION_BARRIER, sizeof(*cmd), (void**)&cmd));
+  cmd->source_stage_mask = source_stage_mask;
+  cmd->target_stage_mask = target_stage_mask;
+  cmd->flags = flags;
+  cmd->memory_barrier_count = memory_barrier_count;
+  cmd->memory_barriers = NULL;
+  cmd->buffer_barrier_count = buffer_barrier_count;
+  cmd->buffer_barriers = NULL;
+  // NOTE(review): if a clone below fails the command has already been appended
+  // with a non-zero count and a NULL pointer; replay after such a failure
+  // would pass NULL barrier arrays — presumably callers discard the command
+  // buffer when recording fails. TODO confirm.
+  if (memory_barrier_count > 0) {
+    IREE_RETURN_IF_ERROR(iree_hal_cmd_list_clone_data(
+        cmd_list, memory_barriers,
+        sizeof(memory_barriers[0]) * memory_barrier_count,
+        (void**)&cmd->memory_barriers));
+  }
+  if (buffer_barrier_count > 0) {
+    IREE_RETURN_IF_ERROR(iree_hal_cmd_list_clone_data(
+        cmd_list, buffer_barriers,
+        sizeof(buffer_barriers[0]) * buffer_barrier_count,
+        (void**)&cmd->buffer_barriers));
+  }
+  return iree_ok_status();
+}
+
+// Replays a recorded execution barrier against |target_command_buffer|.
+static iree_status_t iree_hal_deferred_command_buffer_apply_execution_barrier(
+    iree_hal_command_buffer_t* target_command_buffer,
+    const iree_hal_cmd_execution_barrier_t* cmd) {
+  return iree_hal_command_buffer_execution_barrier(
+      target_command_buffer, cmd->source_stage_mask, cmd->target_stage_mask,
+      cmd->flags, cmd->memory_barrier_count, cmd->memory_barriers,
+      cmd->buffer_barrier_count, cmd->buffer_barriers);
+}
+
+//===----------------------------------------------------------------------===//
+// IREE_HAL_CMD_SIGNAL_EVENT
+//===----------------------------------------------------------------------===//
+
+// Recorded arguments of a signal_event call. |event| is stored by handle and
+// is not retained; per the command buffer contract the caller must keep it
+// live until replay completes.
+typedef struct iree_hal_cmd_signal_event_t {
+  iree_hal_cmd_header_t header;
+  iree_hal_event_t* event;
+  iree_hal_execution_stage_t source_stage_mask;
+} iree_hal_cmd_signal_event_t;
+
+// Records an event signal into the in-memory command list.
+static iree_status_t iree_hal_deferred_command_buffer_signal_event(
+    iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event,
+    iree_hal_execution_stage_t source_stage_mask) {
+  iree_hal_cmd_list_t* cmd_list =
+      &iree_hal_deferred_command_buffer_cast(base_command_buffer)->cmd_list;
+  iree_hal_cmd_signal_event_t* cmd = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_cmd_list_append_command(
+      cmd_list, IREE_HAL_CMD_SIGNAL_EVENT, sizeof(*cmd), (void**)&cmd));
+  cmd->event = event;
+  cmd->source_stage_mask = source_stage_mask;
+  return iree_ok_status();
+}
+
+// Replays a recorded signal_event against |target_command_buffer|.
+static iree_status_t iree_hal_deferred_command_buffer_apply_signal_event(
+    iree_hal_command_buffer_t* target_command_buffer,
+    const iree_hal_cmd_signal_event_t* cmd) {
+  return iree_hal_command_buffer_signal_event(target_command_buffer, cmd->event,
+                                              cmd->source_stage_mask);
+}
+
+//===----------------------------------------------------------------------===//
+// IREE_HAL_CMD_RESET_EVENT
+//===----------------------------------------------------------------------===//
+
+// Recorded arguments of a reset_event call. |event| is stored by handle and
+// not retained (caller owns lifetime).
+typedef struct iree_hal_cmd_reset_event_t {
+  iree_hal_cmd_header_t header;
+  iree_hal_event_t* event;
+  iree_hal_execution_stage_t source_stage_mask;
+} iree_hal_cmd_reset_event_t;
+
+// Records an event reset into the in-memory command list.
+static iree_status_t iree_hal_deferred_command_buffer_reset_event(
+    iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event,
+    iree_hal_execution_stage_t source_stage_mask) {
+  iree_hal_cmd_list_t* cmd_list =
+      &iree_hal_deferred_command_buffer_cast(base_command_buffer)->cmd_list;
+  iree_hal_cmd_reset_event_t* cmd = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_cmd_list_append_command(
+      cmd_list, IREE_HAL_CMD_RESET_EVENT, sizeof(*cmd), (void**)&cmd));
+  cmd->event = event;
+  cmd->source_stage_mask = source_stage_mask;
+  return iree_ok_status();
+}
+
+// Replays a recorded reset_event against |target_command_buffer|.
+static iree_status_t iree_hal_deferred_command_buffer_apply_reset_event(
+    iree_hal_command_buffer_t* target_command_buffer,
+    const iree_hal_cmd_reset_event_t* cmd) {
+  return iree_hal_command_buffer_reset_event(target_command_buffer, cmd->event,
+                                             cmd->source_stage_mask);
+}
+
+//===----------------------------------------------------------------------===//
+// IREE_HAL_CMD_WAIT_EVENTS
+//===----------------------------------------------------------------------===//
+
+// Recorded arguments of a wait_events call. Event handles are stored inline
+// via the trailing flexible array member; barrier arrays are cloned into the
+// command list arena.
+typedef struct iree_hal_cmd_wait_events_t {
+  iree_hal_cmd_header_t header;
+  iree_host_size_t event_count;
+  iree_hal_execution_stage_t source_stage_mask;
+  iree_hal_execution_stage_t target_stage_mask;
+  iree_host_size_t memory_barrier_count;
+  const iree_hal_memory_barrier_t* memory_barriers;
+  iree_host_size_t buffer_barrier_count;
+  const iree_hal_buffer_barrier_t* buffer_barriers;
+  iree_hal_event_t* events[];
+} iree_hal_cmd_wait_events_t;
+
+// Records a wait on |event_count| events, deep-copying the event handle array
+// and both barrier arrays.
+static iree_status_t iree_hal_deferred_command_buffer_wait_events(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_host_size_t event_count, const iree_hal_event_t** events,
+    iree_hal_execution_stage_t source_stage_mask,
+    iree_hal_execution_stage_t target_stage_mask,
+    iree_host_size_t memory_barrier_count,
+    const iree_hal_memory_barrier_t* memory_barriers,
+    iree_host_size_t buffer_barrier_count,
+    const iree_hal_buffer_barrier_t* buffer_barriers) {
+  iree_hal_cmd_list_t* cmd_list =
+      &iree_hal_deferred_command_buffer_cast(base_command_buffer)->cmd_list;
+  iree_hal_cmd_wait_events_t* cmd = NULL;
+  // Command storage is sized to carry the event handles inline.
+  IREE_RETURN_IF_ERROR(iree_hal_cmd_list_append_command(
+      cmd_list, IREE_HAL_CMD_WAIT_EVENTS,
+      sizeof(*cmd) + sizeof(cmd->events[0]) * event_count, (void**)&cmd));
+  cmd->event_count = event_count;
+  cmd->source_stage_mask = source_stage_mask;
+  cmd->target_stage_mask = target_stage_mask;
+  cmd->memory_barrier_count = memory_barrier_count;
+  cmd->memory_barriers = NULL;
+  cmd->buffer_barrier_count = buffer_barrier_count;
+  cmd->buffer_barriers = NULL;
+  // Copies handles only; const-ness of the caller array is dropped because
+  // the stored array is replayed through the same wait_events signature below.
+  memcpy(cmd->events, events, sizeof(cmd->events[0]) * event_count);
+  // NOTE(review): as with execution_barrier, a clone failure here leaves an
+  // appended command with count > 0 and a NULL array pointer — TODO confirm
+  // callers discard the command buffer on record failure.
+  if (memory_barrier_count > 0) {
+    IREE_RETURN_IF_ERROR(iree_hal_cmd_list_clone_data(
+        cmd_list, memory_barriers,
+        sizeof(memory_barriers[0]) * memory_barrier_count,
+        (void**)&cmd->memory_barriers));
+  }
+  if (buffer_barrier_count > 0) {
+    IREE_RETURN_IF_ERROR(iree_hal_cmd_list_clone_data(
+        cmd_list, buffer_barriers,
+        sizeof(buffer_barriers[0]) * buffer_barrier_count,
+        (void**)&cmd->buffer_barriers));
+  }
+  return iree_ok_status();
+}
+
+// Replays a recorded wait_events against |target_command_buffer|.
+static iree_status_t iree_hal_deferred_command_buffer_apply_wait_events(
+    iree_hal_command_buffer_t* target_command_buffer,
+    const iree_hal_cmd_wait_events_t* cmd) {
+  return iree_hal_command_buffer_wait_events(
+      target_command_buffer, cmd->event_count,
+      // Cast restores the const qualifier expected by the public API.
+      (const iree_hal_event_t**)cmd->events, cmd->source_stage_mask,
+      cmd->target_stage_mask, cmd->memory_barrier_count, cmd->memory_barriers,
+      cmd->buffer_barrier_count, cmd->buffer_barriers);
+}
+
+//===----------------------------------------------------------------------===//
+// IREE_HAL_CMD_DISCARD_BUFFER
+//===----------------------------------------------------------------------===//
+
+// Recorded arguments of a discard_buffer call; |buffer| is not retained.
+typedef struct iree_hal_cmd_discard_buffer_t {
+  iree_hal_cmd_header_t header;
+  iree_hal_buffer_t* buffer;
+} iree_hal_cmd_discard_buffer_t;
+
+// Records a buffer-contents discard into the in-memory command list.
+static iree_status_t iree_hal_deferred_command_buffer_discard_buffer(
+    iree_hal_command_buffer_t* base_command_buffer, iree_hal_buffer_t* buffer) {
+  iree_hal_cmd_list_t* cmd_list =
+      &iree_hal_deferred_command_buffer_cast(base_command_buffer)->cmd_list;
+  iree_hal_cmd_discard_buffer_t* cmd = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_cmd_list_append_command(
+      cmd_list, IREE_HAL_CMD_DISCARD_BUFFER, sizeof(*cmd), (void**)&cmd));
+  cmd->buffer = buffer;
+  return iree_ok_status();
+}
+
+// Replays a recorded discard_buffer against |target_command_buffer|.
+static iree_status_t iree_hal_deferred_command_buffer_apply_discard_buffer(
+    iree_hal_command_buffer_t* target_command_buffer,
+    const iree_hal_cmd_discard_buffer_t* cmd) {
+  return iree_hal_command_buffer_discard_buffer(target_command_buffer,
+                                                cmd->buffer);
+}
+
+//===----------------------------------------------------------------------===//
+// IREE_HAL_CMD_FILL_BUFFER
+//===----------------------------------------------------------------------===//
+
+// Recorded arguments of a fill_buffer call; the pattern (at most 8 bytes) is
+// stored inline in |pattern| so no separate clone is needed.
+typedef struct iree_hal_cmd_fill_buffer_t {
+  iree_hal_cmd_header_t header;
+  iree_hal_buffer_t* target_buffer;
+  iree_device_size_t target_offset;
+  iree_device_size_t length;
+  uint64_t pattern;
+  iree_host_size_t pattern_length;
+} iree_hal_cmd_fill_buffer_t;
+
+// Records a fill into the in-memory command list. Patterns larger than the
+// inline 8-byte storage are rejected.
+static iree_status_t iree_hal_deferred_command_buffer_fill_buffer(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+    iree_device_size_t length, const void* pattern,
+    iree_host_size_t pattern_length) {
+  iree_hal_cmd_list_t* cmd_list =
+      &iree_hal_deferred_command_buffer_cast(base_command_buffer)->cmd_list;
+  iree_hal_cmd_fill_buffer_t* cmd = NULL;
+  if (pattern_length > sizeof(cmd->pattern)) {
+    // The guard admits pattern_length == sizeof(cmd->pattern) (8); the message
+    // reads "<= 8" so it matches the actual condition.
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "fill patterns must be <= 8 bytes");
+  }
+  IREE_RETURN_IF_ERROR(iree_hal_cmd_list_append_command(
+      cmd_list, IREE_HAL_CMD_FILL_BUFFER, sizeof(*cmd), (void**)&cmd));
+  cmd->target_buffer = target_buffer;
+  cmd->target_offset = target_offset;
+  cmd->length = length;
+  // Only pattern_length bytes are meaningful; the remainder of the inline
+  // storage is left as-is and ignored at replay.
+  memcpy(&cmd->pattern, pattern, pattern_length);
+  cmd->pattern_length = pattern_length;
+  return iree_ok_status();
+}
+
+// Replays a recorded fill against |target_command_buffer|.
+static iree_status_t iree_hal_deferred_command_buffer_apply_fill_buffer(
+    iree_hal_command_buffer_t* target_command_buffer,
+    const iree_hal_cmd_fill_buffer_t* cmd) {
+  return iree_hal_command_buffer_fill_buffer(
+      target_command_buffer, cmd->target_buffer, cmd->target_offset,
+      // The API takes a const void* pattern pointer; the previous
+      // (void**)&cmd->pattern cast was wrong-typed and stripped const.
+      cmd->length, (const void*)&cmd->pattern, cmd->pattern_length);
+}
+
+//===----------------------------------------------------------------------===//
+// IREE_HAL_CMD_UPDATE_BUFFER
+//===----------------------------------------------------------------------===//
+
+// Recorded arguments of an update_buffer call. The host source data is copied
+// inline into the trailing flexible array so the caller's buffer may be
+// stack-allocated and freed immediately after recording.
+typedef struct iree_hal_cmd_update_buffer_t {
+  iree_hal_cmd_header_t header;
+  iree_hal_buffer_t* target_buffer;
+  iree_device_size_t target_offset;
+  iree_device_size_t length;
+  uint8_t source_buffer[];
+} iree_hal_cmd_update_buffer_t;
+
+// Records an update, snapshotting |length| bytes starting at |source_offset|.
+static iree_status_t iree_hal_deferred_command_buffer_update_buffer(
+    iree_hal_command_buffer_t* base_command_buffer, const void* source_buffer,
+    iree_host_size_t source_offset, iree_hal_buffer_t* target_buffer,
+    iree_device_size_t target_offset, iree_device_size_t length) {
+  iree_hal_cmd_list_t* cmd_list =
+      &iree_hal_deferred_command_buffer_cast(base_command_buffer)->cmd_list;
+  iree_hal_cmd_update_buffer_t* cmd = NULL;
+  // NOTE(review): |length| is a device size used to size a host allocation;
+  // presumably update sources are small — TODO confirm an upstream bound.
+  IREE_RETURN_IF_ERROR(iree_hal_cmd_list_append_command(
+      cmd_list, IREE_HAL_CMD_UPDATE_BUFFER,
+      sizeof(*cmd) + sizeof(cmd->source_buffer[0]) * length, (void**)&cmd));
+  cmd->target_buffer = target_buffer;
+  cmd->target_offset = target_offset;
+  cmd->length = length;
+  memcpy(cmd->source_buffer, (const uint8_t*)source_buffer + source_offset,
+         sizeof(cmd->source_buffer[0]) * length);
+  return iree_ok_status();
+}
+
+// Replays a recorded update. source_offset is 0 because the snapshot above
+// already applied the caller's offset.
+static iree_status_t iree_hal_deferred_command_buffer_apply_update_buffer(
+    iree_hal_command_buffer_t* target_command_buffer,
+    const iree_hal_cmd_update_buffer_t* cmd) {
+  return iree_hal_command_buffer_update_buffer(
+      target_command_buffer, cmd->source_buffer, 0, cmd->target_buffer,
+      cmd->target_offset, cmd->length);
+}
+
+//===----------------------------------------------------------------------===//
+// IREE_HAL_CMD_COPY_BUFFER
+//===----------------------------------------------------------------------===//
+
+// Recorded arguments of a copy_buffer call; buffers are stored by handle and
+// not retained.
+typedef struct iree_hal_cmd_copy_buffer_t {
+  iree_hal_cmd_header_t header;
+  iree_hal_buffer_t* source_buffer;
+  iree_device_size_t source_offset;
+  iree_hal_buffer_t* target_buffer;
+  iree_device_size_t target_offset;
+  iree_device_size_t length;
+} iree_hal_cmd_copy_buffer_t;
+
+// Records a device buffer-to-buffer copy into the in-memory command list.
+static iree_status_t iree_hal_deferred_command_buffer_copy_buffer(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
+    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+    iree_device_size_t length) {
+  iree_hal_cmd_list_t* cmd_list =
+      &iree_hal_deferred_command_buffer_cast(base_command_buffer)->cmd_list;
+  iree_hal_cmd_copy_buffer_t* cmd = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_cmd_list_append_command(
+      cmd_list, IREE_HAL_CMD_COPY_BUFFER, sizeof(*cmd), (void**)&cmd));
+  cmd->source_buffer = source_buffer;
+  cmd->source_offset = source_offset;
+  cmd->target_buffer = target_buffer;
+  cmd->target_offset = target_offset;
+  cmd->length = length;
+  return iree_ok_status();
+}
+
+// Replays a recorded copy against |target_command_buffer|.
+static iree_status_t iree_hal_deferred_command_buffer_apply_copy_buffer(
+    iree_hal_command_buffer_t* target_command_buffer,
+    const iree_hal_cmd_copy_buffer_t* cmd) {
+  return iree_hal_command_buffer_copy_buffer(
+      target_command_buffer, cmd->source_buffer, cmd->source_offset,
+      cmd->target_buffer, cmd->target_offset, cmd->length);
+}
+
+//===----------------------------------------------------------------------===//
+// IREE_HAL_CMD_PUSH_CONSTANTS
+//===----------------------------------------------------------------------===//
+
+// Recorded arguments of a push_constants call; the constant bytes are copied
+// inline into the trailing flexible array.
+typedef struct iree_hal_cmd_push_constants_t {
+  iree_hal_cmd_header_t header;
+  iree_hal_executable_layout_t* executable_layout;
+  iree_host_size_t offset;
+  iree_host_size_t values_length;
+  uint8_t values[];
+} iree_hal_cmd_push_constants_t;
+
+// Records a push-constant update, snapshotting |values_length| bytes.
+static iree_status_t iree_hal_deferred_command_buffer_push_constants(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_layout_t* executable_layout, iree_host_size_t offset,
+    const void* values, iree_host_size_t values_length) {
+  iree_hal_cmd_list_t* cmd_list =
+      &iree_hal_deferred_command_buffer_cast(base_command_buffer)->cmd_list;
+  iree_hal_cmd_push_constants_t* cmd = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_cmd_list_append_command(
+      cmd_list, IREE_HAL_CMD_PUSH_CONSTANTS,
+      sizeof(*cmd) + sizeof(cmd->values[0]) * values_length, (void**)&cmd));
+  cmd->executable_layout = executable_layout;
+  cmd->offset = offset;
+  cmd->values_length = values_length;
+  memcpy(cmd->values, values, sizeof(cmd->values[0]) * values_length);
+  return iree_ok_status();
+}
+
+// Replays a recorded push-constant update against |target_command_buffer|.
+static iree_status_t iree_hal_deferred_command_buffer_apply_push_constants(
+    iree_hal_command_buffer_t* target_command_buffer,
+    const iree_hal_cmd_push_constants_t* cmd) {
+  return iree_hal_command_buffer_push_constants(
+      target_command_buffer, cmd->executable_layout, cmd->offset, cmd->values,
+      cmd->values_length);
+}
+
+//===----------------------------------------------------------------------===//
+// IREE_HAL_CMD_PUSH_DESCRIPTOR_SET
+//===----------------------------------------------------------------------===//
+
+// Recorded arguments of a push_descriptor_set call; bindings are copied inline
+// (buffers referenced by the bindings are not retained).
+typedef struct iree_hal_cmd_push_descriptor_set_t {
+  iree_hal_cmd_header_t header;
+  iree_hal_executable_layout_t* executable_layout;
+  uint32_t set;
+  iree_host_size_t binding_count;
+  iree_hal_descriptor_set_binding_t bindings[];
+} iree_hal_cmd_push_descriptor_set_t;
+
+// Records a descriptor-set push, snapshotting the binding array.
+static iree_status_t iree_hal_deferred_command_buffer_push_descriptor_set(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_layout_t* executable_layout, uint32_t set,
+    iree_host_size_t binding_count,
+    const iree_hal_descriptor_set_binding_t* bindings) {
+  iree_hal_cmd_list_t* cmd_list =
+      &iree_hal_deferred_command_buffer_cast(base_command_buffer)->cmd_list;
+  iree_hal_cmd_push_descriptor_set_t* cmd = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_cmd_list_append_command(
+      cmd_list, IREE_HAL_CMD_PUSH_DESCRIPTOR_SET,
+      sizeof(*cmd) + sizeof(cmd->bindings[0]) * binding_count, (void**)&cmd));
+  cmd->executable_layout = executable_layout;
+  cmd->set = set;
+  cmd->binding_count = binding_count;
+  memcpy(cmd->bindings, bindings, sizeof(cmd->bindings[0]) * binding_count);
+  return iree_ok_status();
+}
+
+// Replays a recorded descriptor-set push against |target_command_buffer|.
+static iree_status_t iree_hal_deferred_command_buffer_apply_push_descriptor_set(
+    iree_hal_command_buffer_t* target_command_buffer,
+    const iree_hal_cmd_push_descriptor_set_t* cmd) {
+  return iree_hal_command_buffer_push_descriptor_set(
+      target_command_buffer, cmd->executable_layout, cmd->set,
+      cmd->binding_count, cmd->bindings);
+}
+
+//===----------------------------------------------------------------------===//
+// IREE_HAL_CMD_BIND_DESCRIPTOR_SET
+//===----------------------------------------------------------------------===//
+
+// Recorded arguments of a bind_descriptor_set call; dynamic offsets are
+// copied inline, the descriptor set itself is stored by handle.
+typedef struct iree_hal_cmd_bind_descriptor_set_t {
+  iree_hal_cmd_header_t header;
+  iree_hal_executable_layout_t* executable_layout;
+  uint32_t set;
+  iree_hal_descriptor_set_t* descriptor_set;
+  iree_host_size_t dynamic_offset_count;
+  iree_device_size_t dynamic_offsets[];
+} iree_hal_cmd_bind_descriptor_set_t;
+
+// Records a descriptor-set bind, snapshotting the dynamic offset array.
+static iree_status_t iree_hal_deferred_command_buffer_bind_descriptor_set(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_layout_t* executable_layout, uint32_t set,
+    iree_hal_descriptor_set_t* descriptor_set,
+    iree_host_size_t dynamic_offset_count,
+    const iree_device_size_t* dynamic_offsets) {
+  iree_hal_cmd_list_t* cmd_list =
+      &iree_hal_deferred_command_buffer_cast(base_command_buffer)->cmd_list;
+  iree_hal_cmd_bind_descriptor_set_t* cmd = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_cmd_list_append_command(
+      cmd_list, IREE_HAL_CMD_BIND_DESCRIPTOR_SET,
+      sizeof(*cmd) + sizeof(cmd->dynamic_offsets[0]) * dynamic_offset_count,
+      (void**)&cmd));
+  cmd->executable_layout = executable_layout;
+  cmd->set = set;
+  cmd->descriptor_set = descriptor_set;
+  cmd->dynamic_offset_count = dynamic_offset_count;
+  memcpy(cmd->dynamic_offsets, dynamic_offsets,
+         sizeof(cmd->dynamic_offsets[0]) * dynamic_offset_count);
+  return iree_ok_status();
+}
+
+// Replays a recorded descriptor-set bind against |target_command_buffer|.
+static iree_status_t iree_hal_deferred_command_buffer_apply_bind_descriptor_set(
+    iree_hal_command_buffer_t* target_command_buffer,
+    const iree_hal_cmd_bind_descriptor_set_t* cmd) {
+  return iree_hal_command_buffer_bind_descriptor_set(
+      target_command_buffer, cmd->executable_layout, cmd->set,
+      cmd->descriptor_set, cmd->dynamic_offset_count, cmd->dynamic_offsets);
+}
+
+//===----------------------------------------------------------------------===//
+// IREE_HAL_CMD_DISPATCH
+//===----------------------------------------------------------------------===//
+
+// Recorded arguments of a dispatch call; |executable| is stored by handle and
+// not retained.
+typedef struct iree_hal_cmd_dispatch_t {
+  iree_hal_cmd_header_t header;
+  iree_hal_executable_t* executable;
+  int32_t entry_point;
+  uint32_t workgroup_x;
+  uint32_t workgroup_y;
+  uint32_t workgroup_z;
+} iree_hal_cmd_dispatch_t;
+
+// Records a direct dispatch with a static workgroup count.
+static iree_status_t iree_hal_deferred_command_buffer_dispatch(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_t* executable, int32_t entry_point,
+    uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
+  iree_hal_cmd_list_t* cmd_list =
+      &iree_hal_deferred_command_buffer_cast(base_command_buffer)->cmd_list;
+  iree_hal_cmd_dispatch_t* cmd = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_cmd_list_append_command(
+      cmd_list, IREE_HAL_CMD_DISPATCH, sizeof(*cmd), (void**)&cmd));
+  cmd->executable = executable;
+  cmd->entry_point = entry_point;
+  cmd->workgroup_x = workgroup_x;
+  cmd->workgroup_y = workgroup_y;
+  cmd->workgroup_z = workgroup_z;
+  return iree_ok_status();
+}
+
+// Replays a recorded dispatch against |target_command_buffer|.
+static iree_status_t iree_hal_deferred_command_buffer_apply_dispatch(
+    iree_hal_command_buffer_t* target_command_buffer,
+    const iree_hal_cmd_dispatch_t* cmd) {
+  return iree_hal_command_buffer_dispatch(
+      target_command_buffer, cmd->executable, cmd->entry_point,
+      cmd->workgroup_x, cmd->workgroup_y, cmd->workgroup_z);
+}
+
+//===----------------------------------------------------------------------===//
+// IREE_HAL_CMD_DISPATCH_INDIRECT
+//===----------------------------------------------------------------------===//
+
+// Recorded arguments of an indirect dispatch; the workgroup count is read by
+// the device from |workgroups_buffer| at replay-execution time.
+typedef struct iree_hal_cmd_dispatch_indirect_t {
+  iree_hal_cmd_header_t header;
+  iree_hal_executable_t* executable;
+  int32_t entry_point;
+  iree_hal_buffer_t* workgroups_buffer;
+  iree_device_size_t workgroups_offset;
+} iree_hal_cmd_dispatch_indirect_t;
+
+// Records an indirect dispatch into the in-memory command list.
+static iree_status_t iree_hal_deferred_command_buffer_dispatch_indirect(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_t* executable, int32_t entry_point,
+    iree_hal_buffer_t* workgroups_buffer,
+    iree_device_size_t workgroups_offset) {
+  iree_hal_cmd_list_t* cmd_list =
+      &iree_hal_deferred_command_buffer_cast(base_command_buffer)->cmd_list;
+  iree_hal_cmd_dispatch_indirect_t* cmd = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_cmd_list_append_command(
+      cmd_list, IREE_HAL_CMD_DISPATCH_INDIRECT, sizeof(*cmd), (void**)&cmd));
+  cmd->executable = executable;
+  cmd->entry_point = entry_point;
+  cmd->workgroups_buffer = workgroups_buffer;
+  cmd->workgroups_offset = workgroups_offset;
+  return iree_ok_status();
+}
+
+// Replays a recorded indirect dispatch against |target_command_buffer|.
+static iree_status_t iree_hal_deferred_command_buffer_apply_dispatch_indirect(
+    iree_hal_command_buffer_t* target_command_buffer,
+    const iree_hal_cmd_dispatch_indirect_t* cmd) {
+  return iree_hal_command_buffer_dispatch_indirect(
+      target_command_buffer, cmd->executable, cmd->entry_point,
+      cmd->workgroups_buffer, cmd->workgroups_offset);
+}
+
+//===----------------------------------------------------------------------===//
+// Dynamic replay dispatch
+//===----------------------------------------------------------------------===//
+
+// Maps each recorded command type to its replay function. All entries share a
+// common signature via the iree_hal_cmd_apply_fn_t cast.
+static const iree_hal_cmd_apply_fn_t iree_hal_cmd_apply_table[] = {
+    [IREE_HAL_CMD_EXECUTION_BARRIER] = (iree_hal_cmd_apply_fn_t)
+        iree_hal_deferred_command_buffer_apply_execution_barrier,
+    [IREE_HAL_CMD_SIGNAL_EVENT] = (iree_hal_cmd_apply_fn_t)
+        iree_hal_deferred_command_buffer_apply_signal_event,
+    [IREE_HAL_CMD_RESET_EVENT] = (iree_hal_cmd_apply_fn_t)
+        iree_hal_deferred_command_buffer_apply_reset_event,
+    [IREE_HAL_CMD_WAIT_EVENTS] = (iree_hal_cmd_apply_fn_t)
+        iree_hal_deferred_command_buffer_apply_wait_events,
+    [IREE_HAL_CMD_DISCARD_BUFFER] = (iree_hal_cmd_apply_fn_t)
+        iree_hal_deferred_command_buffer_apply_discard_buffer,
+    [IREE_HAL_CMD_FILL_BUFFER] = (iree_hal_cmd_apply_fn_t)
+        iree_hal_deferred_command_buffer_apply_fill_buffer,
+    [IREE_HAL_CMD_UPDATE_BUFFER] = (iree_hal_cmd_apply_fn_t)
+        iree_hal_deferred_command_buffer_apply_update_buffer,
+    [IREE_HAL_CMD_COPY_BUFFER] = (iree_hal_cmd_apply_fn_t)
+        iree_hal_deferred_command_buffer_apply_copy_buffer,
+    [IREE_HAL_CMD_PUSH_CONSTANTS] = (iree_hal_cmd_apply_fn_t)
+        iree_hal_deferred_command_buffer_apply_push_constants,
+    [IREE_HAL_CMD_PUSH_DESCRIPTOR_SET] = (iree_hal_cmd_apply_fn_t)
+        iree_hal_deferred_command_buffer_apply_push_descriptor_set,
+    [IREE_HAL_CMD_BIND_DESCRIPTOR_SET] = (iree_hal_cmd_apply_fn_t)
+        iree_hal_deferred_command_buffer_apply_bind_descriptor_set,
+    [IREE_HAL_CMD_DISPATCH] = (iree_hal_cmd_apply_fn_t)
+        iree_hal_deferred_command_buffer_apply_dispatch,
+    [IREE_HAL_CMD_DISPATCH_INDIRECT] = (iree_hal_cmd_apply_fn_t)
+        iree_hal_deferred_command_buffer_apply_dispatch_indirect,
+};
+
+// Replays the recorded command list into |target_command_buffer|, wrapping it
+// in begin/end. Stops at the first failing command.
+IREE_API_EXPORT iree_status_t iree_hal_deferred_command_buffer_apply(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_command_buffer_t* target_command_buffer) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_deferred_command_buffer_t* command_buffer =
+      iree_hal_deferred_command_buffer_cast(base_command_buffer);
+  iree_hal_cmd_list_t* cmd_list = &command_buffer->cmd_list;
+
+  iree_status_t status = iree_hal_command_buffer_begin(target_command_buffer);
+  if (iree_status_is_ok(status)) {
+    for (iree_hal_cmd_header_t* cmd = cmd_list->head; cmd != NULL;
+         cmd = cmd->next) {
+      // NOTE(review): cmd->type indexes the table without a bounds/NULL check;
+      // this relies on the recorder above only ever appending known types.
+      status = iree_hal_cmd_apply_table[cmd->type](target_command_buffer, cmd);
+      if (!iree_status_is_ok(status)) break;
+    }
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_command_buffer_end(target_command_buffer);
+  }
+
+  // One-shot command buffers can't be replayed so we can drop the memory
+  // immediately. As command buffers must remain live for the duration of their
+  // execution this prevents us from hanging on to the commands we will never
+  // use again.
+  if (iree_status_is_ok(status) &&
+      iree_all_bits_set(command_buffer->mode,
+                        IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT)) {
+    iree_hal_cmd_list_reset(cmd_list);
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+// Vtable wiring the deferred implementation into the generic
+// iree_hal_command_buffer_t interface.
+static const iree_hal_command_buffer_vtable_t
+    iree_hal_deferred_command_buffer_vtable = {
+        .destroy = iree_hal_deferred_command_buffer_destroy,
+        .allowed_categories =
+            iree_hal_deferred_command_buffer_allowed_categories,
+        .begin = iree_hal_deferred_command_buffer_begin,
+        .end = iree_hal_deferred_command_buffer_end,
+        .execution_barrier = iree_hal_deferred_command_buffer_execution_barrier,
+        .signal_event = iree_hal_deferred_command_buffer_signal_event,
+        .reset_event = iree_hal_deferred_command_buffer_reset_event,
+        .wait_events = iree_hal_deferred_command_buffer_wait_events,
+        .discard_buffer = iree_hal_deferred_command_buffer_discard_buffer,
+        .fill_buffer = iree_hal_deferred_command_buffer_fill_buffer,
+        .update_buffer = iree_hal_deferred_command_buffer_update_buffer,
+        .copy_buffer = iree_hal_deferred_command_buffer_copy_buffer,
+        .push_constants = iree_hal_deferred_command_buffer_push_constants,
+        .push_descriptor_set =
+            iree_hal_deferred_command_buffer_push_descriptor_set,
+        .bind_descriptor_set =
+            iree_hal_deferred_command_buffer_bind_descriptor_set,
+        .dispatch = iree_hal_deferred_command_buffer_dispatch,
+        .dispatch_indirect = iree_hal_deferred_command_buffer_dispatch_indirect,
+};
diff --git a/iree/hal/utils/deferred_command_buffer.h b/iree/hal/utils/deferred_command_buffer.h
new file mode 100644
index 0000000..ee104e4
--- /dev/null
+++ b/iree/hal/utils/deferred_command_buffer.h
@@ -0,0 +1,62 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_UTILS_DEFERRED_COMMAND_BUFFER_H_
+#define IREE_HAL_UTILS_DEFERRED_COMMAND_BUFFER_H_
+
+#include "iree/base/api.h"
+#include "iree/hal/command_buffer.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+typedef struct iree_arena_block_pool_t iree_arena_block_pool_t;
+
+//===----------------------------------------------------------------------===//
+// iree_hal_command_buffer_t deferred record/replay wrapper
+//===----------------------------------------------------------------------===//
+
+// Records an in-memory command buffer that can be replayed against a target
+// command buffer at a later time.
+//
+// Argument arrays (like push constants) and host buffers (like the source
+// buffer in iree_hal_command_buffer_update_buffer) that usually live on the
+// stack will be cloned. As with all command buffers the resources (buffers,
+// events, etc) referenced will not be retained and the caller must ensure that
+// all resource lifetimes outlive the command buffer.
+//
+// |block_pool| will be used to allocate the underlying storage and the blocks
+// will be retained until the command buffer is reset or released, or if
+// IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT is set after the first time the command
+// buffer is replayed. The block size of the pool can be whatever the caller
+// wants with the caveat being that smaller sizes may result in more oversized
+// allocations from the system. 16KB, 32KB, and 64KB are reasonable starting
+// points based on system availability.
+// NOTE: the |block_pool| must remain live for the lifetime of the command
+// buffers that use it.
+//
+// After recording iree_hal_deferred_command_buffer_apply can be used to replay
+// the sequence of commands against a target command buffer implementation.
+// The command buffer can be replayed multiple times.
+IREE_API_EXPORT iree_status_t iree_hal_deferred_command_buffer_create(
+ iree_hal_command_buffer_mode_t mode,
+ iree_hal_command_category_t command_categories,
+ iree_arena_block_pool_t* block_pool, iree_allocator_t host_allocator,
+ iree_hal_command_buffer_t** out_command_buffer);
+
+// Replays a recorded |command_buffer| against a |target_command_buffer|.
+// If the command buffer was recorded in one-shot mode it will be reset upon
+// return.
+IREE_API_EXPORT iree_status_t iree_hal_deferred_command_buffer_apply(
+ iree_hal_command_buffer_t* command_buffer,
+ iree_hal_command_buffer_t* target_command_buffer);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_UTILS_DEFERRED_COMMAND_BUFFER_H_
diff --git a/iree/samples/variables_and_state/main.c b/iree/samples/variables_and_state/main.c
index 9a7a288..9cdef29 100644
--- a/iree/samples/variables_and_state/main.c
+++ b/iree/samples/variables_and_state/main.c
@@ -166,7 +166,7 @@
fprintf(stdout, "Calling functions\n\n");
// 1. get_value() // initial value
- int value;
+ int value = -1;
if (iree_status_is_ok(status)) {
status = counter_get_value(session, &value);
}
diff --git a/iree/tools/BUILD b/iree/tools/BUILD
index 0e729c9..9ba8e4a 100644
--- a/iree/tools/BUILD
+++ b/iree/tools/BUILD
@@ -230,6 +230,17 @@
],
)
+cc_library(
+ name = "iree-mlir-lsp-server",
+ srcs = ["iree-mlir-lsp-server.cc"],
+ deps = [
+ ":init_passes_and_dialects",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:MlirLspServerLib",
+ "@llvm-project//mlir:Support",
+ ],
+)
+
cc_binary(
name = "iree-run-mlir",
srcs = ["iree-run-mlir-main.cc"],
diff --git a/iree/tools/CMakeLists.txt b/iree/tools/CMakeLists.txt
index fb42f8f..f0ae8d2 100644
--- a/iree/tools/CMakeLists.txt
+++ b/iree/tools/CMakeLists.txt
@@ -336,6 +336,20 @@
iree_cc_binary(
NAME
+ iree-mlir-lsp-server
+ SRCS
+ "iree-mlir-lsp-server.cc"
+ DEPS
+ ::init_passes_and_dialects
+ MLIRIR
+ MLIRLspServerLib
+ MLIRSupport
+ ${IREE_OPT_CONDITIONAL_DEPS}
+ PUBLIC
+ )
+
+ iree_cc_binary(
+ NAME
iree-run-mlir
SRCS
"iree-run-mlir-main.cc"
diff --git a/iree/tools/iree-mlir-lsp-server.cc b/iree/tools/iree-mlir-lsp-server.cc
new file mode 100644
index 0000000..0363a68
--- /dev/null
+++ b/iree/tools/iree-mlir-lsp-server.cc
@@ -0,0 +1,20 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+// Main entry function for the IREE variant of mlir-lsp-server.
+//
+// See https://mlir.llvm.org/docs/Tools/MLIRLSP/
+
+#include "iree/tools/init_dialects.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h"
+
+// Registers every IREE dialect so the LSP server can parse/verify IREE IR,
+// then runs the standard MLIR LSP server loop. Exit code is nonzero on
+// failure per mlir::failed().
+int main(int argc, char **argv) {
+  mlir::DialectRegistry registry;
+  mlir::iree_compiler::registerAllDialects(registry);
+  return failed(mlir::MlirLspServerMain(argc, argv, registry));
+}
diff --git a/iree/vm/bytecode_dispatch.c b/iree/vm/bytecode_dispatch.c
index d7995fd..316c55e 100644
--- a/iree/vm/bytecode_dispatch.c
+++ b/iree/vm/bytecode_dispatch.c
@@ -1340,7 +1340,7 @@
bool rhs_is_move;
iree_vm_ref_t* rhs = VM_DecOperandRegRef("rhs", &rhs_is_move);
int32_t* result = VM_DecResultRegI32("result");
- *result = iree_vm_ref_equal(lhs, rhs);
+ *result = vm_cmp_eq_ref(lhs, rhs);
if (lhs_is_move) iree_vm_ref_release(lhs);
if (rhs_is_move) iree_vm_ref_release(rhs);
});
@@ -1350,7 +1350,7 @@
bool rhs_is_move;
iree_vm_ref_t* rhs = VM_DecOperandRegRef("rhs", &rhs_is_move);
int32_t* result = VM_DecResultRegI32("result");
- *result = !iree_vm_ref_equal(lhs, rhs);
+ *result = vm_cmp_ne_ref(lhs, rhs);
if (lhs_is_move) iree_vm_ref_release(lhs);
if (rhs_is_move) iree_vm_ref_release(rhs);
});
@@ -1358,7 +1358,7 @@
bool operand_is_move;
iree_vm_ref_t* operand = VM_DecOperandRegRef("operand", &operand_is_move);
int32_t* result = VM_DecResultRegI32("result");
- *result = operand->ptr != NULL ? 1 : 0;
+ *result = vm_cmp_nz_ref(operand);
if (operand_is_move) iree_vm_ref_release(operand);
});
diff --git a/iree/vm/ops.h b/iree/vm/ops.h
index b957bdc..2cd3d3b 100644
--- a/iree/vm/ops.h
+++ b/iree/vm/ops.h
@@ -182,6 +182,15 @@
static inline int32_t vm_cmp_nz_i32(int32_t operand) {
return (operand != 0) ? 1 : 0;
}
+// Ref equality normalized to the VM's 0/1 i32 comparison convention.
+static inline int32_t vm_cmp_eq_ref(iree_vm_ref_t* lhs, iree_vm_ref_t* rhs) {
+  return iree_vm_ref_equal(lhs, rhs) ? 1 : 0;
+}
+// Ref inequality normalized to 0/1.
+static inline int32_t vm_cmp_ne_ref(iree_vm_ref_t* lhs, iree_vm_ref_t* rhs) {
+  return (!iree_vm_ref_equal(lhs, rhs)) ? 1 : 0;
+}
+// Returns 1 if |operand| holds a non-null reference, else 0.
+static inline int32_t vm_cmp_nz_ref(iree_vm_ref_t* operand) {
+  return (operand->ptr != NULL) ? 1 : 0;
+}
static inline int32_t vm_cmp_eq_f32o(float lhs, float rhs) {
return (lhs == rhs) ? 1 : 0;
@@ -207,7 +216,9 @@
static inline int32_t vm_cmp_lte_f32u(float lhs, float rhs) {
return (isunordered(lhs, rhs) || islessequal(lhs, rhs)) ? 1 : 0;
}
-static inline int32_t vm_cmp_nan_f32(float operand) { return isnan(operand); }
+static inline int32_t vm_cmp_nan_f32(float operand) {
+ return isnan(operand) ? 1 : 0;
+}
//===------------------------------------------------------------------===//
// Control flow ops
@@ -312,7 +323,7 @@
}
//===------------------------------------------------------------------===//
-// Utility macros (Used for things that EmitC can't hadnle)
+// Utility macros (Used for things that EmitC can't handle)
//===------------------------------------------------------------------===//
// Get the address of an array element
@@ -324,6 +335,13 @@
iree_vm_ref_release(VM_ARRAY_ELEMENT_ADDRESS(array, i)); \
}
+#define VM_REF_RELEASE_IF_TYPE_MISMATCH(ref, type_def) \
+ if (ref->type != IREE_VM_REF_TYPE_NULL && \
+ (iree_vm_type_def_is_value(type_def) || \
+ ref->type != type_def->ref_type)) { \
+ iree_vm_ref_release(ref); \
+ }
+
// TODO(simon-camp): This macro should resemble the error handling part of the
// IREE_RETURN_IF_ERROR macro. There are two different definitions in
// iree/base/api.h depending on a feature flag.
diff --git a/iree/vm/test/BUILD b/iree/vm/test/BUILD
index 998350b..a86337e 100644
--- a/iree/vm/test/BUILD
+++ b/iree/vm/test/BUILD
@@ -45,6 +45,7 @@
":global_ops_i64.vmfb",
":list_ops.vmfb",
":list_variant_ops.vmfb",
+ ":ref_ops.vmfb",
":shift_ops.vmfb",
":shift_ops_i64.vmfb",
],
@@ -168,6 +169,12 @@
)
iree_bytecode_module(
+ name = "ref_ops",
+ src = "ref_ops.mlir",
+ flags = ["-iree-vm-ir-to-bytecode-module"],
+)
+
+iree_bytecode_module(
name = "shift_ops",
src = "shift_ops.mlir",
flags = ["-iree-vm-ir-to-bytecode-module"],
diff --git a/iree/vm/test/CMakeLists.txt b/iree/vm/test/CMakeLists.txt
index 2ffa3c7..5866f70 100644
--- a/iree/vm/test/CMakeLists.txt
+++ b/iree/vm/test/CMakeLists.txt
@@ -37,6 +37,7 @@
"global_ops_i64.vmfb"
"list_ops.vmfb"
"list_variant_ops.vmfb"
+ "ref_ops.vmfb"
"shift_ops.vmfb"
"shift_ops_i64.vmfb"
C_FILE_OUTPUT
@@ -239,6 +240,16 @@
iree_bytecode_module(
NAME
+ ref_ops
+ SRC
+ "ref_ops.mlir"
+ FLAGS
+ "-iree-vm-ir-to-bytecode-module"
+ PUBLIC
+)
+
+iree_bytecode_module(
+ NAME
shift_ops
SRC
"shift_ops.mlir"
diff --git a/iree/vm/test/emitc/CMakeLists.txt b/iree/vm/test/emitc/CMakeLists.txt
index 0cd1c9d..7fe076a 100644
--- a/iree/vm/test/emitc/CMakeLists.txt
+++ b/iree/vm/test/emitc/CMakeLists.txt
@@ -25,7 +25,7 @@
::arithmetic_ops
::arithmetic_ops_f32
::arithmetic_ops_i64
- ::assignment_ops
+ # ::assignment_ops
::assignment_ops_i64
::comparison_ops
::comparison_ops_f32
@@ -35,6 +35,8 @@
::conversion_ops_i64
::global_ops
::list_ops
+ ::list_variant_ops
+ ::ref_ops
::shift_ops
::shift_ops_i64
)
@@ -66,14 +68,15 @@
"arithmetic_ops_i64.h"
)
-iree_c_module(
- NAME
- assignment_ops
- SRC
- "../assignment_ops.mlir"
- H_FILE_OUTPUT
- "assignment_ops.h"
-)
+# TODO(simon-camp): Reenable this test once the 'vm.select.ref' op is supported.
+# iree_c_module(
+# NAME
+# assignment_ops
+# SRC
+# "../assignment_ops.mlir"
+# H_FILE_OUTPUT
+# "assignment_ops.h"
+# )
iree_c_module(
NAME
@@ -158,6 +161,24 @@
iree_c_module(
NAME
+ list_variant_ops
+ SRC
+ "../list_variant_ops.mlir"
+ H_FILE_OUTPUT
+ "list_variant_ops.h"
+)
+
+iree_c_module(
+ NAME
+ ref_ops
+ SRC
+ "../ref_ops.mlir"
+ H_FILE_OUTPUT
+ "ref_ops.h"
+)
+
+iree_c_module(
+ NAME
shift_ops
SRC
"../shift_ops.mlir"
diff --git a/iree/vm/test/emitc/module_test.cc b/iree/vm/test/emitc/module_test.cc
index 71f884e..0c0e1e9 100644
--- a/iree/vm/test/emitc/module_test.cc
+++ b/iree/vm/test/emitc/module_test.cc
@@ -13,7 +13,7 @@
#include "iree/vm/test/emitc/arithmetic_ops.h"
#include "iree/vm/test/emitc/arithmetic_ops_f32.h"
#include "iree/vm/test/emitc/arithmetic_ops_i64.h"
-#include "iree/vm/test/emitc/assignment_ops.h"
+// #include "iree/vm/test/emitc/assignment_ops.h"
#include "iree/vm/test/emitc/assignment_ops_i64.h"
#include "iree/vm/test/emitc/comparison_ops.h"
#include "iree/vm/test/emitc/comparison_ops_f32.h"
@@ -23,6 +23,8 @@
#include "iree/vm/test/emitc/conversion_ops_i64.h"
#include "iree/vm/test/emitc/global_ops.h"
#include "iree/vm/test/emitc/list_ops.h"
+#include "iree/vm/test/emitc/list_variant_ops.h"
+#include "iree/vm/test/emitc/ref_ops.h"
#include "iree/vm/test/emitc/shift_ops.h"
#include "iree/vm/test/emitc/shift_ops_i64.h"
@@ -55,7 +57,7 @@
{arithmetic_ops_descriptor_, arithmetic_ops_create},
{arithmetic_ops_f32_descriptor_, arithmetic_ops_f32_create},
{arithmetic_ops_i64_descriptor_, arithmetic_ops_i64_create},
- {assignment_ops_descriptor_, assignment_ops_create},
+ // {assignment_ops_descriptor_, assignment_ops_create},
{assignment_ops_i64_descriptor_, assignment_ops_i64_create},
{comparison_ops_descriptor_, comparison_ops_create},
{comparison_ops_f32_descriptor_, comparison_ops_f32_create},
@@ -65,6 +67,8 @@
{conversion_ops_i64_descriptor_, conversion_ops_i64_create},
{global_ops_descriptor_, global_ops_create},
{list_ops_descriptor_, list_ops_create},
+ {list_variant_ops_descriptor_, list_variant_ops_create},
+ {ref_ops_descriptor_, ref_ops_create},
{shift_ops_descriptor_, shift_ops_create},
{shift_ops_i64_descriptor_, shift_ops_i64_create}};
@@ -97,8 +101,7 @@
iree_vm_module_t* module_ = nullptr;
IREE_CHECK_OK(
- test_params.create_function(iree_allocator_system(), &module_))
- << "Module failed to load";
+ test_params.create_function(iree_allocator_system(), &module_));
std::vector<iree_vm_module_t*> modules = {module_};
IREE_CHECK_OK(iree_vm_context_create_with_modules(
@@ -119,8 +122,7 @@
IREE_CHECK_OK(iree_vm_context_resolve_function(
context_,
iree_string_view_t{qualified_name.data(), qualified_name.size()},
- &function))
- << "Exported function '" << local_name << "' not found";
+ &function));
return iree_vm_invoke(context_, function,
/*policy=*/nullptr, /*inputs=*/nullptr,
diff --git a/iree/vm/test/list_variant_ops.mlir b/iree/vm/test/list_variant_ops.mlir
index 1a19a5d..bd1a48d 100644
--- a/iree/vm/test/list_variant_ops.mlir
+++ b/iree/vm/test/list_variant_ops.mlir
@@ -89,6 +89,47 @@
vm.return
}
+ //===--------------------------------------------------------------------===//
+ // Failure tests
+ //===--------------------------------------------------------------------===//
+
+ vm.export @fail_uninitialized_access
+ vm.func @fail_uninitialized_access() {
+ %c0 = vm.const.i32 0 : i32
+ %c1 = vm.const.i32 1 : i32
+
+ %ref = vm.const.ref.rodata @byte_buffer : !vm.buffer
+ %list = vm.list.alloc %c1 : (i32) -> !vm.list<?>
+
+ vm.list.set.ref %list, %c0, %ref : (!vm.list<?>, i32, !vm.buffer)
+ vm.return
+ }
+
+ vm.export @fail_out_of_bounds_read
+ vm.func @fail_out_of_bounds_read() {
+ %c1 = vm.const.i32 1 : i32
+
+ %list = vm.list.alloc %c1 : (i32) -> !vm.list<?>
+ vm.list.resize %list, %c1 : (!vm.list<?>, i32)
+
+ %ref = vm.list.get.ref %list, %c1 : (!vm.list<?>, i32) -> !vm.buffer
+ %ref_dno = iree.do_not_optimize(%ref) : !vm.buffer
+ vm.return
+ }
+
+ vm.export @fail_out_of_bounds_write
+ vm.func @fail_out_of_bounds_write() {
+ %c0 = vm.const.i32 0 : i32
+ %c1 = vm.const.i32 1 : i32
+
+ %ref = vm.const.ref.rodata @byte_buffer : !vm.buffer
+ %list = vm.list.alloc %c1 : (i32) -> !vm.list<?>
+ vm.list.resize %list, %c1 : (!vm.list<?>, i32)
+
+ vm.list.set.ref %list, %c1, %ref : (!vm.list<?>, i32, !vm.buffer)
+ vm.return
+ }
+
vm.export @fail_variant_slot_change
vm.func @fail_variant_slot_change() {
%capacity = vm.const.i32 42 : i32
diff --git a/iree/vm/test/ref_ops.mlir b/iree/vm/test/ref_ops.mlir
new file mode 100644
index 0000000..d1ce68e
--- /dev/null
+++ b/iree/vm/test/ref_ops.mlir
@@ -0,0 +1,41 @@
+vm.module @ref_ops {
+ vm.rodata private @buffer_i8 dense<[1, 2, 3]> : tensor<3xi8>
+ vm.rodata private @buffer_i32 dense<[1, 2, 3]> : tensor<3xi32>
+
+ vm.export @test_zero_ref_eq
+ vm.func @test_zero_ref_eq() {
+ %ref = vm.const.ref.zero : !vm.ref<?>
+ %ref_dno = iree.do_not_optimize(%ref) : !vm.ref<?>
+ vm.check.eq %ref_dno, %ref_dno : !vm.ref<?>
+ vm.return
+ }
+
+ vm.export @test_ref_eq
+ vm.func @test_ref_eq() {
+ %ref_1 = vm.const.ref.rodata @buffer_i8 : !vm.buffer
+ %ref_1_dno = iree.do_not_optimize(%ref_1) : !vm.buffer
+ %ref_2 = vm.const.ref.rodata @buffer_i8 : !vm.buffer
+ %ref_2_dno = iree.do_not_optimize(%ref_2) : !vm.buffer
+ vm.check.eq %ref_1_dno, %ref_2_dno : !vm.buffer
+ vm.return
+ }
+
+ vm.export @test_ref_ne
+ vm.func @test_ref_ne() {
+ %ref_i8 = vm.const.ref.rodata @buffer_i8 : !vm.buffer
+ %ref_i8_dno = iree.do_not_optimize(%ref_i8) : !vm.buffer
+ %ref_i32 = vm.const.ref.rodata @buffer_i32 : !vm.buffer
+ %ref_i32_dno = iree.do_not_optimize(%ref_i32) : !vm.buffer
+ vm.check.ne %ref_i8_dno, %ref_i32_dno : !vm.buffer
+ vm.return
+ }
+
+ vm.export @test_ref_nz
+ vm.func @test_ref_nz() {
+ %ref = vm.const.ref.rodata @buffer_i8 : !vm.buffer
+ %ref_dno = iree.do_not_optimize(%ref) : !vm.buffer
+ vm.check.nz %ref_dno : !vm.buffer
+ vm.return
+ }
+
+}