Introduce pooling ops in VMLA.

- Add pooling.max/min/sum ops.
- Add support for lowering xla_hlo.reduce_window to VMLA pooling ops.
- Enable e2e/xla/reduce_window.mlir test.

Fixes https://github.com/google/iree/issues/1225

PiperOrigin-RevId: 305943244
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertReductionOps.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertReductionOps.cpp index 93f72a3..b36d00a 100644 --- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertReductionOps.cpp +++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertReductionOps.cpp
@@ -224,6 +224,84 @@ TypeConverter &typeConverter; }; +struct BuiltinPoolingOpConversion + : public OpConversionPattern<xla_hlo::ReduceWindowOp> { + BuiltinPoolingOpConversion(MLIRContext *context, TypeConverter &typeConverter) + : OpConversionPattern(context, /*benefit=*/1000), + typeConverter(typeConverter) {} + + LogicalResult matchAndRewrite( + xla_hlo::ReduceWindowOp srcOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + if (srcOp.body().getBlocks().size() > 1) { + // Control flow within the computation is not supported; bail to fallback. + return failure(); + } else if (srcOp.body().front().getOperations().size() > 2) { + // Require splitting first. + return failure(); + } + + auto operand = operands[0]; + auto operandShape = VMLAConversionTarget::getTensorShape( + srcOp.getLoc(), srcOp.operand(), typeConverter, rewriter); + auto initValue = operands[1]; + auto initValueShape = VMLAConversionTarget::getTensorShape( + srcOp.getLoc(), srcOp.init_value(), typeConverter, rewriter); + auto dst = VMLAConversionTarget::allocateOutputBuffer( + srcOp.getLoc(), srcOp.getResult(), typeConverter, rewriter); + auto dstShape = VMLAConversionTarget::getTensorShape( + srcOp.getLoc(), srcOp.getResult(), typeConverter, rewriter); + auto elementType = + srcOp.operand().getType().cast<ShapedType>().getElementType(); + + SmallVector<int32_t, 4> windowDimensions; + for (const auto &value : srcOp.window_dimensions().getIntValues()) + windowDimensions.push_back(value.getSExtValue()); + int rank = windowDimensions.size(); + SmallVector<int32_t, 4> windowStrides(rank, 1); + SmallVector<int32_t, 4> padding(rank, 0); + for (unsigned i = 0; i < rank; ++i) { + if (srcOp.window_strides()) + windowStrides[i] = srcOp.window_stridesAttr().getValue<int64_t>(i); + if (srcOp.padding()) + padding[i] = srcOp.paddingAttr().getValue<int64_t>({i, 0}); + } + + auto &computeOp = *srcOp.body().front().begin(); + if (isa<mlir::AddIOp>(computeOp) || 
isa<mlir::AddFOp>(computeOp) || + isa<xla_hlo::AddOp>(computeOp)) { + rewriter.create<IREE::VMLA::PoolingSumOp>( + srcOp.getLoc(), operand, operandShape, initValue, initValueShape, dst, + dstShape, TypeAttr::get(elementType), + rewriter.getI32VectorAttr(windowDimensions), + rewriter.getI32VectorAttr(windowStrides), + rewriter.getI32VectorAttr(padding)); + } else if (isa<xla_hlo::MinOp>(computeOp)) { + rewriter.create<IREE::VMLA::PoolingMinOp>( + srcOp.getLoc(), operand, operandShape, initValue, initValueShape, dst, + dstShape, TypeAttr::get(elementType), + rewriter.getI32VectorAttr(windowDimensions), + rewriter.getI32VectorAttr(windowStrides), + rewriter.getI32VectorAttr(padding)); + } else if (isa<xla_hlo::MaxOp>(computeOp)) { + rewriter.create<IREE::VMLA::PoolingMaxOp>( + srcOp.getLoc(), operand, operandShape, initValue, initValueShape, dst, + dstShape, TypeAttr::get(elementType), + rewriter.getI32VectorAttr(windowDimensions), + rewriter.getI32VectorAttr(windowStrides), + rewriter.getI32VectorAttr(padding)); + } else { + computeOp.emitRemark() << "unsupported builtin reduction operation"; + return failure(); + } + + rewriter.replaceOp(srcOp, {dst}); + return success(); + } + + TypeConverter &typeConverter; +}; + } // namespace void populateHLOReductionToVMLAPatterns(MLIRContext *context, @@ -232,6 +310,7 @@ patterns.insert<SplitIndependentReductionOpConversion>(context, typeConverter); patterns.insert<BuiltinReduceOpConversion>(context, typeConverter); + patterns.insert<BuiltinPoolingOpConversion>(context, typeConverter); patterns.insert<GenericReduceOpConversion>(context, typeConverter); }
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/reduce_window.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/reduce_window.mlir new file mode 100644 index 0000000..0e92a70 --- /dev/null +++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/reduce_window.mlir
@@ -0,0 +1,51 @@ +// RUN: iree-opt -split-input-file -iree-vmla-conversion -cse %s | IreeFileCheck %s + +// CHECK-LABEL: @pooling_max +func @pooling_max(%arg0: tensor<1x4x6x1xf32>) -> tensor<1x2x2x1xf32> + attributes { sym_visibility = "private" } { + // CHECK: vmla.pooling.max + %cst = constant dense<0.000000e+00> : tensor<f32> + %0 = "xla_hlo.reduce_window"(%arg0, %cst) ( { + ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors + %1 = xla_hlo.maximum %arg1, %arg2 : tensor<f32> + "xla_hlo.return"(%1) : (tensor<f32>) -> () + }) {window_dimensions = dense<[1, 2, 3, 1]> : tensor<4xi64>, + window_strides = dense<1> : tensor<4xi64> + } : (tensor<1x4x6x1xf32>, tensor<f32>) -> tensor<1x2x2x1xf32> + return %0 : tensor<1x2x2x1xf32> +} + +// ----- + +// CHECK-LABEL: @pooling_min +func @pooling_min(%arg0: tensor<1x4x6x1xi32>) -> tensor<1x2x2x1xi32> + attributes { sym_visibility = "private" } { + // CHECK: vmla.pooling.min + %cst = constant dense<0> : tensor<i32> + %0 = "xla_hlo.reduce_window"(%arg0, %cst) ( { + ^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>): // no predecessors + %1 = xla_hlo.minimum %arg1, %arg2 : tensor<i32> + "xla_hlo.return"(%1) : (tensor<i32>) -> () + }) {window_dimensions = dense<[1, 2, 3, 1]> : tensor<4xi64>, + window_strides = dense<1> : tensor<4xi64> + } : (tensor<1x4x6x1xi32>, tensor<i32>) -> tensor<1x2x2x1xi32> + return %0 : tensor<1x2x2x1xi32> +} + +// ----- + +// CHECK-LABEL: @pooling_sum +func @pooling_sum(%arg0: tensor<4x6xf32>) -> tensor<3x4xf32> attributes + { sym_visibility = "private" } { + // CHECK: vmla.pooling.sum + %cst = constant dense<0.000000e+00> : tensor<f32> + %0 = "xla_hlo.reduce_window"(%arg0, %cst) ( { + ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors + %1 = xla_hlo.add %arg1, %arg2 : tensor<f32> + "xla_hlo.return"(%1) : (tensor<f32>) -> () + }) {window_dimensions = dense<[2, 3]> : tensor<2xi64>, + window_strides = dense<1> : tensor<2xi64>, + padding = dense<[[1, 0], [2, 0]]> : tensor<2x2xi64> + } : 
(tensor<4x6xf32>, tensor<f32>) -> tensor<3x4xf32> + return %0 : tensor<3x4xf32> +}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp index df5ddd5..9dbaf9b 100644 --- a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp +++ b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
@@ -324,6 +324,10 @@ VMLA_TYPED_IMPORT_OP(IREE::VMLA::ReduceMinOp, "vmla.reduce.min"); VMLA_TYPED_IMPORT_OP(IREE::VMLA::ReduceMaxOp, "vmla.reduce.max"); + VMLA_TYPED_IMPORT_OP(IREE::VMLA::PoolingSumOp, "vmla.pooling.sum"); + VMLA_TYPED_IMPORT_OP(IREE::VMLA::PoolingMinOp, "vmla.pooling.min"); + VMLA_TYPED_IMPORT_OP(IREE::VMLA::PoolingMaxOp, "vmla.pooling.max"); + VMLA_IMPORT_OP(IREE::VMLA::InterfaceConstOp, "vmla.interface.const"); VMLA_IMPORT_OP(IREE::VMLA::InterfaceBindingOp, "vmla.interface.binding"); }
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td index d064ea4..ff7e863 100644 --- a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td +++ b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
@@ -401,6 +401,26 @@ def VMLA_ReduceMinOp : VMLA_ReduceOp<"reduce.min">; def VMLA_ReduceMaxOp : VMLA_ReduceOp<"reduce.max">; +class VMLA_PoolingOp<string mnemonic, list<OpTrait> traits = []> : + VMLA_ElementTypeOp<mnemonic, !listconcat(traits, [VMLA_IncludeShapes])> { + let arguments = (ins + VMLA_Buffer:$src, + VMLA_Shape:$src_shape, + VMLA_Buffer:$init, + VMLA_Shape:$init_shape, + VMLA_Buffer:$dst, + VMLA_Shape:$dst_shape, + VMLA_AnyTypeAttr:$element_type, + I32ElementsAttr:$window_dimensions, + I32ElementsAttr:$window_strides, + I32ElementsAttr:$padding + ); +} + +def VMLA_PoolingSumOp : VMLA_PoolingOp<"pooling.sum">; +def VMLA_PoolingMinOp : VMLA_PoolingOp<"pooling.min">; +def VMLA_PoolingMaxOp : VMLA_PoolingOp<"pooling.max">; + //===----------------------------------------------------------------------===// // VMLA Ops: ABI //===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/VMLA/vmla.imports.mlir b/iree/compiler/Dialect/VMLA/vmla.imports.mlir index a01a097..069fa2e 100644 --- a/iree/compiler/Dialect/VMLA/vmla.imports.mlir +++ b/iree/compiler/Dialect/VMLA/vmla.imports.mlir
@@ -412,4 +412,103 @@ %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ... ) +vm.import @pooling.sum.i8( + %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ..., + %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ..., + %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ..., + %window_dimensions: i32 ..., + %window_strides: i32 ..., + %padding: i32 ... +) +vm.import @pooling.sum.i16( + %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ..., + %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ..., + %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ..., + %window_dimensions: i32 ..., + %window_strides: i32 ..., + %padding: i32 ... +) +vm.import @pooling.sum.i32( + %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ..., + %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ..., + %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ..., + %window_dimensions: i32 ..., + %window_strides: i32 ..., + %padding: i32 ... +) +vm.import @pooling.sum.f32( + %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ..., + %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ..., + %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ..., + %window_dimensions: i32 ..., + %window_strides: i32 ..., + %padding: i32 ... +) + +vm.import @pooling.min.i8( + %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ..., + %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ..., + %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ..., + %window_dimensions: i32 ..., + %window_strides: i32 ..., + %padding: i32 ... +) +vm.import @pooling.min.i16( + %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ..., + %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ..., + %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ..., + %window_dimensions: i32 ..., + %window_strides: i32 ..., + %padding: i32 ... +) +vm.import @pooling.min.i32( + %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ..., + %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ..., + %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ..., + %window_dimensions: i32 ..., + %window_strides: i32 ..., + %padding: i32 ... 
+) +vm.import @pooling.min.f32( + %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ..., + %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ..., + %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ..., + %window_dimensions: i32 ..., + %window_strides: i32 ..., + %padding: i32 ... +) + +vm.import @pooling.max.i8( + %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ..., + %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ..., + %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ..., + %window_dimensions: i32 ..., + %window_strides: i32 ..., + %padding: i32 ... +) +vm.import @pooling.max.i16( + %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ..., + %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ..., + %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ..., + %window_dimensions: i32 ..., + %window_strides: i32 ..., + %padding: i32 ... +) +vm.import @pooling.max.i32( + %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ..., + %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ..., + %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ..., + %window_dimensions: i32 ..., + %window_strides: i32 ..., + %padding: i32 ... +) +vm.import @pooling.max.f32( + %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ..., + %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ..., + %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ..., + %window_dimensions: i32 ..., + %window_strides: i32 ..., + %padding: i32 ... +) + } // module
diff --git a/iree/hal/vmla/op_kernels.h b/iree/hal/vmla/op_kernels.h index 370b2ad..37eb313 100644 --- a/iree/hal/vmla/op_kernels.h +++ b/iree/hal/vmla/op_kernels.h
@@ -392,6 +392,33 @@ const Shape& src_shape, const Shape& dst_shape); }; +struct PoolingSum { + template <typename T> + static Status Execute(absl::Span<const T> src_buffer, + absl::Span<const T> init_buffer, + absl::Span<T> dst_buffer, const Shape& src_shape, + const Shape& dst_shape, const Shape& window_dimensions, + const Shape& strides, const Shape& pad_low); +}; + +struct PoolingMin { + template <typename T> + static Status Execute(absl::Span<const T> src_buffer, + absl::Span<const T> init_buffer, + absl::Span<T> dst_buffer, const Shape& src_shape, + const Shape& dst_shape, const Shape& window_dimensions, + const Shape& strides, const Shape& pad_low); +}; + +struct PoolingMax { + template <typename T> + static Status Execute(absl::Span<const T> src_buffer, + absl::Span<const T> init_buffer, + absl::Span<T> dst_buffer, const Shape& src_shape, + const Shape& dst_shape, const Shape& window_dimensions, + const Shape& strides, const Shape& pad_low); +}; + } // namespace kernels } // namespace vmla } // namespace hal
diff --git a/iree/hal/vmla/op_kernels_generic.h b/iree/hal/vmla/op_kernels_generic.h index e21bc63..2e90a42 100644 --- a/iree/hal/vmla/op_kernels_generic.h +++ b/iree/hal/vmla/op_kernels_generic.h
@@ -779,6 +779,92 @@ src_buffer, init_buffer, dst_buffer, dimension, src_shape, dst_shape); } +namespace impl { + +template <typename T, typename KernelImpl> +Status ComputePoolingWindow(absl::Span<const T> src_buffer, + absl::Span<const int> src_indices, + const Shape& src_shape, T init_value, + const Shape& window_dimensions, T* dst_value) { + int rank = src_shape.size(); + absl::InlinedVector<int, 8> window_indices(rank, 0); + auto getSrcValue = [&]() -> T { + int flat_idx = 0; + for (int i = 0; i < rank; ++i) { + int idx = src_indices[i] + window_indices[i]; + if (idx < 0 || idx >= src_shape[i]) return init_value; + flat_idx = flat_idx * src_shape[i] + idx; + } + return src_buffer[flat_idx]; + }; + + *dst_value = init_value; + for (int i = 0, e = window_dimensions.element_count(); i < e; ++i) { + KernelImpl()(dst_value, getSrcValue()); + IncrementShapeIndex(absl::MakeSpan(window_indices), window_dimensions); + } + return OkStatus(); +} + +template <typename T, typename KernelImpl> +Status GenericPooling(absl::Span<const T> src_buffer, + absl::Span<const T> init_buffer, absl::Span<T> dst_buffer, + const Shape& src_shape, const Shape& dst_shape, + const Shape& window_dimensions, const Shape& strides, + const Shape& pad_low) { + int rank = src_shape.size(); + absl::InlinedVector<int, 8> src_indices(rank, 0); + absl::InlinedVector<int, 8> dst_indices(rank, 0); + for (int i = 0, e = dst_shape.element_count(); i < e; ++i) { + for (int j = 0; j < rank; ++j) + src_indices[j] = dst_indices[j] * strides[j] - pad_low[j]; + Status status = ComputePoolingWindow<T, KernelImpl>( + src_buffer, src_indices, src_shape, init_buffer[0], window_dimensions, + &dst_buffer[i]); + if (!status.ok()) return status; + IncrementShapeIndex(absl::MakeSpan(dst_indices), dst_shape); + } + return OkStatus(); +} + +} // namespace impl + +template <typename T> +Status PoolingSum::Execute(absl::Span<const T> src_buffer, + absl::Span<const T> init_buffer, + absl::Span<T> dst_buffer, const Shape& 
src_shape, + const Shape& dst_shape, + const Shape& window_dimensions, const Shape& strides, + const Shape& pad_low) { + return impl::GenericPooling<T, impl::SumKernel>( + src_buffer, init_buffer, dst_buffer, src_shape, dst_shape, + window_dimensions, strides, pad_low); +} + +template <typename T> +Status PoolingMin::Execute(absl::Span<const T> src_buffer, + absl::Span<const T> init_buffer, + absl::Span<T> dst_buffer, const Shape& src_shape, + const Shape& dst_shape, + const Shape& window_dimensions, const Shape& strides, + const Shape& pad_low) { + return impl::GenericPooling<T, impl::MinKernel>( + src_buffer, init_buffer, dst_buffer, src_shape, dst_shape, + window_dimensions, strides, pad_low); +} + +template <typename T> +Status PoolingMax::Execute(absl::Span<const T> src_buffer, + absl::Span<const T> init_buffer, + absl::Span<T> dst_buffer, const Shape& src_shape, + const Shape& dst_shape, + const Shape& window_dimensions, const Shape& strides, + const Shape& pad_low) { + return impl::GenericPooling<T, impl::MaxKernel>( + src_buffer, init_buffer, dst_buffer, src_shape, dst_shape, + window_dimensions, strides, pad_low); +} + } // namespace kernels } // namespace vmla } // namespace hal
diff --git a/iree/hal/vmla/op_kernels_test.cc b/iree/hal/vmla/op_kernels_test.cc index c39215c..b01c9ac 100644 --- a/iree/hal/vmla/op_kernels_test.cc +++ b/iree/hal/vmla/op_kernels_test.cc
@@ -401,6 +401,63 @@ } } +TEST(PoolingMax, NoOverlapping) { + Shape src_shape = {1, 4, 6, 1}; + Shape dst_shape = {1, 2, 2, 1}; + Shape window_sizes = {1, 2, 3, 1}; + Shape strides = {1, 2, 3, 1}; + Shape pad_low = {0, 0, 0, 0}; + std::vector<int> src_buffer = MakeIota<int>(src_shape.element_count()); + std::vector<int> init_buffer(1, 0.0f); + std::vector<int> dst_buffer(dst_shape.element_count(), 0.0f); + std::vector<int> expected_dst = {9, 12, 21, 24}; + + EXPECT_OK(PoolingMax::Execute<int>( + src_buffer, init_buffer, absl::MakeSpan(dst_buffer), src_shape, dst_shape, + window_sizes, strides, pad_low)); + EXPECT_EQ(dst_buffer, expected_dst); +} + +TEST(PoolingMin, Padding) { + // Padded input: + // 100 100 100 100 + // 100 1 2 3 + // 100 4 5 6 + Shape src_shape = {2, 3}; + Shape dst_shape = {2, 3}; + Shape window_sizes = {2, 2}; + Shape strides = {1, 1}; + Shape pad_low = {1, 1}; + std::vector<int> src_buffer = MakeIota<int>(src_shape.element_count()); + std::vector<int> init_buffer(1, 100.0); + std::vector<int> dst_buffer(dst_shape.element_count(), 0.0f); + std::vector<int> expected_dst = {1, 1, 2, 1, 1, 2}; + + EXPECT_OK(PoolingMin::Execute<int>( + src_buffer, init_buffer, absl::MakeSpan(dst_buffer), src_shape, dst_shape, + window_sizes, strides, pad_low)); + EXPECT_EQ(dst_buffer, expected_dst); +} + +TEST(PoolingSum, Overlapping) { + Shape src_shape = {3, 4}; + Shape dst_shape = {2, 2}; + Shape window_sizes = {2, 3}; + Shape strides = {1, 1}; + Shape pad_low = {0, 0}; + std::vector<float> src_buffer = MakeIota<float>(src_shape.element_count()); + std::vector<float> init_buffer(1, 0.0f); + std::vector<float> dst_buffer(dst_shape.element_count(), 0.0f); + std::vector<float> expected_dst = {24, 30, 48, 54}; + + EXPECT_OK(PoolingSum::Execute<float>( + src_buffer, init_buffer, absl::MakeSpan(dst_buffer), src_shape, dst_shape, + window_sizes, strides, pad_low)); + for (int i = 0; i < dst_buffer.size(); ++i) { + EXPECT_NEAR(expected_dst[i], dst_buffer[i], kEpsilon); + 
} +} + TEST(Conv2d, NoDilation) { Shape input_shape = {4, 5, 2}; Shape filter_shape = {3, 2, 2, 1};
diff --git a/iree/hal/vmla/vmla_module.cc b/iree/hal/vmla/vmla_module.cc index 9878cb3..cda4327 100644 --- a/iree/hal/vmla/vmla_module.cc +++ b/iree/hal/vmla/vmla_module.cc
@@ -744,6 +744,31 @@ IREE_VMLA_REDUCTION_OP(ReduceMaxI32, kernels::ReduceMax, int32_t); IREE_VMLA_REDUCTION_OP(ReduceMaxF32, kernels::ReduceMax, float); +#define IREE_VMLA_POOLING_OP(name, kernel, type) \ + Status name(vm::ref<Buffer> src, iree_vmla_shape_t src_shape, \ + vm::ref<Buffer> init, iree_vmla_shape_t init_shape, \ + vm::ref<Buffer> dst, iree_vmla_shape_t dst_shape, \ + iree_vmla_shape_t window_dimensions, iree_vmla_shape_t strides, \ + iree_vmla_shape_t pad_low) { \ + IREE_TRACE_SCOPE0("VMLAModuleState::" #name); \ + return kernel::Execute<type>(src->As<type>(), init->As<type>(), \ + dst->As<type>(), Shape(src_shape), \ + Shape(dst_shape), Shape(window_dimensions), \ + Shape(strides), Shape(pad_low)); \ + } + IREE_VMLA_POOLING_OP(PoolingSumI8, kernels::PoolingSum, int8_t); + IREE_VMLA_POOLING_OP(PoolingSumI16, kernels::PoolingSum, int16_t); + IREE_VMLA_POOLING_OP(PoolingSumI32, kernels::PoolingSum, int32_t); + IREE_VMLA_POOLING_OP(PoolingSumF32, kernels::PoolingSum, float); + IREE_VMLA_POOLING_OP(PoolingMinI8, kernels::PoolingMin, int8_t); + IREE_VMLA_POOLING_OP(PoolingMinI16, kernels::PoolingMin, int16_t); + IREE_VMLA_POOLING_OP(PoolingMinI32, kernels::PoolingMin, int32_t); + IREE_VMLA_POOLING_OP(PoolingMinF32, kernels::PoolingMin, float); + IREE_VMLA_POOLING_OP(PoolingMaxI8, kernels::PoolingMax, int8_t); + IREE_VMLA_POOLING_OP(PoolingMaxI16, kernels::PoolingMax, int16_t); + IREE_VMLA_POOLING_OP(PoolingMaxI32, kernels::PoolingMax, int32_t); + IREE_VMLA_POOLING_OP(PoolingMaxF32, kernels::PoolingMax, float); + private: iree_allocator_t allocator_; @@ -907,6 +932,19 @@ vm::MakeNativeFunction("reduce.max.i32", &VMLAModuleState::ReduceMaxI32), vm::MakeNativeFunction("reduce.max.f32", &VMLAModuleState::ReduceMaxF32), + vm::MakeNativeFunction("pooling.sum.i8", &VMLAModuleState::PoolingSumI8), + vm::MakeNativeFunction("pooling.sum.i16", &VMLAModuleState::PoolingSumI16), + vm::MakeNativeFunction("pooling.sum.i32", &VMLAModuleState::PoolingSumI32), + 
vm::MakeNativeFunction("pooling.sum.f32", &VMLAModuleState::PoolingSumF32), + vm::MakeNativeFunction("pooling.min.i8", &VMLAModuleState::PoolingMinI8), + vm::MakeNativeFunction("pooling.min.i16", &VMLAModuleState::PoolingMinI16), + vm::MakeNativeFunction("pooling.min.i32", &VMLAModuleState::PoolingMinI32), + vm::MakeNativeFunction("pooling.min.f32", &VMLAModuleState::PoolingMinF32), + vm::MakeNativeFunction("pooling.max.i8", &VMLAModuleState::PoolingMaxI8), + vm::MakeNativeFunction("pooling.max.i16", &VMLAModuleState::PoolingMaxI16), + vm::MakeNativeFunction("pooling.max.i32", &VMLAModuleState::PoolingMaxI32), + vm::MakeNativeFunction("pooling.max.f32", &VMLAModuleState::PoolingMaxF32), + vm::MakeNativeFunction("batch.matmul.f32f32.f32", &VMLAModuleState::BatchMatMulF32F32F32),
diff --git a/iree/test/e2e/xla/BUILD b/iree/test/e2e/xla/BUILD index efd59d3..c718049 100644 --- a/iree/test/e2e/xla/BUILD +++ b/iree/test/e2e/xla/BUILD
@@ -48,6 +48,7 @@ "pad.mlir", "reduce_float.mlir", "reduce_int.mlir", + "reduce_window.mlir", "rem.mlir", "reshape.mlir", "reshape_adddims.mlir",