Merge pull request #4696 from google/benvanik-tf-compiler-cleanup
diff --git a/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributePass.cpp b/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributePass.cpp index 4ebc130..307e8bc 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributePass.cpp +++ b/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributePass.cpp
@@ -18,6 +18,7 @@ #include "iree/compiler/Conversion/Common/Attributes.h" #include "iree/compiler/Conversion/Common/Transforms.h" #include "iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h" +#include "iree/compiler/Conversion/LinalgToLLVM/Passes.h" #include "iree/compiler/Dialect/HAL/IR/HALDialect.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.h b/iree/compiler/Conversion/LinalgToLLVM/Passes.h index 301eee3..053d84c 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/Passes.h +++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.h
@@ -15,6 +15,7 @@ #ifndef IREE_COMPILER_CONVERSION_LINALGTOLLVM_PASSES_H_ #define IREE_COMPILER_CONVERSION_LINALGTOLLVM_PASSES_H_ +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "mlir/Pass/Pass.h" namespace mlir { @@ -29,7 +30,8 @@ std::unique_ptr<FunctionPass> createPlanConvLoopOrderPass(); /// Distributes linalg ops among hal.interface.workgroup logical threads. -std::unique_ptr<OperationPass<ModuleOp>> createLinalgTileAndDistributePass(); +std::unique_ptr<OperationPass<IREE::HAL::ExecutableTargetOp>> +createLinalgTileAndDistributePass(); /// Vectorizes linalg ops executed in the same hal.interface.workgroup. std::unique_ptr<FunctionPass> createLinalgTileAndVectorizeWorkgroupsPass();
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp index fffc559..0339231 100644 --- a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -65,17 +65,18 @@ //---------------------------------------------------------------------------- passManager.addPass(createCanonicalizerPass()); + // Flatten structured control flow to our CFG. + passManager.addNestedPass<FuncOp>(mhlo::createLegalizeControlFlowPass()); + passManager.addNestedPass<FuncOp>(createHLOPreprocessingPass()); + // Frontload linalg-on-tensors transformations and dispatch region creation. if (clEnableLinalgOnTensorsDispatch) { + passManager.addNestedPass<FuncOp>(createCanonicalizerPass()); addHLOToLinalgOnTensorsPasses(passManager); passManager.addNestedPass<FuncOp>(createDispatchLinalgOnTensorsPass( clLinalgOnTensorsTileSizes, clLinalgOnTensorsEnableFusion)); } - // Flatten structured control flow to our CFG. - passManager.addNestedPass<FuncOp>(mhlo::createLegalizeControlFlowPass()); - passManager.addNestedPass<FuncOp>(createHLOPreprocessingPass()); - // Convert TOSA ops to Linalg-on-tensor ops. passManager.addNestedPass<FuncOp>(tosa::createTosaToLinalgOnTensors()); @@ -209,10 +210,8 @@ // Note that as we are rematerializing things here it's critical we do not run // the canonicalizer/CSE between now and when we outline - otherwise it'll // undo all of our work! - if (!clEnableLinalgOnTensorsDispatch) { - passManager.addNestedPass<FuncOp>( - IREE::Flow::createRematerializeDispatchConstantsPass()); - } + passManager.addNestedPass<FuncOp>( + IREE::Flow::createRematerializeDispatchConstantsPass()); // Outline the dispatch regions into their own functions wrapped in // executables. This separates sequencer functions performing dispatches from
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp index dd09c02..7bb954a 100644 --- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp +++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
@@ -158,7 +158,12 @@ funcOp->getAttrOfType<FlatSymbolRefAttr>( getNumWorkgroupsFnAttrName()); if (!numWorkgroupsFnAttr) { - return funcOp.emitError("expected llvm.num_workgroups_fn "); + auto constantOne = rewriter.createOrFold<mlir::ConstantIndexOp>(loc, 1); + rewriter.create<IREE::HAL::CommandBufferDispatchSymbolOp>( + loc, commandBuffer, dispatchState.entryPointOp, constantOne, + constantOne, constantOne); + rewriter.create<IREE::HAL::ReturnOp>(loc); + return success(); } std::array<Value, 3> workgroupCount = {nullptr, nullptr, nullptr}; FuncOp numWorkgroupsFn = dyn_cast<FuncOp>(SymbolTable::lookupSymbolIn(
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp index fb6d657..1bde01a 100644 --- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp +++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
@@ -52,6 +52,35 @@ StringRef funcName; }; +template <typename SrcOpTy> +class ShiftArithmeticOpConversion : public OpConversionPattern<SrcOpTy> { + using OpConversionPattern<SrcOpTy>::OpConversionPattern; + + public: + ShiftArithmeticOpConversion(MLIRContext *context, StringRef funcName) + : OpConversionPattern<SrcOpTy>(context), funcName(funcName) {} + + private: + LogicalResult matchAndRewrite( + SrcOpTy op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + typename SrcOpTy::Adaptor srcAdaptor( + operands, op.getOperation()->getAttrDictionary()); + + StringAttr callee = rewriter.getStringAttr(funcName); + ArrayAttr args = rewriter.getArrayAttr( + {IntegerAttr::get(rewriter.getIndexType(), 0), srcAdaptor.amount()}); + ArrayAttr templateArgs; + + rewriter.replaceOpWithNewOp<emitc::CallOp>(op, op.getType(), callee, args, + templateArgs, operands); + + return success(); + } + + StringRef funcName; +}; + // TODO(simon-camp): These conversions to macro calls should be deleted once // support for control flow ops has landed in the c module target template <typename SrcOpTy> @@ -137,6 +166,14 @@ patterns.insert<NoAttributeOpConversion<IREE::VM::XorI32Op>>(context, "vm_xor_i32"); + // Native bitwise shift and rotate ops + patterns.insert<ShiftArithmeticOpConversion<IREE::VM::ShlI32Op>>( + context, "vm_shl_i32"); + patterns.insert<ShiftArithmeticOpConversion<IREE::VM::ShrI32SOp>>( + context, "vm_shr_i32s"); + patterns.insert<ShiftArithmeticOpConversion<IREE::VM::ShrI32UOp>>( + context, "vm_shr_i32u"); + // Check // TODO(simon-camp): These conversions to macro calls should be deleted once // support for control flow ops has landed in the c module target @@ -188,6 +225,11 @@ target.addIllegalOp<IREE::VM::OrI32Op>(); target.addIllegalOp<IREE::VM::XorI32Op>(); + // Native bitwise shift and rotate ops + target.addIllegalOp<IREE::VM::ShlI32Op>(); + target.addIllegalOp<IREE::VM::ShrI32SOp>(); + target.addIllegalOp<IREE::VM::ShrI32UOp>(); + // 
Check ops // TODO(simon-camp): These conversions to macro calls should be deleted once // support for control flow ops has landed in the c module target
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/arithmetic.mlir b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/arithmetic_ops.mlir similarity index 96% rename from iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/arithmetic.mlir rename to iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/arithmetic_ops.mlir index 16383e1..06435e8 100644 --- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/arithmetic.mlir +++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/arithmetic_ops.mlir
@@ -1,12 +1,10 @@ // RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s -// CHECK: vm.module @my_module { +// CHECK-LABEL: @add_i32 vm.module @my_module { - // CHECK-LABEL: vm.func @add_i32 vm.func @add_i32(%arg0: i32, %arg1: i32) { // CHECK-NEXT: %0 = emitc.call "vm_add_i32"(%arg0, %arg1) : (i32, i32) -> i32 %0 = vm.add.i32 %arg0, %arg1 : i32 - // CHECK-NEXT: vm.return %0 : i32 vm.return %0 : i32 } }
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/shift_ops.mlir b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/shift_ops.mlir new file mode 100644 index 0000000..0222c95 --- /dev/null +++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/shift_ops.mlir
@@ -0,0 +1,32 @@ +// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s + +// CHECK-LABEL: @shl_i32 +vm.module @my_module { + vm.func @shl_i32(%arg0 : i32) -> i32 { + // CHECK: %0 = emitc.call "vm_shl_i32"(%arg0) {args = [0 : index, 2 : i8]} : (i32) -> i32 + %0 = vm.shl.i32 %arg0, 2 : i32 + vm.return %0 : i32 + } +} + +// ----- + +// CHECK-LABEL: @shr_i32_s +vm.module @my_module { + vm.func @shr_i32_s(%arg0 : i32) -> i32 { + // CHECK: %0 = emitc.call "vm_shr_i32s"(%arg0) {args = [0 : index, 2 : i8]} : (i32) -> i32 + %0 = vm.shr.i32.s %arg0, 2 : i32 + vm.return %0 : i32 + } +} + +// ----- + +// CHECK-LABEL: @shr_i32_u +vm.module @my_module { + vm.func @shr_i32_u(%arg0 : i32) -> i32 { + // CHECK: %0 = emitc.call "vm_shr_i32u"(%arg0) {args = [0 : index, 2 : i8]} : (i32) -> i32 + %0 = vm.shr.i32.u %arg0, 2 : i32 + vm.return %0 : i32 + } +}
diff --git a/iree/compiler/Dialect/VM/IR/test/arithmetic_ops.mlir b/iree/compiler/Dialect/VM/IR/test/arithmetic_ops.mlir index 4fd7c65..bbb8381 100644 --- a/iree/compiler/Dialect/VM/IR/test/arithmetic_ops.mlir +++ b/iree/compiler/Dialect/VM/IR/test/arithmetic_ops.mlir
@@ -120,36 +120,3 @@ vm.return %0 : i32 } } - -// ----- - -// CHECK-LABEL: @shl_i32 -vm.module @my_module { - vm.func @shl_i32(%arg0 : i32) -> i32 { - // CHECK: %0 = vm.shl.i32 %arg0, 2 : i32 - %0 = vm.shl.i32 %arg0, 2 : i32 - vm.return %0 : i32 - } -} - -// ----- - -// CHECK-LABEL: @shr_i32_s -vm.module @my_module { - vm.func @shr_i32_s(%arg0 : i32) -> i32 { - // CHECK: %0 = vm.shr.i32.s %arg0, 2 : i32 - %0 = vm.shr.i32.s %arg0, 2 : i32 - vm.return %0 : i32 - } -} - -// ----- - -// CHECK-LABEL: @shr_i32_u -vm.module @my_module { - vm.func @shr_i32_u(%arg0 : i32) -> i32 { - // CHECK: %0 = vm.shr.i32.u %arg0, 2 : i32 - %0 = vm.shr.i32.u %arg0, 2 : i32 - vm.return %0 : i32 - } -}
diff --git a/iree/compiler/Dialect/VM/IR/test/shift_ops.mlir b/iree/compiler/Dialect/VM/IR/test/shift_ops.mlir new file mode 100644 index 0000000..f6375bc --- /dev/null +++ b/iree/compiler/Dialect/VM/IR/test/shift_ops.mlir
@@ -0,0 +1,34 @@ +// Tests printing and parsing of bitwise shift and rotate ops. + +// RUN: iree-opt -split-input-file %s | IreeFileCheck %s + +// CHECK-LABEL: @shl_i32 +vm.module @my_module { + vm.func @shl_i32(%arg0 : i32) -> i32 { + // CHECK: %0 = vm.shl.i32 %arg0, 2 : i32 + %0 = vm.shl.i32 %arg0, 2 : i32 + vm.return %0 : i32 + } +} + +// ----- + +// CHECK-LABEL: @shr_i32_s +vm.module @my_module { + vm.func @shr_i32_s(%arg0 : i32) -> i32 { + // CHECK: %0 = vm.shr.i32.s %arg0, 2 : i32 + %0 = vm.shr.i32.s %arg0, 2 : i32 + vm.return %0 : i32 + } +} + +// ----- + +// CHECK-LABEL: @shr_i32_u +vm.module @my_module { + vm.func @shr_i32_u(%arg0 : i32) -> i32 { + // CHECK: %0 = vm.shr.i32.u %arg0, 2 : i32 + %0 = vm.shr.i32.u %arg0, 2 : i32 + vm.return %0 : i32 + } +}
diff --git a/iree/test/e2e/xla_ops/BUILD b/iree/test/e2e/xla_ops/BUILD index 2693834..8d13056 100644 --- a/iree/test/e2e/xla_ops/BUILD +++ b/iree/test/e2e/xla_ops/BUILD
@@ -134,6 +134,70 @@ target_backend = "dylib-llvm-aot", ) +iree_check_single_backend_test_suite( + name = "check_linalg_on_tensors_dylib-llvm-aot_dylib", + srcs = [ + "abs.mlir", + "add.mlir", + "batch_norm_inference.mlir", + "broadcast.mlir", + "broadcast_add.mlir", + "broadcast_in_dim.mlir", + # https://github.com/google/iree/issues/4692 + # "clamp.mlir", + "compare.mlir", + # https://github.com/google/iree/issues/4079 + # "concatenate.mlir", + "constant.mlir", + # https://github.com/google/iree/issues/4079 + # "convolution.mlir", + "cosine.mlir", + "divide.mlir", + # https://github.com/google/iree/issues/4079 + # "dot.mlir", + # "dot_general.mlir", + "exponential.mlir", + # https://github.com/google/iree/issues/4692 + # "exponential_minus_one.mlir", + "floor.mlir", + # https://github.com/google/iree/issues/4692 + # "gather.mlir", + "iota.mlir", + "log.mlir", + # https://github.com/google/iree/issues/4692 + # "log_plus_one.mlir", + # "maximum.mlir", + # "minimum.mlir", + "multiply.mlir", + "negate.mlir", + # https://github.com/google/iree/issues/4079 + # "pad.mlir", + # "reduce.mlir", + # "reduce_window.mlir", + "remainder.mlir", + # "reshape.mlir", + # "reverse.mlir", + "rsqrt.mlir", + "select.mlir", + "sine.mlir", + # https://github.com/google/iree/issues/4692 + # "slice.mlir", + "sqrt.mlir", + "subtract.mlir", + "tanh.mlir", + # https://github.com/google/iree/issues/4079 + # "torch_index_select.mlir", + "transpose.mlir", + # "while.mlir", + ], + compiler_flags = [ + "-iree-flow-dispatch-linalg-on-tensors", + "-iree-codegen-llvm-experimental-linalg-on-tensors", + ], + driver = "dylib", + target_backend = "dylib-llvm-aot", +) + test_suite( name = "check", tests = [
diff --git a/iree/test/e2e/xla_ops/CMakeLists.txt b/iree/test/e2e/xla_ops/CMakeLists.txt index 65c4d93..9a53940 100644 --- a/iree/test/e2e/xla_ops/CMakeLists.txt +++ b/iree/test/e2e/xla_ops/CMakeLists.txt
@@ -130,3 +130,40 @@ DRIVER "dylib" ) + +iree_check_single_backend_test_suite( + NAME + check_linalg_on_tensors_dylib-llvm-aot_dylib + SRCS + "abs.mlir" + "add.mlir" + "batch_norm_inference.mlir" + "broadcast.mlir" + "broadcast_add.mlir" + "broadcast_in_dim.mlir" + "compare.mlir" + "constant.mlir" + "cosine.mlir" + "divide.mlir" + "exponential.mlir" + "floor.mlir" + "iota.mlir" + "log.mlir" + "multiply.mlir" + "negate.mlir" + "remainder.mlir" + "rsqrt.mlir" + "select.mlir" + "sine.mlir" + "sqrt.mlir" + "subtract.mlir" + "tanh.mlir" + "transpose.mlir" + TARGET_BACKEND + "dylib-llvm-aot" + DRIVER + "dylib" + COMPILER_FLAGS + "-iree-flow-dispatch-linalg-on-tensors" + "-iree-codegen-llvm-experimental-linalg-on-tensors" +)
diff --git a/iree/vm/bytecode_dispatch.c b/iree/vm/bytecode_dispatch.c index f1cb0bd..a4e052c 100644 --- a/iree/vm/bytecode_dispatch.c +++ b/iree/vm/bytecode_dispatch.c
@@ -1025,17 +1025,17 @@ // Native bitwise shifts and rotates //===------------------------------------------------------------------===// -#define DISPATCH_OP_CORE_SHIFT_I32(op_name, type, op) \ - DISPATCH_OP(CORE, op_name, { \ - int32_t operand = VM_DecOperandRegI32("operand"); \ - int8_t amount = VM_DecConstI8("amount"); \ - int32_t* result = VM_DecResultRegI32("result"); \ - *result = (int32_t)(((type)operand)op amount); \ +#define DISPATCH_OP_CORE_SHIFT_I32(op_name, type, op_func) \ + DISPATCH_OP(CORE, op_name, { \ + int32_t operand = VM_DecOperandRegI32("operand"); \ + int8_t amount = VM_DecConstI8("amount"); \ + int32_t* result = VM_DecResultRegI32("result"); \ + *result = op_func(operand, amount); \ }); - DISPATCH_OP_CORE_SHIFT_I32(ShlI32, int32_t, <<); - DISPATCH_OP_CORE_SHIFT_I32(ShrI32S, int32_t, >>); - DISPATCH_OP_CORE_SHIFT_I32(ShrI32U, uint32_t, >>); + DISPATCH_OP_CORE_SHIFT_I32(ShlI32, int32_t, vm_shl_i32); + DISPATCH_OP_CORE_SHIFT_I32(ShrI32S, int32_t, vm_shr_i32s); + DISPATCH_OP_CORE_SHIFT_I32(ShrI32U, uint32_t, vm_shr_i32u); //===------------------------------------------------------------------===// // Comparison ops
diff --git a/iree/vm/ops.h b/iree/vm/ops.h index d22d8ea..5e9904a 100644 --- a/iree/vm/ops.h +++ b/iree/vm/ops.h
@@ -35,6 +35,24 @@
 static inline int32_t vm_or_i32(int32_t lhs, int32_t rhs) { return lhs | rhs; }
 static inline int32_t vm_xor_i32(int32_t lhs, int32_t rhs) { return lhs ^ rhs; }
 
+//===------------------------------------------------------------------===//
+// Native bitwise shifts and rotates
+//===------------------------------------------------------------------===//
+
+// Shifts |operand| left by |amount| bits. The shift is performed on the
+// unsigned representation because left-shifting a negative signed value is UB.
+static inline int32_t vm_shl_i32(int32_t operand, int8_t amount) {
+  return (int32_t)((uint32_t)operand << amount);
+}
+// Shifts the signed |operand| right by |amount| bits using the signed `>>`.
+static inline int32_t vm_shr_i32s(int32_t operand, int8_t amount) {
+  return (int32_t)(operand >> amount);
+}
+// Shifts the unsigned |operand| right by |amount| bits (vacated bits are 0).
+static inline int32_t vm_shr_i32u(uint32_t operand, int8_t amount) {
+  return (int32_t)(operand >> amount);
+}
+
 // Check ops
 // TODO(simon-camp): These macros should be removed once control flow ops are
 // supported in the c module target