Implementing VMLA reduction conversion.

This allows standard reduce ops coming in through the reduce-in-dispatch path
to be converted to the built-in VMLA reduction ops. There's a placeholder for
a generic fallback path that emits a VM loop, but that path isn't used yet
(we don't support it in the other backends yet either).
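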
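For reference, the conversion roughly turns an xla_hlo.reduce like the one in
the new reduce.mlir test into a single builtin VMLA call (SSA names below are
illustrative, not actual compiler output):

  %0 = "xla_hlo.reduce"(%arg0, %cst) ( {
  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
    %1 = xla_hlo.add %arg1, %arg2 : tensor<f32>
    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
  }) {dimensions = dense<1> : tensor<1xi64>}
      : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32>

  // lowers to:
  "vmla.reduce.sum"(%src, %src_shape, %init, %init_shape, %dst, %dst_shape)
      {dimension = 1 : i32, element_type = f32}
      : (!vmla.buffer, !shapex.ranked_shape<[4,8],i32>,
         !vmla.buffer, !shapex.ranked_shape<[],i32>,
         !vmla.buffer, !shapex.ranked_shape<[4],i32>) -> ()

The path can be exercised end-to-end via the updated RUN lines, e.g.
iree-run-mlir -iree-hal-target-backends=vmla -iree-flow-experimental-dispatch-reduce
on test/e2e/xla/reduce_float.mlir.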

WIP #936.

Closes https://github.com/google/iree/pull/955

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/iree/pull/955 from google:benvanik-vmla-reduce-conversion 35b02988e53499848895ba69697292b26019644e
PiperOrigin-RevId: 299015281
diff --git a/build_tools/cmake/iree_copts.cmake b/build_tools/cmake/iree_copts.cmake
index 3ef833e..4ba9497 100644
--- a/build_tools/cmake/iree_copts.cmake
+++ b/build_tools/cmake/iree_copts.cmake
@@ -43,6 +43,7 @@
     "-Wno-undef"
   MSVC_OR_CLANG_CL
     "/DWIN32_LEAN_AND_MEAN"
+    "/wd4624"
     # TODO(benvanik): figure out if really required or accidentally enabled.
     "/EHsc"
 )
diff --git a/iree/compiler/Dialect/Flow/Transforms/OutlineReductionRegions.cpp b/iree/compiler/Dialect/Flow/Transforms/OutlineReductionRegions.cpp
index 2d89be5..bee1284 100644
--- a/iree/compiler/Dialect/Flow/Transforms/OutlineReductionRegions.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/OutlineReductionRegions.cpp
@@ -151,7 +151,8 @@
   for (uint32_t i = 0; i < dimensions.getNumElements(); ++i) {
     sortedDimensions.push_back(dimensions.getValue<IntegerAttr>({i}).getInt());
   }
-  llvm::sort(sortedDimensions, [](int32_t a, int32_t b) { return a - b; });
+  llvm::sort(sortedDimensions,
+             [](int32_t a, int32_t b) { return (a - b) > 0; });
   for (auto dimension : llvm::enumerate(sortedDimensions)) {
     // Create the executable with the region cloned into it.
     ExecutableOp executableOp;
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir b/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir
index f4fa468..e33d70c 100644
--- a/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir
+++ b/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir
@@ -42,18 +42,57 @@
 
 // -----
 
-// TODO(benvanik): vmla reduction.
-// flow.executable @reduction_ex_reduce_0_dim_0 {
-//   flow.reduction.entry @reduction_rgn_reduce_0_dim_0_entry apply(@reduction_rgn_reduce_0_dim_0) attributes {
-//     dimension = 1 : i32,
-//     workgroup_size = dense<[32, 1, 1]> : vector<3xi32>,
-//     workload = dense<[4, 1, 1]> : vector<3xi32>
-//   }
-//   module {
-//     func @reduction_rgn_reduce_0_dim_0_entry(tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32>
-//     func @reduction_rgn_reduce_0_dim_0(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
-//       %0 = xla_hlo.add %arg0, %arg1 : tensor<f32>
-//       return %0 : tensor<f32>
-//     }
-//   }
-// }
+flow.executable @reduction_ex_dispatch_0 {
+  flow.dispatch.entry @reduction_ex_dispatch_0 attributes {workload = dense<[4, 1, 1]> : vector<3xi32>}
+  module {
+    func @reduction_ex_dispatch_0(%arg0: tensor<4x8xf32>) -> tensor<4xf32> {
+      %cst = constant dense<0.000000e+00> : tensor<f32>
+      %0 = "xla_hlo.reduce"(%arg0, %cst) ( {
+      ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):	// no predecessors
+        %1 = xla_hlo.add %arg1, %arg2 : tensor<f32>
+        "xla_hlo.return"(%1) : (tensor<f32>) -> ()
+      }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32>
+      return %0 : tensor<4xf32>
+    }
+  }
+}
+
+// CHECK-LABEL: hal.executable @reduction_ex_dispatch_0 {
+//  CHECK-NEXT:   hal.interface @legacy_io {
+//  CHECK-NEXT:     hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+//  CHECK-NEXT:     hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard"
+//  CHECK-NEXT:   }
+//  CHECK-NEXT:   hal.executable.entry_point @reduction_ex_dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<4x8xf32>) -> tensor<4xf32>, workgroup_size = dense<1> : vector<3xi32>}
+//  CHECK-NEXT:   hal.executable.binary attributes {
+//  CHECK-SAME:       data = dense<
+//  CHECK-SAME:       format = 1447906369 : i32} {
+//  CHECK-NEXT:     vm.module @module {
+//  CHECK-NEXT:       vm.rodata @reduction_ex_dispatch_0_const_0 dense<0.000000e+00> : tensor<f32>
+//  CHECK-NEXT:       vm.func @reduction_ex_dispatch_0(%arg0: !vm.ref<!vmla.interface>) attributes {ordinal = 0 : i32} {
+//  CHECK-NEXT:         %zero = vm.const.i32.zero : i32
+//  CHECK-NEXT:         %c128 = vm.const.i32 128 : i32
+//  CHECK-NEXT:         %c16 = vm.const.i32 16 : i32
+//  CHECK-NEXT:         %c4 = vm.const.i32 4 : i32
+//  CHECK-NEXT:         %c8 = vm.const.i32 8 : i32
+//  CHECK-NEXT:         %c1 = vm.const.i32 1 : i32
+//  CHECK-NEXT:         %reduction_ex_dispatch_0_const_0 = vm.const.ref.rodata @reduction_ex_dispatch_0_const_0 : !vm.ref<!iree.byte_buffer>
+//  CHECK-NEXT:         %ref = vm.call @vmla.buffer.const(%reduction_ex_dispatch_0_const_0) : (!vm.ref<!iree.byte_buffer>) -> !vm.ref<!vmla.buffer>
+//  CHECK-NEXT:         %ref_0 = vm.call @vmla.interface.binding(%arg0, %zero, %zero) : (!vm.ref<!vmla.interface>, i32, i32) -> !vm.ref<!vmla.buffer>
+//  CHECK-NEXT:         %ref_1 = vm.call @vmla.buffer.view(%ref_0, %zero, %c128) : (!vm.ref<!vmla.buffer>, i32, i32) -> !vm.ref<!vmla.buffer>
+//  CHECK-NEXT:         %ref_2 = vm.call @vmla.buffer.alloc(%c16) : (i32) -> !vm.ref<!vmla.buffer>
+//  CHECK-NEXT:         vm.call.variadic @vmla.reduce.sum.f32(%ref_1, [%c4, %c8], %ref, [], %c1, %ref_2, [%c4]) : (!vm.ref<!vmla.buffer>, i32..., !vm.ref<!vmla.buffer>, i32..., i32, !vm.ref<!vmla.buffer>, i32...)
+//  CHECK-NEXT:         %ref_3 = vm.call @vmla.interface.binding(%arg0, %zero, %c1) : (!vm.ref<!vmla.interface>, i32, i32) -> !vm.ref<!vmla.buffer>
+//  CHECK-NEXT:         vm.call @vmla.buffer.copy(%ref_2, %zero, %ref_3, %zero, %c16) : (!vm.ref<!vmla.buffer>, i32, !vm.ref<!vmla.buffer>, i32, i32) -> ()
+//  CHECK-NEXT:         vm.return
+//  CHECK-NEXT:       }
+//  CHECK-NEXT:       vm.export @reduction_ex_dispatch_0 attributes {ordinal = 0 : i32}
+//  CHECK-NEXT:       vm.rodata @reduction_ex_dispatch_0_impl_const_0 dense<0.000000e+00> : tensor<f32>
+//  CHECK-NEXT:       vm.import @vmla.interface.binding(%interface : !vm.ref<!vmla.interface>, %set : i32, %binding : i32) -> !vm.ref<!vmla.buffer> attributes {nosideeffects, ordinal = 0 : i32, sym_visibility = "private"}
+//  CHECK-NEXT:       vm.import @vmla.buffer.const(%value : !vm.ref<!iree.byte_buffer>) -> !vm.ref<!vmla.buffer> attributes {nosideeffects, ordinal = 1 : i32, sym_visibility = "private"}
+//  CHECK-NEXT:       vm.import @vmla.buffer.alloc(%byte_length : i32) -> !vm.ref<!vmla.buffer> attributes {nosideeffects, ordinal = 2 : i32, sym_visibility = "private"}
+//  CHECK-NEXT:       vm.import @vmla.buffer.view(%src : !vm.ref<!vmla.buffer>, %byte_offset : i32, %byte_length : i32) -> !vm.ref<!vmla.buffer> attributes {nosideeffects, ordinal = 3 : i32, sym_visibility = "private"}
+//  CHECK-NEXT:       vm.import @vmla.buffer.copy(%src : !vm.ref<!vmla.buffer>, %src_byte_offset : i32, %dst : !vm.ref<!vmla.buffer>, %dst_byte_offset : i32, %byte_length : i32) attributes {ordinal = 4 : i32, sym_visibility = "private"}
+//  CHECK-NEXT:       vm.import @vmla.reduce.sum.f32(%src : !vm.ref<!vmla.buffer>, %src_shape : i32..., %init : !vm.ref<!vmla.buffer>, %init_shape : i32..., %dimension : i32, %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32...) attributes {ordinal = 5 : i32, sym_visibility = "private"}
+//  CHECK-NEXT:     }
+//  CHECK-NEXT:   }
+//  CHECK-NEXT: }
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/BUILD b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/BUILD
index 3d6de77..52c81cc 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/BUILD
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/BUILD
@@ -21,6 +21,7 @@
     name = "HLOToVMLA",
     srcs = [
         "ConvertHLOToVMLA.cpp",
+        "ConvertReductionOps.cpp",
     ],
     hdrs = [
         "ConvertHLOToVMLA.h",
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/CMakeLists.txt
index a3975e6..9574f1f 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/CMakeLists.txt
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/CMakeLists.txt
@@ -21,6 +21,7 @@
     "ConvertHLOToVMLA.h"
   SRCS
     "ConvertHLOToVMLA.cpp"
+    "ConvertReductionOps.cpp"
   DEPS
     MLIRIR
     MLIRPass
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
index 681a41e..1d50438 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
@@ -35,6 +35,10 @@
 namespace mlir {
 namespace iree_compiler {
 
+void populateHLOReductionToVMLAPatterns(MLIRContext *context,
+                                        OwningRewritePatternList &patterns,
+                                        TypeConverter &typeConverter);
+
 namespace {
 
 // Clones operand[0] and returns the result.
@@ -372,6 +376,9 @@
   xla_hlo::PopulateXlaToStdPatterns(&patterns, context);
   xla_hlo::PopulateUnfuseBatchNormPatterns(context, &patterns);
 
+  // xla_hlo.reduce and xla_hlo.reduce_window.
+  populateHLOReductionToVMLAPatterns(context, patterns, typeConverter);
+
   // Simple 1:1 conversion patterns using the automated trait-based converter.
   // Used for HLO ops that have equivalent VMLA ops such as most arithmetic ops.
   patterns.insert<VMLAOpConversion<xla_hlo::AddOp, IREE::VMLA::AddOp>>(
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertReductionOps.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertReductionOps.cpp
new file mode 100644
index 0000000..1fdb1fc
--- /dev/null
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertReductionOps.cpp
@@ -0,0 +1,239 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.h"
+#include "iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.h"
+#include "iree/compiler/Dialect/VMLA/IR/VMLADialect.h"
+#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.h"
+#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+// Converts a simple xla_hlo.reduce op that performs independent individual
+// computations into a set of xla_hlo.reduce ops. This is an intermediate
+// conversion that may make it possible to use the much faster builtin VMLA
+// reduction ops.
+//
+// Only supports single dimensional reductions and assumes that unrolling has
+// been performed prior to conversion.
+struct SplitIndependentReductionOpConversion
+    : public OpConversionPattern<xla_hlo::ReduceOp> {
+  SplitIndependentReductionOpConversion(MLIRContext *context,
+                                        TypeConverter &typeConverter)
+      : OpConversionPattern(context), typeConverter(typeConverter) {}
+
+  PatternMatchResult matchAndRewrite(
+      xla_hlo::ReduceOp srcOp, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    if (srcOp.dimensions().getNumElements() > 1) {
+      srcOp.emitOpError() << "multi-dimensional reductions must be unrolled";
+      return matchFailure();
+    } else if (srcOp.body().getBlocks().size() > 1) {
+      // Control flow within the computation is not supported; bail to fallback.
+      return matchFailure();
+    }
+    auto &block = srcOp.body().getBlocks().front();
+    xla_hlo::ReduceOpOperandAdaptor newOperands(operands);
+    SmallVector<Value, 4> setResults;
+    for (auto &op : block) {
+      if (op.isKnownTerminator()) {
+        continue;
+      } else if (op.getOperands().size() != 2) {
+        // Only binary ops are supported for builtins.
+        return matchFailure();
+      }
+
+      // Determine which argument set this op is acting on. For the builtins we
+      // only support ops that act within a single set.
+      // Our arguments are expanded tuples like <lhs0, lhs1>, <rhs0, rhs1>, so
+      // this index gets the set offset.
+      int opSetIndex =
+          std::distance(block.args_begin(),
+                        llvm::find(block.getArguments(), op.getOperand(0)));
+
+      for (auto operand : op.getOperands()) {
+        if (operand.getDefiningOp() != nullptr) {
+          // Operand comes from another op within the block; unsupported.
+          return matchFailure();
+        }
+        int operandSetIndex =
+            std::distance(block.args_begin(),
+                          llvm::find(block.getArguments(), operand)) %
+            newOperands.operands().size();
+        if (operandSetIndex != opSetIndex) {
+          // Operand is not coming from the same set as the other operands of
+          // this op; unsupported.
+          return matchFailure();
+        }
+      }
+      for (auto result : op.getResults()) {
+        for (auto *user : result.getUsers()) {
+          if (!user->isKnownTerminator()) {
+            // Result is not directly returned from the block; unsupported.
+            return matchFailure();
+          }
+        }
+      }
+
+      // Create the new op for this set.
+      Value operandArg = srcOp.operands()[opSetIndex];
+      Value initArg = srcOp.init_values()[opSetIndex];
+      auto splitOp = rewriter.create<xla_hlo::ReduceOp>(
+          op.getLoc(), ValueRange{operandArg}, ValueRange{initArg},
+          srcOp.dimensionsAttr());
+      auto *splitBlock = new Block();
+      splitOp.body().getBlocks().push_back(splitBlock);
+      OpBuilder splitBuilder(splitBlock);
+      BlockAndValueMapping mapping;
+      for (auto operand : op.getOperands()) {
+        mapping.map(operand, splitBlock->addArgument(operand.getType()));
+      }
+      Operation *splitComputeOp = splitBuilder.clone(op, mapping);
+      splitBuilder.create<xla_hlo::ReturnOp>(
+          srcOp.getLoc(), ValueRange{*splitComputeOp->getResults().begin()});
+      setResults.push_back(*splitOp.getResults().begin());
+    }
+
+    rewriter.replaceOp(srcOp, setResults);
+    return matchSuccess();
+  }
+
+  TypeConverter &typeConverter;
+};
+
+// Converts an xla_hlo.reduce with a single op to a builtin reduce op.
+// This is meant to pair with the SplitIndependentReductionOpConversion that
+// tries to unfuse/divide combined reductions. If this cannot match then the
+// fallback path will be used and a VM loop will be emitted (slower, but can
+// perform any reduction).
+//
+// Only supports single dimensional reductions and assumes that unrolling has
+// been performed prior to conversion.
+struct BuiltinReduceOpConversion
+    : public OpConversionPattern<xla_hlo::ReduceOp> {
+  BuiltinReduceOpConversion(MLIRContext *context, TypeConverter &typeConverter)
+      : OpConversionPattern(context, /*benefit=*/1000),
+        typeConverter(typeConverter) {}
+
+  PatternMatchResult matchAndRewrite(
+      xla_hlo::ReduceOp srcOp, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    if (srcOp.dimensions().getNumElements() > 1) {
+      srcOp.emitOpError() << "multi-dimensional reductions must be unrolled";
+      return matchFailure();
+    } else if (srcOp.body().getBlocks().size() > 1) {
+      // Control flow within the computation is not supported; bail to fallback.
+      return matchFailure();
+    } else if (srcOp.body().front().getOperations().size() > 2) {
+      // Require splitting first.
+      return matchFailure();
+    }
+
+    auto operand = operands[0];
+    auto operandShape = VMLAConversionTarget::getTensorShape(
+        srcOp.getLoc(), srcOp.operands()[0], typeConverter, rewriter);
+    auto initValue = operands[1];
+    auto initValueShape = VMLAConversionTarget::getTensorShape(
+        srcOp.getLoc(), srcOp.init_values()[0], typeConverter, rewriter);
+    int dimension = srcOp.dimensions().getValue<IntegerAttr>({0}).getInt();
+    auto dst = VMLAConversionTarget::allocateOutputBuffer(
+        srcOp.getLoc(), srcOp.getResults()[0], typeConverter, rewriter);
+    auto dstShape = VMLAConversionTarget::getTensorShape(
+        srcOp.getLoc(), srcOp.getResults()[0], typeConverter, rewriter);
+    auto elementType =
+        srcOp.operands()[0].getType().cast<ShapedType>().getElementType();
+
+    auto &computeOp = *srcOp.body().front().begin();
+    if (isa<mlir::AddIOp>(computeOp) || isa<mlir::AddFOp>(computeOp) ||
+        isa<xla_hlo::AddOp>(computeOp)) {
+      rewriter.create<IREE::VMLA::ReduceSumOp>(
+          srcOp.getLoc(), operand, operandShape, initValue, initValueShape,
+          rewriter.getI32IntegerAttr(dimension), dst, dstShape,
+          TypeAttr::get(elementType));
+    } else if (isa<xla_hlo::MinOp>(computeOp)) {
+      rewriter.create<IREE::VMLA::ReduceMinOp>(
+          srcOp.getLoc(), operand, operandShape, initValue, initValueShape,
+          rewriter.getI32IntegerAttr(dimension), dst, dstShape,
+          TypeAttr::get(elementType));
+    } else if (isa<xla_hlo::MaxOp>(computeOp)) {
+      rewriter.create<IREE::VMLA::ReduceMaxOp>(
+          srcOp.getLoc(), operand, operandShape, initValue, initValueShape,
+          rewriter.getI32IntegerAttr(dimension), dst, dstShape,
+          TypeAttr::get(elementType));
+    } else {
+      computeOp.emitRemark() << "unsupported builtin reduction operation";
+      return matchFailure();
+    }
+
+    rewriter.replaceOp(srcOp, {dst});
+    return matchSuccess();
+  }
+
+  TypeConverter &typeConverter;
+};
+
+// Converts a generic xla_hlo.reduce to a VM loop.
+//
+// Only supports single dimensional reductions and assumes that unrolling has
+// been performed prior to conversion.
+struct GenericReduceOpConversion
+    : public OpConversionPattern<xla_hlo::ReduceOp> {
+  GenericReduceOpConversion(MLIRContext *context, TypeConverter &typeConverter)
+      : OpConversionPattern(context), typeConverter(typeConverter) {}
+
+  PatternMatchResult matchAndRewrite(
+      xla_hlo::ReduceOp srcOp, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    if (srcOp.dimensions().getNumElements() > 1) {
+      srcOp.emitOpError() << "multi-dimensional reductions must be unrolled";
+      return matchFailure();
+    }
+
+    // TODO(benvanik): emit VM loop around computation.
+    srcOp.emitOpError() << "generic reduction lowering not yet implemented";
+    return matchFailure();
+  }
+
+  TypeConverter &typeConverter;
+};
+
+}  // namespace
+
+void populateHLOReductionToVMLAPatterns(MLIRContext *context,
+                                        OwningRewritePatternList &patterns,
+                                        TypeConverter &typeConverter) {
+  patterns.insert<SplitIndependentReductionOpConversion>(context,
+                                                         typeConverter);
+  patterns.insert<BuiltinReduceOpConversion>(context, typeConverter);
+  patterns.insert<GenericReduceOpConversion>(context, typeConverter);
+}
+
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/reduce.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/reduce.mlir
new file mode 100644
index 0000000..4923313
--- /dev/null
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/reduce.mlir
@@ -0,0 +1,45 @@
+// RUN: iree-opt -split-input-file -iree-vmla-conversion -cse %s | IreeFileCheck %s
+
+// CHECK-LABEL: @single_reduction
+func @single_reduction(%arg0: tensor<4x8xf32>) -> tensor<4xf32> attributes { sym_visibility = "private" } {
+  // CHECK-DAG: [[INIT:%.+]] = "vmla.constant"() {value = dense<0.000000e+00> : tensor<f32>} : () -> !vmla.buffer
+  %cst = constant dense<0.000000e+00> : tensor<f32>
+  //  CHECK-DAG: [[SRC_SHAPE:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[4,8],i32>
+  //  CHECK-DAG: [[INIT_SHAPE:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[],i32>
+  //  CHECK-DAG: [[DST:%.+]] = "vmla.buffer.alloc"
+  //  CHECK-DAG: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[4],i32>
+  // CHECK-NEXT: "vmla.reduce.sum"(%arg0, [[SRC_SHAPE]], [[INIT]], [[INIT_SHAPE]], [[DST]], [[DST_SHAPE]]) {dimension = 1 : i32, element_type = f32} : (!vmla.buffer, !shapex.ranked_shape<[4,8],i32>, !vmla.buffer, !shapex.ranked_shape<[],i32>, !vmla.buffer, !shapex.ranked_shape<[4],i32>) -> ()
+  %0 = "xla_hlo.reduce"(%arg0, %cst) ( {
+  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):	// no predecessors
+    %1 = xla_hlo.add %arg1, %arg2 : tensor<f32>
+    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
+  }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32>
+  // CHECK-NEXT: return [[DST]] : !vmla.buffer
+  return %0 : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @multi_reduction
+func @multi_reduction(%arg0 : tensor<4x8xf32>, %arg1 : tensor<4x8xf32>) -> (tensor<4xf32>, tensor<4xf32>) attributes { sym_visibility = "private" } {
+  //  CHECK-DAG: [[CST0:%.+]] = "vmla.constant"() {value = dense<0.000000e+00> : tensor<f32>} : () -> !vmla.buffer
+  %0 = constant dense<0.000000e+00> : tensor<f32>
+  //  CHECK-DAG: [[CST1:%.+]] = "vmla.constant"() {value = dense<1.000000e+00> : tensor<f32>} : () -> !vmla.buffer
+  %1 = constant dense<1.000000e+00> : tensor<f32>
+  //  CHECK-DAG: [[INPUT_SHAPE:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[4,8],i32>
+  //  CHECK-DAG: [[SCALAR_SHAPE:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[],i32>
+  //  CHECK-DAG: [[RESULT_SHAPE:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[4],i32>
+  //  CHECK-DAG: [[RET_SIZE:%.+]] = muli
+  // CHECK-NEXT: [[RET0:%.+]] = "vmla.buffer.alloc"([[RET_SIZE]]) : (i32) -> !vmla.buffer
+  // CHECK-NEXT: "vmla.reduce.sum"(%arg0, [[INPUT_SHAPE]], [[CST0]], [[SCALAR_SHAPE]], [[RET0]], [[RESULT_SHAPE]]) {dimension = 1 : i32, element_type = f32} : (!vmla.buffer, !shapex.ranked_shape<[4,8],i32>, !vmla.buffer, !shapex.ranked_shape<[],i32>, !vmla.buffer, !shapex.ranked_shape<[4],i32>) -> ()
+  // CHECK-NEXT: [[RET1:%.+]] = "vmla.buffer.alloc"([[RET_SIZE]]) : (i32) -> !vmla.buffer
+  // CHECK-NEXT: "vmla.reduce.sum"(%arg1, [[INPUT_SHAPE]], [[CST1]], [[SCALAR_SHAPE]], [[RET1]], [[RESULT_SHAPE]]) {dimension = 1 : i32, element_type = f32} : (!vmla.buffer, !shapex.ranked_shape<[4,8],i32>, !vmla.buffer, !shapex.ranked_shape<[],i32>, !vmla.buffer, !shapex.ranked_shape<[4],i32>) -> ()
+  %2, %3 = "xla_hlo.reduce"(%arg0, %arg1, %0, %1) ( {
+  ^bb0(%arg0_lhs : tensor<f32>, %arg1_lhs : tensor<f32>, %arg0_rhs : tensor<f32>, %arg1_rhs : tensor<f32>):
+    %4 = xla_hlo.add %arg0_lhs, %arg0_rhs : tensor<f32>
+    %5 = xla_hlo.add %arg1_lhs, %arg1_rhs : tensor<f32>
+    "xla_hlo.return"(%4, %5) : (tensor<f32>, tensor<f32>) -> ()
+  }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<4x8xf32>, tensor<f32>, tensor<f32>) -> (tensor<4xf32>, tensor<4xf32>)
+  // CHECK-NEXT: return [[RET0]], [[RET1]] : !vmla.buffer, !vmla.buffer
+  return %2, %3 : tensor<4xf32>, tensor<4xf32>
+}
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
index 4f52fe0..5e796bd 100644
--- a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
+++ b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
@@ -56,7 +56,7 @@
   );
 }
 
-def VMLA_BufferAllocOp : VMLA_PureOp<"buffer.alloc"> {
+def VMLA_BufferAllocOp : VMLA_Op<"buffer.alloc"> {
   let arguments = (ins
     VMLA_DeviceSize:$byte_length
   );
@@ -65,7 +65,7 @@
   );
 }
 
-def VMLA_BufferCloneOp : VMLA_PureOp<"buffer.clone"> {
+def VMLA_BufferCloneOp : VMLA_Op<"buffer.clone"> {
   let arguments = (ins
     VMLA_Buffer:$src
   );
diff --git a/test/e2e/xla/reduce_float.mlir b/test/e2e/xla/reduce_float.mlir
index e303014..3612a4e 100644
--- a/test/e2e/xla/reduce_float.mlir
+++ b/test/e2e/xla/reduce_float.mlir
@@ -1,4 +1,5 @@
 // RUN: iree-run-mlir %s -iree-hal-target-backends=interpreter-bytecode | IreeFileCheck %s
+// RUN: iree-run-mlir -iree-hal-target-backends=vmla -iree-flow-experimental-dispatch-reduce %s | IreeFileCheck %s
 // TODO(b/142903911): figure out swiftshader+asan crash:
 // RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir %s -iree-hal-target-backends=vulkan-spirv --run=false)
 
diff --git a/test/e2e/xla/reduce_int.mlir b/test/e2e/xla/reduce_int.mlir
index 51ceeb9..6e8c43d 100644
--- a/test/e2e/xla/reduce_int.mlir
+++ b/test/e2e/xla/reduce_int.mlir
@@ -1,4 +1,5 @@
 // RUN: iree-run-mlir -iree-hal-target-backends=interpreter-bytecode %s | IreeFileCheck %s
+// RUN: iree-run-mlir -iree-hal-target-backends=vmla -iree-flow-experimental-dispatch-reduce %s | IreeFileCheck %s
 // TODO(b/146030213): This test fails cause the initialization isn't done
 // correctly within the vulkan backend. Enable this test once that is done.
 // RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir -iree-hal-target-backends=vulkan-spirv --run=false %s)