Initial commit of bridge to use Linalg Comprehensive Bufferize Pass. (#7509)

Co-authored-by: Nicolas Vasilache <nicolasvasilache@users.noreply.github.com>

This is the initial commit to use Linalg Comprehensive bufferize pass
from within IREE. The main work here is to bridge the memory model gap
of using flow.dispatch.tensor.load/store operations.
diff --git a/iree/compiler/Codegen/Common/BUILD b/iree/compiler/Codegen/Common/BUILD
index 0593b9b..380812c 100644
--- a/iree/compiler/Codegen/Common/BUILD
+++ b/iree/compiler/Codegen/Common/BUILD
@@ -37,6 +37,7 @@
         "FlattenMemRefSubspanPass.cpp",
         "FoldTensorExtractOpPass.cpp",
         "ForOpCanonicalizationPass.cpp",
+        "IREEComprehensiveBufferizePass.cpp",
         "LinalgBufferizePass.cpp",
         "OptimizeVectorTransferPass.cpp",
         "SetNumWorkgroupsPass.cpp",
diff --git a/iree/compiler/Codegen/Common/CMakeLists.txt b/iree/compiler/Codegen/Common/CMakeLists.txt
index ce29221..780f6fd 100644
--- a/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -28,6 +28,7 @@
     "FlattenMemRefSubspanPass.cpp"
     "FoldTensorExtractOpPass.cpp"
     "ForOpCanonicalizationPass.cpp"
+    "IREEComprehensiveBufferizePass.cpp"
     "LinalgBufferizePass.cpp"
     "OptimizeVectorTransferPass.cpp"
     "SetNumWorkgroupsPass.cpp"
diff --git a/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp b/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp
new file mode 100644
index 0000000..210e28c
--- /dev/null
+++ b/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp
@@ -0,0 +1,243 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+//===- IREEComprehensiveBufferizePass.cpp.cpp - -------------------------===//
+//
+// Wrapper pass to use MLIRs ComprehensiveBufferization pass.
+//
+//===----------------------------------------------------------------------===//
+
+#include "iree/compiler/Codegen/PassDetail.h"
+#include "iree/compiler/Codegen/Passes.h"
+#include "iree/compiler/Codegen/Transforms/Transforms.h"
+#include "iree/compiler/Codegen/Utils/Utils.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "llvm/ADT/EquivalenceClasses.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/BufferUtils.h"
+#include "mlir/Transforms/Passes.h"
+
+#define DEBUG_TYPE "iree-codegen-linalg-bufferize"
+
+namespace mlir {
+template <typename TensorType>
+static MemRefType getMemrefTypeForTensor(TensorType tensorType,
+                                         MemRefLayoutAttrInterface layout = {},
+                                         Attribute memorySpace = {}) {
+  return MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
+                         layout, memorySpace);
+}
+
+namespace iree_compiler {
+
+namespace {
+/// Pass to convert from tensor based ops to memref based ops.
+class IREEComprehensiveBufferizePass
+    : public IREEComprehensiveBufferizeBase<IREEComprehensiveBufferizePass> {
+ public:
+  IREEComprehensiveBufferizePass(
+      linalg::AllocationCallbacks allocationFn = linalg::AllocationCallbacks())
+      : allocationFn(allocationFn) {}
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry
+        .insert<IREE::Util::UtilDialect, linalg::LinalgDialect,
+                memref::MemRefDialect, scf::SCFDialect, StandardOpsDialect>();
+  }
+  void runOnOperation() override;
+
+ private:
+  linalg::AllocationCallbacks allocationFn;
+};
+}  // namespace
+
+static bool isaTensor(Type t) { return t.isa<TensorType>(); };
+
+/// Stitch comprehensive bufferization inside of IREE by proceeding as follows:
+///   1. a. Bufferizes InterfaceBindingSubspanOp optimistically
+///      b. Insert a memref::TensorLoad to serve as glue between the buffer and
+///         tensor worlds.
+///      c. Record aliasInfo of memref::TensorLoad manually
+///      d. Record inplaceability of memref::TensorLoad manually
+///      e. Record the bufferization of memref::TensorLoad manually
+///   2. Rewrite all Flow::Dispatch::TensorLoad ops as Tensor::ExtractSliceOp
+///      that comprehensive bufferization understands.
+///   3. Specifically select the ops we want to bufferize / skip. In the future,
+///      this may be better specified with a BufferizationOpInterface.
+///   4. Perform analysis and bufferization on the ops.
+void IREEComprehensiveBufferizePass::runOnOperation() {
+  ModuleOp moduleOp = getOperation();
+  MLIRContext *context = &getContext();
+
+  for (auto funcOp : moduleOp.getOps<FuncOp>()) {
+    OpBuilder b(context);
+
+    // 1. First go over all hal.interface.binding.subspan ops and create
+    // counterparts working with memrefs.
+    BlockAndValueMapping bvm, tensorLoads;
+    linalg::BufferizationAliasInfo aliasInfo(funcOp);
+    // These are used until late, erase on scoped exit.
+    SmallVector<Operation *> toEraseLate;
+    auto scopeGuard = llvm::make_scope_exit([&]() {
+      for (Operation *op : llvm::reverse(toEraseLate)) op->erase();
+    });
+    funcOp.walk([&](IREE::HAL::InterfaceBindingSubspanOp op) {
+      auto shapedType =
+          op.getResult().getType().dyn_cast<IREE::Flow::DispatchTensorType>();
+      if (!shapedType || !shapedType.hasRank()) return;
+      OpBuilder::InsertionGuard g(b);
+      b.setInsertionPoint(op);
+      // 1.a. Just change the result type of the InterfaceBindingSubspanOp to
+      // from the base buffer.
+      auto memRefType = getMemrefTypeForTensor(shapedType);
+      auto baseBuffer = b.create<IREE::HAL::InterfaceBindingSubspanOp>(
+          op->getLoc(), memRefType, op.binding(), op.byte_offset(),
+          op.byte_length(), op.dynamic_dims());
+      bvm.map(op, baseBuffer);
+
+      // This op does not operate on core tensor types and has half-side
+      // effecting semantics. It cannot be added to BufferizationAliasInfo.
+      // Instead:
+      // 1.b. Insert a memref::TensorLoad to serve as glue between the buffer
+      // and tensor worlds.
+      Value tensor = b.create<memref::TensorLoadOp>(op->getLoc(), baseBuffer);
+      // 1.c. Insert a new entry manually into the existing aliasInfo.
+      aliasInfo.createAliasInfoEntry(op.result());
+      aliasInfo.createAliasInfoEntry(tensor);
+      tensorLoads.map(op.result(), tensor);
+      // 1.d. Mark tensors that bufferize to writeable memory as such.
+      if (shapedType.getAccess() != IREE::Flow::TensorAccess::ReadOnly) {
+        aliasInfo.setBufferizesToWritableMemory(tensor);
+      }
+      // 1.e. Save tensor -> baseBuffer into BVM.
+      bvm.map(tensor, baseBuffer);
+
+      // Drop the original op that is now bufferized.
+      toEraseLate.push_back(op);
+    });
+
+    // 2. Rewrite all Flow::Dispatch::TensorLoad ops as Tensor::ExtractSliceOp.
+    funcOp.walk<WalkOrder::PostOrder>([&](IREE::Flow::DispatchTensorLoadOp op) {
+      OpBuilder b(op);
+      Value v = b.create<tensor::ExtractSliceOp>(
+          op->getLoc(), op.result().getType().cast<RankedTensorType>(),
+          tensorLoads.lookup(op.source()), op.getMixedOffsets(),
+          op.getMixedSizes(), op.getMixedStrides());
+      // Insert a new entry manually into the existing aliasInfo.
+      aliasInfo.createAliasInfoEntry(v);
+      op.result().replaceAllUsesWith(v);
+      toEraseLate.push_back(op);
+    });
+    funcOp.walk<WalkOrder::PostOrder>(
+        [&](IREE::Flow::DispatchTensorStoreOp op) {
+          OpBuilder b(op);
+          Value v = b.create<tensor::InsertSliceOp>(
+              op->getLoc(), op.value(), tensorLoads.lookup(op.target()),
+              op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides());
+          // Insert a new entry manually into the existing aliasInfo.
+          aliasInfo.createAliasInfoEntry(v);
+          toEraseLate.push_back(op);
+        });
+
+    // TODO: Visit all the operations that return `tensor`s that are not handled
+    // by comprehensive bufferize.
+
+    // 3. Specifically select the ops we want to bufferize / skip. In the
+    // future, this may be better specified with a BufferizationOpInterface.
+    DominanceInfo domInfo(funcOp);
+    SmallVector<Operation *> ops;
+    ops.reserve(funcOp.body().front().getOperations().size());
+    WalkResult opsSelected =
+        funcOp.body().walk([&](Operation *op) -> WalkResult {
+          if (isa<IREE::HAL::InterfaceBindingSubspanOp,
+                  IREE::Flow::DispatchTensorLoadOp,
+                  IREE::Flow::DispatchTensorStoreOp>(op)) {
+            return WalkResult::advance();
+          }
+          if (llvm::none_of(op->getOperandTypes(), isaTensor) &&
+              llvm::none_of(op->getResultTypes(), isaTensor)) {
+            return WalkResult::advance();
+          }
+          if (op->getParentOfType<linalg::LinalgOp>())
+            return WalkResult::advance();
+          // TODO: if we want to bufferize function calls, we need FuncOp
+          // and to pass a proper bufferizedFunctionTypes.
+          if (isa<CallOpInterface>(op)) {
+            return static_cast<LogicalResult>(op->emitError(
+                "CallOpInterface bufferization not supported in IREE"));
+          }
+          ops.push_back(op);
+          return WalkResult::advance();
+        });
+
+    // 4. Perform inplaceability analysis of `ops`.
+    if (opsSelected.wasInterrupted() ||
+        failed(linalg::inPlaceAnalysis(ops, aliasInfo, domInfo))) {
+      return signalPassFailure();
+    }
+
+    // 5. Perform bufferization.
+    for (Operation *op : ops) {
+      if (failed(linalg::bufferizeOp(op, bvm, aliasInfo,
+                                     linalg::AllocationCallbacks(),
+                                     /*bufferizedFunctionTypes=*/nullptr))) {
+        return signalPassFailure();
+      }
+    }
+  }
+}
+
+// TODO: pass this to comprehensive bufferize.
+static Value defaultAllocationFn(OpBuilder &builder, Location loc,
+                                 ArrayRef<int64_t> staticShape,
+                                 Type elementType,
+                                 ArrayRef<Value> dynamicSizes) {
+  auto allocationType = MemRefType::get(staticShape, elementType);
+  return builder.create<memref::AllocOp>(loc, allocationType, dynamicSizes);
+}
+
+std::unique_ptr<OperationPass<ModuleOp>> createIREEComprehensiveBufferizePass(
+    linalg::AllocationCallbacks allocationFns) {
+  return std::make_unique<IREEComprehensiveBufferizePass>(allocationFns);
+}
+
+void addIREEComprehensiveBufferizePasses(
+    OpPassManager &passManager, linalg::AllocationCallbacks allocationFns) {
+  passManager.addPass(createIREEComprehensiveBufferizePass(allocationFns));
+  passManager.addPass(memref::createResolveShapedTypeResultDimsPass());
+  passManager.addNestedPass<FuncOp>(createCanonicalizerPass());
+  passManager.addNestedPass<FuncOp>(createCSEPass());
+  passManager.addNestedPass<FuncOp>(createCleanupBufferAllocViewPass());
+}
+
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Codegen/Common/test/BUILD b/iree/compiler/Codegen/Common/test/BUILD
index c83460a..9475ce3 100644
--- a/iree/compiler/Codegen/Common/test/BUILD
+++ b/iree/compiler/Codegen/Common/test/BUILD
@@ -22,6 +22,7 @@
             "affinemin_canonicalization.mlir",
             "canonicalize_interface_load_store.mlir",
             "dead_alloc.mlir",
+            "iree_comprehensive_bufferize.mlir",
             "f32Tof16.mlir",
             "flatten_memref_subspan.mlir",
             "fold_tensor_extract_op.mlir",
diff --git a/iree/compiler/Codegen/Common/test/CMakeLists.txt b/iree/compiler/Codegen/Common/test/CMakeLists.txt
index 19d4e3f..3336df7 100644
--- a/iree/compiler/Codegen/Common/test/CMakeLists.txt
+++ b/iree/compiler/Codegen/Common/test/CMakeLists.txt
@@ -21,6 +21,7 @@
     "flatten_memref_subspan.mlir"
     "fold_tensor_extract_op.mlir"
     "forop_canonicalization.mlir"
+    "iree_comprehensive_bufferize.mlir"
     "linalg_bufferize.mlir"
     "remove_dead_allocs.mlir"
     "transpose_canonicalization.mlir"
diff --git a/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir b/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
new file mode 100644
index 0000000..e7973ce
--- /dev/null
+++ b/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
@@ -0,0 +1,150 @@
+// RUN: iree-opt %s --iree-codegen-iree-comprehensive-bufferize -canonicalize -cse -split-input-file | IreeFileCheck %s
+
+func @matmul() {
+  %c0 = arith.constant 0 : index
+  %m = hal.interface.load.constant offset = 0 : index
+  %n = hal.interface.load.constant offset = 1 : index
+  %k = hal.interface.load.constant offset = 2 : index
+  %lhs = hal.interface.binding.subspan @io::@arg0[%c0] : !flow.dispatch.tensor<readonly:?x?xf32>{%m, %k}
+  %rhs = hal.interface.binding.subspan @io::@arg1[%c0] : !flow.dispatch.tensor<readonly:?x?xf32>{%k, %n}
+  %init = hal.interface.binding.subspan @io::@arg2[%c0] : !flow.dispatch.tensor<readonly:?x?xf32>{%m, %n}
+  %result = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:?x?xf32>{%m, %n}
+  %wg_id_y = hal.interface.workgroup.id[1] : index
+  %wg_count_y = hal.interface.workgroup.count[1] : index
+  %wg_size_y = hal.interface.workgroup.size[1] : index
+  %wg_id_x = hal.interface.workgroup.id[0] : index
+  %wg_count_x = hal.interface.workgroup.count[0] : index
+  %wg_size_x = hal.interface.workgroup.size[0] : index
+  %offset_y = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%wg_id_y, %wg_size_y]
+  %step_y = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%wg_count_y, %wg_size_y]
+  %offset_x = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%wg_id_x, %wg_size_x]
+  %step_x = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%wg_count_x, %wg_size_x]
+  scf.for %iv0 = %offset_y to %m step %step_y {
+    %tilesize_y = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%iv0)[%wg_size_y, %m]
+    scf.for %iv1 = %offset_x to %n step %step_x {
+      %tilesize_x = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%iv1)[%wg_size_x, %n]
+      %lhs_tile = flow.dispatch.tensor.load %lhs, offsets = [%iv0, 0], sizes = [%tilesize_y, %k], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<?x?xf32>
+      %rhs_tile = flow.dispatch.tensor.load %rhs, offsets = [0, %iv1], sizes = [%k, %tilesize_x], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<?x?xf32>
+      %init_tile = flow.dispatch.tensor.load %init, offsets = [%iv0, %iv1], sizes = [%tilesize_y, %tilesize_x], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<?x?xf32>
+      %matmul_tile = linalg.matmul ins(%lhs_tile, %rhs_tile : tensor<?x?xf32>, tensor<?x?xf32>) outs(%init_tile : tensor<?x?xf32>) -> tensor<?x?xf32>
+      flow.dispatch.tensor.store %matmul_tile, %result, offsets = [%iv0, %iv1], sizes = [%tilesize_y, %tilesize_x], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:?x?xf32>
+    }
+  }
+  return
+}
+hal.interface private @io  {
+  hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+  hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+  hal.interface.binding @arg2, set=0, binding=2, type="StorageBuffer", access="Read"
+  hal.interface.binding @ret0, set=0, binding=3, type="StorageBuffer", access="Write|Discard"
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>
+//      CHECK: func @matmul()
+//  CHECK-DAG:   %[[M:.+]] = hal.interface.load.constant offset = 0
+//  CHECK-DAG:   %[[N:.+]] = hal.interface.load.constant offset = 1
+//  CHECK-DAG:   %[[K:.+]] = hal.interface.load.constant offset = 2
+//  CHECK-DAG:   %[[LHS:.+]] = hal.interface.binding.subspan @io::@arg0
+//  CHECK-DAG:   %[[RHS:.+]] = hal.interface.binding.subspan @io::@arg1
+//  CHECK-DAG:   %[[INIT:.+]] = hal.interface.binding.subspan @io::@arg2
+//  CHECK-DAG:   %[[RESULT:.+]] = hal.interface.binding.subspan @io::@ret0
+//  CHECK-DAG:   %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1]
+//  CHECK-DAG:   %[[WG_COUNT_Y:.+]] = hal.interface.workgroup.count[1]
+//  CHECK-DAG:   %[[WG_SIZE_Y:.+]] = hal.interface.workgroup.size[1]
+//  CHECK-DAG:   %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0]
+//  CHECK-DAG:   %[[WG_COUNT_X:.+]] = hal.interface.workgroup.count[0]
+//  CHECK-DAG:   %[[WG_SIZE_X:.+]] = hal.interface.workgroup.size[0]
+//  CHECK-DAG:   %[[OFFSET_Y:.+]] = affine.apply #[[MAP0]]()[%[[WG_ID_Y]], %[[WG_SIZE_Y]]]
+//  CHECK-DAG:   %[[STEP_Y:.+]] = affine.apply #[[MAP0]]()[%[[WG_COUNT_Y]], %[[WG_SIZE_Y]]]
+//  CHECK-DAG:   %[[OFFSET_X:.+]] = affine.apply #[[MAP0]]()[%[[WG_ID_X]], %[[WG_SIZE_X]]]
+//  CHECK-DAG:   %[[STEP_X:.+]] = affine.apply #[[MAP0]]()[%[[WG_COUNT_X]], %[[WG_SIZE_X]]]
+//      CHECK:   scf.for %[[IV0:.+]] = %[[OFFSET_Y]] to %[[M]] step %[[STEP_Y]]
+//      CHECK:     %[[TILESIZE_Y:.+]] = affine.min #[[MAP1]](%[[IV0]])[%[[WG_SIZE_Y]], %[[M]]]
+//      CHECK:     scf.for %[[IV1:.+]] = %[[OFFSET_X]] to %[[N]] step %[[STEP_X]]
+//      CHECK:       %[[TILESIZE_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[WG_SIZE_X]], %[[N]]]
+//  CHECK-DAG:       %[[LHS_TILE:.+]] = memref.subview %[[LHS]][%[[IV0]], 0] [%[[TILESIZE_Y]], %[[K]]]
+//  CHECK-DAG:       %[[RHS_TILE:.+]] = memref.subview %[[RHS]][0, %[[IV1]]] [%[[K]], %[[TILESIZE_X]]]
+//  CHECK-DAG:       %[[ALLOC:.+]] = memref.alloc(%[[TILESIZE_Y]], %[[TILESIZE_X]]) {alignment = 128 : i64}
+//  CHECK-DAG:       %[[INIT_TILE:.+]] = memref.subview %[[INIT]][%[[IV0]], %[[IV1]]] [%[[TILESIZE_Y]], %[[TILESIZE_X]]]
+//      CHECK:       linalg.copy(%[[INIT_TILE]], %[[ALLOC]])
+//      CHECK:       linalg.matmul
+// CHECK-SAME:           ins(%[[LHS_TILE]], %[[RHS_TILE]]
+// CHECK-SAME:           outs(%[[ALLOC]]
+//      CHECK:       %[[RESULT_TILE:.+]] = memref.subview %[[RESULT]][%[[IV0]], %[[IV1]]] [%[[TILESIZE_Y]], %[[TILESIZE_X]]]
+//      CHECK:       linalg.copy(%[[ALLOC]], %[[RESULT_TILE]])
+//      CHECK:       memref.dealloc %[[ALLOC]]
+
+// -----
+
+func @matmul_fill() {
+  %cst = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %m = hal.interface.load.constant offset = 0 : index
+  %n = hal.interface.load.constant offset = 1 : index
+  %k = hal.interface.load.constant offset = 2 : index
+  %lhs = hal.interface.binding.subspan @io::@arg0[%c0] : !flow.dispatch.tensor<readonly:?x?xf32>{%m, %k}
+  %rhs = hal.interface.binding.subspan @io::@arg1[%c0] : !flow.dispatch.tensor<readonly:?x?xf32>{%k, %n}
+  %result = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:?x?xf32>{%m, %n}
+  %wg_id_y = hal.interface.workgroup.id[1] : index
+  %wg_count_y = hal.interface.workgroup.count[1] : index
+  %wg_size_y = hal.interface.workgroup.size[1] : index
+  %wg_id_x = hal.interface.workgroup.id[0] : index
+  %wg_count_x = hal.interface.workgroup.count[0] : index
+  %wg_size_x = hal.interface.workgroup.size[0] : index
+  %offset_y = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%wg_id_y, %wg_size_y]
+  %step_y = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%wg_count_y, %wg_size_y]
+  %offset_x = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%wg_id_x, %wg_size_x]
+  %step_x = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%wg_count_x, %wg_size_x]
+  scf.for %iv0 = %offset_y to %m step %step_y {
+    %tilesize_y = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%iv0)[%wg_size_y, %m]
+    scf.for %iv1 = %offset_x to %n step %step_x {
+      %tilesize_x = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%iv1)[%wg_size_x, %n]
+      %lhs_tile = flow.dispatch.tensor.load %lhs, offsets = [%iv0, 0], sizes = [%tilesize_y, %k], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<?x?xf32>
+      %rhs_tile = flow.dispatch.tensor.load %rhs, offsets = [0, %iv1], sizes = [%k, %tilesize_x], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<?x?xf32>
+      %init_tile = linalg.init_tensor [%tilesize_y, %tilesize_x] : tensor<?x?xf32>
+      %fill_tile = linalg.fill(%cst, %init_tile) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
+      %matmul_tile = linalg.matmul ins(%lhs_tile, %rhs_tile : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill_tile : tensor<?x?xf32>) -> tensor<?x?xf32>
+      flow.dispatch.tensor.store %matmul_tile, %result, offsets = [%iv0, %iv1], sizes = [%tilesize_y, %tilesize_x], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:?x?xf32>
+    }
+  }
+  return
+}
+hal.interface private @io  {
+  hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+  hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+  hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>
+//      CHECK: func @matmul_fill()
+//  CHECK-DAG:   %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+//  CHECK-DAG:   %[[M:.+]] = hal.interface.load.constant offset = 0
+//  CHECK-DAG:   %[[N:.+]] = hal.interface.load.constant offset = 1
+//  CHECK-DAG:   %[[K:.+]] = hal.interface.load.constant offset = 2
+//  CHECK-DAG:   %[[LHS:.+]] = hal.interface.binding.subspan @io::@arg0
+//  CHECK-DAG:   %[[RHS:.+]] = hal.interface.binding.subspan @io::@arg1
+//  CHECK-DAG:   %[[RESULT:.+]] = hal.interface.binding.subspan @io::@ret0
+//  CHECK-DAG:   %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1]
+//  CHECK-DAG:   %[[WG_COUNT_Y:.+]] = hal.interface.workgroup.count[1]
+//  CHECK-DAG:   %[[WG_SIZE_Y:.+]] = hal.interface.workgroup.size[1]
+//  CHECK-DAG:   %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0]
+//  CHECK-DAG:   %[[WG_COUNT_X:.+]] = hal.interface.workgroup.count[0]
+//  CHECK-DAG:   %[[WG_SIZE_X:.+]] = hal.interface.workgroup.size[0]
+//  CHECK-DAG:   %[[OFFSET_Y:.+]] = affine.apply #[[MAP0]]()[%[[WG_ID_Y]], %[[WG_SIZE_Y]]]
+//  CHECK-DAG:   %[[STEP_Y:.+]] = affine.apply #[[MAP0]]()[%[[WG_COUNT_Y]], %[[WG_SIZE_Y]]]
+//  CHECK-DAG:   %[[OFFSET_X:.+]] = affine.apply #[[MAP0]]()[%[[WG_ID_X]], %[[WG_SIZE_X]]]
+//  CHECK-DAG:   %[[STEP_X:.+]] = affine.apply #[[MAP0]]()[%[[WG_COUNT_X]], %[[WG_SIZE_X]]]
+//      CHECK:   scf.for %[[IV0:.+]] = %[[OFFSET_Y]] to %[[M]] step %[[STEP_Y]]
+//      CHECK:     %[[TILESIZE_Y:.+]] = affine.min #[[MAP1]](%[[IV0]])[%[[WG_SIZE_Y]], %[[M]]]
+//      CHECK:     scf.for %[[IV1:.+]] = %[[OFFSET_X]] to %[[N]] step %[[STEP_X]]
+//      CHECK:       %[[TILESIZE_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[WG_SIZE_X]], %[[N]]]
+//  CHECK-DAG:       %[[LHS_TILE:.+]] = memref.subview %[[LHS]][%[[IV0]], 0] [%[[TILESIZE_Y]], %[[K]]]
+//  CHECK-DAG:       %[[RHS_TILE:.+]] = memref.subview %[[RHS]][0, %[[IV1]]] [%[[K]], %[[TILESIZE_X]]]
+//  CHECK-DAG:       %[[ALLOC:.+]] = memref.alloc(%[[TILESIZE_Y]], %[[TILESIZE_X]]) {alignment = 128 : i64}
+//      CHECK:       linalg.fill(%[[CST]], %[[ALLOC]])
+//      CHECK:       linalg.matmul
+// CHECK-SAME:           ins(%[[LHS_TILE]], %[[RHS_TILE]]
+// CHECK-SAME:           outs(%[[ALLOC]]
+//  CHECK-DAG:       %[[RESULT_TILE:.+]] = memref.subview %[[RESULT]][%[[IV0]], %[[IV1]]] [%[[TILESIZE_Y]], %[[TILESIZE_X]]]
+//      CHECK:       linalg.copy(%[[ALLOC]], %[[RESULT_TILE]])
+//      CHECK:       memref.dealloc %[[ALLOC]]
diff --git a/iree/compiler/Codegen/Passes.h b/iree/compiler/Codegen/Passes.h
index d550517..51795ac 100644
--- a/iree/compiler/Codegen/Passes.h
+++ b/iree/compiler/Codegen/Passes.h
@@ -10,6 +10,7 @@
 #include <memory>
 
 #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassOptions.h"
@@ -39,6 +40,9 @@
 void addLinalgBufferizePasses(
     OpPassManager &passManager,
     WorkgroupMemoryAllocationFn allocationFn = nullptr);
+void addIREEComprehensiveBufferizePasses(
+    OpPassManager &passManager,
+    linalg::AllocationCallbacks allocationFn = linalg::AllocationCallbacks());
 
 /// Pass to perform canonicalizations/cleanups related to HAL interface/buffer
 /// allocations and view operations.
@@ -72,6 +76,8 @@
 /// and default memory space.
 std::unique_ptr<OperationPass<FuncOp>> createLinalgBufferizePass(
     WorkgroupMemoryAllocationFn allocationFn = nullptr);
+std::unique_ptr<OperationPass<ModuleOp>> createIREEComprehensiveBufferizePass(
+    linalg::AllocationCallbacks = linalg::AllocationCallbacks());
 
 /// Creates a pass to vectorize a very specific form of linalg.conv ops.
 std::unique_ptr<OperationPass<FuncOp>> createLinalgToVectorVectorizeConvPass();
diff --git a/iree/compiler/Codegen/Passes.td b/iree/compiler/Codegen/Passes.td
index 83e376f..6e650fd 100644
--- a/iree/compiler/Codegen/Passes.td
+++ b/iree/compiler/Codegen/Passes.td
@@ -52,6 +52,12 @@
   let constructor = "mlir::iree_compiler::createLinalgBufferizePass(nullptr)";
 }
 
+def IREEComprehensiveBufferize :
+    Pass<"iree-codegen-iree-comprehensive-bufferize", "ModuleOp"> {
+  let summary = "Convert from to Linalg ops on tensors to buffers";
+  let constructor = "mlir::iree_compiler::createIREEComprehensiveBufferizePass()";
+}
+
 def OptimizeVectorTransfer :
     Pass<"iree-codegen-optimize-vector-transfer", "FuncOp"> {
   let summary =