First pass at linalg_ext.scan op (#8020)

This pass adds the first version of the scan op in
the linalg_ext dialect. Currently, the op requires specifying
the dimension, whether the scan is inclusive or not as well
as body specifying the operator for scan.

The current patch adds code to lower the op to loops
using the naive sequential implementation. The tiled implementation
is borrowed from the sort op.

Finally, some unit tests have also been added to test the passes.
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td
index fa75ea3..7c729e1 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -247,6 +247,56 @@
   }];
 }
 
+def IREELinalgExt_ScanOp : IREELinalgExt_Op<"scan",
+    [DeclareOpInterfaceMethods<TiledOpInterface,
+      ["getPartitionableLoops", "generateScalarImplementation",
+       "getTiledImplementation"]>]> {
+  let summary = "Scan operator";
+  let description = [{
+    Computes the inclusive/exclusive scan along a given dimension.
+  }];
+
+  let arguments = (ins Variadic<AnyShaped>:$inputs,
+                       Variadic<AnyShaped>:$outputs,
+                       AnyType:$identity,
+                       I64Attr:$dimension,
+                       BoolAttr:$inclusive
+  );
+
+  let builders = [
+    OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
+      "Value":$identity, CArg<"int64_t", "0">:$dimension,
+      CArg<"bool", "true">:$inclusive)>
+  ];
+
+  let results = (outs Variadic<AnyRankedTensor>:$results);
+  let regions = (region AnyRegion:$region);
+  let assemblyFormat = [{
+    `dimension` `(` $dimension `)`
+    `inclusive` `(` $inclusive `)`
+    attr-dict
+    `identity` `(` $identity `:` type($identity) `)`
+    `ins` `(` $inputs `:` type($inputs) `)`
+    (`outs` `(` $outputs^ `:` type($outputs) `)`)?
+    $region (`->` type($results)^)?
+  }];
+
+  let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
+    Value input() {
+      return getInputOperand(0)->get();
+    }
+    Value output() {
+      return getOutputOperand(0)->get();
+    }
+    ShapedType getOperandType() {
+      return input().getType().cast<ShapedType>();
+    }
+    int64_t getOperandRank() {
+      return getOperandType().getRank();
+    }
+  }];
+}
+
 def IREELinalgExt_ReverseOp : IREELinalgExt_Op<"reverse", [
   DeclareOpInterfaceMethods<
       TiledOpInterface,
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index abe37c0..688cdfe 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -768,6 +768,172 @@
 }
 
 //===----------------------------------------------------------------------===//
+// ScanOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyScanOp(ScanOp op) {
+  if (op.getNumInputs() != 1) {
+    return op.emitOpError("expected one input operands");
+  }
+  if (op.getNumOutputs() != 1) {
+    return op.emitOpError("expected one output operand");
+  }
+  auto identityElementType = op.identity().getType();
+  auto inputType = op.input().getType().cast<ShapedType>();
+  auto outputType = op.output().getType().cast<ShapedType>();
+  if (identityElementType != inputType.getElementType()) {
+    return op.emitOpError(
+        "expected input/identity element types to be identical");
+  }
+  if (inputType.getElementType() != outputType.getElementType()) {
+    return op.emitOpError(
+        "expected input/output element types to be identical");
+  }
+  ArrayRef<int64_t> inputShapes = inputType.getShape();
+  ArrayRef<int64_t> outputShapes = outputType.getShape();
+  if (inputShapes.size() != outputShapes.size()) {
+    return op.emitOpError("expected input/output to have identical ranks");
+  }
+  if (llvm::any_of(llvm::zip(inputShapes, outputShapes),
+                   [](std::tuple<int64_t, int64_t> s) {
+                     return std::get<0>(s) != ShapedType::kDynamicSize &&
+                            std::get<1>(s) != ShapedType::kDynamicSize &&
+                            std::get<0>(s) != std::get<1>(s);
+                   })) {
+    return op.emitOpError("incompatible input/output shapes");
+  }
+  return success();
+}
+
+SmallVector<Range> ScanOp::getIterationDomain(OpBuilder &builder) {
+  int64_t operandRank = getOperandRank();
+  SmallVector<Range> loopBounds(operandRank);
+  Location loc = getLoc();
+  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
+  Value source = input();
+  for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
+    loopBounds[dim].offset = zero;
+    loopBounds[dim].size = getDimValue(builder, loc, source, dim);
+    loopBounds[dim].stride = one;
+  }
+  return loopBounds;
+}
+
+SmallVector<StringRef> ScanOp::getLoopIteratorTypes() {
+  SmallVector<StringRef> iteratorTypes(getOperandRank(),
+                                       getParallelIteratorTypeName());
+  iteratorTypes[dimension()] = getReductionIteratorTypeName();
+  return iteratorTypes;
+}
+
+SmallVector<unsigned> ScanOp::getPartitionableLoops(
+    unsigned maxNumParallelDims) {
+  auto range = llvm::seq<unsigned>(0, getOperandRank());
+  SmallVector<unsigned> partitionableLoops(range.begin(), range.end());
+  partitionableLoops.erase(std::next(partitionableLoops.begin(), dimension()));
+  return partitionableLoops;
+}
+
+// Generates naive scalar implementation of scan for a given operator f.
+// For inclusive,
+//     output[0] = input[0]
+//     output[i] = f(output[i-1], input[i])
+//
+// For exclusive,
+//     output[0] = 0
+//     output[i] = f(output[i-1], input[i-1])
+
+LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc,
+                                                   ValueRange ivs) {
+  SmallVector<Value> indices, scanBlkArgs;
+  indices.append(ivs.begin(), ivs.end());
+  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
+  auto scanDim = dimension();
+  auto cond = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
+                                      indices[scanDim], zero);
+  bool isInclusive = inclusive();
+  auto scfIf = b.create<scf::IfOp>(
+      loc, TypeRange{}, cond,
+      [&](OpBuilder &b, Location loc) {
+        if (isInclusive) {
+          auto value = b.create<memref::LoadOp>(loc, input(), indices);
+          b.create<memref::StoreOp>(loc, value, output(), indices);
+        } else {
+          b.create<memref::StoreOp>(loc, identity(), output(), indices);
+        }
+        b.create<scf::YieldOp>(loc);
+      },
+      [&](OpBuilder &b, Location loc) {
+        SmallVector<Value> indices(ivs.begin(), ivs.end());
+        Value iv = indices[scanDim];
+        Value ivMinusOne = b.create<arith::SubIOp>(loc, iv, one);
+        indices[scanDim] = ivMinusOne;
+        scanBlkArgs.push_back(b.create<memref::LoadOp>(loc, output(), indices));
+        Value i0;
+        if (!isInclusive) i0 = b.create<memref::LoadOp>(loc, input(), indices);
+        indices[scanDim] = iv;
+        if (isInclusive) i0 = b.create<memref::LoadOp>(loc, input(), indices);
+        scanBlkArgs.push_back(i0);
+      });
+
+  auto &srcBlock = region().front();
+  Region &region = scfIf.getElseRegion();
+  BlockAndValueMapping bvm;
+  {
+    OpBuilder::InsertionGuard guard(b);
+    auto &block = region.front();
+    b.setInsertionPointToEnd(&block);
+    for (auto it : llvm::zip(srcBlock.getArguments(), scanBlkArgs)) {
+      bvm.map(std::get<0>(it), std::get<1>(it));
+    }
+    for (auto &blockOp : srcBlock.without_terminator()) {
+      b.clone(blockOp, bvm);
+    }
+    b.create<memref::StoreOp>(
+        loc, bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0)),
+        output(), indices);
+    b.create<scf::YieldOp>(loc);
+  }
+  return success();
+}
+
+Operation *ScanOp::getTiledImplementation(OpBuilder &builder,
+                                          ValueRange outputs,
+                                          ArrayRef<OpFoldResult> offsets,
+                                          ArrayRef<OpFoldResult> sizes,
+                                          SmallVectorImpl<Value> &results) {
+  assert(outputs.size() == this->outputs().size());
+  int64_t rank = getOperandRank();
+  assert(offsets.size() == static_cast<size_t>(rank) &&
+         sizes.size() == static_cast<size_t>(rank));
+  auto oneAttr = builder.getI64IntegerAttr(1);
+  SmallVector<OpFoldResult> strides(rank, oneAttr);
+  Location loc = getLoc();
+  SmallVector<Value> tiledOperands;
+  tiledOperands.emplace_back(
+      getSlice(builder, getLoc(), input(), offsets, sizes, strides));
+  tiledOperands.emplace_back(
+      getSlice(builder, getLoc(), output(), offsets, sizes, strides));
+  tiledOperands.emplace_back(identity());
+
+  SmallVector<Type, 4> resultTypes;
+  if (hasTensorSemantics()) {
+    resultTypes.push_back(tiledOperands[1].getType());
+  }
+
+  Operation *tiledScanOp = cast<LinalgExtOp>(getOperation())
+                               .clone(builder, loc, resultTypes, tiledOperands);
+  for (auto result : llvm::enumerate(tiledScanOp->getResults())) {
+    auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
+        loc, result.value(), outputs[result.index()], offsets, sizes, strides);
+    results.push_back(insertSliceOp.getResult());
+  }
+  return tiledScanOp;
+}
+
+//===----------------------------------------------------------------------===//
 // ReverseOp
 //===----------------------------------------------------------------------===//
 
@@ -913,6 +1079,7 @@
 DEFINE_OP_GET_EFFECTS(SortOp)
 DEFINE_OP_GET_EFFECTS(FftOp)
 DEFINE_OP_GET_EFFECTS(ReverseOp)
+DEFINE_OP_GET_EFFECTS(ScanOp)
 
 namespace {
 /// This is derived from mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp without any
diff --git a/llvm-external-projects/iree-dialects/test/iree_linalgext/convert_to_loops.mlir b/llvm-external-projects/iree-dialects/test/iree_linalgext/convert_to_loops.mlir
index f96850f..c49c7b4 100644
--- a/llvm-external-projects/iree-dialects/test/iree_linalgext/convert_to_loops.mlir
+++ b/llvm-external-projects/iree-dialects/test/iree_linalgext/convert_to_loops.mlir
@@ -505,3 +505,96 @@
 // CHECK:             %[[T2:.+]] = arith.subi %[[T1]], %[[I]] : index
 // CHECK:             %[[V0:.+]] = memref.load %[[IN]][%[[I]], %[[J]]]
 // CHECK:             memref.store %[[V0]], %[[OUT]][%[[T2]], %[[J]]] : memref<?x?xi32>
+
+func @scan_1d_inclusive(%0: memref<128xi32>, %1: memref<128xi32>) {
+  %c0 = arith.constant 0 : i32
+  iree_linalg_ext.scan dimension(0) inclusive(true) identity(%c0 : i32)
+    ins(%0 : memref<128xi32>) outs(%1 : memref<128xi32>) {
+    ^bb0(%arg0 : i32, %arg1 : i32):
+      %sum = arith.addi %arg0, %arg1 : i32
+      iree_linalg_ext.yield %sum : i32
+  }
+  return
+}
+// CHECK-LABEL: func @scan_1d_inclusive
+// CHECK-SAME:    %[[BUFI:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[BUFO:[a-zA-Z0-9]+]]
+// CHECK-DAG:     %[[C128:.+]] = arith.constant 128 : index
+// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
+// CHECK:         scf.for %[[ARG1:.+]] = %[[C0]] to %[[C128]] step %[[C1]]
+// CHECK:           %[[COND:.+]] = arith.cmpi eq, %[[ARG1]], %[[C0]] : index
+// CHECK:           scf.if %[[COND]] {
+// CHECK:             %[[V1:.+]] = memref.load %[[BUFI]][%[[ARG1]]]
+// CHECK:             memref.store %[[V1]], %[[BUFO]][%[[ARG1]]]
+// CHECK:           } else {
+// CHECK:             %[[T1:.+]] = arith.subi %[[ARG1]], %[[C1]] : index
+// CHECK:             %[[V2:.+]] = memref.load %[[BUFO]][%[[T1]]]
+// CHECK:             %[[V3:.+]] = memref.load %[[BUFI]][%[[ARG1]]]
+// CHECK:             %[[V4:.+]] = arith.addi %[[V2]], %[[V3]] : i32
+// CHECK:             memref.store %[[V4]], %[[BUFO]][%[[ARG1]]]
+// CHECK:           }
+
+// -----
+
+func @scan_1d_exclusive(%0: memref<128xi32>, %1: memref<128xi32>) {
+  %c0 = arith.constant 0 : i32
+  iree_linalg_ext.scan dimension(0) inclusive(false) identity(%c0 : i32)
+    ins(%0 : memref<128xi32>) outs(%1 : memref<128xi32>) {
+    ^bb0(%arg0 : i32, %arg1 : i32):
+      %sum = arith.addi %arg0, %arg1 : i32
+      iree_linalg_ext.yield %sum : i32
+  }
+  return
+}
+// CHECK-LABEL: func @scan_1d_exclusive
+// CHECK-SAME:    %[[BUFI:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[BUFO:[a-zA-Z0-9]+]]
+// CHECK-DAG:     %[[C128:.+]] = arith.constant 128 : index
+// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C0_I32:.+]] = arith.constant 0 : i32
+// CHECK:         scf.for %[[ARG1:.+]] = %[[C0]] to %[[C128]] step %[[C1]]
+// CHECK:           %[[COND:.+]] = arith.cmpi eq, %[[ARG1]], %[[C0]] : index
+// CHECK:           scf.if %[[COND]] {
+// CHECK:             memref.store %[[C0_I32]], %[[BUFO]][%[[ARG1]]]
+// CHECK:           } else {
+// CHECK:             %[[T1:.+]] = arith.subi %[[ARG1]], %[[C1]] : index
+// CHECK:             %[[V2:.+]] = memref.load %[[BUFO]][%[[T1]]]
+// CHECK:             %[[V3:.+]] = memref.load %[[BUFI]][%[[T1]]]
+// CHECK:             %[[V4:.+]] = arith.addi %[[V2]], %[[V3]] : i32
+// CHECK:             memref.store %[[V4]], %[[BUFO]][%[[ARG1]]]
+// CHECK:           }
+
+// -----
+
+func @scan_2d(%0: memref<16x32xi32>, %1: memref<16x32xi32>) {
+  %c0 = arith.constant 0 : i32
+  iree_linalg_ext.scan dimension(0) inclusive(true) identity(%c0 : i32)
+    ins(%0 : memref<16x32xi32>) outs(%1 : memref<16x32xi32>) {
+    ^bb0(%arg0 : i32, %arg1 : i32):
+      %sum = arith.addi %arg0, %arg1 : i32
+      iree_linalg_ext.yield %sum : i32
+  }
+  return
+}
+// CHECK-LABEL: func @scan_2d
+// CHECK-SAME:    %[[BUFI:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[BUFO:[a-zA-Z0-9]+]]
+// CHECK-DAG:     %[[C16:.+]] = arith.constant 16 : index
+// CHECK-DAG:     %[[C32:.+]] = arith.constant 32 : index
+// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
+// CHECK:         scf.for %[[ARG1:.+]] = %[[C0]] to %[[C16]] step %[[C1]]
+// CHECK:           scf.for %[[ARG2:.+]] = %[[C0]] to %[[C32]] step %[[C1]]
+// CHECK:             %[[COND:.+]] = arith.cmpi eq, %[[ARG1]], %[[C0]] : index
+// CHECK:             scf.if %[[COND]] {
+// CHECK:               %[[V1:.+]] = memref.load %[[BUFI]][%[[ARG1]], %[[ARG2]]]
+// CHECK:               memref.store %[[V1]], %[[BUFO]][%[[ARG1]], %[[ARG2]]]
+// CHECK:             } else {
+// CHECK:               %[[T1:.+]] = arith.subi %[[ARG1]], %[[C1]] : index
+// CHECK:               %[[V2:.+]] = memref.load %[[BUFO]][%[[T1]], %[[ARG2]]]
+// CHECK:               %[[V3:.+]] = memref.load %[[BUFI]][%[[ARG1]], %[[ARG2]]]
+// CHECK:               %[[V4:.+]] = arith.addi %[[V2]], %[[V3]] : i32
+// CHECK:               memref.store %[[V4]], %[[BUFO]][%[[ARG1]], %[[ARG2]]]
+// CHECK:             }
diff --git a/llvm-external-projects/iree-dialects/test/iree_linalgext/tiling.mlir b/llvm-external-projects/iree-dialects/test/iree_linalgext/tiling.mlir
index 219ee59..afaf909 100644
--- a/llvm-external-projects/iree-dialects/test/iree_linalgext/tiling.mlir
+++ b/llvm-external-projects/iree-dialects/test/iree_linalgext/tiling.mlir
@@ -1206,3 +1206,106 @@
 //      CHECK:       scf.yield %[[INSERT]]
 //      CHECK:     scf.yield %[[YIELD]]
 //      CHECK:   return %[[RESULT]]
+
+// -----
+
+func @scan_1d(%0: tensor<128xi32>) -> tensor<128xi32> {
+  %c0 = arith.constant 0 : i32
+  %1 = linalg.init_tensor [128] : tensor<128xi32>
+  %2 = iree_linalg_ext.scan
+    dimension(0) inclusive(true)
+    {__internal_linalg_transform__ = "outer_reduce_input"}
+    identity(%c0 : i32)
+    ins(%0 : tensor<128xi32>) outs(%1 : tensor<128xi32>) {
+    ^bb0(%arg0 : i32, %arg1 : i32):
+      %sum = arith.addi %arg0, %arg1 : i32
+      iree_linalg_ext.yield %sum : i32
+  } -> tensor<128xi32>
+  return %2 : tensor<128xi32>
+}
+//      CHECK: func @scan_1d(
+// CHECK-SAME:   %[[OPERAND:.+]]: tensor<128xi32>
+//      CHECK:   %[[IDENTITY:.+]] = arith.constant 0 : i32
+//      CHECK:   %[[OUTPUT:.+]] = linalg.init_tensor [128] : tensor<128xi32>
+//      CHECK:   %[[RESULT:.+]] = iree_linalg_ext.scan
+// CHECK-SAME:           __internal_linalg_transform__ = "outer_reduce_output"
+// CHECK-SAME:       identity(%[[IDENTITY]] :
+// CHECK-SAME:       ins(%[[OPERAND]] :
+// CHECK-SAME:       outs(%[[OUTPUT]] :
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
+func @scan_2d(%0: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %c0 = arith.constant 0 : i32
+  %1 = linalg.init_tensor [16, 32] : tensor<16x32xi32>
+  %2 = iree_linalg_ext.scan
+    dimension(0) inclusive(true)
+    {__internal_linalg_transform__ = "outer_reduce_input"}
+    identity(%c0 : i32)
+    ins(%0 : tensor<16x32xi32>) outs(%1 : tensor<16x32xi32>) {
+    ^bb0(%arg0 : i32, %arg1 : i32):
+      %sum = arith.addi %arg0, %arg1 : i32
+      iree_linalg_ext.yield %sum : i32
+  } -> tensor<16x32xi32>
+  return %2 : tensor<16x32xi32>
+}
+//  CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+//      CHECK:  func @scan_2d(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]
+//  CHECK-DAG:    %[[IDENTITY:.+]] = arith.constant 0 : i32
+//      CHECK:    %[[C0:.+]] = arith.constant 0 : index
+//      CHECK:    %[[C16:.+]] = arith.constant 16 : index
+//      CHECK:    %[[C32:.+]] = arith.constant 32 : index
+//      CHECK:    %[[C20:.+]] = arith.constant 20 : index
+//      CHECK:    %[[OUTPUT:.+]] = linalg.init_tensor [16, 32] : tensor<16x32xi32>
+//      CHECK:    %[[RESULT:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C32]] step %[[C20]] 
+// CHECK-SAME:      iter_args(%[[ARG2:.+]] = %[[OUTPUT]])
+//      CHECK:      %[[SIZE:.+]] = affine.min #[[MAP0]](%[[I]])[%[[C20]], %[[C32]]]
+//      CHECK:      %[[UPDATE_SLICE_IN:.+]] = tensor.extract_slice %[[ARG0]][0, %[[I]]] [%[[C16]], %[[SIZE]]]
+//      CHECK:      %[[UPDATE_SLICE_OUT:.+]] = tensor.extract_slice %[[OUTPUT]][0, %[[I]]] [%[[C16]], %[[SIZE]]]
+//      CHECK:      %[[SCAN_TILE:.+]] = iree_linalg_ext.scan
+// CHECK-SAME:       dimension(0) inclusive(true)
+// CHECK-SAME:       {__internal_linalg_transform__ = "outer_reduce_output"}
+// CHECK-SAME:       ins(%[[UPDATE_SLICE_IN]]
+// CHECK-SAME:       outs(%[[UPDATE_SLICE_OUT]]
+//      CHECK:       %[[YIELD:.+]] = tensor.insert_slice %[[SCAN_TILE]] into %[[ARG2]][0, %[[I]]]
+// CHECK-SAME:           [%[[C16]], %[[SIZE]]]
+//      CHECK:       scf.yield %[[YIELD]]
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
+func @scan_2d_memref(%0: memref<16x32xi32>, %1: memref<16x32xi32>) {
+  %c0 = arith.constant 0 : i32
+  iree_linalg_ext.scan
+    dimension(0) inclusive(true)
+    {__internal_linalg_transform__ = "outer_reduce_input"}
+    identity(%c0 : i32)
+    ins(%0 : memref<16x32xi32>) outs(%1 : memref<16x32xi32>) {
+    ^bb0(%arg0 : i32, %arg1 : i32):
+      %sum = arith.addi %arg0, %arg1 : i32
+      iree_linalg_ext.yield %sum : i32
+  }
+  return
+}
+//  CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+//  CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)>
+//      CHECK:  func @scan_2d_memref(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]+]]
+//  CHECK-DAG:    %[[IDENTITY:.+]] = arith.constant 0 : i32
+//      CHECK:    %[[C0:.+]] = arith.constant 0 : index
+//      CHECK:    %[[C16:.+]] = arith.constant 16 : index
+//      CHECK:    %[[C32:.+]] = arith.constant 32 : index
+//      CHECK:    %[[C20:.+]] = arith.constant 20 : index
+//      CHECK:    scf.for %[[I:.+]] = %[[C0]] to %[[C32]] step %[[C20]]
+//      CHECK:      %[[SIZE:.+]] = affine.min #[[MAP0]](%[[I]])[%[[C20]], %[[C32]]]
+//      CHECK:      %[[UPDATE_SLICE_IN:.+]] = memref.subview %[[ARG0]][0, %[[I]]] [%[[C16]], %[[SIZE]]]
+//      CHECK:      %[[UPDATE_SLICE_OUT:.+]] = memref.subview %[[ARG1]][0, %[[I]]] [%[[C16]], %[[SIZE]]]
+//      CHECK:      iree_linalg_ext.scan
+// CHECK-SAME:       dimension(0) inclusive(true)
+// CHECK-SAME:       {__internal_linalg_transform__ = "outer_reduce_output"}
+// CHECK-SAME:       ins(%[[UPDATE_SLICE_IN]]
+// CHECK-SAME:       outs(%[[UPDATE_SLICE_OUT]]
+//      CHECK:   return