Initial commit of bridge to use Linalg Comprehensive Bufferize Pass. (#7509)
Co-authored-by: Nicolas Vasilache <nicolasvasilache@users.noreply.github.com>
This is the initial commit to use Linalg Comprehensive bufferize pass
from within IREE. The main work here is to bridge the memory model gap
of using flow.dispatch.tensor.load/store operations.
diff --git a/iree/compiler/Codegen/Common/BUILD b/iree/compiler/Codegen/Common/BUILD
index 0593b9b..380812c 100644
--- a/iree/compiler/Codegen/Common/BUILD
+++ b/iree/compiler/Codegen/Common/BUILD
@@ -37,6 +37,7 @@
"FlattenMemRefSubspanPass.cpp",
"FoldTensorExtractOpPass.cpp",
"ForOpCanonicalizationPass.cpp",
+ "IREEComprehensiveBufferizePass.cpp",
"LinalgBufferizePass.cpp",
"OptimizeVectorTransferPass.cpp",
"SetNumWorkgroupsPass.cpp",
diff --git a/iree/compiler/Codegen/Common/CMakeLists.txt b/iree/compiler/Codegen/Common/CMakeLists.txt
index ce29221..780f6fd 100644
--- a/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -28,6 +28,7 @@
"FlattenMemRefSubspanPass.cpp"
"FoldTensorExtractOpPass.cpp"
"ForOpCanonicalizationPass.cpp"
+ "IREEComprehensiveBufferizePass.cpp"
"LinalgBufferizePass.cpp"
"OptimizeVectorTransferPass.cpp"
"SetNumWorkgroupsPass.cpp"
diff --git a/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp b/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp
new file mode 100644
index 0000000..210e28c
--- /dev/null
+++ b/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp
@@ -0,0 +1,243 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+//===- IREEComprehensiveBufferizePass.cpp.cpp - -------------------------===//
+//
+// Wrapper pass to use MLIRs ComprehensiveBufferization pass.
+//
+//===----------------------------------------------------------------------===//
+
+#include "iree/compiler/Codegen/PassDetail.h"
+#include "iree/compiler/Codegen/Passes.h"
+#include "iree/compiler/Codegen/Transforms/Transforms.h"
+#include "iree/compiler/Codegen/Utils/Utils.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "llvm/ADT/EquivalenceClasses.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/BufferUtils.h"
+#include "mlir/Transforms/Passes.h"
+
+#define DEBUG_TYPE "iree-codegen-linalg-bufferize"
+
+namespace mlir {
+template <typename TensorType>
+static MemRefType getMemrefTypeForTensor(TensorType tensorType,
+ MemRefLayoutAttrInterface layout = {},
+ Attribute memorySpace = {}) {
+ return MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
+ layout, memorySpace);
+}
+
+namespace iree_compiler {
+
+namespace {
+/// Pass to convert from tensor based ops to memref based ops.
+class IREEComprehensiveBufferizePass
+ : public IREEComprehensiveBufferizeBase<IREEComprehensiveBufferizePass> {
+ public:
+ IREEComprehensiveBufferizePass(
+ linalg::AllocationCallbacks allocationFn = linalg::AllocationCallbacks())
+ : allocationFn(allocationFn) {}
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry
+ .insert<IREE::Util::UtilDialect, linalg::LinalgDialect,
+ memref::MemRefDialect, scf::SCFDialect, StandardOpsDialect>();
+ }
+ void runOnOperation() override;
+
+ private:
+ linalg::AllocationCallbacks allocationFn;
+};
+} // namespace
+
+static bool isaTensor(Type t) { return t.isa<TensorType>(); };
+
+/// Stitch comprehensive bufferization inside of IREE by proceeding as follows:
+/// 1. a. Bufferizes InterfaceBindingSubspanOp optimistically
+/// b. Insert a memref::TensorLoad to serve as glue between the buffer and
+/// tensor worlds.
+/// c. Record aliasInfo of memref::TensorLoad manually
+/// d. Record inplaceability of memref::TensorLoad manually
+/// e. Record the bufferization of memref::TensorLoad manually
+/// 2. Rewrite all Flow::Dispatch::TensorLoad ops as Tensor::ExtractSliceOp
+/// that comprehensive bufferization understands.
+/// 3. Specifically select the ops we want to bufferize / skip. In the future,
+/// this may be better specified with a BufferizationOpInterface.
+/// 4. Perform analysis and bufferization on the ops.
+void IREEComprehensiveBufferizePass::runOnOperation() {
+ ModuleOp moduleOp = getOperation();
+ MLIRContext *context = &getContext();
+
+ for (auto funcOp : moduleOp.getOps<FuncOp>()) {
+ OpBuilder b(context);
+
+ // 1. First go over all hal.interface.binding.subspan ops and create
+ // counterparts working with memrefs.
+ BlockAndValueMapping bvm, tensorLoads;
+ linalg::BufferizationAliasInfo aliasInfo(funcOp);
+ // These are used until late, erase on scoped exit.
+ SmallVector<Operation *> toEraseLate;
+ auto scopeGuard = llvm::make_scope_exit([&]() {
+ for (Operation *op : llvm::reverse(toEraseLate)) op->erase();
+ });
+ funcOp.walk([&](IREE::HAL::InterfaceBindingSubspanOp op) {
+ auto shapedType =
+ op.getResult().getType().dyn_cast<IREE::Flow::DispatchTensorType>();
+ if (!shapedType || !shapedType.hasRank()) return;
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPoint(op);
+ // 1.a. Just change the result type of the InterfaceBindingSubspanOp to
+ // from the base buffer.
+ auto memRefType = getMemrefTypeForTensor(shapedType);
+ auto baseBuffer = b.create<IREE::HAL::InterfaceBindingSubspanOp>(
+ op->getLoc(), memRefType, op.binding(), op.byte_offset(),
+ op.byte_length(), op.dynamic_dims());
+ bvm.map(op, baseBuffer);
+
+ // This op does not operate on core tensor types and has half-side
+ // effecting semantics. It cannot be added to BufferizationAliasInfo.
+ // Instead:
+ // 1.b. Insert a memref::TensorLoad to serve as glue between the buffer
+ // and tensor worlds.
+ Value tensor = b.create<memref::TensorLoadOp>(op->getLoc(), baseBuffer);
+ // 1.c. Insert a new entry manually into the existing aliasInfo.
+ aliasInfo.createAliasInfoEntry(op.result());
+ aliasInfo.createAliasInfoEntry(tensor);
+ tensorLoads.map(op.result(), tensor);
+ // 1.d. Mark tensors that bufferize to writeable memory as such.
+ if (shapedType.getAccess() != IREE::Flow::TensorAccess::ReadOnly) {
+ aliasInfo.setBufferizesToWritableMemory(tensor);
+ }
+ // 1.e. Save tensor -> baseBuffer into BVM.
+ bvm.map(tensor, baseBuffer);
+
+ // Drop the original op that is now bufferized.
+ toEraseLate.push_back(op);
+ });
+
+ // 2. Rewrite all Flow::Dispatch::TensorLoad ops as Tensor::ExtractSliceOp.
+ funcOp.walk<WalkOrder::PostOrder>([&](IREE::Flow::DispatchTensorLoadOp op) {
+ OpBuilder b(op);
+ Value v = b.create<tensor::ExtractSliceOp>(
+ op->getLoc(), op.result().getType().cast<RankedTensorType>(),
+ tensorLoads.lookup(op.source()), op.getMixedOffsets(),
+ op.getMixedSizes(), op.getMixedStrides());
+ // Insert a new entry manually into the existing aliasInfo.
+ aliasInfo.createAliasInfoEntry(v);
+ op.result().replaceAllUsesWith(v);
+ toEraseLate.push_back(op);
+ });
+ funcOp.walk<WalkOrder::PostOrder>(
+ [&](IREE::Flow::DispatchTensorStoreOp op) {
+ OpBuilder b(op);
+ Value v = b.create<tensor::InsertSliceOp>(
+ op->getLoc(), op.value(), tensorLoads.lookup(op.target()),
+ op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides());
+ // Insert a new entry manually into the existing aliasInfo.
+ aliasInfo.createAliasInfoEntry(v);
+ toEraseLate.push_back(op);
+ });
+
+ // TODO: Visit all the operations that return `tensor`s that are not handled
+ // by comprehensive bufferize.
+
+ // 3. Specifically select the ops we want to bufferize / skip. In the
+ // future, this may be better specified with a BufferizationOpInterface.
+ DominanceInfo domInfo(funcOp);
+ SmallVector<Operation *> ops;
+ ops.reserve(funcOp.body().front().getOperations().size());
+ WalkResult opsSelected =
+ funcOp.body().walk([&](Operation *op) -> WalkResult {
+ if (isa<IREE::HAL::InterfaceBindingSubspanOp,
+ IREE::Flow::DispatchTensorLoadOp,
+ IREE::Flow::DispatchTensorStoreOp>(op)) {
+ return WalkResult::advance();
+ }
+ if (llvm::none_of(op->getOperandTypes(), isaTensor) &&
+ llvm::none_of(op->getResultTypes(), isaTensor)) {
+ return WalkResult::advance();
+ }
+ if (op->getParentOfType<linalg::LinalgOp>())
+ return WalkResult::advance();
+ // TODO: if we want to bufferize function calls, we need FuncOp
+ // and to pass a proper bufferizedFunctionTypes.
+ if (isa<CallOpInterface>(op)) {
+ return static_cast<LogicalResult>(op->emitError(
+ "CallOpInterface bufferization not supported in IREE"));
+ }
+ ops.push_back(op);
+ return WalkResult::advance();
+ });
+
+ // 4. Perform inplaceability analysis of `ops`.
+ if (opsSelected.wasInterrupted() ||
+ failed(linalg::inPlaceAnalysis(ops, aliasInfo, domInfo))) {
+ return signalPassFailure();
+ }
+
+ // 5. Perform bufferization.
+ for (Operation *op : ops) {
+ if (failed(linalg::bufferizeOp(op, bvm, aliasInfo,
+ linalg::AllocationCallbacks(),
+ /*bufferizedFunctionTypes=*/nullptr))) {
+ return signalPassFailure();
+ }
+ }
+ }
+}
+
+// TODO: pass this to comprehensive bufferize.
+static Value defaultAllocationFn(OpBuilder &builder, Location loc,
+ ArrayRef<int64_t> staticShape,
+ Type elementType,
+ ArrayRef<Value> dynamicSizes) {
+ auto allocationType = MemRefType::get(staticShape, elementType);
+ return builder.create<memref::AllocOp>(loc, allocationType, dynamicSizes);
+}
+
+std::unique_ptr<OperationPass<ModuleOp>> createIREEComprehensiveBufferizePass(
+ linalg::AllocationCallbacks allocationFns) {
+ return std::make_unique<IREEComprehensiveBufferizePass>(allocationFns);
+}
+
+void addIREEComprehensiveBufferizePasses(
+ OpPassManager &passManager, linalg::AllocationCallbacks allocationFns) {
+ passManager.addPass(createIREEComprehensiveBufferizePass(allocationFns));
+ passManager.addPass(memref::createResolveShapedTypeResultDimsPass());
+ passManager.addNestedPass<FuncOp>(createCanonicalizerPass());
+ passManager.addNestedPass<FuncOp>(createCSEPass());
+ passManager.addNestedPass<FuncOp>(createCleanupBufferAllocViewPass());
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Codegen/Common/test/BUILD b/iree/compiler/Codegen/Common/test/BUILD
index c83460a..9475ce3 100644
--- a/iree/compiler/Codegen/Common/test/BUILD
+++ b/iree/compiler/Codegen/Common/test/BUILD
@@ -22,6 +22,7 @@
"affinemin_canonicalization.mlir",
"canonicalize_interface_load_store.mlir",
"dead_alloc.mlir",
+ "iree_comprehensive_bufferize.mlir",
"f32Tof16.mlir",
"flatten_memref_subspan.mlir",
"fold_tensor_extract_op.mlir",
diff --git a/iree/compiler/Codegen/Common/test/CMakeLists.txt b/iree/compiler/Codegen/Common/test/CMakeLists.txt
index 19d4e3f..3336df7 100644
--- a/iree/compiler/Codegen/Common/test/CMakeLists.txt
+++ b/iree/compiler/Codegen/Common/test/CMakeLists.txt
@@ -21,6 +21,7 @@
"flatten_memref_subspan.mlir"
"fold_tensor_extract_op.mlir"
"forop_canonicalization.mlir"
+ "iree_comprehensive_bufferize.mlir"
"linalg_bufferize.mlir"
"remove_dead_allocs.mlir"
"transpose_canonicalization.mlir"
diff --git a/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir b/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
new file mode 100644
index 0000000..e7973ce
--- /dev/null
+++ b/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
@@ -0,0 +1,150 @@
+// RUN: iree-opt %s --iree-codegen-iree-comprehensive-bufferize -canonicalize -cse -split-input-file | IreeFileCheck %s
+
+func @matmul() {
+ %c0 = arith.constant 0 : index
+ %m = hal.interface.load.constant offset = 0 : index
+ %n = hal.interface.load.constant offset = 1 : index
+ %k = hal.interface.load.constant offset = 2 : index
+ %lhs = hal.interface.binding.subspan @io::@arg0[%c0] : !flow.dispatch.tensor<readonly:?x?xf32>{%m, %k}
+ %rhs = hal.interface.binding.subspan @io::@arg1[%c0] : !flow.dispatch.tensor<readonly:?x?xf32>{%k, %n}
+ %init = hal.interface.binding.subspan @io::@arg2[%c0] : !flow.dispatch.tensor<readonly:?x?xf32>{%m, %n}
+ %result = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:?x?xf32>{%m, %n}
+ %wg_id_y = hal.interface.workgroup.id[1] : index
+ %wg_count_y = hal.interface.workgroup.count[1] : index
+ %wg_size_y = hal.interface.workgroup.size[1] : index
+ %wg_id_x = hal.interface.workgroup.id[0] : index
+ %wg_count_x = hal.interface.workgroup.count[0] : index
+ %wg_size_x = hal.interface.workgroup.size[0] : index
+ %offset_y = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%wg_id_y, %wg_size_y]
+ %step_y = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%wg_count_y, %wg_size_y]
+ %offset_x = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%wg_id_x, %wg_size_x]
+ %step_x = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%wg_count_x, %wg_size_x]
+ scf.for %iv0 = %offset_y to %m step %step_y {
+ %tilesize_y = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%iv0)[%wg_size_y, %m]
+ scf.for %iv1 = %offset_x to %n step %step_x {
+ %tilesize_x = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%iv1)[%wg_size_x, %n]
+ %lhs_tile = flow.dispatch.tensor.load %lhs, offsets = [%iv0, 0], sizes = [%tilesize_y, %k], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<?x?xf32>
+ %rhs_tile = flow.dispatch.tensor.load %rhs, offsets = [0, %iv1], sizes = [%k, %tilesize_x], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<?x?xf32>
+ %init_tile = flow.dispatch.tensor.load %init, offsets = [%iv0, %iv1], sizes = [%tilesize_y, %tilesize_x], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<?x?xf32>
+ %matmul_tile = linalg.matmul ins(%lhs_tile, %rhs_tile : tensor<?x?xf32>, tensor<?x?xf32>) outs(%init_tile : tensor<?x?xf32>) -> tensor<?x?xf32>
+ flow.dispatch.tensor.store %matmul_tile, %result, offsets = [%iv0, %iv1], sizes = [%tilesize_y, %tilesize_x], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:?x?xf32>
+ }
+ }
+ return
+}
+hal.interface private @io {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg2, set=0, binding=2, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=3, type="StorageBuffer", access="Write|Discard"
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>
+// CHECK: func @matmul()
+// CHECK-DAG: %[[M:.+]] = hal.interface.load.constant offset = 0
+// CHECK-DAG: %[[N:.+]] = hal.interface.load.constant offset = 1
+// CHECK-DAG: %[[K:.+]] = hal.interface.load.constant offset = 2
+// CHECK-DAG: %[[LHS:.+]] = hal.interface.binding.subspan @io::@arg0
+// CHECK-DAG: %[[RHS:.+]] = hal.interface.binding.subspan @io::@arg1
+// CHECK-DAG: %[[INIT:.+]] = hal.interface.binding.subspan @io::@arg2
+// CHECK-DAG: %[[RESULT:.+]] = hal.interface.binding.subspan @io::@ret0
+// CHECK-DAG: %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1]
+// CHECK-DAG: %[[WG_COUNT_Y:.+]] = hal.interface.workgroup.count[1]
+// CHECK-DAG: %[[WG_SIZE_Y:.+]] = hal.interface.workgroup.size[1]
+// CHECK-DAG: %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0]
+// CHECK-DAG: %[[WG_COUNT_X:.+]] = hal.interface.workgroup.count[0]
+// CHECK-DAG: %[[WG_SIZE_X:.+]] = hal.interface.workgroup.size[0]
+// CHECK-DAG: %[[OFFSET_Y:.+]] = affine.apply #[[MAP0]]()[%[[WG_ID_Y]], %[[WG_SIZE_Y]]]
+// CHECK-DAG: %[[STEP_Y:.+]] = affine.apply #[[MAP0]]()[%[[WG_COUNT_Y]], %[[WG_SIZE_Y]]]
+// CHECK-DAG: %[[OFFSET_X:.+]] = affine.apply #[[MAP0]]()[%[[WG_ID_X]], %[[WG_SIZE_X]]]
+// CHECK-DAG: %[[STEP_X:.+]] = affine.apply #[[MAP0]]()[%[[WG_COUNT_X]], %[[WG_SIZE_X]]]
+// CHECK: scf.for %[[IV0:.+]] = %[[OFFSET_Y]] to %[[M]] step %[[STEP_Y]]
+// CHECK: %[[TILESIZE_Y:.+]] = affine.min #[[MAP1]](%[[IV0]])[%[[WG_SIZE_Y]], %[[M]]]
+// CHECK: scf.for %[[IV1:.+]] = %[[OFFSET_X]] to %[[N]] step %[[STEP_X]]
+// CHECK: %[[TILESIZE_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[WG_SIZE_X]], %[[N]]]
+// CHECK-DAG: %[[LHS_TILE:.+]] = memref.subview %[[LHS]][%[[IV0]], 0] [%[[TILESIZE_Y]], %[[K]]]
+// CHECK-DAG: %[[RHS_TILE:.+]] = memref.subview %[[RHS]][0, %[[IV1]]] [%[[K]], %[[TILESIZE_X]]]
+// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc(%[[TILESIZE_Y]], %[[TILESIZE_X]]) {alignment = 128 : i64}
+// CHECK-DAG: %[[INIT_TILE:.+]] = memref.subview %[[INIT]][%[[IV0]], %[[IV1]]] [%[[TILESIZE_Y]], %[[TILESIZE_X]]]
+// CHECK: linalg.copy(%[[INIT_TILE]], %[[ALLOC]])
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]]
+// CHECK-SAME: outs(%[[ALLOC]]
+// CHECK: %[[RESULT_TILE:.+]] = memref.subview %[[RESULT]][%[[IV0]], %[[IV1]]] [%[[TILESIZE_Y]], %[[TILESIZE_X]]]
+// CHECK: linalg.copy(%[[ALLOC]], %[[RESULT_TILE]])
+// CHECK: memref.dealloc %[[ALLOC]]
+
+// -----
+
+func @matmul_fill() {
+ %cst = arith.constant 0.0 : f32
+ %c0 = arith.constant 0 : index
+ %m = hal.interface.load.constant offset = 0 : index
+ %n = hal.interface.load.constant offset = 1 : index
+ %k = hal.interface.load.constant offset = 2 : index
+ %lhs = hal.interface.binding.subspan @io::@arg0[%c0] : !flow.dispatch.tensor<readonly:?x?xf32>{%m, %k}
+ %rhs = hal.interface.binding.subspan @io::@arg1[%c0] : !flow.dispatch.tensor<readonly:?x?xf32>{%k, %n}
+ %result = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:?x?xf32>{%m, %n}
+ %wg_id_y = hal.interface.workgroup.id[1] : index
+ %wg_count_y = hal.interface.workgroup.count[1] : index
+ %wg_size_y = hal.interface.workgroup.size[1] : index
+ %wg_id_x = hal.interface.workgroup.id[0] : index
+ %wg_count_x = hal.interface.workgroup.count[0] : index
+ %wg_size_x = hal.interface.workgroup.size[0] : index
+ %offset_y = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%wg_id_y, %wg_size_y]
+ %step_y = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%wg_count_y, %wg_size_y]
+ %offset_x = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%wg_id_x, %wg_size_x]
+ %step_x = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%wg_count_x, %wg_size_x]
+ scf.for %iv0 = %offset_y to %m step %step_y {
+ %tilesize_y = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%iv0)[%wg_size_y, %m]
+ scf.for %iv1 = %offset_x to %n step %step_x {
+ %tilesize_x = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%iv1)[%wg_size_x, %n]
+ %lhs_tile = flow.dispatch.tensor.load %lhs, offsets = [%iv0, 0], sizes = [%tilesize_y, %k], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<?x?xf32>
+ %rhs_tile = flow.dispatch.tensor.load %rhs, offsets = [0, %iv1], sizes = [%k, %tilesize_x], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<?x?xf32>
+ %init_tile = linalg.init_tensor [%tilesize_y, %tilesize_x] : tensor<?x?xf32>
+ %fill_tile = linalg.fill(%cst, %init_tile) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
+ %matmul_tile = linalg.matmul ins(%lhs_tile, %rhs_tile : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill_tile : tensor<?x?xf32>) -> tensor<?x?xf32>
+ flow.dispatch.tensor.store %matmul_tile, %result, offsets = [%iv0, %iv1], sizes = [%tilesize_y, %tilesize_x], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:?x?xf32>
+ }
+ }
+ return
+}
+hal.interface private @io {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>
+// CHECK: func @matmul_fill()
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[M:.+]] = hal.interface.load.constant offset = 0
+// CHECK-DAG: %[[N:.+]] = hal.interface.load.constant offset = 1
+// CHECK-DAG: %[[K:.+]] = hal.interface.load.constant offset = 2
+// CHECK-DAG: %[[LHS:.+]] = hal.interface.binding.subspan @io::@arg0
+// CHECK-DAG: %[[RHS:.+]] = hal.interface.binding.subspan @io::@arg1
+// CHECK-DAG: %[[RESULT:.+]] = hal.interface.binding.subspan @io::@ret0
+// CHECK-DAG: %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1]
+// CHECK-DAG: %[[WG_COUNT_Y:.+]] = hal.interface.workgroup.count[1]
+// CHECK-DAG: %[[WG_SIZE_Y:.+]] = hal.interface.workgroup.size[1]
+// CHECK-DAG: %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0]
+// CHECK-DAG: %[[WG_COUNT_X:.+]] = hal.interface.workgroup.count[0]
+// CHECK-DAG: %[[WG_SIZE_X:.+]] = hal.interface.workgroup.size[0]
+// CHECK-DAG: %[[OFFSET_Y:.+]] = affine.apply #[[MAP0]]()[%[[WG_ID_Y]], %[[WG_SIZE_Y]]]
+// CHECK-DAG: %[[STEP_Y:.+]] = affine.apply #[[MAP0]]()[%[[WG_COUNT_Y]], %[[WG_SIZE_Y]]]
+// CHECK-DAG: %[[OFFSET_X:.+]] = affine.apply #[[MAP0]]()[%[[WG_ID_X]], %[[WG_SIZE_X]]]
+// CHECK-DAG: %[[STEP_X:.+]] = affine.apply #[[MAP0]]()[%[[WG_COUNT_X]], %[[WG_SIZE_X]]]
+// CHECK: scf.for %[[IV0:.+]] = %[[OFFSET_Y]] to %[[M]] step %[[STEP_Y]]
+// CHECK: %[[TILESIZE_Y:.+]] = affine.min #[[MAP1]](%[[IV0]])[%[[WG_SIZE_Y]], %[[M]]]
+// CHECK: scf.for %[[IV1:.+]] = %[[OFFSET_X]] to %[[N]] step %[[STEP_X]]
+// CHECK: %[[TILESIZE_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[WG_SIZE_X]], %[[N]]]
+// CHECK-DAG: %[[LHS_TILE:.+]] = memref.subview %[[LHS]][%[[IV0]], 0] [%[[TILESIZE_Y]], %[[K]]]
+// CHECK-DAG: %[[RHS_TILE:.+]] = memref.subview %[[RHS]][0, %[[IV1]]] [%[[K]], %[[TILESIZE_X]]]
+// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc(%[[TILESIZE_Y]], %[[TILESIZE_X]]) {alignment = 128 : i64}
+// CHECK: linalg.fill(%[[CST]], %[[ALLOC]])
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]]
+// CHECK-SAME: outs(%[[ALLOC]]
+// CHECK-DAG: %[[RESULT_TILE:.+]] = memref.subview %[[RESULT]][%[[IV0]], %[[IV1]]] [%[[TILESIZE_Y]], %[[TILESIZE_X]]]
+// CHECK: linalg.copy(%[[ALLOC]], %[[RESULT_TILE]])
+// CHECK: memref.dealloc %[[ALLOC]]
diff --git a/iree/compiler/Codegen/Passes.h b/iree/compiler/Codegen/Passes.h
index d550517..51795ac 100644
--- a/iree/compiler/Codegen/Passes.h
+++ b/iree/compiler/Codegen/Passes.h
@@ -10,6 +10,7 @@
#include <memory>
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassOptions.h"
@@ -39,6 +40,9 @@
void addLinalgBufferizePasses(
OpPassManager &passManager,
WorkgroupMemoryAllocationFn allocationFn = nullptr);
+void addIREEComprehensiveBufferizePasses(
+ OpPassManager &passManager,
+ linalg::AllocationCallbacks allocationFn = linalg::AllocationCallbacks());
/// Pass to perform canonicalizations/cleanups related to HAL interface/buffer
/// allocations and view operations.
@@ -72,6 +76,8 @@
/// and default memory space.
std::unique_ptr<OperationPass<FuncOp>> createLinalgBufferizePass(
WorkgroupMemoryAllocationFn allocationFn = nullptr);
+std::unique_ptr<OperationPass<ModuleOp>> createIREEComprehensiveBufferizePass(
+ linalg::AllocationCallbacks = linalg::AllocationCallbacks());
/// Creates a pass to vectorize a very specific form of linalg.conv ops.
std::unique_ptr<OperationPass<FuncOp>> createLinalgToVectorVectorizeConvPass();
diff --git a/iree/compiler/Codegen/Passes.td b/iree/compiler/Codegen/Passes.td
index 83e376f..6e650fd 100644
--- a/iree/compiler/Codegen/Passes.td
+++ b/iree/compiler/Codegen/Passes.td
@@ -52,6 +52,12 @@
let constructor = "mlir::iree_compiler::createLinalgBufferizePass(nullptr)";
}
+def IREEComprehensiveBufferize :
+ Pass<"iree-codegen-iree-comprehensive-bufferize", "ModuleOp"> {
+ let summary = "Convert from to Linalg ops on tensors to buffers";
+ let constructor = "mlir::iree_compiler::createIREEComprehensiveBufferizePass()";
+}
+
def OptimizeVectorTransfer :
Pass<"iree-codegen-optimize-vector-transfer", "FuncOp"> {
let summary =