Implement VMLA reduction conversion.
This allows standard reduce ops coming through the reduce-in-dispatch path to
convert to the built-in VMLA reduction ops. There's a placeholder for a
generic fallback path that emits a VM loop; that path is not used yet, as the
other backends don't support such reductions yet either.
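
For example (roughly, following the new test in
iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/reduce.mlir; the
buffer/shape value names below are illustrative), a summing xla_hlo.reduce
such as

  %0 = "xla_hlo.reduce"(%arg0, %cst) ( {
  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
    %1 = xla_hlo.add %arg1, %arg2 : tensor<f32>
    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
  }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32>

now lowers to a builtin call along the lines of

  "vmla.reduce.sum"(%arg0, %src_shape, %init, %init_shape, %dst, %dst_shape)
      {dimension = 1 : i32, element_type = f32}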
WIP #936.
Closes https://github.com/google/iree/pull/955
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/iree/pull/955 from google:benvanik-vmla-reduce-conversion 35b02988e53499848895ba69697292b26019644e
PiperOrigin-RevId: 299015281
diff --git a/build_tools/cmake/iree_copts.cmake b/build_tools/cmake/iree_copts.cmake
index 3ef833e..4ba9497 100644
--- a/build_tools/cmake/iree_copts.cmake
+++ b/build_tools/cmake/iree_copts.cmake
@@ -43,6 +43,7 @@
"-Wno-undef"
MSVC_OR_CLANG_CL
"/DWIN32_LEAN_AND_MEAN"
+ "/wd4624"
# TODO(benvanik): figure out if really required or accidentally enabled.
"/EHsc"
)
diff --git a/iree/compiler/Dialect/Flow/Transforms/OutlineReductionRegions.cpp b/iree/compiler/Dialect/Flow/Transforms/OutlineReductionRegions.cpp
index 2d89be5..bee1284 100644
--- a/iree/compiler/Dialect/Flow/Transforms/OutlineReductionRegions.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/OutlineReductionRegions.cpp
@@ -151,7 +151,8 @@
for (uint32_t i = 0; i < dimensions.getNumElements(); ++i) {
sortedDimensions.push_back(dimensions.getValue<IntegerAttr>({i}).getInt());
}
- llvm::sort(sortedDimensions, [](int32_t a, int32_t b) { return a - b; });
+ llvm::sort(sortedDimensions,
+ [](int32_t a, int32_t b) { return (a - b) > 0; });
for (auto dimension : llvm::enumerate(sortedDimensions)) {
// Create the executable with the region cloned into it.
ExecutableOp executableOp;
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir b/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir
index f4fa468..e33d70c 100644
--- a/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir
+++ b/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir
@@ -42,18 +42,57 @@
// -----
-// TODO(benvanik): vmla reduction.
-// flow.executable @reduction_ex_reduce_0_dim_0 {
-// flow.reduction.entry @reduction_rgn_reduce_0_dim_0_entry apply(@reduction_rgn_reduce_0_dim_0) attributes {
-// dimension = 1 : i32,
-// workgroup_size = dense<[32, 1, 1]> : vector<3xi32>,
-// workload = dense<[4, 1, 1]> : vector<3xi32>
-// }
-// module {
-// func @reduction_rgn_reduce_0_dim_0_entry(tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32>
-// func @reduction_rgn_reduce_0_dim_0(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
-// %0 = xla_hlo.add %arg0, %arg1 : tensor<f32>
-// return %0 : tensor<f32>
-// }
-// }
-// }
+flow.executable @reduction_ex_dispatch_0 {
+ flow.dispatch.entry @reduction_ex_dispatch_0 attributes {workload = dense<[4, 1, 1]> : vector<3xi32>}
+ module {
+ func @reduction_ex_dispatch_0(%arg0: tensor<4x8xf32>) -> tensor<4xf32> {
+ %cst = constant dense<0.000000e+00> : tensor<f32>
+ %0 = "xla_hlo.reduce"(%arg0, %cst) ( {
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors
+ %1 = xla_hlo.add %arg1, %arg2 : tensor<f32>
+ "xla_hlo.return"(%1) : (tensor<f32>) -> ()
+ }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32>
+ return %0 : tensor<4xf32>
+ }
+ }
+}
+
+// CHECK-LABEL: hal.executable @reduction_ex_dispatch_0 {
+// CHECK-NEXT: hal.interface @legacy_io {
+// CHECK-NEXT: hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+// CHECK-NEXT: hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard"
+// CHECK-NEXT: }
+// CHECK-NEXT: hal.executable.entry_point @reduction_ex_dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<4x8xf32>) -> tensor<4xf32>, workgroup_size = dense<1> : vector<3xi32>}
+// CHECK-NEXT: hal.executable.binary attributes {
+// CHECK-SAME: data = dense<
+// CHECK-SAME: format = 1447906369 : i32} {
+// CHECK-NEXT: vm.module @module {
+// CHECK-NEXT: vm.rodata @reduction_ex_dispatch_0_const_0 dense<0.000000e+00> : tensor<f32>
+// CHECK-NEXT: vm.func @reduction_ex_dispatch_0(%arg0: !vm.ref<!vmla.interface>) attributes {ordinal = 0 : i32} {
+// CHECK-NEXT: %zero = vm.const.i32.zero : i32
+// CHECK-NEXT: %c128 = vm.const.i32 128 : i32
+// CHECK-NEXT: %c16 = vm.const.i32 16 : i32
+// CHECK-NEXT: %c4 = vm.const.i32 4 : i32
+// CHECK-NEXT: %c8 = vm.const.i32 8 : i32
+// CHECK-NEXT: %c1 = vm.const.i32 1 : i32
+// CHECK-NEXT: %reduction_ex_dispatch_0_const_0 = vm.const.ref.rodata @reduction_ex_dispatch_0_const_0 : !vm.ref<!iree.byte_buffer>
+// CHECK-NEXT: %ref = vm.call @vmla.buffer.const(%reduction_ex_dispatch_0_const_0) : (!vm.ref<!iree.byte_buffer>) -> !vm.ref<!vmla.buffer>
+// CHECK-NEXT: %ref_0 = vm.call @vmla.interface.binding(%arg0, %zero, %zero) : (!vm.ref<!vmla.interface>, i32, i32) -> !vm.ref<!vmla.buffer>
+// CHECK-NEXT: %ref_1 = vm.call @vmla.buffer.view(%ref_0, %zero, %c128) : (!vm.ref<!vmla.buffer>, i32, i32) -> !vm.ref<!vmla.buffer>
+// CHECK-NEXT: %ref_2 = vm.call @vmla.buffer.alloc(%c16) : (i32) -> !vm.ref<!vmla.buffer>
+// CHECK-NEXT: vm.call.variadic @vmla.reduce.sum.f32(%ref_1, [%c4, %c8], %ref, [], %c1, %ref_2, [%c4]) : (!vm.ref<!vmla.buffer>, i32..., !vm.ref<!vmla.buffer>, i32..., i32, !vm.ref<!vmla.buffer>, i32...)
+// CHECK-NEXT: %ref_3 = vm.call @vmla.interface.binding(%arg0, %zero, %c1) : (!vm.ref<!vmla.interface>, i32, i32) -> !vm.ref<!vmla.buffer>
+// CHECK-NEXT: vm.call @vmla.buffer.copy(%ref_2, %zero, %ref_3, %zero, %c16) : (!vm.ref<!vmla.buffer>, i32, !vm.ref<!vmla.buffer>, i32, i32) -> ()
+// CHECK-NEXT: vm.return
+// CHECK-NEXT: }
+// CHECK-NEXT: vm.export @reduction_ex_dispatch_0 attributes {ordinal = 0 : i32}
+// CHECK-NEXT: vm.rodata @reduction_ex_dispatch_0_impl_const_0 dense<0.000000e+00> : tensor<f32>
+// CHECK-NEXT: vm.import @vmla.interface.binding(%interface : !vm.ref<!vmla.interface>, %set : i32, %binding : i32) -> !vm.ref<!vmla.buffer> attributes {nosideeffects, ordinal = 0 : i32, sym_visibility = "private"}
+// CHECK-NEXT: vm.import @vmla.buffer.const(%value : !vm.ref<!iree.byte_buffer>) -> !vm.ref<!vmla.buffer> attributes {nosideeffects, ordinal = 1 : i32, sym_visibility = "private"}
+// CHECK-NEXT: vm.import @vmla.buffer.alloc(%byte_length : i32) -> !vm.ref<!vmla.buffer> attributes {nosideeffects, ordinal = 2 : i32, sym_visibility = "private"}
+// CHECK-NEXT: vm.import @vmla.buffer.view(%src : !vm.ref<!vmla.buffer>, %byte_offset : i32, %byte_length : i32) -> !vm.ref<!vmla.buffer> attributes {nosideeffects, ordinal = 3 : i32, sym_visibility = "private"}
+// CHECK-NEXT: vm.import @vmla.buffer.copy(%src : !vm.ref<!vmla.buffer>, %src_byte_offset : i32, %dst : !vm.ref<!vmla.buffer>, %dst_byte_offset : i32, %byte_length : i32) attributes {ordinal = 4 : i32, sym_visibility = "private"}
+// CHECK-NEXT: vm.import @vmla.reduce.sum.f32(%src : !vm.ref<!vmla.buffer>, %src_shape : i32..., %init : !vm.ref<!vmla.buffer>, %init_shape : i32..., %dimension : i32, %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32...) attributes {ordinal = 5 : i32, sym_visibility = "private"}
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/BUILD b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/BUILD
index 3d6de77..52c81cc 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/BUILD
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/BUILD
@@ -21,6 +21,7 @@
name = "HLOToVMLA",
srcs = [
"ConvertHLOToVMLA.cpp",
+ "ConvertReductionOps.cpp",
],
hdrs = [
"ConvertHLOToVMLA.h",
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/CMakeLists.txt
index a3975e6..9574f1f 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/CMakeLists.txt
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/CMakeLists.txt
@@ -21,6 +21,7 @@
"ConvertHLOToVMLA.h"
SRCS
"ConvertHLOToVMLA.cpp"
+ "ConvertReductionOps.cpp"
DEPS
MLIRIR
MLIRPass
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
index 681a41e..1d50438 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
@@ -35,6 +35,10 @@
namespace mlir {
namespace iree_compiler {
+void populateHLOReductionToVMLAPatterns(MLIRContext *context,
+ OwningRewritePatternList &patterns,
+ TypeConverter &typeConverter);
+
namespace {
// Clones operand[0] and returns the result.
@@ -372,6 +376,9 @@
xla_hlo::PopulateXlaToStdPatterns(&patterns, context);
xla_hlo::PopulateUnfuseBatchNormPatterns(context, &patterns);
+ // xla_hlo.reduce and xla_hlo.reduce_window.
+ populateHLOReductionToVMLAPatterns(context, patterns, typeConverter);
+
// Simple 1:1 conversion patterns using the automated trait-based converter.
// Used for HLO ops that have equivalent VMLA ops such as most arithmetic ops.
patterns.insert<VMLAOpConversion<xla_hlo::AddOp, IREE::VMLA::AddOp>>(
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertReductionOps.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertReductionOps.cpp
new file mode 100644
index 0000000..1fdb1fc
--- /dev/null
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertReductionOps.cpp
@@ -0,0 +1,239 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.h"
+#include "iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.h"
+#include "iree/compiler/Dialect/VMLA/IR/VMLADialect.h"
+#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.h"
+#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+// Converts a simple xla_hlo.reduce op that performs multiple independent
+// reductions into a set of single-result xla_hlo.reduce ops. This is an
+// intermediate conversion that may make it possible to use the much faster
+// builtin VMLA reduction ops.
+//
+// Only supports single dimensional reductions and assumes that unrolling has
+// been performed prior to conversion.
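+//
+// A sketch of the intent (based on the multi_reduction case in the new
+// test/reduce.mlir; value names are illustrative): a fused two-result
+// reduction such as
+//   %2, %3 = "xla_hlo.reduce"(%arg0, %arg1, %init0, %init1) ( {
+//   ^bb0(%lhs0: tensor<f32>, %lhs1: tensor<f32>,
+//        %rhs0: tensor<f32>, %rhs1: tensor<f32>):
+//     %4 = xla_hlo.add %lhs0, %rhs0 : tensor<f32>
+//     %5 = xla_hlo.add %lhs1, %rhs1 : tensor<f32>
+//     "xla_hlo.return"(%4, %5) : (tensor<f32>, tensor<f32>) -> ()
+//   }) {dimensions = dense<1> : tensor<1xi64>} : ...
+// is split into two single-result xla_hlo.reduce ops (one per operand/init
+// pair) that BuiltinReduceOpConversion below can then match individually.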
+struct SplitIndependentReductionOpConversion
+ : public OpConversionPattern<xla_hlo::ReduceOp> {
+ SplitIndependentReductionOpConversion(MLIRContext *context,
+ TypeConverter &typeConverter)
+ : OpConversionPattern(context), typeConverter(typeConverter) {}
+
+ PatternMatchResult matchAndRewrite(
+ xla_hlo::ReduceOp srcOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ if (srcOp.dimensions().getNumElements() > 1) {
+ srcOp.emitOpError() << "multi-dimensional reductions must be unrolled";
+ return matchFailure();
+ } else if (srcOp.body().getBlocks().size() > 1) {
+ // Control flow within the computation is not supported; bail to fallback.
+ return matchFailure();
+ }
+ auto &block = srcOp.body().getBlocks().front();
+ xla_hlo::ReduceOpOperandAdaptor newOperands(operands);
+ SmallVector<Value, 4> setResults;
+ for (auto &op : block) {
+ if (op.isKnownTerminator()) {
+ continue;
+ } else if (op.getOperands().size() != 2) {
+ // Only binary ops are supported for builtins.
+ return matchFailure();
+ }
+
+ // Determine which argument set this op is acting on. For the builtins we
+ // only support ops that act within a single set.
+ // Our arguments are expanded tuples like <lhs0, lhs1>, <rhs0, rhs1>, so
+ // this index gets the set offset.
+ int opSetIndex =
+ std::distance(block.args_begin(),
+ llvm::find(block.getArguments(), op.getOperand(0)));
+
+ for (auto operand : op.getOperands()) {
+ if (operand.getDefiningOp() != nullptr) {
+ // Operand comes from another op within the block; unsupported.
+ return matchFailure();
+ }
+ int operandSetIndex =
+ std::distance(block.args_begin(),
+ llvm::find(block.getArguments(), operand)) %
+ newOperands.operands().size();
+ if (operandSetIndex != opSetIndex) {
+ // Operand is not coming from the same set as the other operands of
+ // this op; unsupported.
+ return matchFailure();
+ }
+ }
+ for (auto result : op.getResults()) {
+ for (auto *user : result.getUsers()) {
+ if (!user->isKnownTerminator()) {
+ // Result is not directly returned from the block; unsupported.
+ return matchFailure();
+ }
+ }
+ }
+
+ // Create the new op for this set.
+ Value operandArg = srcOp.operands()[opSetIndex];
+ Value initArg = srcOp.init_values()[opSetIndex];
+ auto splitOp = rewriter.create<xla_hlo::ReduceOp>(
+ op.getLoc(), ValueRange{operandArg}, ValueRange{initArg},
+ srcOp.dimensionsAttr());
+ auto *splitBlock = new Block();
+ splitOp.body().getBlocks().push_back(splitBlock);
+ OpBuilder splitBuilder(splitBlock);
+ BlockAndValueMapping mapping;
+ for (auto operand : op.getOperands()) {
+ mapping.map(operand, splitBlock->addArgument(operand.getType()));
+ }
+ Operation *splitComputeOp = splitBuilder.clone(op, mapping);
+ splitBuilder.create<xla_hlo::ReturnOp>(
+ srcOp.getLoc(), ValueRange{*splitComputeOp->getResults().begin()});
+ setResults.push_back(*splitOp.getResults().begin());
+ }
+
+ rewriter.replaceOp(srcOp, setResults);
+ return matchSuccess();
+ }
+
+ TypeConverter &typeConverter;
+};
+
+// Converts an xla_hlo.reduce whose body is a single compute op to a builtin
+// VMLA reduce op.
+// This is meant to pair with the SplitIndependentReductionOpConversion that
+// tries to unfuse/divide combined reductions. If this cannot match then the
+// fallback path will be used and a VM loop will be emitted (slower, but can
+// perform any reduction).
+//
+// Only supports single dimensional reductions and assumes that unrolling has
+// been performed prior to conversion.
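+//
+// A sketch of the mapping (based on the single_reduction case in the new
+// test/reduce.mlir; value names are illustrative): a reduce whose body is a
+// single xla_hlo.add, e.g.
+//   %0 = "xla_hlo.reduce"(%arg0, %cst) ( {
+//   ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
+//     %1 = xla_hlo.add %lhs, %rhs : tensor<f32>
+//     "xla_hlo.return"(%1) : (tensor<f32>) -> ()
+//   }) {dimensions = dense<1> : tensor<1xi64>}
+//       : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32>
+// becomes a single vmla.reduce.sum on the resolved buffers and shapes;
+// the add ops (xla_hlo.add, std.addi, std.addf) map to reduce.sum,
+// xla_hlo.min to reduce.min, and xla_hlo.max to reduce.max.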
+struct BuiltinReduceOpConversion
+ : public OpConversionPattern<xla_hlo::ReduceOp> {
+ BuiltinReduceOpConversion(MLIRContext *context, TypeConverter &typeConverter)
+ : OpConversionPattern(context, /*benefit=*/1000),
+ typeConverter(typeConverter) {}
+
+ PatternMatchResult matchAndRewrite(
+ xla_hlo::ReduceOp srcOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ if (srcOp.dimensions().getNumElements() > 1) {
+ srcOp.emitOpError() << "multi-dimensional reductions must be unrolled";
+ return matchFailure();
+ } else if (srcOp.body().getBlocks().size() > 1) {
+ // Control flow within the computation is not supported; bail to fallback.
+ return matchFailure();
+ } else if (srcOp.body().front().getOperations().size() > 2) {
+ // Require splitting first.
+ return matchFailure();
+ }
+
+ auto operand = operands[0];
+ auto operandShape = VMLAConversionTarget::getTensorShape(
+ srcOp.getLoc(), srcOp.operands()[0], typeConverter, rewriter);
+ auto initValue = operands[1];
+ auto initValueShape = VMLAConversionTarget::getTensorShape(
+ srcOp.getLoc(), srcOp.init_values()[0], typeConverter, rewriter);
+ int dimension = srcOp.dimensions().getValue<IntegerAttr>({0}).getInt();
+ auto dst = VMLAConversionTarget::allocateOutputBuffer(
+ srcOp.getLoc(), srcOp.getResults()[0], typeConverter, rewriter);
+ auto dstShape = VMLAConversionTarget::getTensorShape(
+ srcOp.getLoc(), srcOp.getResults()[0], typeConverter, rewriter);
+ auto elementType =
+ srcOp.operands()[0].getType().cast<ShapedType>().getElementType();
+
+ auto &computeOp = *srcOp.body().front().begin();
+ if (isa<mlir::AddIOp>(computeOp) || isa<mlir::AddFOp>(computeOp) ||
+ isa<xla_hlo::AddOp>(computeOp)) {
+ rewriter.create<IREE::VMLA::ReduceSumOp>(
+ srcOp.getLoc(), operand, operandShape, initValue, initValueShape,
+ rewriter.getI32IntegerAttr(dimension), dst, dstShape,
+ TypeAttr::get(elementType));
+ } else if (isa<xla_hlo::MinOp>(computeOp)) {
+ rewriter.create<IREE::VMLA::ReduceMinOp>(
+ srcOp.getLoc(), operand, operandShape, initValue, initValueShape,
+ rewriter.getI32IntegerAttr(dimension), dst, dstShape,
+ TypeAttr::get(elementType));
+ } else if (isa<xla_hlo::MaxOp>(computeOp)) {
+ rewriter.create<IREE::VMLA::ReduceMaxOp>(
+ srcOp.getLoc(), operand, operandShape, initValue, initValueShape,
+ rewriter.getI32IntegerAttr(dimension), dst, dstShape,
+ TypeAttr::get(elementType));
+ } else {
+ computeOp.emitRemark() << "unsupported builtin reduction operation";
+ return matchFailure();
+ }
+
+ rewriter.replaceOp(srcOp, {dst});
+ return matchSuccess();
+ }
+
+ TypeConverter &typeConverter;
+};
+
+// Converts a generic xla_hlo.reduce to a VM loop.
+//
+// Only supports single dimensional reductions and assumes that unrolling has
+// been performed prior to conversion.
+struct GenericReduceOpConversion
+ : public OpConversionPattern<xla_hlo::ReduceOp> {
+ GenericReduceOpConversion(MLIRContext *context, TypeConverter &typeConverter)
+ : OpConversionPattern(context), typeConverter(typeConverter) {}
+
+ PatternMatchResult matchAndRewrite(
+ xla_hlo::ReduceOp srcOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ if (srcOp.dimensions().getNumElements() > 1) {
+ srcOp.emitOpError() << "multi-dimensional reductions must be unrolled";
+ return matchFailure();
+ }
+
+ // TODO(benvanik): emit VM loop around computation.
+ srcOp.emitOpError() << "generic reduction lowering not yet implemented";
+ return matchFailure();
+ }
+
+ TypeConverter &typeConverter;
+};
+
+} // namespace
+
+void populateHLOReductionToVMLAPatterns(MLIRContext *context,
+ OwningRewritePatternList &patterns,
+ TypeConverter &typeConverter) {
+ patterns.insert<SplitIndependentReductionOpConversion>(context,
+ typeConverter);
+ patterns.insert<BuiltinReduceOpConversion>(context, typeConverter);
+ patterns.insert<GenericReduceOpConversion>(context, typeConverter);
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/reduce.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/reduce.mlir
new file mode 100644
index 0000000..4923313
--- /dev/null
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/reduce.mlir
@@ -0,0 +1,45 @@
+// RUN: iree-opt -split-input-file -iree-vmla-conversion -cse %s | IreeFileCheck %s
+
+// CHECK-LABEL: @single_reduction
+func @single_reduction(%arg0: tensor<4x8xf32>) -> tensor<4xf32> attributes { sym_visibility = "private" } {
+ // CHECK-DAG: [[INIT:%.+]] = "vmla.constant"() {value = dense<0.000000e+00> : tensor<f32>} : () -> !vmla.buffer
+ %cst = constant dense<0.000000e+00> : tensor<f32>
+ // CHECK-DAG: [[SRC_SHAPE:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[4,8],i32>
+ // CHECK-DAG: [[INIT_SHAPE:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[],i32>
+ // CHECK-DAG: [[DST:%.+]] = "vmla.buffer.alloc"
+ // CHECK-DAG: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[4],i32>
+ // CHECK-NEXT: "vmla.reduce.sum"(%arg0, [[SRC_SHAPE]], [[INIT]], [[INIT_SHAPE]], [[DST]], [[DST_SHAPE]]) {dimension = 1 : i32, element_type = f32} : (!vmla.buffer, !shapex.ranked_shape<[4,8],i32>, !vmla.buffer, !shapex.ranked_shape<[],i32>, !vmla.buffer, !shapex.ranked_shape<[4],i32>) -> ()
+ %0 = "xla_hlo.reduce"(%arg0, %cst) ( {
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors
+ %1 = xla_hlo.add %arg1, %arg2 : tensor<f32>
+ "xla_hlo.return"(%1) : (tensor<f32>) -> ()
+ }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32>
+ // CHECK-NEXT: return [[DST]] : !vmla.buffer
+ return %0 : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @multi_reduction
+func @multi_reduction(%arg0 : tensor<4x8xf32>, %arg1 : tensor<4x8xf32>) -> (tensor<4xf32>, tensor<4xf32>) attributes { sym_visibility = "private" } {
+ // CHECK-DAG: [[CST0:%.+]] = "vmla.constant"() {value = dense<0.000000e+00> : tensor<f32>} : () -> !vmla.buffer
+ %0 = constant dense<0.000000e+00> : tensor<f32>
+ // CHECK-DAG: [[CST1:%.+]] = "vmla.constant"() {value = dense<1.000000e+00> : tensor<f32>} : () -> !vmla.buffer
+ %1 = constant dense<1.000000e+00> : tensor<f32>
+ // CHECK-DAG: [[INPUT_SHAPE:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[4,8],i32>
+ // CHECK-DAG: [[SCALAR_SHAPE:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[],i32>
+ // CHECK-DAG: [[RESULT_SHAPE:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[4],i32>
+ // CHECK-DAG: [[RET_SIZE:%.+]] = muli
+ // CHECK-NEXT: [[RET0:%.+]] = "vmla.buffer.alloc"([[RET_SIZE]]) : (i32) -> !vmla.buffer
+ // CHECK-NEXT: "vmla.reduce.sum"(%arg0, [[INPUT_SHAPE]], [[CST0]], [[SCALAR_SHAPE]], [[RET0]], [[RESULT_SHAPE]]) {dimension = 1 : i32, element_type = f32} : (!vmla.buffer, !shapex.ranked_shape<[4,8],i32>, !vmla.buffer, !shapex.ranked_shape<[],i32>, !vmla.buffer, !shapex.ranked_shape<[4],i32>) -> ()
+ // CHECK-NEXT: [[RET1:%.+]] = "vmla.buffer.alloc"([[RET_SIZE]]) : (i32) -> !vmla.buffer
+ // CHECK-NEXT: "vmla.reduce.sum"(%arg1, [[INPUT_SHAPE]], [[CST1]], [[SCALAR_SHAPE]], [[RET1]], [[RESULT_SHAPE]]) {dimension = 1 : i32, element_type = f32} : (!vmla.buffer, !shapex.ranked_shape<[4,8],i32>, !vmla.buffer, !shapex.ranked_shape<[],i32>, !vmla.buffer, !shapex.ranked_shape<[4],i32>) -> ()
+ %2, %3 = "xla_hlo.reduce"(%arg0, %arg1, %0, %1) ( {
+ ^bb0(%arg0_lhs : tensor<f32>, %arg1_lhs : tensor<f32>, %arg0_rhs : tensor<f32>, %arg1_rhs : tensor<f32>):
+ %4 = xla_hlo.add %arg0_lhs, %arg0_rhs : tensor<f32>
+ %5 = xla_hlo.add %arg1_lhs, %arg1_rhs : tensor<f32>
+ "xla_hlo.return"(%4, %5) : (tensor<f32>, tensor<f32>) -> ()
+ }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<4x8xf32>, tensor<f32>, tensor<f32>) -> (tensor<4xf32>, tensor<4xf32>)
+ // CHECK-NEXT: return [[RET0]], [[RET1]] : !vmla.buffer, !vmla.buffer
+ return %2, %3 : tensor<4xf32>, tensor<4xf32>
+}
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
index 4f52fe0..5e796bd 100644
--- a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
+++ b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
@@ -56,7 +56,7 @@
);
}
-def VMLA_BufferAllocOp : VMLA_PureOp<"buffer.alloc"> {
+def VMLA_BufferAllocOp : VMLA_Op<"buffer.alloc"> {
let arguments = (ins
VMLA_DeviceSize:$byte_length
);
@@ -65,7 +65,7 @@
);
}
-def VMLA_BufferCloneOp : VMLA_PureOp<"buffer.clone"> {
+def VMLA_BufferCloneOp : VMLA_Op<"buffer.clone"> {
let arguments = (ins
VMLA_Buffer:$src
);
diff --git a/test/e2e/xla/reduce_float.mlir b/test/e2e/xla/reduce_float.mlir
index e303014..3612a4e 100644
--- a/test/e2e/xla/reduce_float.mlir
+++ b/test/e2e/xla/reduce_float.mlir
@@ -1,4 +1,5 @@
// RUN: iree-run-mlir %s -iree-hal-target-backends=interpreter-bytecode | IreeFileCheck %s
+// RUN: iree-run-mlir -iree-hal-target-backends=vmla -iree-flow-experimental-dispatch-reduce %s | IreeFileCheck %s
// TODO(b/142903911): figure out swiftshader+asan crash:
// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir %s -iree-hal-target-backends=vulkan-spirv --run=false)
diff --git a/test/e2e/xla/reduce_int.mlir b/test/e2e/xla/reduce_int.mlir
index 51ceeb9..6e8c43d 100644
--- a/test/e2e/xla/reduce_int.mlir
+++ b/test/e2e/xla/reduce_int.mlir
@@ -1,4 +1,5 @@
// RUN: iree-run-mlir -iree-hal-target-backends=interpreter-bytecode %s | IreeFileCheck %s
+// RUN: iree-run-mlir -iree-hal-target-backends=vmla -iree-flow-experimental-dispatch-reduce %s | IreeFileCheck %s
// TODO(b/146030213): This test fails cause the initialization isn't done
// correctly within the vulkan backend. Enable this test once that is done.
// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir -iree-hal-target-backends=vulkan-spirv --run=false %s)