Merge pull request #4696 from google/benvanik-tf-compiler-cleanup

diff --git a/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributePass.cpp b/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributePass.cpp
index 4ebc130..307e8bc 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributePass.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributePass.cpp
@@ -18,6 +18,7 @@
 #include "iree/compiler/Conversion/Common/Attributes.h"
 #include "iree/compiler/Conversion/Common/Transforms.h"
 #include "iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h"
+#include "iree/compiler/Conversion/LinalgToLLVM/Passes.h"
 #include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
 #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
 #include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.h b/iree/compiler/Conversion/LinalgToLLVM/Passes.h
index 301eee3..053d84c 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/Passes.h
+++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.h
@@ -15,6 +15,7 @@
 #ifndef IREE_COMPILER_CONVERSION_LINALGTOLLVM_PASSES_H_
 #define IREE_COMPILER_CONVERSION_LINALGTOLLVM_PASSES_H_
 
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
@@ -29,7 +30,8 @@
 std::unique_ptr<FunctionPass> createPlanConvLoopOrderPass();
 
 /// Distributes linalg ops among hal.interface.workgroup logical threads.
-std::unique_ptr<OperationPass<ModuleOp>> createLinalgTileAndDistributePass();
+std::unique_ptr<OperationPass<IREE::HAL::ExecutableTargetOp>>
+createLinalgTileAndDistributePass();
 
 /// Vectorizes linalg ops executed in the same hal.interface.workgroup.
 std::unique_ptr<FunctionPass> createLinalgTileAndVectorizeWorkgroupsPass();
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index fffc559..0339231 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -65,17 +65,18 @@
   //----------------------------------------------------------------------------
   passManager.addPass(createCanonicalizerPass());
 
+  // Flatten structured control flow to our CFG.
+  passManager.addNestedPass<FuncOp>(mhlo::createLegalizeControlFlowPass());
+  passManager.addNestedPass<FuncOp>(createHLOPreprocessingPass());
+
   // Frontload linalg-on-tensors transformations and dispatch region creation.
   if (clEnableLinalgOnTensorsDispatch) {
+    passManager.addNestedPass<FuncOp>(createCanonicalizerPass());
     addHLOToLinalgOnTensorsPasses(passManager);
     passManager.addNestedPass<FuncOp>(createDispatchLinalgOnTensorsPass(
         clLinalgOnTensorsTileSizes, clLinalgOnTensorsEnableFusion));
   }
 
-  // Flatten structured control flow to our CFG.
-  passManager.addNestedPass<FuncOp>(mhlo::createLegalizeControlFlowPass());
-  passManager.addNestedPass<FuncOp>(createHLOPreprocessingPass());
-
   // Convert TOSA ops to Linalg-on-tensor ops.
   passManager.addNestedPass<FuncOp>(tosa::createTosaToLinalgOnTensors());
 
@@ -209,10 +210,8 @@
   // Note that as we are rematerializing things here it's critical we do not run
   // the canonicalizer/CSE between now and when we outline - otherwise it'll
   // undo all of our work!
-  if (!clEnableLinalgOnTensorsDispatch) {
-    passManager.addNestedPass<FuncOp>(
-        IREE::Flow::createRematerializeDispatchConstantsPass());
-  }
+  passManager.addNestedPass<FuncOp>(
+      IREE::Flow::createRematerializeDispatchConstantsPass());
 
   // Outline the dispatch regions into their own functions wrapped in
   // executables. This separates sequencer functions performing dispatches from
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
index dd09c02..7bb954a 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
@@ -158,7 +158,12 @@
           funcOp->getAttrOfType<FlatSymbolRefAttr>(
               getNumWorkgroupsFnAttrName());
       if (!numWorkgroupsFnAttr) {
-        return funcOp.emitError("expected llvm.num_workgroups_fn ");
+        auto constantOne = rewriter.createOrFold<mlir::ConstantIndexOp>(loc, 1);
+        rewriter.create<IREE::HAL::CommandBufferDispatchSymbolOp>(
+            loc, commandBuffer, dispatchState.entryPointOp, constantOne,
+            constantOne, constantOne);
+        rewriter.create<IREE::HAL::ReturnOp>(loc);
+        return success();
       }
       std::array<Value, 3> workgroupCount = {nullptr, nullptr, nullptr};
       FuncOp numWorkgroupsFn = dyn_cast<FuncOp>(SymbolTable::lookupSymbolIn(
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
index fb6d657..1bde01a 100644
--- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
+++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
@@ -52,6 +52,35 @@
   StringRef funcName;
 };
 
+template <typename SrcOpTy>
+class ShiftArithmeticOpConversion : public OpConversionPattern<SrcOpTy> {
+  using OpConversionPattern<SrcOpTy>::OpConversionPattern;
+
+ public:
+  ShiftArithmeticOpConversion(MLIRContext *context, StringRef funcName)
+      : OpConversionPattern<SrcOpTy>(context), funcName(funcName) {}
+
+ private:
+  LogicalResult matchAndRewrite(
+      SrcOpTy op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    typename SrcOpTy::Adaptor srcAdaptor(
+        operands, op.getOperation()->getAttrDictionary());
+
+    StringAttr callee = rewriter.getStringAttr(funcName);
+    ArrayAttr args = rewriter.getArrayAttr(
+        {IntegerAttr::get(rewriter.getIndexType(), 0), srcAdaptor.amount()});
+    ArrayAttr templateArgs;
+
+    rewriter.replaceOpWithNewOp<emitc::CallOp>(op, op.getType(), callee, args,
+                                               templateArgs, operands);
+
+    return success();
+  }
+
+  StringRef funcName;
+};
+
 // TODO(simon-camp): These conversions to macro calls should be deleted once
 // support for control flow ops has landed in the c module target
 template <typename SrcOpTy>
@@ -137,6 +166,14 @@
   patterns.insert<NoAttributeOpConversion<IREE::VM::XorI32Op>>(context,
                                                                "vm_xor_i32");
 
+  // Native bitwise shift ops
+  patterns.insert<ShiftArithmeticOpConversion<IREE::VM::ShlI32Op>>(
+      context, "vm_shl_i32");
+  patterns.insert<ShiftArithmeticOpConversion<IREE::VM::ShrI32SOp>>(
+      context, "vm_shr_i32s");
+  patterns.insert<ShiftArithmeticOpConversion<IREE::VM::ShrI32UOp>>(
+      context, "vm_shr_i32u");
+
   // Check
   // TODO(simon-camp): These conversions to macro calls should be deleted once
   // support for control flow ops has landed in the c module target
@@ -188,6 +225,11 @@
     target.addIllegalOp<IREE::VM::OrI32Op>();
     target.addIllegalOp<IREE::VM::XorI32Op>();
 
+    // Native bitwise shift ops
+    target.addIllegalOp<IREE::VM::ShlI32Op>();
+    target.addIllegalOp<IREE::VM::ShrI32SOp>();
+    target.addIllegalOp<IREE::VM::ShrI32UOp>();
+
     // Check ops
     // TODO(simon-camp): These conversions to macro calls should be deleted once
     // support for control flow ops has landed in the c module target
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/arithmetic.mlir b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/arithmetic_ops.mlir
similarity index 96%
rename from iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/arithmetic.mlir
rename to iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/arithmetic_ops.mlir
index 16383e1..06435e8 100644
--- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/arithmetic.mlir
+++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/arithmetic_ops.mlir
@@ -1,12 +1,10 @@
 // RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s
 
-// CHECK: vm.module @my_module {
+// CHECK-LABEL: @add_i32
 vm.module @my_module {
-  // CHECK-LABEL: vm.func @add_i32
   vm.func @add_i32(%arg0: i32, %arg1: i32) {
     // CHECK-NEXT: %0 = emitc.call "vm_add_i32"(%arg0, %arg1) : (i32, i32) -> i32
     %0 = vm.add.i32 %arg0, %arg1 : i32
-    // CHECK-NEXT: vm.return %0 : i32
     vm.return %0 : i32
   }
 }
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/shift_ops.mlir b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/shift_ops.mlir
new file mode 100644
index 0000000..0222c95
--- /dev/null
+++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/shift_ops.mlir
@@ -0,0 +1,32 @@
+// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s
+
+// CHECK-LABEL: @shl_i32
+vm.module @my_module {
+  vm.func @shl_i32(%arg0 : i32) -> i32 {
+    // CHECK: %0 = emitc.call "vm_shl_i32"(%arg0) {args = [0 : index, 2 : i8]} : (i32) -> i32
+    %0 = vm.shl.i32 %arg0, 2 : i32
+    vm.return %0 : i32
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @shr_i32_s
+vm.module @my_module {
+  vm.func @shr_i32_s(%arg0 : i32) -> i32 {
+    // CHECK: %0 = emitc.call "vm_shr_i32s"(%arg0) {args = [0 : index, 2 : i8]} : (i32) -> i32
+    %0 = vm.shr.i32.s %arg0, 2 : i32
+    vm.return %0 : i32
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @shr_i32_u
+vm.module @my_module {
+  vm.func @shr_i32_u(%arg0 : i32) -> i32 {
+    // CHECK: %0 = emitc.call "vm_shr_i32u"(%arg0) {args = [0 : index, 2 : i8]} : (i32) -> i32
+    %0 = vm.shr.i32.u %arg0, 2 : i32
+    vm.return %0 : i32
+  }
+}
diff --git a/iree/compiler/Dialect/VM/IR/test/arithmetic_ops.mlir b/iree/compiler/Dialect/VM/IR/test/arithmetic_ops.mlir
index 4fd7c65..bbb8381 100644
--- a/iree/compiler/Dialect/VM/IR/test/arithmetic_ops.mlir
+++ b/iree/compiler/Dialect/VM/IR/test/arithmetic_ops.mlir
@@ -120,36 +120,3 @@
     vm.return %0 : i32
   }
 }
-
-// -----
-
-// CHECK-LABEL: @shl_i32
-vm.module @my_module {
-  vm.func @shl_i32(%arg0 : i32) -> i32 {
-    // CHECK: %0 = vm.shl.i32 %arg0, 2 : i32
-    %0 = vm.shl.i32 %arg0, 2 : i32
-    vm.return %0 : i32
-  }
-}
-
-// -----
-
-// CHECK-LABEL: @shr_i32_s
-vm.module @my_module {
-  vm.func @shr_i32_s(%arg0 : i32) -> i32 {
-    // CHECK: %0 = vm.shr.i32.s %arg0, 2 : i32
-    %0 = vm.shr.i32.s %arg0, 2 : i32
-    vm.return %0 : i32
-  }
-}
-
-// -----
-
-// CHECK-LABEL: @shr_i32_u
-vm.module @my_module {
-  vm.func @shr_i32_u(%arg0 : i32) -> i32 {
-    // CHECK: %0 = vm.shr.i32.u %arg0, 2 : i32
-    %0 = vm.shr.i32.u %arg0, 2 : i32
-    vm.return %0 : i32
-  }
-}
diff --git a/iree/compiler/Dialect/VM/IR/test/shift_ops.mlir b/iree/compiler/Dialect/VM/IR/test/shift_ops.mlir
new file mode 100644
index 0000000..f6375bc
--- /dev/null
+++ b/iree/compiler/Dialect/VM/IR/test/shift_ops.mlir
@@ -0,0 +1,34 @@
+// Tests printing and parsing of bitwise shift ops.
+
+// RUN: iree-opt -split-input-file %s | IreeFileCheck %s
+
+// CHECK-LABEL: @shl_i32
+vm.module @my_module {
+  vm.func @shl_i32(%arg0 : i32) -> i32 {
+    // CHECK: %0 = vm.shl.i32 %arg0, 2 : i32
+    %0 = vm.shl.i32 %arg0, 2 : i32
+    vm.return %0 : i32
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @shr_i32_s
+vm.module @my_module {
+  vm.func @shr_i32_s(%arg0 : i32) -> i32 {
+    // CHECK: %0 = vm.shr.i32.s %arg0, 2 : i32
+    %0 = vm.shr.i32.s %arg0, 2 : i32
+    vm.return %0 : i32
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @shr_i32_u
+vm.module @my_module {
+  vm.func @shr_i32_u(%arg0 : i32) -> i32 {
+    // CHECK: %0 = vm.shr.i32.u %arg0, 2 : i32
+    %0 = vm.shr.i32.u %arg0, 2 : i32
+    vm.return %0 : i32
+  }
+}
diff --git a/iree/test/e2e/xla_ops/BUILD b/iree/test/e2e/xla_ops/BUILD
index 2693834..8d13056 100644
--- a/iree/test/e2e/xla_ops/BUILD
+++ b/iree/test/e2e/xla_ops/BUILD
@@ -134,6 +134,70 @@
     target_backend = "dylib-llvm-aot",
 )
 
+iree_check_single_backend_test_suite(
+    name = "check_linalg_on_tensors_dylib-llvm-aot_dylib",
+    srcs = [
+        "abs.mlir",
+        "add.mlir",
+        "batch_norm_inference.mlir",
+        "broadcast.mlir",
+        "broadcast_add.mlir",
+        "broadcast_in_dim.mlir",
+        # https://github.com/google/iree/issues/4692
+        # "clamp.mlir",
+        "compare.mlir",
+        # https://github.com/google/iree/issues/4079
+        # "concatenate.mlir",
+        "constant.mlir",
+        # https://github.com/google/iree/issues/4079
+        # "convolution.mlir",
+        "cosine.mlir",
+        "divide.mlir",
+        # https://github.com/google/iree/issues/4079
+        # "dot.mlir",
+        # "dot_general.mlir",
+        "exponential.mlir",
+        # https://github.com/google/iree/issues/4692
+        # "exponential_minus_one.mlir",
+        "floor.mlir",
+        # https://github.com/google/iree/issues/4692
+        # "gather.mlir",
+        "iota.mlir",
+        "log.mlir",
+        # https://github.com/google/iree/issues/4692
+        # "log_plus_one.mlir",
+        # "maximum.mlir",
+        # "minimum.mlir",
+        "multiply.mlir",
+        "negate.mlir",
+        # https://github.com/google/iree/issues/4079
+        # "pad.mlir",
+        # "reduce.mlir",
+        # "reduce_window.mlir",
+        "remainder.mlir",
+        # "reshape.mlir",
+        # "reverse.mlir",
+        "rsqrt.mlir",
+        "select.mlir",
+        "sine.mlir",
+        # https://github.com/google/iree/issues/4692
+        # "slice.mlir",
+        "sqrt.mlir",
+        "subtract.mlir",
+        "tanh.mlir",
+        # https://github.com/google/iree/issues/4079
+        # "torch_index_select.mlir",
+        "transpose.mlir",
+        # "while.mlir",
+    ],
+    compiler_flags = [
+        "-iree-flow-dispatch-linalg-on-tensors",
+        "-iree-codegen-llvm-experimental-linalg-on-tensors",
+    ],
+    driver = "dylib",
+    target_backend = "dylib-llvm-aot",
+)
+
 test_suite(
     name = "check",
     tests = [
diff --git a/iree/test/e2e/xla_ops/CMakeLists.txt b/iree/test/e2e/xla_ops/CMakeLists.txt
index 65c4d93..9a53940 100644
--- a/iree/test/e2e/xla_ops/CMakeLists.txt
+++ b/iree/test/e2e/xla_ops/CMakeLists.txt
@@ -130,3 +130,40 @@
   DRIVER
     "dylib"
 )
+
+iree_check_single_backend_test_suite(
+  NAME
+    check_linalg_on_tensors_dylib-llvm-aot_dylib
+  SRCS
+    "abs.mlir"
+    "add.mlir"
+    "batch_norm_inference.mlir"
+    "broadcast.mlir"
+    "broadcast_add.mlir"
+    "broadcast_in_dim.mlir"
+    "compare.mlir"
+    "constant.mlir"
+    "cosine.mlir"
+    "divide.mlir"
+    "exponential.mlir"
+    "floor.mlir"
+    "iota.mlir"
+    "log.mlir"
+    "multiply.mlir"
+    "negate.mlir"
+    "remainder.mlir"
+    "rsqrt.mlir"
+    "select.mlir"
+    "sine.mlir"
+    "sqrt.mlir"
+    "subtract.mlir"
+    "tanh.mlir"
+    "transpose.mlir"
+  TARGET_BACKEND
+    "dylib-llvm-aot"
+  DRIVER
+    "dylib"
+  COMPILER_FLAGS
+    "-iree-flow-dispatch-linalg-on-tensors"
+    "-iree-codegen-llvm-experimental-linalg-on-tensors"
+)
diff --git a/iree/vm/bytecode_dispatch.c b/iree/vm/bytecode_dispatch.c
index f1cb0bd..a4e052c 100644
--- a/iree/vm/bytecode_dispatch.c
+++ b/iree/vm/bytecode_dispatch.c
@@ -1025,17 +1025,17 @@
     // Native bitwise shifts and rotates
     //===------------------------------------------------------------------===//
 
-#define DISPATCH_OP_CORE_SHIFT_I32(op_name, type, op) \
-  DISPATCH_OP(CORE, op_name, {                        \
-    int32_t operand = VM_DecOperandRegI32("operand"); \
-    int8_t amount = VM_DecConstI8("amount");          \
-    int32_t* result = VM_DecResultRegI32("result");   \
-    *result = (int32_t)(((type)operand)op amount);    \
+#define DISPATCH_OP_CORE_SHIFT_I32(op_name, type, op_func) \
+  DISPATCH_OP(CORE, op_name, {                             \
+    int32_t operand = VM_DecOperandRegI32("operand");      \
+    int8_t amount = VM_DecConstI8("amount");               \
+    int32_t* result = VM_DecResultRegI32("result");        \
+    *result = op_func(operand, amount);                    \
   });
 
-    DISPATCH_OP_CORE_SHIFT_I32(ShlI32, int32_t, <<);
-    DISPATCH_OP_CORE_SHIFT_I32(ShrI32S, int32_t, >>);
-    DISPATCH_OP_CORE_SHIFT_I32(ShrI32U, uint32_t, >>);
+    DISPATCH_OP_CORE_SHIFT_I32(ShlI32, int32_t, vm_shl_i32);
+    DISPATCH_OP_CORE_SHIFT_I32(ShrI32S, int32_t, vm_shr_i32s);
+    DISPATCH_OP_CORE_SHIFT_I32(ShrI32U, uint32_t, vm_shr_i32u);
 
     //===------------------------------------------------------------------===//
     // Comparison ops
diff --git a/iree/vm/ops.h b/iree/vm/ops.h
index d22d8ea..5e9904a 100644
--- a/iree/vm/ops.h
+++ b/iree/vm/ops.h
@@ -35,6 +35,20 @@
 static inline int32_t vm_or_i32(int32_t lhs, int32_t rhs) { return lhs | rhs; }
 static inline int32_t vm_xor_i32(int32_t lhs, int32_t rhs) { return lhs ^ rhs; }
 
+//===------------------------------------------------------------------===//
+// Native bitwise shifts and rotates
+//===------------------------------------------------------------------===//
+
+// Left shift; amount is a compile-time constant decoded from the bytecode.
+static inline int32_t vm_shl_i32(int32_t operand, int8_t amount) {
+  return (int32_t)(operand << amount);
+}
+// Arithmetic (sign-extending) right shift.
+static inline int32_t vm_shr_i32s(int32_t operand, int8_t amount) {
+  return (int32_t)(operand >> amount);
+}
+// Logical (zero-filling) right shift via the unsigned operand type.
+static inline int32_t vm_shr_i32u(uint32_t operand, int8_t amount) {
+  return (int32_t)(operand >> amount);
+}
+
 // Check ops
 // TODO(simon-camp): These macros should be removed once control flow ops are
 // supported in the c module target