Introduce pooling ops in VMLA.

- Add pooling.max/min/sum ops.
- Add support for lowering xla_hlo.reduce_window to VMLA pooling ops.
- Enable e2e/xla/reduce_window.mlir test.

Fixes https://github.com/google/iree/issues/1225

PiperOrigin-RevId: 305943244
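
For reference, the new vmla.pooling.* ops compute a strided sliding-window
reduction in which window taps that fall into the padded region contribute
the init value. As a rough sketch of the 2-D max case (a hypothetical
MaxPool2D helper, low padding only to mirror the kernels' pad_low parameter;
the kernels added below are rank-generic):

  #include <algorithm>
  #include <vector>

  // Row-major src of size src_h x src_w; taps outside src read `init`.
  std::vector<float> MaxPool2D(const std::vector<float>& src, int src_h,
                               int src_w, int win_h, int win_w, int stride_h,
                               int stride_w, int pad_h, int pad_w, float init) {
    int dst_h = (src_h + pad_h - win_h) / stride_h + 1;
    int dst_w = (src_w + pad_w - win_w) / stride_w + 1;
    std::vector<float> dst(dst_h * dst_w, init);
    for (int y = 0; y < dst_h; ++y) {
      for (int x = 0; x < dst_w; ++x) {
        for (int wy = 0; wy < win_h; ++wy) {
          for (int wx = 0; wx < win_w; ++wx) {
            int sy = y * stride_h - pad_h + wy;
            int sx = x * stride_w - pad_w + wx;
            if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) continue;
            dst[y * dst_w + x] =
                std::max(dst[y * dst_w + x], src[sy * src_w + sx]);
          }
        }
      }
    }
    return dst;
  }

E.g. MaxPool2D(src, 4, 6, 2, 3, 2, 3, 0, 0, 0.0f) reproduces the
PoolingMax.NoOverlapping kernel test below, modulo the unit batch and
feature dimensions.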
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertReductionOps.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertReductionOps.cpp
index 93f72a3..b36d00a 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertReductionOps.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertReductionOps.cpp
@@ -224,6 +224,84 @@
   TypeConverter &typeConverter;
 };
 
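+// Converts an xla_hlo.reduce_window with a single-op body (add/min/max) into
+// the matching vmla.pooling.* op. Registered with a high benefit so it is
+// tried before the generic lowering; unsupported bodies fail the match and
+// fall through.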
+struct BuiltinPoolingOpConversion
+    : public OpConversionPattern<xla_hlo::ReduceWindowOp> {
+  BuiltinPoolingOpConversion(MLIRContext *context, TypeConverter &typeConverter)
+      : OpConversionPattern(context, /*benefit=*/1000),
+        typeConverter(typeConverter) {}
+
+  LogicalResult matchAndRewrite(
+      xla_hlo::ReduceWindowOp srcOp, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    if (srcOp.body().getBlocks().size() > 1) {
+      // Control flow within the computation is not supported; fall back to
+      // the generic lowering.
+      return failure();
+    } else if (srcOp.body().front().getOperations().size() > 2) {
+      // Body has more than one compute op (plus the terminator); require the
+      // splitting pattern to run first.
+      return failure();
+    }
+
+    auto operand = operands[0];
+    auto operandShape = VMLAConversionTarget::getTensorShape(
+        srcOp.getLoc(), srcOp.operand(), typeConverter, rewriter);
+    auto initValue = operands[1];
+    auto initValueShape = VMLAConversionTarget::getTensorShape(
+        srcOp.getLoc(), srcOp.init_value(), typeConverter, rewriter);
+    auto dst = VMLAConversionTarget::allocateOutputBuffer(
+        srcOp.getLoc(), srcOp.getResult(), typeConverter, rewriter);
+    auto dstShape = VMLAConversionTarget::getTensorShape(
+        srcOp.getLoc(), srcOp.getResult(), typeConverter, rewriter);
+    auto elementType =
+        srcOp.operand().getType().cast<ShapedType>().getElementType();
+
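+    // Window strides default to 1 and padding to 0 when the optional
+    // attributes are absent. Only the low padding (column 0 of the padding
+    // attribute) is forwarded: the kernels read the init value for any tap
+    // outside the source buffer, so high padding is implied by dst_shape.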
+    SmallVector<int32_t, 4> windowDimensions;
+    for (const auto &value : srcOp.window_dimensions().getIntValues())
+      windowDimensions.push_back(value.getSExtValue());
+    int rank = windowDimensions.size();
+    SmallVector<int32_t, 4> windowStrides(rank, 1);
+    SmallVector<int32_t, 4> padding(rank, 0);
+    for (unsigned i = 0; i < rank; ++i) {
+      if (srcOp.window_strides())
+        windowStrides[i] = srcOp.window_stridesAttr().getValue<int64_t>(i);
+      if (srcOp.padding())
+        padding[i] = srcOp.paddingAttr().getValue<int64_t>({i, 0});
+    }
+
+    auto &computeOp = *srcOp.body().front().begin();
+    if (isa<mlir::AddIOp>(computeOp) || isa<mlir::AddFOp>(computeOp) ||
+        isa<xla_hlo::AddOp>(computeOp)) {
+      rewriter.create<IREE::VMLA::PoolingSumOp>(
+          srcOp.getLoc(), operand, operandShape, initValue, initValueShape, dst,
+          dstShape, TypeAttr::get(elementType),
+          rewriter.getI32VectorAttr(windowDimensions),
+          rewriter.getI32VectorAttr(windowStrides),
+          rewriter.getI32VectorAttr(padding));
+    } else if (isa<xla_hlo::MinOp>(computeOp)) {
+      rewriter.create<IREE::VMLA::PoolingMinOp>(
+          srcOp.getLoc(), operand, operandShape, initValue, initValueShape, dst,
+          dstShape, TypeAttr::get(elementType),
+          rewriter.getI32VectorAttr(windowDimensions),
+          rewriter.getI32VectorAttr(windowStrides),
+          rewriter.getI32VectorAttr(padding));
+    } else if (isa<xla_hlo::MaxOp>(computeOp)) {
+      rewriter.create<IREE::VMLA::PoolingMaxOp>(
+          srcOp.getLoc(), operand, operandShape, initValue, initValueShape, dst,
+          dstShape, TypeAttr::get(elementType),
+          rewriter.getI32VectorAttr(windowDimensions),
+          rewriter.getI32VectorAttr(windowStrides),
+          rewriter.getI32VectorAttr(padding));
+    } else {
+      computeOp.emitRemark() << "unsupported builtin pooling operation";
+      return failure();
+    }
+
+    rewriter.replaceOp(srcOp, {dst});
+    return success();
+  }
+
+  TypeConverter &typeConverter;
+};
+
 }  // namespace
 
 void populateHLOReductionToVMLAPatterns(MLIRContext *context,
@@ -232,6 +310,7 @@
   patterns.insert<SplitIndependentReductionOpConversion>(context,
                                                          typeConverter);
   patterns.insert<BuiltinReduceOpConversion>(context, typeConverter);
+  patterns.insert<BuiltinPoolingOpConversion>(context, typeConverter);
   patterns.insert<GenericReduceOpConversion>(context, typeConverter);
 }
 
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/reduce_window.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/reduce_window.mlir
new file mode 100644
index 0000000..0e92a70
--- /dev/null
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/reduce_window.mlir
@@ -0,0 +1,51 @@
+// RUN: iree-opt -split-input-file -iree-vmla-conversion -cse %s | IreeFileCheck %s
+
+// CHECK-LABEL: @pooling_max
+func @pooling_max(%arg0: tensor<1x4x6x1xf32>) -> tensor<1x2x2x1xf32>
+    attributes { sym_visibility = "private" } {
+  // CHECK: vmla.pooling.max
+  %cst = constant dense<0.000000e+00> : tensor<f32>
+  %0 = "xla_hlo.reduce_window"(%arg0, %cst) ( {
+  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):	// no predecessors
+    %1 = xla_hlo.maximum %arg1, %arg2 : tensor<f32>
+    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
+  }) {window_dimensions = dense<[1, 2, 3, 1]> : tensor<4xi64>,
+      window_strides = dense<1> : tensor<4xi64>
+  } : (tensor<1x4x6x1xf32>, tensor<f32>) -> tensor<1x2x2x1xf32>
+  return %0 : tensor<1x2x2x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @pooling_min
+func @pooling_min(%arg0: tensor<1x4x6x1xi32>) -> tensor<1x2x2x1xi32>
+    attributes { sym_visibility = "private" } {
+  // CHECK: vmla.pooling.min
+  %cst = constant dense<0> : tensor<i32>
+  %0 = "xla_hlo.reduce_window"(%arg0, %cst) ( {
+  ^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>):	// no predecessors
+    %1 = xla_hlo.minimum %arg1, %arg2 : tensor<i32>
+    "xla_hlo.return"(%1) : (tensor<i32>) -> ()
+  }) {window_dimensions = dense<[1, 2, 3, 1]> : tensor<4xi64>,
+      window_strides = dense<1> : tensor<4xi64>
+  } : (tensor<1x4x6x1xi32>, tensor<i32>) -> tensor<1x2x2x1xi32>
+  return %0 : tensor<1x2x2x1xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @pooling_sum
+func @pooling_sum(%arg0: tensor<4x6xf32>) -> tensor<4x6xf32>
+    attributes { sym_visibility = "private" } {
+  // CHECK: vmla.pooling.sum
+  %cst = constant dense<0.000000e+00> : tensor<f32>
+  %0 = "xla_hlo.reduce_window"(%arg0, %cst) ( {
+  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):	// no predecessors
+    %1 = xla_hlo.add %arg1, %arg2 : tensor<f32>
+    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
+  }) {window_dimensions = dense<[2, 3]> : tensor<2xi64>,
+      window_strides = dense<1> : tensor<2xi64>,
+      padding = dense<[[1, 0], [2, 0]]> : tensor<2x2xi64>
+  } : (tensor<4x6xf32>, tensor<f32>) -> tensor<4x6xf32>
+  return %0 : tensor<4x6xf32>
+}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
index df5ddd5..9dbaf9b 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
@@ -324,6 +324,10 @@
   VMLA_TYPED_IMPORT_OP(IREE::VMLA::ReduceMinOp, "vmla.reduce.min");
   VMLA_TYPED_IMPORT_OP(IREE::VMLA::ReduceMaxOp, "vmla.reduce.max");
 
+  VMLA_TYPED_IMPORT_OP(IREE::VMLA::PoolingSumOp, "vmla.pooling.sum");
+  VMLA_TYPED_IMPORT_OP(IREE::VMLA::PoolingMinOp, "vmla.pooling.min");
+  VMLA_TYPED_IMPORT_OP(IREE::VMLA::PoolingMaxOp, "vmla.pooling.max");
+
   VMLA_IMPORT_OP(IREE::VMLA::InterfaceConstOp, "vmla.interface.const");
   VMLA_IMPORT_OP(IREE::VMLA::InterfaceBindingOp, "vmla.interface.binding");
 }
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
index d064ea4..ff7e863 100644
--- a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
+++ b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
@@ -401,6 +401,26 @@
 def VMLA_ReduceMinOp : VMLA_ReduceOp<"reduce.min">;
 def VMLA_ReduceMaxOp : VMLA_ReduceOp<"reduce.max">;
 
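+// Pooling ops: same src/init/dst buffer+shape structure as the reduce ops,
+// plus the static window configuration (dimensions, strides, low padding)
+// carried as i32 vector attributes.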
+class VMLA_PoolingOp<string mnemonic, list<OpTrait> traits = []> :
+    VMLA_ElementTypeOp<mnemonic, !listconcat(traits, [VMLA_IncludeShapes])> {
+  let arguments = (ins
+    VMLA_Buffer:$src,
+    VMLA_Shape:$src_shape,
+    VMLA_Buffer:$init,
+    VMLA_Shape:$init_shape,
+    VMLA_Buffer:$dst,
+    VMLA_Shape:$dst_shape,
+    VMLA_AnyTypeAttr:$element_type,
+    I32ElementsAttr:$window_dimensions,
+    I32ElementsAttr:$window_strides,
+    I32ElementsAttr:$padding
+  );
+}
+
+def VMLA_PoolingSumOp : VMLA_PoolingOp<"pooling.sum">;
+def VMLA_PoolingMinOp : VMLA_PoolingOp<"pooling.min">;
+def VMLA_PoolingMaxOp : VMLA_PoolingOp<"pooling.max">;
+
 //===----------------------------------------------------------------------===//
 // VMLA Ops: ABI
 //===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/VMLA/vmla.imports.mlir b/iree/compiler/Dialect/VMLA/vmla.imports.mlir
index a01a097..069fa2e 100644
--- a/iree/compiler/Dialect/VMLA/vmla.imports.mlir
+++ b/iree/compiler/Dialect/VMLA/vmla.imports.mlir
@@ -412,4 +412,103 @@
   %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
 )
 
+vm.import @pooling.sum.i8(
+  %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
+  %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
+  %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...,
+  %window_dimensions : i32 ...,
+  %window_strides : i32 ...,
+  %padding : i32 ...
+)
+vm.import @pooling.sum.i16(
+  %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
+  %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
+  %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...,
+  %window_dimensions : i32 ...,
+  %window_strides : i32 ...,
+  %padding : i32 ...
+)
+vm.import @pooling.sum.i32(
+  %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
+  %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
+  %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...,
+  %window_dimensions : i32 ...,
+  %window_strides : i32 ...,
+  %padding : i32 ...
+)
+vm.import @pooling.sum.f32(
+  %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
+  %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
+  %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...,
+  %window_dimensions : i32 ...,
+  %window_strides : i32 ...,
+  %padding : i32 ...
+)
+
+vm.import @pooling.min.i8(
+  %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
+  %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
+  %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...,
+  %window_dimensions : i32 ...,
+  %window_strides : i32 ...,
+  %padding : i32 ...
+)
+vm.import @pooling.min.i16(
+  %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
+  %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
+  %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...,
+  %window_dimensions : i32 ...,
+  %window_strides : i32 ...,
+  %padding : i32 ...
+)
+vm.import @pooling.min.i32(
+  %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
+  %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
+  %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...,
+  %window_dimensions : i32 ...,
+  %window_strides : i32 ...,
+  %padding : i32 ...
+)
+vm.import @pooling.min.f32(
+  %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
+  %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
+  %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...,
+  %window_dimensions : i32 ...,
+  %window_strides : i32 ...,
+  %padding : i32 ...
+)
+
+vm.import @pooling.max.i8(
+  %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
+  %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
+  %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...,
+  %window_dimensions : i32 ...,
+  %window_strides : i32 ...,
+  %padding : i32 ...
+)
+vm.import @pooling.max.i16(
+  %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
+  %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
+  %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...,
+  %window_dimensions : i32 ...,
+  %window_strides : i32 ...,
+  %padding : i32 ...
+)
+vm.import @pooling.max.i32(
+  %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
+  %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
+  %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...,
+  %window_dimensions : i32 ...,
+  %window_strides : i32 ...,
+  %padding : i32 ...
+)
+vm.import @pooling.max.f32(
+  %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
+  %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
+  %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...,
+  %window_dimensions : i32 ...,
+  %window_strides : i32 ...,
+  %padding : i32 ...
+)
+
 }  // module
diff --git a/iree/hal/vmla/op_kernels.h b/iree/hal/vmla/op_kernels.h
index 370b2ad..37eb313 100644
--- a/iree/hal/vmla/op_kernels.h
+++ b/iree/hal/vmla/op_kernels.h
@@ -392,6 +392,33 @@
                         const Shape& src_shape, const Shape& dst_shape);
 };
 
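+// Sliding-window reductions. `pad_low` holds the per-dimension low padding;
+// any window tap outside the source buffer contributes the init value.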
+struct PoolingSum {
+  template <typename T>
+  static Status Execute(absl::Span<const T> src_buffer,
+                        absl::Span<const T> init_buffer,
+                        absl::Span<T> dst_buffer, const Shape& src_shape,
+                        const Shape& dst_shape, const Shape& window_dimensions,
+                        const Shape& strides, const Shape& pad_low);
+};
+
+struct PoolingMin {
+  template <typename T>
+  static Status Execute(absl::Span<const T> src_buffer,
+                        absl::Span<const T> init_buffer,
+                        absl::Span<T> dst_buffer, const Shape& src_shape,
+                        const Shape& dst_shape, const Shape& window_dimensions,
+                        const Shape& strides, const Shape& pad_low);
+};
+
+struct PoolingMax {
+  template <typename T>
+  static Status Execute(absl::Span<const T> src_buffer,
+                        absl::Span<const T> init_buffer,
+                        absl::Span<T> dst_buffer, const Shape& src_shape,
+                        const Shape& dst_shape, const Shape& window_dimensions,
+                        const Shape& strides, const Shape& pad_low);
+};
+
 }  // namespace kernels
 }  // namespace vmla
 }  // namespace hal
diff --git a/iree/hal/vmla/op_kernels_generic.h b/iree/hal/vmla/op_kernels_generic.h
index e21bc63..2e90a42 100644
--- a/iree/hal/vmla/op_kernels_generic.h
+++ b/iree/hal/vmla/op_kernels_generic.h
@@ -779,6 +779,92 @@
       src_buffer, init_buffer, dst_buffer, dimension, src_shape, dst_shape);
 }
 
+namespace impl {
+
+template <typename T, typename KernelImpl>
+Status ComputePoolingWindow(absl::Span<const T> src_buffer,
+                            absl::Span<const int> src_indices,
+                            const Shape& src_shape, T init_value,
+                            const Shape& window_dimensions, T* dst_value) {
+  int rank = src_shape.size();
+  absl::InlinedVector<int, 8> window_indices(rank, 0);
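+  // Reads the source element at src_indices + window_indices, substituting
+  // the init value for taps that land outside the source buffer (i.e. in
+  // the padded region).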
+  auto getSrcValue = [&]() -> T {
+    int flat_idx = 0;
+    for (int i = 0; i < rank; ++i) {
+      int idx = src_indices[i] + window_indices[i];
+      if (idx < 0 || idx >= src_shape[i]) return init_value;
+      flat_idx = flat_idx * src_shape[i] + idx;
+    }
+    return src_buffer[flat_idx];
+  };
+
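+  // Visit every tap of the window in row-major order, folding each source
+  // value into *dst_value with the kernel (sum/min/max).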
+  *dst_value = init_value;
+  for (int i = 0, e = window_dimensions.element_count(); i < e; ++i) {
+    KernelImpl()(dst_value, getSrcValue());
+    IncrementShapeIndex(absl::MakeSpan(window_indices), window_dimensions);
+  }
+  return OkStatus();
+}
+
+template <typename T, typename KernelImpl>
+Status GenericPooling(absl::Span<const T> src_buffer,
+                      absl::Span<const T> init_buffer, absl::Span<T> dst_buffer,
+                      const Shape& src_shape, const Shape& dst_shape,
+                      const Shape& window_dimensions, const Shape& strides,
+                      const Shape& pad_low) {
+  int rank = src_shape.size();
+  absl::InlinedVector<int, 8> src_indices(rank, 0);
+  absl::InlinedVector<int, 8> dst_indices(rank, 0);
+  for (int i = 0, e = dst_shape.element_count(); i < e; ++i) {
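+    // Top-left source coordinate of the window feeding this output element,
+    // offset by the low padding (may be negative).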
+    for (int j = 0; j < rank; ++j)
+      src_indices[j] = dst_indices[j] * strides[j] - pad_low[j];
+    Status status = ComputePoolingWindow<T, KernelImpl>(
+        src_buffer, src_indices, src_shape, init_buffer[0], window_dimensions,
+        &dst_buffer[i]);
+    if (!status.ok()) return status;
+    IncrementShapeIndex(absl::MakeSpan(dst_indices), dst_shape);
+  }
+  return OkStatus();
+}
+
+}  // namespace impl
+
+template <typename T>
+Status PoolingSum::Execute(absl::Span<const T> src_buffer,
+                           absl::Span<const T> init_buffer,
+                           absl::Span<T> dst_buffer, const Shape& src_shape,
+                           const Shape& dst_shape,
+                           const Shape& window_dimensions, const Shape& strides,
+                           const Shape& pad_low) {
+  return impl::GenericPooling<T, impl::SumKernel>(
+      src_buffer, init_buffer, dst_buffer, src_shape, dst_shape,
+      window_dimensions, strides, pad_low);
+}
+
+template <typename T>
+Status PoolingMin::Execute(absl::Span<const T> src_buffer,
+                           absl::Span<const T> init_buffer,
+                           absl::Span<T> dst_buffer, const Shape& src_shape,
+                           const Shape& dst_shape,
+                           const Shape& window_dimensions, const Shape& strides,
+                           const Shape& pad_low) {
+  return impl::GenericPooling<T, impl::MinKernel>(
+      src_buffer, init_buffer, dst_buffer, src_shape, dst_shape,
+      window_dimensions, strides, pad_low);
+}
+
+template <typename T>
+Status PoolingMax::Execute(absl::Span<const T> src_buffer,
+                           absl::Span<const T> init_buffer,
+                           absl::Span<T> dst_buffer, const Shape& src_shape,
+                           const Shape& dst_shape,
+                           const Shape& window_dimensions, const Shape& strides,
+                           const Shape& pad_low) {
+  return impl::GenericPooling<T, impl::MaxKernel>(
+      src_buffer, init_buffer, dst_buffer, src_shape, dst_shape,
+      window_dimensions, strides, pad_low);
+}
+
 }  // namespace kernels
 }  // namespace vmla
 }  // namespace hal
diff --git a/iree/hal/vmla/op_kernels_test.cc b/iree/hal/vmla/op_kernels_test.cc
index c39215c..b01c9ac 100644
--- a/iree/hal/vmla/op_kernels_test.cc
+++ b/iree/hal/vmla/op_kernels_test.cc
@@ -401,6 +401,63 @@
   }
 }
 
+TEST(PoolingMax, NoOverlapping) {
+  Shape src_shape = {1, 4, 6, 1};
+  Shape dst_shape = {1, 2, 2, 1};
+  Shape window_sizes = {1, 2, 3, 1};
+  Shape strides = {1, 2, 3, 1};
+  Shape pad_low = {0, 0, 0, 0};
+  std::vector<int> src_buffer = MakeIota<int>(src_shape.element_count());
+  std::vector<int> init_buffer(1, 0);
+  std::vector<int> dst_buffer(dst_shape.element_count(), 0);
+  std::vector<int> expected_dst = {9, 12, 21, 24};
+
+  EXPECT_OK(PoolingMax::Execute<int>(
+      src_buffer, init_buffer, absl::MakeSpan(dst_buffer), src_shape, dst_shape,
+      window_sizes, strides, pad_low));
+  EXPECT_EQ(dst_buffer, expected_dst);
+}
+
+TEST(PoolingMin, Padding) {
+  // Padded input:
+  // 100 100 100 100
+  // 100   1   2   3
+  // 100   4   5   6
+  Shape src_shape = {2, 3};
+  Shape dst_shape = {2, 3};
+  Shape window_sizes = {2, 2};
+  Shape strides = {1, 1};
+  Shape pad_low = {1, 1};
+  std::vector<int> src_buffer = MakeIota<int>(src_shape.element_count());
+  std::vector<int> init_buffer(1, 100);
+  std::vector<int> dst_buffer(dst_shape.element_count(), 0);
+  std::vector<int> expected_dst = {1, 1, 2, 1, 1, 2};
+
+  EXPECT_OK(PoolingMin::Execute<int>(
+      src_buffer, init_buffer, absl::MakeSpan(dst_buffer), src_shape, dst_shape,
+      window_sizes, strides, pad_low));
+  EXPECT_EQ(dst_buffer, expected_dst);
+}
+
+TEST(PoolingSum, Overlapping) {
+  Shape src_shape = {3, 4};
+  Shape dst_shape = {2, 2};
+  Shape window_sizes = {2, 3};
+  Shape strides = {1, 1};
+  Shape pad_low = {0, 0};
+  std::vector<float> src_buffer = MakeIota<float>(src_shape.element_count());
+  std::vector<float> init_buffer(1, 0.0f);
+  std::vector<float> dst_buffer(dst_shape.element_count(), 0.0f);
+  std::vector<float> expected_dst = {24, 30, 48, 54};
+
+  EXPECT_OK(PoolingSum::Execute<float>(
+      src_buffer, init_buffer, absl::MakeSpan(dst_buffer), src_shape, dst_shape,
+      window_sizes, strides, pad_low));
+  for (int i = 0; i < dst_buffer.size(); ++i) {
+    EXPECT_NEAR(expected_dst[i], dst_buffer[i], kEpsilon);
+  }
+}
+
 TEST(Conv2d, NoDilation) {
   Shape input_shape = {4, 5, 2};
   Shape filter_shape = {3, 2, 2, 1};
diff --git a/iree/hal/vmla/vmla_module.cc b/iree/hal/vmla/vmla_module.cc
index 9878cb3..cda4327 100644
--- a/iree/hal/vmla/vmla_module.cc
+++ b/iree/hal/vmla/vmla_module.cc
@@ -744,6 +744,31 @@
   IREE_VMLA_REDUCTION_OP(ReduceMaxI32, kernels::ReduceMax, int32_t);
   IREE_VMLA_REDUCTION_OP(ReduceMaxF32, kernels::ReduceMax, float);
 
+#define IREE_VMLA_POOLING_OP(name, kernel, type)                              \
+  Status name(vm::ref<Buffer> src, iree_vmla_shape_t src_shape,               \
+              vm::ref<Buffer> init, iree_vmla_shape_t init_shape,             \
+              vm::ref<Buffer> dst, iree_vmla_shape_t dst_shape,               \
+              iree_vmla_shape_t window_dimensions, iree_vmla_shape_t strides, \
+              iree_vmla_shape_t pad_low) {                                    \
+    IREE_TRACE_SCOPE0("VMLAModuleState::" #name);                             \
+    return kernel::Execute<type>(src->As<type>(), init->As<type>(),           \
+                                 dst->As<type>(), Shape(src_shape),           \
+                                 Shape(dst_shape), Shape(window_dimensions),  \
+                                 Shape(strides), Shape(pad_low));             \
+  }
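+  // One exported entry point per (reduction kind, element type), matching
+  // the import declarations in vmla.imports.mlir.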
+  IREE_VMLA_POOLING_OP(PoolingSumI8, kernels::PoolingSum, int8_t);
+  IREE_VMLA_POOLING_OP(PoolingSumI16, kernels::PoolingSum, int16_t);
+  IREE_VMLA_POOLING_OP(PoolingSumI32, kernels::PoolingSum, int32_t);
+  IREE_VMLA_POOLING_OP(PoolingSumF32, kernels::PoolingSum, float);
+  IREE_VMLA_POOLING_OP(PoolingMinI8, kernels::PoolingMin, int8_t);
+  IREE_VMLA_POOLING_OP(PoolingMinI16, kernels::PoolingMin, int16_t);
+  IREE_VMLA_POOLING_OP(PoolingMinI32, kernels::PoolingMin, int32_t);
+  IREE_VMLA_POOLING_OP(PoolingMinF32, kernels::PoolingMin, float);
+  IREE_VMLA_POOLING_OP(PoolingMaxI8, kernels::PoolingMax, int8_t);
+  IREE_VMLA_POOLING_OP(PoolingMaxI16, kernels::PoolingMax, int16_t);
+  IREE_VMLA_POOLING_OP(PoolingMaxI32, kernels::PoolingMax, int32_t);
+  IREE_VMLA_POOLING_OP(PoolingMaxF32, kernels::PoolingMax, float);
+
  private:
   iree_allocator_t allocator_;
 
@@ -907,6 +932,19 @@
     vm::MakeNativeFunction("reduce.max.i32", &VMLAModuleState::ReduceMaxI32),
     vm::MakeNativeFunction("reduce.max.f32", &VMLAModuleState::ReduceMaxF32),
 
+    vm::MakeNativeFunction("pooling.sum.i8", &VMLAModuleState::PoolingSumI8),
+    vm::MakeNativeFunction("pooling.sum.i16", &VMLAModuleState::PoolingSumI16),
+    vm::MakeNativeFunction("pooling.sum.i32", &VMLAModuleState::PoolingSumI32),
+    vm::MakeNativeFunction("pooling.sum.f32", &VMLAModuleState::PoolingSumF32),
+    vm::MakeNativeFunction("pooling.min.i8", &VMLAModuleState::PoolingMinI8),
+    vm::MakeNativeFunction("pooling.min.i16", &VMLAModuleState::PoolingMinI16),
+    vm::MakeNativeFunction("pooling.min.i32", &VMLAModuleState::PoolingMinI32),
+    vm::MakeNativeFunction("pooling.min.f32", &VMLAModuleState::PoolingMinF32),
+    vm::MakeNativeFunction("pooling.max.i8", &VMLAModuleState::PoolingMaxI8),
+    vm::MakeNativeFunction("pooling.max.i16", &VMLAModuleState::PoolingMaxI16),
+    vm::MakeNativeFunction("pooling.max.i32", &VMLAModuleState::PoolingMaxI32),
+    vm::MakeNativeFunction("pooling.max.f32", &VMLAModuleState::PoolingMaxF32),
+
     vm::MakeNativeFunction("batch.matmul.f32f32.f32",
                            &VMLAModuleState::BatchMatMulF32F32F32),
 
diff --git a/iree/test/e2e/xla/BUILD b/iree/test/e2e/xla/BUILD
index efd59d3..c718049 100644
--- a/iree/test/e2e/xla/BUILD
+++ b/iree/test/e2e/xla/BUILD
@@ -48,6 +48,7 @@
     "pad.mlir",
     "reduce_float.mlir",
     "reduce_int.mlir",
+    "reduce_window.mlir",
     "rem.mlir",
     "reshape.mlir",
     "reshape_adddims.mlir",