Add LinalgOnTensors xla_ops tests (#4693)
- Move HLO -> HLO & Legalize workflow before LinalgOnTensors -> HLO
- Run canonicalization after HLO -> HLO.
- Create xla_ops tests.
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index fffc559..0339231 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -65,17 +65,18 @@
//----------------------------------------------------------------------------
passManager.addPass(createCanonicalizerPass());
+ // Flatten structured control flow to our CFG.
+ passManager.addNestedPass<FuncOp>(mhlo::createLegalizeControlFlowPass());
+ passManager.addNestedPass<FuncOp>(createHLOPreprocessingPass());
+
// Frontload linalg-on-tensors transformations and dispatch region creation.
if (clEnableLinalgOnTensorsDispatch) {
+ passManager.addNestedPass<FuncOp>(createCanonicalizerPass());
addHLOToLinalgOnTensorsPasses(passManager);
passManager.addNestedPass<FuncOp>(createDispatchLinalgOnTensorsPass(
clLinalgOnTensorsTileSizes, clLinalgOnTensorsEnableFusion));
}
- // Flatten structured control flow to our CFG.
- passManager.addNestedPass<FuncOp>(mhlo::createLegalizeControlFlowPass());
- passManager.addNestedPass<FuncOp>(createHLOPreprocessingPass());
-
// Convert TOSA ops to Linalg-on-tensor ops.
passManager.addNestedPass<FuncOp>(tosa::createTosaToLinalgOnTensors());
@@ -209,10 +210,8 @@
// Note that as we are rematerializing things here it's critical we do not run
// the canonicalizer/CSE between now and when we outline - otherwise it'll
// undo all of our work!
- if (!clEnableLinalgOnTensorsDispatch) {
- passManager.addNestedPass<FuncOp>(
- IREE::Flow::createRematerializeDispatchConstantsPass());
- }
+ passManager.addNestedPass<FuncOp>(
+ IREE::Flow::createRematerializeDispatchConstantsPass());
// Outline the dispatch regions into their own functions wrapped in
// executables. This separates sequencer functions performing dispatches from
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
index dd09c02..7bb954a 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
@@ -158,7 +158,12 @@
funcOp->getAttrOfType<FlatSymbolRefAttr>(
getNumWorkgroupsFnAttrName());
if (!numWorkgroupsFnAttr) {
- return funcOp.emitError("expected llvm.num_workgroups_fn ");
+ auto constantOne = rewriter.createOrFold<mlir::ConstantIndexOp>(loc, 1);
+ rewriter.create<IREE::HAL::CommandBufferDispatchSymbolOp>(
+ loc, commandBuffer, dispatchState.entryPointOp, constantOne,
+ constantOne, constantOne);
+ rewriter.create<IREE::HAL::ReturnOp>(loc);
+ return success();
}
std::array<Value, 3> workgroupCount = {nullptr, nullptr, nullptr};
FuncOp numWorkgroupsFn = dyn_cast<FuncOp>(SymbolTable::lookupSymbolIn(
diff --git a/iree/test/e2e/xla_ops/BUILD b/iree/test/e2e/xla_ops/BUILD
index 2693834..8d13056 100644
--- a/iree/test/e2e/xla_ops/BUILD
+++ b/iree/test/e2e/xla_ops/BUILD
@@ -134,6 +134,70 @@
target_backend = "dylib-llvm-aot",
)
+iree_check_single_backend_test_suite(
+ name = "check_linalg_on_tensors_dylib-llvm-aot_dylib",
+ srcs = [
+ "abs.mlir",
+ "add.mlir",
+ "batch_norm_inference.mlir",
+ "broadcast.mlir",
+ "broadcast_add.mlir",
+ "broadcast_in_dim.mlir",
+ # https://github.com/google/iree/issues/4692
+ # "clamp.mlir",
+ "compare.mlir",
+ # https://github.com/google/iree/issues/4079
+ # "concatenate.mlir",
+ "constant.mlir",
+ # https://github.com/google/iree/issues/4079
+ # "convolution.mlir",
+ "cosine.mlir",
+ "divide.mlir",
+ # https://github.com/google/iree/issues/4079
+ # "dot.mlir",
+ # "dot_general.mlir",
+ "exponential.mlir",
+ # https://github.com/google/iree/issues/4692
+ # "exponential_minus_one.mlir",
+ "floor.mlir",
+ # https://github.com/google/iree/issues/4692
+ # "gather.mlir",
+ "iota.mlir",
+ "log.mlir",
+ # https://github.com/google/iree/issues/4692
+ # "log_plus_one.mlir",
+ # "maximum.mlir",
+ # "minimum.mlir",
+ "multiply.mlir",
+ "negate.mlir",
+ # https://github.com/google/iree/issues/4079
+ # "pad.mlir",
+ # "reduce.mlir",
+ # "reduce_window.mlir",
+ "remainder.mlir",
+ # "reshape.mlir",
+ # "reverse.mlir",
+ "rsqrt.mlir",
+ "select.mlir",
+ "sine.mlir",
+ # https://github.com/google/iree/issues/4692
+ # "slice.mlir",
+ "sqrt.mlir",
+ "subtract.mlir",
+ "tanh.mlir",
+ # https://github.com/google/iree/issues/4079
+ # "torch_index_select.mlir",
+ "transpose.mlir",
+ # "while.mlir",
+ ],
+ compiler_flags = [
+ "-iree-flow-dispatch-linalg-on-tensors",
+ "-iree-codegen-llvm-experimental-linalg-on-tensors",
+ ],
+ driver = "dylib",
+ target_backend = "dylib-llvm-aot",
+)
+
test_suite(
name = "check",
tests = [
diff --git a/iree/test/e2e/xla_ops/CMakeLists.txt b/iree/test/e2e/xla_ops/CMakeLists.txt
index 65c4d93..9a53940 100644
--- a/iree/test/e2e/xla_ops/CMakeLists.txt
+++ b/iree/test/e2e/xla_ops/CMakeLists.txt
@@ -130,3 +130,40 @@
DRIVER
"dylib"
)
+
+iree_check_single_backend_test_suite(
+ NAME
+ check_linalg_on_tensors_dylib-llvm-aot_dylib
+ SRCS
+ "abs.mlir"
+ "add.mlir"
+ "batch_norm_inference.mlir"
+ "broadcast.mlir"
+ "broadcast_add.mlir"
+ "broadcast_in_dim.mlir"
+ "compare.mlir"
+ "constant.mlir"
+ "cosine.mlir"
+ "divide.mlir"
+ "exponential.mlir"
+ "floor.mlir"
+ "iota.mlir"
+ "log.mlir"
+ "multiply.mlir"
+ "negate.mlir"
+ "remainder.mlir"
+ "rsqrt.mlir"
+ "select.mlir"
+ "sine.mlir"
+ "sqrt.mlir"
+ "subtract.mlir"
+ "tanh.mlir"
+ "transpose.mlir"
+ TARGET_BACKEND
+ "dylib-llvm-aot"
+ DRIVER
+ "dylib"
+ COMPILER_FLAGS
+ "-iree-flow-dispatch-linalg-on-tensors"
+ "-iree-codegen-llvm-experimental-linalg-on-tensors"
+)