Merge pull request #9817 from iree-org/benvanik-util-buffer
Implementing `util.buffer` conversions and lowering to `vm.buffer`.

This moves the generic structural op conversions (`func`, `cf`, `scf`, and `arith.select`) out of StandardToHAL and into the shared Util conversion library as `populateGenericStructuralConversionPatterns`, adds a MemRefToUtil conversion that lowers rank-0/rank-1 memref ops onto `util.buffer` ops, generalizes the Stream-specific `-iree-stream-propagate-subviews` pass into `-iree-util-propagate-subranges` backed by new subrange type/op interfaces, and switches `util.buffer.constant` alignment attributes from `i64` to `index`.
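As a rough sketch of the new memref lowering (distilled from the checks in the added `memref_ops.mlir` test; the SSA names and exact op ordering here are illustrative rather than what the pass pipeline prints verbatim), a rank-1 `memref.load` becomes a byte-offset `util.buffer.load`:

```mlir
// Before: scalar access on a rank-1 memref.
func.func @load_f32(%buffer: memref<?xf32>, %idx: index) -> f32 {
  %0 = memref.load %buffer[%idx] : memref<?xf32>
  return %0 : f32
}

// After MemRefToUtil: the memref becomes a !util.buffer and the access turns
// into a byte-offset load at (index * element byte width).
func.func @load_f32(%buffer: !util.buffer, %idx: index) -> f32 {
  %c4 = arith.constant 4 : index
  %size = util.buffer.size %buffer : !util.buffer
  %offset = arith.muli %idx, %c4 : index
  %0 = util.buffer.load %buffer[%offset] : !util.buffer{%size} -> f32
  return %0 : f32
}
```

Rank-2+ and non-identity-layout memrefs are rejected by the patterns and must be flattened first, as exercised by the `verify_invalid_*` tests.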
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/BUILD b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/BUILD
index 538a160..6fb092a 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/BUILD
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/BUILD
@@ -17,7 +17,6 @@
srcs = [
"ConvertShapeOps.cpp",
"ConvertStandardToHAL.cpp",
- "ConvertStructuralOps.cpp",
],
hdrs = [
"ConvertStandardToHAL.h",
@@ -30,11 +29,9 @@
"//compiler/src/iree/compiler/Dialect/Util/Conversion",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithmeticDialect",
- "@llvm-project//mlir:ControlFlowDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
- "@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:ShapeDialect",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/CMakeLists.txt
index a3e0ed1..996a111 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/CMakeLists.txt
@@ -18,15 +18,12 @@
SRCS
"ConvertShapeOps.cpp"
"ConvertStandardToHAL.cpp"
- "ConvertStructuralOps.cpp"
DEPS
LLVMSupport
MLIRArithmeticDialect
- MLIRControlFlowDialect
MLIRFuncDialect
MLIRIR
MLIRPass
- MLIRSCFDialect
MLIRShapeDialect
MLIRTensorDialect
MLIRTransforms
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStandardToHAL.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStandardToHAL.cpp
index 37f65ef..e968c7b 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStandardToHAL.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStandardToHAL.cpp
@@ -22,19 +22,12 @@
RewritePatternSet &patterns,
TypeConverter &converter);
-void populateStandardStructuralToHALPatterns(MLIRContext *context,
- ConversionTarget &conversionTarget,
- RewritePatternSet &patterns,
- TypeConverter &converter);
-
void populateStandardToHALPatterns(MLIRContext *context,
ConversionTarget &conversionTarget,
TypeConverter &typeConverter,
RewritePatternSet &patterns) {
populateStandardShapeToHALPatterns(context, conversionTarget, patterns,
typeConverter);
- populateStandardStructuralToHALPatterns(context, conversionTarget, patterns,
- typeConverter);
}
} // namespace iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStructuralOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStructuralOps.cpp
deleted file mode 100644
index c000956..0000000
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStructuralOps.cpp
+++ /dev/null
@@ -1,214 +0,0 @@
-// Copyright 2020 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStandardToHAL.h"
-#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
-#include "iree/compiler/Dialect/Util/Conversion/ConversionPatterns.h"
-#include "llvm/ADT/DenseMap.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
-#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace {
-
-class FuncOpSignatureConversion
- : public OpConversionPattern<mlir::func::FuncOp> {
- public:
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult matchAndRewrite(
- mlir::func::FuncOp funcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto &typeConverter = *getTypeConverter();
-
- // Convert the input signature types.
- // TODO(benvanik): dynamic shapes by passing in tensor dynamic dims.
- auto originalType = funcOp.getFunctionType();
- TypeConverter::SignatureConversion newSignature(
- originalType.getNumInputs());
- for (auto argType : llvm::enumerate(originalType.getInputs())) {
- if (failed(typeConverter.convertSignatureArg(
- argType.index(), argType.value(), newSignature))) {
- return rewriter.notifyMatchFailure(funcOp,
- "failed to convert arg type");
- }
- }
- SmallVector<Type, 4> newResultTypes;
- if (failed(typeConverter.convertTypes(originalType.getResults(),
- newResultTypes))) {
- return rewriter.notifyMatchFailure(funcOp,
- "failed to convert result type");
- }
-
- // Replace function.
- auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
- newFuncOp.getBlocks().clear();
- rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
- newFuncOp.end());
- newFuncOp.setType(rewriter.getFunctionType(newSignature.getConvertedTypes(),
- newResultTypes));
- if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
- &newSignature))) {
- return failure();
- }
-
- rewriter.eraseOp(funcOp);
- return success();
- }
-};
-
-class CallOpConversion : public OpConversionPattern<mlir::func::CallOp> {
- public:
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult matchAndRewrite(
- mlir::func::CallOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- SmallVector<Type, 4> resultTypes;
- if (failed(getTypeConverter()->convertTypes(op.getResultTypes(),
- resultTypes))) {
- return rewriter.notifyMatchFailure(op, "unable to convert result types");
- }
- rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
- op, resultTypes, op.getCallee(), adaptor.operands());
- return success();
- }
-};
-
-class ReturnOpConversion : public OpConversionPattern<mlir::func::ReturnOp> {
- public:
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult matchAndRewrite(
- mlir::func::ReturnOp returnOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(returnOp,
- adaptor.operands());
- return success();
- }
-};
-
-class BranchOpConversion : public OpConversionPattern<mlir::cf::BranchOp> {
- public:
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult matchAndRewrite(
- mlir::cf::BranchOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<mlir::cf::BranchOp>(op, op.getDest(),
- adaptor.getDestOperands());
- return success();
- }
-};
-
-class CondBranchOpConversion
- : public OpConversionPattern<mlir::cf::CondBranchOp> {
- public:
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult matchAndRewrite(
- mlir::cf::CondBranchOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<mlir::cf::CondBranchOp>(
- op, adaptor.getCondition(), op.getTrueDest(),
- adaptor.getTrueDestOperands(), op.getFalseDest(),
- adaptor.getFalseDestOperands());
- return success();
- }
-};
-
-class SelectOpConversion : public OpConversionPattern<mlir::arith::SelectOp> {
- public:
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult matchAndRewrite(
- mlir::arith::SelectOp selectOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<mlir::arith::SelectOp>(
- selectOp, adaptor.getCondition(), adaptor.getTrueValue(),
- adaptor.getFalseValue());
- return success();
- }
-};
-
-struct ConvertIfOp : public OpConversionPattern<scf::IfOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- scf::IfOp ifOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto resultTypes = llvm::to_vector<4>(llvm::map_range(
- ifOp.getResultTypes(),
- [&](Type type) { return getTypeConverter()->convertType(type); }));
- auto newOp = rewriter.create<scf::IfOp>(ifOp.getLoc(), resultTypes,
- adaptor.getCondition(),
- ifOp.elseBlock() != nullptr);
- rewriter.inlineRegionBefore(ifOp.getThenRegion(), newOp.getThenRegion(),
- newOp.getThenRegion().end());
- rewriter.eraseBlock(&newOp.getThenRegion().front());
- if (ifOp.elseBlock()) {
- rewriter.inlineRegionBefore(ifOp.getElseRegion(), newOp.getElseRegion(),
- newOp.getElseRegion().end());
- rewriter.eraseBlock(&newOp.getElseRegion().front());
- }
- rewriter.replaceOp(ifOp, newOp.getResults());
- return success();
- }
-};
-
-struct ConvertYieldOp : public OpConversionPattern<scf::YieldOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- scf::YieldOp yieldOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, adaptor.getResults());
- return success();
- }
-};
-
-} // namespace
-
-void populateStandardStructuralToHALPatterns(MLIRContext *context,
- ConversionTarget &conversionTarget,
- RewritePatternSet &patterns,
- TypeConverter &typeConverter) {
- conversionTarget.addLegalOp<mlir::ModuleOp>();
-
- // We need to rewrite certain types on operands/results so use the default
- // dynamic legality checker to force any ops using such types to run through
- // our patterns.
- conversionTarget.addDynamicallyLegalOp<mlir::func::FuncOp>(
- [&](mlir::func::FuncOp op) {
- return typeConverter.isSignatureLegal(op.getFunctionType()) &&
- typeConverter.isLegal(&op.getBody());
- });
- addGenericLegalOp<func::CallOp>(conversionTarget, typeConverter);
- addGenericLegalOp<func::ReturnOp>(conversionTarget, typeConverter);
- addGenericLegalOp<cf::BranchOp>(conversionTarget, typeConverter);
- addGenericLegalOp<cf::CondBranchOp>(conversionTarget, typeConverter);
- addGenericLegalOp<arith::SelectOp>(conversionTarget, typeConverter);
- patterns
- .insert<FuncOpSignatureConversion, CallOpConversion, ReturnOpConversion,
- BranchOpConversion, CondBranchOpConversion, SelectOpConversion>(
- typeConverter, context);
-
- // TODO(benvanik): move to general utils conversion.
- addGenericLegalOp<scf::IfOp>(conversionTarget, typeConverter);
- addGenericLegalOp<scf::YieldOp>(conversionTarget, typeConverter);
- patterns.insert<ConvertIfOp, ConvertYieldOp>(typeConverter, context);
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/BUILD b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/BUILD
index 06cbeb3..f62c215 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/BUILD
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/BUILD
@@ -18,7 +18,6 @@
srcs = enforce_glob(
[
"shape_ops.mlir",
- "structural_ops.mlir",
],
include = ["*.mlir"],
),
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/CMakeLists.txt
index e1f8b88..28444ed 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/CMakeLists.txt
@@ -15,7 +15,6 @@
lit
SRCS
"shape_ops.mlir"
- "structural_ops.mlir"
TOOLS
FileCheck
iree-opt
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/structural_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/structural_ops.mlir
deleted file mode 100644
index 5b5307e..0000000
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/structural_ops.mlir
+++ /dev/null
@@ -1,80 +0,0 @@
-// RUN: iree-opt --split-input-file --iree-hal-conversion %s | FileCheck %s
-
-// These patterns are not doing anything HAL-specific and instead just allowing
-// for the ops to update their types during dialect conversions. These should be
-// moved to a general utility location or really become something upstream that
-// can be reused.
-
-// CHECK-LABEL: @funcOp
-// CHECK-SAME: (%[[ARG0:.+]]: !hal.buffer_view) -> !hal.buffer_view
-func.func @funcOp(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> {
- // CHECK: return %[[ARG0]] : !hal.buffer_view
- return %arg0 : tensor<4x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @callOp
-// CHECK-SAME: (%[[ARG0:.+]]: !hal.buffer_view) -> !hal.buffer_view
-func.func @callOp(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> {
- // CHECK: %[[RET0:.+]] = call @extern(%[[ARG0]]) : (!hal.buffer_view) -> !hal.buffer_view
- %ret0 = call @extern(%arg0) : (tensor<4x2xf32>) -> tensor<4x2xf32>
- // CHECK: return %[[RET0]] : !hal.buffer_view
- return %ret0 : tensor<4x2xf32>
-}
-// CHECK: func.func private @extern(!hal.buffer_view) -> !hal.buffer_view
-func.func private @extern(tensor<4x2xf32>) -> tensor<4x2xf32>
-
-// -----
-
-// CHECK-LABEL: @brOp
-// CHECK-SAME: (%[[ARG0:.+]]: !hal.buffer_view) -> !hal.buffer_view
-func.func @brOp(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> {
- // CHECK: cf.br ^bb1(%[[ARG0]] : !hal.buffer_view)
- cf.br ^bb1(%arg0 : tensor<4x2xf32>)
-// CHECK: ^bb1(%[[BB1_ARG0:.+]]: !hal.buffer_view):
-^bb1(%bb1_arg0: tensor<4x2xf32>):
- // CHECK: return %[[BB1_ARG0]] : !hal.buffer_view
- return %bb1_arg0 : tensor<4x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @condBrOp
-// CHECK-SAME: (%[[COND:.+]]: i1, %[[ARG0:.+]]: !hal.buffer_view, %[[ARG1:.+]]: !hal.buffer_view) -> !hal.buffer_view
-func.func @condBrOp(%cond: i1, %arg0: tensor<4x2xf32>, %arg1: tensor<4x2xf32>) -> tensor<4x2xf32> {
- // CHECK: cf.cond_br %[[COND]], ^bb1(%[[ARG0]] : !hal.buffer_view), ^bb1(%[[ARG1]] : !hal.buffer_view)
- cf.cond_br %cond, ^bb1(%arg0 : tensor<4x2xf32>), ^bb1(%arg1 : tensor<4x2xf32>)
-// CHECK: ^bb1(%[[BB1_ARG0:.+]]: !hal.buffer_view):
-^bb1(%bb1_arg0 : tensor<4x2xf32>):
- // CHECK: return %[[BB1_ARG0]] : !hal.buffer_view
- return %bb1_arg0 : tensor<4x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @selectOp
-// CHECK-SAME: (%[[COND:.+]]: i1, %[[ARG0:.+]]: !hal.buffer_view, %[[ARG1:.+]]: !hal.buffer_view) -> !hal.buffer_view
-func.func @selectOp(%cond: i1, %arg0: tensor<4x2xf32>, %arg1: tensor<4x2xf32>) -> tensor<4x2xf32> {
- // CHECK: %[[RET0:.+]] = arith.select %[[COND]], %[[ARG0]], %[[ARG1]] : !hal.buffer_view
- %ret0 = arith.select %cond, %arg0, %arg1 : tensor<4x2xf32>
- // CHECK: return %[[RET0]] : !hal.buffer_view
- return %ret0 : tensor<4x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @ifOp
-// CHECK-SAME: (%[[COND:.+]]: i1, %[[ARG0:.+]]: !hal.buffer_view, %[[ARG1:.+]]: !hal.buffer_view) -> !hal.buffer_view
-func.func @ifOp(%cond: i1, %arg0: tensor<4x2xf32>, %arg1: tensor<4x2xf32>) -> tensor<4x2xf32> {
- // CHECK: %[[RET0:.+]] = scf.if %[[COND]] -> (!hal.buffer_view)
- %ret0 = scf.if %cond -> (tensor<4x2xf32>) {
- // CHECK: scf.yield %[[ARG0]] : !hal.buffer_view
- scf.yield %arg0 : tensor<4x2xf32>
- } else {
- // CHECK: scf.yield %[[ARG1]] : !hal.buffer_view
- scf.yield %arg1 : tensor<4x2xf32>
- }
- // CHECK: return %[[RET0]] : !hal.buffer_view
- return %ret0 : tensor<4x2xf32>
-}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/UtilToHAL/ConvertUtilToHAL.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/UtilToHAL/ConvertUtilToHAL.cpp
index d645c31..f6b6e15 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/UtilToHAL/ConvertUtilToHAL.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/UtilToHAL/ConvertUtilToHAL.cpp
@@ -58,6 +58,8 @@
populateUtilConversionPatterns(context, conversionTarget, typeConverter,
patterns);
+ populateGenericStructuralConversionPatterns(context, conversionTarget,
+ typeConverter, patterns);
}
} // namespace iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp
index 7d3300e..dbb12a2 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp
@@ -77,8 +77,6 @@
populateUtilToHALPatterns(context, conversionTarget, typeConverter,
patterns);
- populateUtilConversionPatterns(context, conversionTarget, typeConverter,
- patterns);
populateStandardToHALPatterns(context, conversionTarget, typeConverter,
patterns);
populateStreamToHALPatterns(context, conversionTarget, typeConverter,
diff --git a/compiler/src/iree/compiler/Dialect/Modules/VMVX/Transforms/BUILD b/compiler/src/iree/compiler/Dialect/Modules/VMVX/Transforms/BUILD
index 5380d2a..091714e 100644
--- a/compiler/src/iree/compiler/Dialect/Modules/VMVX/Transforms/BUILD
+++ b/compiler/src/iree/compiler/Dialect/Modules/VMVX/Transforms/BUILD
@@ -59,6 +59,7 @@
"//compiler/src/iree/compiler/Dialect/Modules/VMVX/Conversion/StandardToVMVX",
"//compiler/src/iree/compiler/Dialect/Modules/VMVX/IR",
"//compiler/src/iree/compiler/Dialect/Modules/VMVX/IR:VMVXDialect",
+ "//compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"//compiler/src/iree/compiler/Dialect/Util/Transforms",
"//compiler/src/iree/compiler/Dialect/VM/IR",
diff --git a/compiler/src/iree/compiler/Dialect/Modules/VMVX/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Modules/VMVX/Transforms/CMakeLists.txt
index 1b86e47..e5f121c 100644
--- a/compiler/src/iree/compiler/Dialect/Modules/VMVX/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Modules/VMVX/Transforms/CMakeLists.txt
@@ -79,6 +79,7 @@
iree::compiler::Dialect::Modules::VMVX::Conversion::StandardToVMVX
iree::compiler::Dialect::Modules::VMVX::IR
iree::compiler::Dialect::Modules::VMVX::IR::VMVXDialect
+ iree::compiler::Dialect::Util::Conversion::MemRefToUtil
iree::compiler::Dialect::Util::IR
iree::compiler::Dialect::Util::Transforms
iree::compiler::Dialect::VM::IR
diff --git a/compiler/src/iree/compiler/Dialect/Modules/VMVX/Transforms/Conversion.cpp b/compiler/src/iree/compiler/Dialect/Modules/VMVX/Transforms/Conversion.cpp
index a64e58a..53d6769 100644
--- a/compiler/src/iree/compiler/Dialect/Modules/VMVX/Transforms/Conversion.cpp
+++ b/compiler/src/iree/compiler/Dialect/Modules/VMVX/Transforms/Conversion.cpp
@@ -70,9 +70,9 @@
conversionTarget.addLegalDialect<memref::MemRefDialect>();
conversionTarget.addLegalOp<mlir::UnrealizedConversionCastOp>();
- RewritePatternSet conversionPatterns(&getContext());
- populateHALToVMVXPatterns(context, conversionPatterns, typeConverter);
- populateStandardToVMVXPatterns(context, conversionPatterns, typeConverter);
+ RewritePatternSet patterns(&getContext());
+ populateHALToVMVXPatterns(context, patterns, typeConverter);
+ populateStandardToVMVXPatterns(context, patterns, typeConverter);
// Use the default 64-bit lowering for TOSA's ApplyScale operator:
// This lowering widens integer types to 64-bit and performs the non-fused
@@ -82,10 +82,10 @@
//
// TODO(suderman): remove the TOSA layering violation and lower to standard/
// math ops instead.
- tosa::populateTosaRescaleToArithConversionPatterns(&conversionPatterns);
+ tosa::populateTosaRescaleToArithConversionPatterns(&patterns);
if (failed(applyPartialConversion(getOperation(), conversionTarget,
- std::move(conversionPatterns)))) {
+ std::move(patterns)))) {
getOperation().emitError() << "conversion to the VMVX dialect failed";
return signalPassFailure();
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamBase.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamBase.td
index a6d3bee..04c3911 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamBase.td
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamBase.td
@@ -303,7 +303,7 @@
//===----------------------------------------------------------------------===//
def Stream_Timepoint : TypeDef<Stream_Dialect, "Timepoint", [
- Util_GlobalTypeInterface,
+ Util_GlobalType,
]> {
let mnemonic = "timepoint";
@@ -376,12 +376,15 @@
def Stream_Resource : TypeDef<Stream_Dialect, "Resource", [
Util_ReferenceType,
Util_SizeAwareType,
- DeclareTypeInterfaceMethods<Util_GlobalTypeInterface, [
+ DeclareTypeInterfaceMethods<Util_GlobalType, [
"isAccessStorageCompatible",
]>,
DeclareTypeInterfaceMethods<Util_InferTypeSize, [
"inferSizeFromValue",
]>,
+ DeclareTypeInterfaceMethods<Util_SubrangeType, [
+ "createSubrangeOp",
+ ]>,
]> {
let mnemonic = "resource";
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
index d19d596..6a4bccc 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
@@ -541,6 +541,7 @@
"isMetadata",
]>,
Util_SizeAwareOp,
+ Util_SubrangeOp,
DeclareOpInterfaceMethods<Util_TiedOpInterface, [
"getTiedResult",
"getTiedResultOperandIndex",
@@ -573,6 +574,11 @@
Value getOperandSize(unsigned idx) { return getSourceSize(); }
Value getResultSize(unsigned idx) { return getResultSize(); }
+ Value getSubrangeResource() { return getSource(); }
+ Value getSubrangeResourceSize() { return getSourceSize(); }
+ Value getSubrangeOffset() { return getSourceOffset(); }
+ Value getSubrangeLength() { return getResultSize(); }
+
// Walks up the use-def chain to find a subview op that feeds into |value|.
static IREE::Stream::ResourceSubviewOp findSubviewOp(Value value);
}];
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp
index ed0c2ae..9daa48d 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp
@@ -336,6 +336,14 @@
loc, builder.getIndexType(), value);
}
+Value ResourceType::createSubrangeOp(Location loc, Value resource,
+ Value resourceSize, Value subrangeOffset,
+ Value subrangeLength,
+ OpBuilder &builder) const {
+ return builder.create<IREE::Stream::ResourceSubviewOp>(
+ loc, resource, resourceSize, subrangeOffset, subrangeLength);
+}
+
//===----------------------------------------------------------------------===//
// Dialect registration
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/resource_folding.mlir b/compiler/src/iree/compiler/Dialect/Stream/IR/test/resource_folding.mlir
index 0bfed35..aab5d30 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/resource_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/resource_folding.mlir
@@ -172,8 +172,8 @@
%c500 = arith.constant 500 : index
// CHECK: %[[RET:.+]] = stream.resource.subview %arg0[%c300] : !stream.resource<*>{%arg1} -> !stream.resource<*>{%c300}
%0 = stream.resource.subview %arg0[%c100] : !stream.resource<*>{%arg1} -> !stream.resource<*>{%c500}
- %1 = stream.resource.subview %0[%c100] : !stream.resource<*>{%arg1} -> !stream.resource<*>{%c400}
- %2 = stream.resource.subview %1[%c100] : !stream.resource<*>{%arg1} -> !stream.resource<*>{%c300}
+ %1 = stream.resource.subview %0[%c100] : !stream.resource<*>{%c500} -> !stream.resource<*>{%c400}
+ %2 = stream.resource.subview %1[%c100] : !stream.resource<*>{%c400} -> !stream.resource<*>{%c300}
// CHECK-NEXT: return %[[RET]]
return %2 : !stream.resource<*>
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD b/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD
index 88b0c58..86c84da 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD
@@ -32,7 +32,6 @@
"PackDispatchOperands.cpp",
"PassDetail.h",
"Passes.cpp",
- "PropagateSubviews.cpp",
"PropagateTimepoints.cpp",
"RefineUsage.cpp",
"ScheduleAllocation.cpp",
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt
index 769082e..66f2882 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt
@@ -33,7 +33,6 @@
"PackDispatchOperands.cpp"
"PassDetail.h"
"Passes.cpp"
- "PropagateSubviews.cpp"
"PropagateTimepoints.cpp"
"RefineUsage.cpp"
"ScheduleAllocation.cpp"
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackConstants.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackConstants.cpp
index 1076b73..098495f 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackConstants.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackConstants.cpp
@@ -502,8 +502,7 @@
auto rodataOp = builder.create<IREE::Util::BufferConstantOp>(
storageResource.loc, builder.getType<IREE::Util::BufferType>(),
storageResource.data,
- builder.getI64IntegerAttr(
- resourceConfig.getMinBufferOffsetAlignment()));
+ builder.getIndexAttr(resourceConfig.getMinBufferOffsetAlignment()));
storageBuffers.push_back(rodataOp);
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
index 50dc695..ab42cd7 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
@@ -187,7 +187,7 @@
// Propagate subviews throughout the program to unify resource storage access.
// After propagation many resource SSA values can be deduped or folded by the
// cleanup patterns.
- passManager.addPass(IREE::Stream::createPropagateSubviewsPass());
+ passManager.addPass(IREE::Util::createPropagateSubrangesPass());
addCleanupPatterns(passManager);
// TODO(benvanik): outline streams (ala dispatch regions). Note that we may
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.h
index 1f3187a..4a22df1 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.h
@@ -136,8 +136,6 @@
std::unique_ptr<InterfacePass<CallableOpInterface>> createPackAllocationsPass();
std::unique_ptr<InterfacePass<CallableOpInterface>> createLayoutSlicesPass();
-std::unique_ptr<OperationPass<mlir::ModuleOp>> createPropagateSubviewsPass();
-
//===----------------------------------------------------------------------===//
// Stream memoization
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td
index b5e104e..96a8134 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td
@@ -149,14 +149,6 @@
}];
}
-def PropagateSubviews :
- Pass<"iree-stream-propagate-subviews", "mlir::ModuleOp"> {
- let summary = "Propagates resource subviews throughout the whole program.";
- let constructor = [{
- mlir::iree_compiler::IREE::Stream::createPropagateSubviewsPass()
- }];
-}
-
//===----------------------------------------------------------------------===//
// Stream memoization
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/dump_statistics.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/dump_statistics.mlir
index 70c47d4..1bec315 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/dump_statistics.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/dump_statistics.mlir
@@ -25,7 +25,7 @@
%c0 = arith.constant 0 : index
%c192 = arith.constant 192 : index
%0 = stream.timepoint.immediate => !stream.timepoint
- %1 = util.buffer.constant {alignment = 32 : i64} : !util.buffer = #util.composite<192xi8, [
+ %1 = util.buffer.constant {alignment = 32 : index} : !util.buffer = #util.composite<192xi8, [
dense<[5, 6, 7, 8]> : tensor<4xi32>,
dense<0> : vector<16xi8>,
dense<[5, 6, 3, 8]> : tensor<4xi32>,
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/pack_constants.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/pack_constants.mlir
index 500d9d5..ee1efc5 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/pack_constants.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/pack_constants.mlir
@@ -17,7 +17,7 @@
%c8 = arith.constant 8 : index
// Fetch the read-only host data containing the constants.
- // CHECK: %[[RODATA:.+]] = util.buffer.constant {alignment = 64 : i64} : !util.buffer = #composite_of_128b
+ // CHECK: %[[RODATA:.+]] = util.buffer.constant {alignment = 64 : index} : !util.buffer = #composite_of_128b
%0:3 = stream.resource.constants :
!stream.resource<constant>{%c4} = dense<100> : tensor<1xi32>,
!stream.resource<constant>{%c8} = dense<[101, 102]> : tensor<2xi32>
@@ -79,8 +79,8 @@
%c4 = arith.constant 4 : index
%c8 = arith.constant 8 : index
- // CHECK: %[[RODATA0:.+]] = util.buffer.constant {alignment = 16 : i64} : !util.buffer = #composite_of_16b0
- // CHECK: %[[RODATA1:.+]] = util.buffer.constant {alignment = 16 : i64} : !util.buffer = #composite_of_16b1
+ // CHECK: %[[RODATA0:.+]] = util.buffer.constant {alignment = 16 : index} : !util.buffer = #composite_of_16b0
+ // CHECK: %[[RODATA1:.+]] = util.buffer.constant {alignment = 16 : index} : !util.buffer = #composite_of_16b1
%0:3 = stream.resource.constants :
!stream.resource<constant>{%c4} = dense<100> : tensor<1xi32>,
!stream.resource<constant>{%c8} = dense<[101, 102]> : tensor<2xi32>
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/propagate_subviews.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/propagate_subviews.mlir
index 19b9cbc..9c39c7f 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/propagate_subviews.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/propagate_subviews.mlir
@@ -1,4 +1,6 @@
-// RUN: iree-opt --split-input-file --iree-stream-propagate-subviews %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-util-propagate-subranges %s | FileCheck %s
+
+// NOTE: this only tests how the common pass handles !stream.resource types.
// Tests that resource global loads also load all the subview params.
//
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/BUILD b/compiler/src/iree/compiler/Dialect/Util/Conversion/BUILD
index 9dbd5a5..a399b26 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Conversion/BUILD
+++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/BUILD
@@ -23,8 +23,11 @@
deps = [
"//compiler/src/iree/compiler/Dialect/Util/IR",
"@llvm-project//llvm:Support",
+ "@llvm-project//mlir:ArithmeticDialect",
+ "@llvm-project//mlir:ControlFlowDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
+ "@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Conversion/CMakeLists.txt
index c79cb02..94eb865 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Conversion/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/CMakeLists.txt
@@ -19,8 +19,11 @@
"ConversionPatterns.cpp"
DEPS
LLVMSupport
+ MLIRArithmeticDialect
+ MLIRControlFlowDialect
MLIRFuncDialect
MLIRIR
+ MLIRSCFDialect
MLIRSupport
MLIRTransforms
iree::compiler::Dialect::Util::IR
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.cpp b/compiler/src/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.cpp
index 5eb16a6..061d8d1 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.cpp
@@ -9,7 +9,10 @@
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "llvm/ADT/DenseMap.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
@@ -66,5 +69,179 @@
populateUtilConversionPatterns(context, typeConverter, patterns);
}
+//===----------------------------------------------------------------------===//
+// Structural op conversion
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+struct ConvertFuncOp : public OpConversionPattern<mlir::func::FuncOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ mlir::func::FuncOp funcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto &typeConverter = *getTypeConverter();
+
+ // Convert the input signature types.
+ // TODO(benvanik): dynamic shapes by passing in tensor dynamic dims.
+ auto originalType = funcOp.getFunctionType();
+ TypeConverter::SignatureConversion newSignature(
+ originalType.getNumInputs());
+ for (auto argType : llvm::enumerate(originalType.getInputs())) {
+ if (failed(typeConverter.convertSignatureArg(
+ argType.index(), argType.value(), newSignature))) {
+ return rewriter.notifyMatchFailure(funcOp,
+ "failed to convert arg type");
+ }
+ }
+ SmallVector<Type, 4> newResultTypes;
+ if (failed(typeConverter.convertTypes(originalType.getResults(),
+ newResultTypes))) {
+ return rewriter.notifyMatchFailure(funcOp,
+ "failed to convert result type");
+ }
+
+ // Replace function.
+ auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
+ newFuncOp.getBlocks().clear();
+ rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+ newFuncOp.end());
+ newFuncOp.setType(rewriter.getFunctionType(newSignature.getConvertedTypes(),
+ newResultTypes));
+ if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
+ &newSignature))) {
+ return failure();
+ }
+
+ rewriter.eraseOp(funcOp);
+ return success();
+ }
+};
+
+struct ConvertCallOp : public OpConversionPattern<mlir::func::CallOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ mlir::func::CallOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ SmallVector<Type, 4> resultTypes;
+ if (failed(getTypeConverter()->convertTypes(op.getResultTypes(),
+ resultTypes))) {
+ return rewriter.notifyMatchFailure(op, "unable to convert result types");
+ }
+ rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
+ op, resultTypes, op.getCallee(), adaptor.operands());
+ return success();
+ }
+};
+
+struct ConvertReturnOp : public OpConversionPattern<mlir::func::ReturnOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ mlir::func::ReturnOp returnOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(returnOp,
+ adaptor.operands());
+ return success();
+ }
+};
+
+struct ConvertBranchOp : public OpConversionPattern<mlir::cf::BranchOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ mlir::cf::BranchOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<mlir::cf::BranchOp>(op, op.getDest(),
+ adaptor.getDestOperands());
+ return success();
+ }
+};
+
+struct ConvertCondBranchOp
+ : public OpConversionPattern<mlir::cf::CondBranchOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ mlir::cf::CondBranchOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<mlir::cf::CondBranchOp>(
+ op, adaptor.getCondition(), op.getTrueDest(),
+ adaptor.getTrueDestOperands(), op.getFalseDest(),
+ adaptor.getFalseDestOperands());
+ return success();
+ }
+};
+
+struct ConvertSelectOp : public OpConversionPattern<mlir::arith::SelectOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ mlir::arith::SelectOp selectOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<mlir::arith::SelectOp>(
+ selectOp, adaptor.getCondition(), adaptor.getTrueValue(),
+ adaptor.getFalseValue());
+ return success();
+ }
+};
+
+struct ConvertIfOp : public OpConversionPattern<scf::IfOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ scf::IfOp ifOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto resultTypes = llvm::to_vector<4>(llvm::map_range(
+ ifOp.getResultTypes(),
+ [&](Type type) { return getTypeConverter()->convertType(type); }));
+ auto newOp = rewriter.create<scf::IfOp>(ifOp.getLoc(), resultTypes,
+ adaptor.getCondition(),
+ ifOp.elseBlock() != nullptr);
+ rewriter.inlineRegionBefore(ifOp.getThenRegion(), newOp.getThenRegion(),
+ newOp.getThenRegion().end());
+ rewriter.eraseBlock(&newOp.getThenRegion().front());
+ if (ifOp.elseBlock()) {
+ rewriter.inlineRegionBefore(ifOp.getElseRegion(), newOp.getElseRegion(),
+ newOp.getElseRegion().end());
+ rewriter.eraseBlock(&newOp.getElseRegion().front());
+ }
+ rewriter.replaceOp(ifOp, newOp.getResults());
+ return success();
+ }
+};
+
+struct ConvertYieldOp : public OpConversionPattern<scf::YieldOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ scf::YieldOp yieldOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, adaptor.getResults());
+ return success();
+ }
+};
+
+} // namespace
+
+void populateGenericStructuralConversionPatterns(
+ MLIRContext *context, ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter, RewritePatternSet &patterns) {
+ conversionTarget.addLegalOp<mlir::ModuleOp>();
+
+ // We need to rewrite certain types on operands/results so use the default
+ // dynamic legality checker to force any ops using such types to run through
+ // our patterns.
+ conversionTarget.addDynamicallyLegalOp<mlir::func::FuncOp>(
+ [&](mlir::func::FuncOp op) {
+ return typeConverter.isSignatureLegal(op.getFunctionType()) &&
+ typeConverter.isLegal(&op.getBody());
+ });
+ addGenericLegalOp<func::CallOp>(conversionTarget, typeConverter);
+ addGenericLegalOp<func::ReturnOp>(conversionTarget, typeConverter);
+ addGenericLegalOp<cf::BranchOp>(conversionTarget, typeConverter);
+ addGenericLegalOp<cf::CondBranchOp>(conversionTarget, typeConverter);
+ addGenericLegalOp<arith::SelectOp>(conversionTarget, typeConverter);
+ addGenericLegalOp<scf::IfOp>(conversionTarget, typeConverter);
+ addGenericLegalOp<scf::YieldOp>(conversionTarget, typeConverter);
+ patterns.insert<ConvertFuncOp, ConvertCallOp, ConvertReturnOp,
+ ConvertBranchOp, ConvertCondBranchOp, ConvertSelectOp,
+ ConvertIfOp, ConvertYieldOp>(typeConverter, context);
+}
+
} // namespace iree_compiler
} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.h b/compiler/src/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.h
index aada228..71763b8 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.h
+++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.h
@@ -60,6 +60,13 @@
TypeConverter &typeConverter,
RewritePatternSet &patterns);
+// Populates conversion patterns for generic structural ops (func, scf, etc).
+// The ops will be made dynamically legal based on whether all types can be
+// converted using the provided |typeConverter|.
+void populateGenericStructuralConversionPatterns(
+ MLIRContext *context, ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter, RewritePatternSet &patterns);
+
} // namespace iree_compiler
} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/BUILD b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/BUILD
new file mode 100644
index 0000000..070b616
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/BUILD
@@ -0,0 +1,35 @@
+# Copyright 2022 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_compiler_cc_library(
+ name = "MemRefToUtil",
+ srcs = [
+ "ConvertMemRefToUtil.cpp",
+ ],
+ hdrs = [
+ "ConvertMemRefToUtil.h",
+ ],
+ deps = [
+ "//compiler/src/iree/compiler/Dialect/Util/IR",
+ "@llvm-project//mlir:AffineDialect",
+ "@llvm-project//mlir:ArithmeticDialect",
+ "@llvm-project//mlir:BufferizationDialect",
+ "@llvm-project//mlir:FuncDialect",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:MemRefDialect",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:TransformUtils",
+ "@llvm-project//mlir:Transforms",
+ ],
+)
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/CMakeLists.txt
new file mode 100644
index 0000000..f1db35b
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/CMakeLists.txt
@@ -0,0 +1,34 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/BUILD #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ MemRefToUtil
+ HDRS
+ "ConvertMemRefToUtil.h"
+ SRCS
+ "ConvertMemRefToUtil.cpp"
+ DEPS
+ MLIRAffineDialect
+ MLIRArithmeticDialect
+ MLIRBufferizationDialect
+ MLIRFuncDialect
+ MLIRIR
+ MLIRMemRefDialect
+ MLIRPass
+ MLIRTransformUtils
+ MLIRTransforms
+ iree::compiler::Dialect::Util::IR
+ PUBLIC
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.cpp b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.cpp
new file mode 100644
index 0000000..4de1861
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.cpp
@@ -0,0 +1,277 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.h"
+
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+/// Returns true if the given `type` is a MemRef of rank 0 or 1.
+static bool isRankZeroOrOneMemRef(Type type) {
+ if (auto memrefType = type.dyn_cast<MemRefType>()) {
+ return memrefType.hasRank() && memrefType.getRank() <= 1 &&
+ memrefType.getLayout().isIdentity();
+ }
+ return false;
+}
+
+/// Returns the offset, in bytes, of an index within a linearized dense buffer.
+/// Expects that the |memrefValue| has been linearized already.
+static Value getBufferOffset(Location loc, Value memrefValue,
+ ValueRange indices, Type elementType,
+ ConversionPatternRewriter &rewriter) {
+ auto memrefType = memrefValue.getType().cast<ShapedType>();
+ if (memrefType.getRank() == 0) {
+ // Rank 0 buffers (like memref<i32>) have only a single valid offset at 0.
+ return rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
+ }
+ assert(memrefType.getRank() == 1 && "memrefs should have been flattened");
+
+ // Element type byte length as the base.
+ auto elementSize = rewriter.createOrFold<arith::ConstantIndexOp>(
+ loc, IREE::Util::getRoundedElementByteWidth(elementType));
+
+  // Rank 1 memrefs are offset by the index scaled by the element byte width.
+ auto elementCount = indices.front();
+ return rewriter.create<arith::MulIOp>(loc, elementSize, elementCount);
+}
+
+/// Pattern to lower operations that become a no-ops at this level.
+/// Passes through operands to results.
+template <typename OpTy>
+struct FoldAsNoOp final : public OpConversionPattern<OpTy> {
+ using OpConversionPattern<OpTy>::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ OpTy op, typename OpTy::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOp(op, adaptor.getOperands());
+ return success();
+ }
+};
+
+/// Pattern to lower operations that become a no-ops at this level.
+/// Erases the op entirely.
+template <typename OpTy>
+struct ElideNoOp final : public OpConversionPattern<OpTy> {
+ using OpConversionPattern<OpTy>::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ OpTy op, typename OpTy::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
+struct ConvertMemRefGlobalOp : public OpConversionPattern<memref::GlobalOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ memref::GlobalOp globalOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!isRankZeroOrOneMemRef(globalOp.getType())) {
+ return rewriter.notifyMatchFailure(
+ globalOp,
+ "only rank-0 and rank-1 memrefs are supported; flatten first");
+ }
+
+ // For mutable values we'd want to either have a RwdataOp or a global
+ // !vm.buffer that we initialized with rodata.
+ if (!globalOp.getConstant()) {
+ return rewriter.notifyMatchFailure(
+ globalOp, "mutable global memrefs not yet implemented");
+ }
+
+ auto newOp = rewriter.replaceOpWithNewOp<IREE::Util::GlobalOp>(
+ globalOp, globalOp.getSymName(), /*isMutable=*/false,
+ rewriter.getType<IREE::Util::BufferType>());
+ newOp.setPrivate();
+
+ auto initializerOp =
+ rewriter.create<IREE::Util::InitializerOp>(globalOp.getLoc());
+ auto initializerBuilder =
+ OpBuilder::atBlockBegin(initializerOp.addEntryBlock());
+ auto alignmentAttr = globalOp.getAlignmentAttr()
+ ? initializerBuilder.getIndexAttr(
+ globalOp.getAlignmentAttr().getInt())
+ : IntegerAttr{};
+ auto constantOp = initializerBuilder.create<IREE::Util::BufferConstantOp>(
+ globalOp.getLoc(), initializerBuilder.getType<IREE::Util::BufferType>(),
+ globalOp.getInitialValueAttr(), alignmentAttr);
+ initializerBuilder.create<IREE::Util::GlobalStoreOp>(
+ globalOp.getLoc(), constantOp.getResult(), newOp.getName());
+ initializerBuilder.create<IREE::Util::InitializerReturnOp>(
+ globalOp.getLoc());
+
+ return success();
+ }
+};
+
+struct ConvertMemRefGetGlobalOp
+ : public OpConversionPattern<memref::GetGlobalOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ memref::GetGlobalOp getOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!isRankZeroOrOneMemRef(getOp.getResult().getType())) {
+ return rewriter.notifyMatchFailure(
+ getOp, "only rank-0 and rank-1 memrefs are supported; flatten first");
+ }
+ rewriter.replaceOpWithNewOp<IREE::Util::GlobalLoadOp>(
+ getOp, rewriter.getType<IREE::Util::BufferType>(), getOp.getName());
+ return success();
+ }
+};
+
+struct ConvertMemRefAllocaOp : public OpConversionPattern<memref::AllocaOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ memref::AllocaOp allocaOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto type = allocaOp.getType().cast<ShapedType>();
+ if (!type.hasStaticShape()) {
+ return rewriter.notifyMatchFailure(
+ allocaOp, "unable to create buffers for dynamic shapes");
+ }
+ int64_t memRefLength =
+ type.getNumElements() *
+ IREE::Util::getRoundedElementByteWidth(type.getElementType());
+ Value allocationSize = rewriter.create<arith::ConstantIndexOp>(
+ allocaOp.getLoc(), memRefLength);
+ rewriter.replaceOpWithNewOp<IREE::Util::BufferAllocOp>(
+ allocaOp, rewriter.getType<IREE::Util::BufferType>(), allocationSize);
+ return success();
+ }
+};
+
+struct ConvertMemRefDimOp : public OpConversionPattern<memref::DimOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ memref::DimOp dimOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!isRankZeroOrOneMemRef(dimOp.getSource().getType())) {
+ return rewriter.notifyMatchFailure(
+ dimOp, "only rank-0 and rank-1 memrefs are supported; flatten first");
+ }
+ auto newElementType = getTypeConverter()->convertType(
+ dimOp.getSource().getType().cast<MemRefType>().getElementType());
+ if (!newElementType) {
+ return rewriter.notifyMatchFailure(dimOp, "unsupported element type");
+ }
+ Value elementSize = rewriter.create<arith::ConstantIndexOp>(
+ dimOp.getLoc(), IREE::Util::getRoundedElementByteWidth(newElementType));
+ Value bufferSize = rewriter.create<IREE::Util::BufferSizeOp>(
+ dimOp.getLoc(), rewriter.getIndexType(), adaptor.getSource());
+ rewriter.replaceOpWithNewOp<arith::FloorDivSIOp>(dimOp, bufferSize,
+ elementSize);
+ return success();
+ }
+};
+
+struct ConvertMemRefLoadOp : public OpConversionPattern<memref::LoadOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ memref::LoadOp loadOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!isRankZeroOrOneMemRef(loadOp.getMemref().getType())) {
+ return rewriter.notifyMatchFailure(
+ loadOp,
+ "only rank-0 and rank-1 memrefs are supported; flatten first");
+ }
+ auto oldType = loadOp.getResult().getType();
+ auto newType = getTypeConverter()->convertType(oldType);
+ auto newElementType = getTypeConverter()->convertType(
+ loadOp.getMemRef().getType().cast<MemRefType>().getElementType());
+ if (!newElementType) {
+ return rewriter.notifyMatchFailure(loadOp, "unsupported element type");
+ }
+ auto memRefSize = rewriter.createOrFold<IREE::Util::BufferSizeOp>(
+ loadOp.getLoc(), rewriter.getIndexType(), adaptor.getMemref());
+ auto byteOffset =
+ getBufferOffset(loadOp.getLoc(), loadOp.getMemref(),
+ loadOp.getIndices(), newElementType, rewriter);
+ rewriter.replaceOpWithNewOp<IREE::Util::BufferLoadOp>(
+ loadOp, newType, adaptor.getMemref(), memRefSize, byteOffset);
+ return success();
+ }
+};
+
+struct ConvertMemRefStoreOp : public OpConversionPattern<memref::StoreOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ memref::StoreOp storeOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!isRankZeroOrOneMemRef(storeOp.getMemref().getType())) {
+ return rewriter.notifyMatchFailure(
+ storeOp,
+ "only rank-0 and rank-1 memrefs are supported; flatten first");
+ }
+ auto newElementType = getTypeConverter()->convertType(
+ storeOp.getMemRef().getType().cast<MemRefType>().getElementType());
+ if (!newElementType) {
+ return rewriter.notifyMatchFailure(storeOp, "unsupported element type");
+ }
+ auto memRefSize = rewriter.createOrFold<IREE::Util::BufferSizeOp>(
+ storeOp.getLoc(), rewriter.getIndexType(), adaptor.getMemref());
+ auto byteOffset =
+ getBufferOffset(storeOp.getLoc(), storeOp.getMemref(),
+ storeOp.getIndices(), newElementType, rewriter);
+ rewriter.replaceOpWithNewOp<IREE::Util::BufferStoreOp>(
+ storeOp, adaptor.getValue(), adaptor.getMemref(), memRefSize,
+ byteOffset);
+ return success();
+ }
+};
+
+} // namespace
+
+void populateMemRefToUtilPatterns(MLIRContext *context,
+ ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter,
+ RewritePatternSet &patterns) {
+ conversionTarget.addIllegalDialect<memref::MemRefDialect>();
+
+ typeConverter.addConversion([&](MemRefType type) -> llvm::Optional<Type> {
+ if (isRankZeroOrOneMemRef(type)) {
+ return IREE::Util::BufferType::get(type.getContext());
+ }
+ return llvm::None;
+ });
+
+ // Unranked memrefs are emitted for library call integration when we just
+ // need void* semantics. An unranked memref is basically just a (pointer,
+ // memory-space, element-type).
+ typeConverter.addConversion(
+ [&](UnrankedMemRefType type) -> llvm::Optional<Type> {
+ return IREE::Util::BufferType::get(type.getContext());
+ });
+
+ patterns
+ .insert<FoldAsNoOp<bufferization::ToMemrefOp>,
+ ElideNoOp<memref::AssumeAlignmentOp>, FoldAsNoOp<memref::CastOp>>(
+ typeConverter, context);
+ patterns.insert<ConvertMemRefGlobalOp, ConvertMemRefGetGlobalOp,
+ ConvertMemRefAllocaOp, ConvertMemRefDimOp,
+ ConvertMemRefLoadOp, ConvertMemRefStoreOp>(typeConverter,
+ context);
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.h b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.h
new file mode 100644
index 0000000..b7d991b
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.h
@@ -0,0 +1,25 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_UTIL_CONVERSION_MEMREFTOUTIL_CONVERTMEMREFTOUTIL_H_
+#define IREE_COMPILER_DIALECT_UTIL_CONVERSION_MEMREFTOUTIL_CONVERTMEMREFTOUTIL_H_
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Appends memref-to-util-buffer conversion patterns to the given pattern list.
+void populateMemRefToUtilPatterns(MLIRContext *context,
+ ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter,
+ RewritePatternSet &patterns);
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_DIALECT_UTIL_CONVERSION_MEMREFTOUTIL_CONVERTMEMREFTOUTIL_H_
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/BUILD b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/BUILD
new file mode 100644
index 0000000..d022190
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/BUILD
@@ -0,0 +1,29 @@
+# Copyright 2022 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite")
+load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_lit_test_suite(
+ name = "lit",
+ srcs = enforce_glob(
+ [
+ "memref_ops.mlir",
+ ],
+ include = ["*.mlir"],
+ ),
+ cfg = "//compiler:lit.cfg.py",
+ tools = [
+ "//tools:iree-opt",
+ "@llvm-project//llvm:FileCheck",
+ ],
+)
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/CMakeLists.txt
new file mode 100644
index 0000000..f18b4d2
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/CMakeLists.txt
@@ -0,0 +1,23 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/BUILD #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_lit_test_suite(
+ NAME
+ lit
+ SRCS
+ "memref_ops.mlir"
+ TOOLS
+ FileCheck
+ iree-opt
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/memref_ops.mlir b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/memref_ops.mlir
new file mode 100644
index 0000000..0d68e1f
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/memref_ops.mlir
@@ -0,0 +1,89 @@
+// RUN: iree-opt --split-input-file --iree-util-test-conversion --cse --canonicalize --verify-diagnostics %s | FileCheck %s
+
+// MemRefs must be rank-0 or rank-1 to be converted.
+
+// expected-error @-5 {{conversion to util failed}}
+func.func @verify_invalid_rank_2(%buffer: memref<4x2xf32>, %idx: index) {
+ // expected-error @below {{failed to legalize operation 'memref.load'}}
+ memref.load %buffer[%idx, %idx] : memref<4x2xf32>
+ return
+}
+
+// -----
+
+// MemRefs must have an identity layout map to be converted.
+
+#map = affine_map<(d0)[s0] -> (d0 * s0)>
+// expected-error @-6 {{conversion to util failed}}
+func.func @verify_invalid_non_identity_map(%buffer: memref<4xf32, #map>, %idx: index) {
+ // expected-error @below {{failed to legalize operation 'memref.load'}}
+ memref.load %buffer[%idx] : memref<4xf32, #map>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @assume_alignment
+func.func @assume_alignment(%buffer: memref<?xf32>) {
+ // CHECK-NOT: assume_alignment
+ memref.assume_alignment %buffer, 64 : memref<?xf32>
+ func.return
+}
+
+// -----
+
+// CHECK-LABEL: @cast
+func.func @cast(%buffer: memref<?xf32>) -> memref<5xf32> {
+ // CHECK-NOT: memref.cast
+ %0 = memref.cast %buffer : memref<?xf32> to memref<5xf32>
+ // CHECK: return %arg0 : !util.buffer
+ func.return %0 : memref<5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @alloca() -> !util.buffer
+func.func @alloca() -> memref<16xi32> {
+ // CHECK: %[[ALLOCATION_SIZE:.+]] = arith.constant 64 : index
+ // CHECK: %[[BUFFER:.+]] = util.buffer.alloc uninitialized : !util.buffer{%[[ALLOCATION_SIZE]]}
+ %0 = memref.alloca() : memref<16xi32>
+ // CHECK: return %[[BUFFER]]
+ return %0 : memref<16xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @load_store_f32
+// CHECK-SAME: (%[[BUFFER:.+]]: !util.buffer, %[[IDX0:.+]]: index, %[[IDX1:.+]]: index) -> f32 {
+func.func @load_store_f32(%buffer: memref<?xf32>, %idx0: index, %idx1: index) -> f32 {
+ // CHECK: %[[BUFFER_SIZE:.+]] = util.buffer.size %[[BUFFER]]
+ // CHECK: %[[IDX0_BYTES:.+]] = arith.muli %[[IDX0]], %c4
+ // CHECK: %[[VALUE:.+]] = util.buffer.load %[[BUFFER]][%[[IDX0_BYTES]]] : !util.buffer{%[[BUFFER_SIZE]]} -> f32
+ %0 = memref.load %buffer[%idx0] : memref<?xf32>
+ // CHECK: %[[IDX1_BYTES:.+]] = arith.muli %[[IDX1]], %c4
+ // CHECK: util.buffer.store %[[VALUE]], %[[BUFFER]][%[[IDX1_BYTES]]] : f32 -> !util.buffer{%[[BUFFER_SIZE]]}
+ memref.store %0, %buffer[%idx1] : memref<?xf32>
+ // CHECK: return %[[VALUE]] : f32
+ return %0 : f32
+}
+
+// -----
+
+// CHECK: util.global private @__constant_f32 : !util.buffer
+// CHECK: util.initializer {
+// CHECK: %[[BUFFER:.+]] = util.buffer.constant : !util.buffer = dense<[0.0287729427, 0.0297581609]> : tensor<2xf32>
+// CHECK: util.global.store %[[BUFFER]], @__constant_f32 : !util.buffer
+memref.global "private" constant @__constant_f32 : memref<2xf32> = dense<[0.0287729427, 0.0297581609]>
+
+// CHECK-LABEL: @constant_global_f32
+// CHECK-SAME: (%[[IDX:.+]]: index) -> f32 {
+func.func @constant_global_f32(%idx: index) -> f32 {
+ // CHECK: %[[BUFFER:.+]] = util.global.load @__constant_f32 : !util.buffer
+ %0 = memref.get_global @__constant_f32 : memref<2xf32>
+ // CHECK: %[[BUFFER_SIZE:.+]] = util.buffer.size %[[BUFFER]]
+ // CHECK: %[[IDX_BYTES:.+]] = arith.muli %[[IDX]], %c4
+ // CHECK: %[[VALUE:.+]] = util.buffer.load %[[BUFFER]][%[[IDX_BYTES]]] : !util.buffer{%[[BUFFER_SIZE]]} -> f32
+ %1 = memref.load %0[%idx] : memref<2xf32>
+ // CHECK: return %[[VALUE]] : f32
+ return %1 : f32
+}
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/test/BUILD b/compiler/src/iree/compiler/Dialect/Util/Conversion/test/BUILD
index b038a21..345c131 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Conversion/test/BUILD
+++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/test/BUILD
@@ -16,7 +16,7 @@
iree_lit_test_suite(
name = "lit",
srcs = enforce_glob(
- ["hint_ops.mlir"],
+ ["structural_ops.mlir"],
include = ["*.mlir"],
),
cfg = "//compiler:lit.cfg.py",
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Conversion/test/CMakeLists.txt
index 5ab00f7..dd67129 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Conversion/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/test/CMakeLists.txt
@@ -14,7 +14,7 @@
NAME
lit
SRCS
- "hint_ops.mlir"
+ "structural_ops.mlir"
TOOLS
FileCheck
iree-opt
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/test/hint_ops.mlir b/compiler/src/iree/compiler/Dialect/Util/Conversion/test/hint_ops.mlir
deleted file mode 100644
index dcfa28f..0000000
--- a/compiler/src/iree/compiler/Dialect/Util/Conversion/test/hint_ops.mlir
+++ /dev/null
@@ -1,10 +0,0 @@
-// RUN: iree-opt %s --split-input-file | FileCheck %s
-
-// CHECK-LABEL: @preserve_compiler_hints
-func.func @preserve_compiler_hints() {
- // CHECK: %[[C:.+]] = arith.constant 2
- %c = arith.constant 2 : i32
- // CHECK: util.do_not_optimize(%[[C]])
- util.do_not_optimize(%c) : i32
- return
-}
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/test/structural_ops.mlir b/compiler/src/iree/compiler/Dialect/Util/Conversion/test/structural_ops.mlir
new file mode 100644
index 0000000..23a5933
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/test/structural_ops.mlir
@@ -0,0 +1,78 @@
+// RUN: iree-opt --split-input-file --iree-util-test-conversion %s | FileCheck %s
+
+// These patterns are not dialect-specific; they only allow structural ops to
+// update their operand and result types during dialect conversion.
+
+// CHECK-LABEL: @funcOp
+// CHECK-SAME: (%[[ARG0:.+]]: !util.buffer) -> !util.buffer
+func.func @funcOp(%arg0: memref<?xi8>) -> memref<?xi8> {
+ // CHECK: return %[[ARG0]] : !util.buffer
+ return %arg0 : memref<?xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @callOp
+// CHECK-SAME: (%[[ARG0:.+]]: !util.buffer) -> !util.buffer
+func.func @callOp(%arg0: memref<?xi8>) -> memref<?xi8> {
+ // CHECK: %[[RET0:.+]] = call @extern(%[[ARG0]]) : (!util.buffer) -> !util.buffer
+ %ret0 = call @extern(%arg0) : (memref<?xi8>) -> memref<?xi8>
+ // CHECK: return %[[RET0]] : !util.buffer
+ return %ret0 : memref<?xi8>
+}
+// CHECK: func.func private @extern(!util.buffer) -> !util.buffer
+func.func private @extern(memref<?xi8>) -> memref<?xi8>
+
+// -----
+
+// CHECK-LABEL: @brOp
+// CHECK-SAME: (%[[ARG0:.+]]: !util.buffer) -> !util.buffer
+func.func @brOp(%arg0: memref<?xi8>) -> memref<?xi8> {
+ // CHECK: cf.br ^bb1(%[[ARG0]] : !util.buffer)
+ cf.br ^bb1(%arg0 : memref<?xi8>)
+// CHECK: ^bb1(%[[BB1_ARG0:.+]]: !util.buffer):
+^bb1(%bb1_arg0: memref<?xi8>):
+ // CHECK: return %[[BB1_ARG0]] : !util.buffer
+ return %bb1_arg0 : memref<?xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @condBrOp
+// CHECK-SAME: (%[[COND:.+]]: i1, %[[ARG0:.+]]: !util.buffer, %[[ARG1:.+]]: !util.buffer) -> !util.buffer
+func.func @condBrOp(%cond: i1, %arg0: memref<?xi8>, %arg1: memref<?xi8>) -> memref<?xi8> {
+ // CHECK: cf.cond_br %[[COND]], ^bb1(%[[ARG0]] : !util.buffer), ^bb1(%[[ARG1]] : !util.buffer)
+ cf.cond_br %cond, ^bb1(%arg0 : memref<?xi8>), ^bb1(%arg1 : memref<?xi8>)
+// CHECK: ^bb1(%[[BB1_ARG0:.+]]: !util.buffer):
+^bb1(%bb1_arg0 : memref<?xi8>):
+ // CHECK: return %[[BB1_ARG0]] : !util.buffer
+ return %bb1_arg0 : memref<?xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @selectOp
+// CHECK-SAME: (%[[COND:.+]]: i1, %[[ARG0:.+]]: !util.buffer, %[[ARG1:.+]]: !util.buffer) -> !util.buffer
+func.func @selectOp(%cond: i1, %arg0: memref<?xi8>, %arg1: memref<?xi8>) -> memref<?xi8> {
+ // CHECK: %[[RET0:.+]] = arith.select %[[COND]], %[[ARG0]], %[[ARG1]] : !util.buffer
+ %ret0 = arith.select %cond, %arg0, %arg1 : memref<?xi8>
+ // CHECK: return %[[RET0]] : !util.buffer
+ return %ret0 : memref<?xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @ifOp
+// CHECK-SAME: (%[[COND:.+]]: i1, %[[ARG0:.+]]: !util.buffer, %[[ARG1:.+]]: !util.buffer) -> !util.buffer
+func.func @ifOp(%cond: i1, %arg0: memref<?xi8>, %arg1: memref<?xi8>) -> memref<?xi8> {
+ // CHECK: %[[RET0:.+]] = scf.if %[[COND]] -> (!util.buffer)
+ %ret0 = scf.if %cond -> (memref<?xi8>) {
+ // CHECK: scf.yield %[[ARG0]] : !util.buffer
+ scf.yield %arg0 : memref<?xi8>
+ } else {
+ // CHECK: scf.yield %[[ARG1]] : !util.buffer
+ scf.yield %arg1 : memref<?xi8>
+ }
+ // CHECK: return %[[RET0]] : !util.buffer
+ return %ret0 : memref<?xi8>
+}
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/BUILD b/compiler/src/iree/compiler/Dialect/Util/IR/BUILD
index 2e930db..ca36433 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/BUILD
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/BUILD
@@ -37,6 +37,7 @@
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:SideEffectInterfacesTdFiles",
"@llvm-project//mlir:SubElementInterfacesTdFiles",
+ "@llvm-project//mlir:ViewLikeInterfaceTdFiles",
],
)
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td
index 58b8e50..af67f0f 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td
@@ -590,7 +590,7 @@
// IREE::Util::GlobalTypeInterface
//===----------------------------------------------------------------------===//
-def Util_GlobalTypeInterface : TypeInterface<"GlobalTypeInterface"> {
+def Util_GlobalType : TypeInterface<"GlobalTypeInterface"> {
let cppNamespace = "::mlir::iree_compiler::IREE::Util";
let description = [{
@@ -622,6 +622,85 @@
}
//===----------------------------------------------------------------------===//
+// IREE::Util::Subrange*Interface
+//===----------------------------------------------------------------------===//
+
+def Util_SubrangeType : TypeInterface<"SubrangeTypeInterface"> {
+ let cppNamespace = "::mlir::iree_compiler::IREE::Util";
+
+ let description = [{
+    Interface used on size-aware types that can represent a linear subrange as
+    an (offset, length) pair.
+ }];
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Creates an op returning an (offset, length) subrange of a resource.
+ }],
+ /*retTy=*/"Value",
+ /*methodName=*/"createSubrangeOp",
+ /*args=*/(ins
+ "Location":$loc,
+ "Value":$resource,
+ "Value":$resourceSize,
+ "Value":$subrangeOffset,
+ "Value":$subrangeLength,
+ "OpBuilder &":$builder
+ ),
+ /*methodBody=*/[{}]
+ >,
+ ];
+}
+
+def Util_SubrangeOp : OpInterface<"SubrangeOpInterface"> {
+ let cppNamespace = "::mlir::iree_compiler::IREE::Util";
+
+ let description = [{
+ A size-aware operation taking a subrange on a resource.
+ }];
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns the size-aware source resource the subrange is taken from.
+ }],
+ /*retTy=*/"Value",
+ /*methodName=*/"getSubrangeResource",
+ /*args=*/(ins),
+ /*methodBody=*/[{}]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns the total size of the source resource.
+ }],
+ /*retTy=*/"Value",
+ /*methodName=*/"getSubrangeResourceSize",
+ /*args=*/(ins),
+ /*methodBody=*/[{}]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns the offset of the taken subrange into the resource.
+ }],
+ /*retTy=*/"Value",
+ /*methodName=*/"getSubrangeOffset",
+ /*args=*/(ins),
+ /*methodBody=*/[{}]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns the length of the taken subrange.
+ }],
+ /*retTy=*/"Value",
+ /*methodName=*/"getSubrangeLength",
+ /*args=*/(ins),
+ /*methodBody=*/[{}]
+ >,
+ ];
+}
+
+//===----------------------------------------------------------------------===//
// IREE::Util::SerializableAttrInterface
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
index e80b7fc..955be21 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
@@ -7,6 +7,7 @@
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/EquivalenceClasses.h"
+#include "llvm/ADT/SmallPtrSet.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
@@ -561,6 +562,474 @@
results.insert<PropagateGlobalStoreAddress>(context);
}
+//===----------------------------------------------------------------------===//
+// util.buffer.alloc
+//===----------------------------------------------------------------------===//
+
+void BufferAllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ // TODO(benvanik): elide if only users are writes and dealloc.
+}
+
+//===----------------------------------------------------------------------===//
+// util.buffer.slice
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Folds subspan ranges into slice ranges.
+//
+// Example:
+// %0 = util.buffer.subspan %src[%subspan_offset] ... -> {%subspan_length}
+// %1 = util.buffer.slice %0[%slice_offset] ... -> {%slice_length}
+// ->
+// %new_offset = arith.addi %slice_offset, %subspan_offset
+// %1 = util.buffer.slice %src[%new_offset] ... -> {%slice_length}
+struct FoldSubspansIntoSliceOp : public OpRewritePattern<BufferSliceOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(BufferSliceOp op,
+ PatternRewriter &rewriter) const override {
+ auto subspanOp = BufferSubspanOp::findSubspanOp(op.getSource());
+ if (!subspanOp) return failure();
+ auto fusedLoc = rewriter.getFusedLoc({subspanOp.getLoc(), op.getLoc()});
+ auto newOffset = rewriter.createOrFold<arith::AddIOp>(
+ fusedLoc, subspanOp.getSourceOffset(), op.getSourceOffset());
+ rewriter.updateRootInPlace(op, [&]() {
+ op.getSourceMutable().assign(subspanOp.getSource());
+ op.getSourceSizeMutable().assign(subspanOp.getSourceSize());
+ op.getSourceOffsetMutable().assign(newOffset);
+ });
+ return success();
+ }
+};
+
+} // namespace
+
+void BufferSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.insert<FoldSubspansIntoSliceOp>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// util.buffer.subspan
+//===----------------------------------------------------------------------===//
+
+OpFoldResult BufferSubspanOp::fold(ArrayRef<Attribute> operands) {
+ if (getSourceSize() == getResultSize()) {
+ // Entire range is covered; return it all.
+ return getSource();
+ }
+ return {};
+}
+
+namespace {
+
+// Folds subspan -> subspan to point at the original source buffer with an
+// updated range.
+struct FoldBufferSubspanOps : public OpRewritePattern<BufferSubspanOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(BufferSubspanOp op,
+ PatternRewriter &rewriter) const override {
+ auto parentOp = BufferSubspanOp::findSubspanOp(op.getSource());
+ if (!parentOp) return failure();
+ auto fusedLoc = rewriter.getFusedLoc({parentOp.getLoc(), op.getLoc()});
+ auto newOffset = rewriter.createOrFold<arith::AddIOp>(
+ fusedLoc, parentOp.getSourceOffset(), op.getSourceOffset());
+ auto newOp = rewriter.create<BufferSubspanOp>(
+ fusedLoc, parentOp.getSource(), parentOp.getSourceSize(), newOffset,
+ op.getResultSize());
+ rewriter.replaceOp(op, newOp.getResult());
+ return success();
+ }
+};
+
+// Turns selects of subspans of a buffer into selects of the offset.
+// This only works if the subspan sizes match.
+//
+// Example:
+// %subspan0 = util.buffer.subspan %src[%offset0]
+// %subspan1 = util.buffer.subspan %src[%offset1]
+// %subspan = select %cond, %subspan0, %subspan1 : !util.buffer
+// ->
+// %offset = select %cond, %offset0, %offset1 : index
+// %subspan = util.buffer.subspan %src[%offset]
+struct SinkSubspanAcrossSelectOps
+ : public OpRewritePattern<mlir::arith::SelectOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(mlir::arith::SelectOp op,
+ PatternRewriter &rewriter) const override {
+ if (!op.getType().isa<IREE::Util::BufferType>()) return failure();
+ auto trueSubspan = dyn_cast_or_null<IREE::Util::BufferSubspanOp>(
+ op.getTrueValue().getDefiningOp());
+ auto falseSubspan = dyn_cast_or_null<IREE::Util::BufferSubspanOp>(
+ op.getFalseValue().getDefiningOp());
+ if (!trueSubspan || !falseSubspan) return failure();
+ if (trueSubspan.getSource() != falseSubspan.getSource() ||
+ trueSubspan.getResultSize() != falseSubspan.getResultSize()) {
+ return failure();
+ }
+ auto offsetSelectOp = rewriter.create<mlir::arith::SelectOp>(
+ op.getLoc(), op.getCondition(), trueSubspan.getSourceOffset(),
+ falseSubspan.getSourceOffset());
+ rewriter.replaceOpWithNewOp<IREE::Util::BufferSubspanOp>(
+ op, op.getResult().getType(), trueSubspan.getSource(),
+ trueSubspan.getSourceSize(), offsetSelectOp.getResult(),
+ trueSubspan.getResultSize());
+ return success();
+ }
+};
+
+} // namespace
+
+void BufferSubspanOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.insert<FoldBufferSubspanOps>(context);
+ results.insert<SinkSubspanAcrossSelectOps>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// util.buffer.size
+//===----------------------------------------------------------------------===//
+
+OpFoldResult BufferSizeOp::fold(ArrayRef<Attribute> operands) {
+ // Try to find the size in the use-def chain.
+ // If it's out of the local scope we'll need IPO to help out.
+ auto operand = getOperand();
+ auto sizeAwareType =
+ operand.getType().cast<IREE::Util::SizeAwareTypeInterface>();
+ Operation *op = this->getOperation();
+ if (auto sizeValue = sizeAwareType.findSizeValue(operand, op->getBlock(),
+ Block::iterator(op))) {
+ return sizeValue;
+ }
+
+ // If the source is a constant then we can calculate that immediately.
+ if (auto constantOp = dyn_cast_or_null<IREE::Util::BufferConstantOp>(
+ operand.getDefiningOp())) {
+ if (auto attr =
+ constantOp.getValue()
+ .dyn_cast_or_null<IREE::Util::SerializableAttrInterface>()) {
+ return IntegerAttr::get(IndexType::get(attr.getContext()),
+ attr.getStorageSize());
+ }
+ }
+
+ return {};
+}
+
+namespace {
+
+// Propagates buffer sizes through select ops by selecting on the sizes of the
+// select operands.
+//
+// Example:
+// %a = util.buffer... : !util.buffer{%a_sz}
+// %b = util.buffer... : !util.buffer{%b_sz}
+// %c = select %cond, %a, %b : !util.buffer
+// %c_sz = util.buffer.size %c : !util.buffer
+// ->
+// %c = select %cond, %a, %b : !util.buffer
+// %c_sz = select %cond, %a_sz, %b_sz : index
+struct SelectBufferSizeOp : public OpRewritePattern<BufferSizeOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(BufferSizeOp op,
+ PatternRewriter &rewriter) const override {
+ auto selectOp = op.getOperand().getDefiningOp<mlir::arith::SelectOp>();
+ if (!selectOp) return failure();
+ auto trueSize = rewriter.createOrFold<IREE::Util::BufferSizeOp>(
+ op.getLoc(), selectOp.getTrueValue());
+ auto falseSize = rewriter.createOrFold<IREE::Util::BufferSizeOp>(
+ op.getLoc(), selectOp.getFalseValue());
+ rewriter.replaceOpWithNewOp<mlir::arith::SelectOp>(
+ op, selectOp.getCondition(), trueSize, falseSize);
+ return success();
+ }
+};
+
+} // namespace
+
+void BufferSizeOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.insert<SelectBufferSizeOp>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// util.buffer.storage
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Folds subspan ranges into storage ranges.
+//
+// Example:
+// %0 = util.buffer.subspan %src[%subspan_offset] ... -> {%subspan_length}
+// %storage, %offset = util.buffer.storage %0
+// ->
+// %storage, %raw_offset = util.buffer.storage %src
+// %offset = arith.addi %raw_offset, %subspan_offset
+struct FoldSubspansIntoStorageOp : public OpRewritePattern<BufferStorageOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(BufferStorageOp op,
+ PatternRewriter &rewriter) const override {
+ auto subspanOp = BufferSubspanOp::findSubspanOp(op.getOperand());
+ if (!subspanOp) return failure();
+ auto fusedLoc = rewriter.getFusedLoc({subspanOp.getLoc(), op.getLoc()});
+ rewriter.setInsertionPointAfter(op);
+ auto newOffset = rewriter.createOrFold<arith::AddIOp>(
+ fusedLoc, subspanOp.getSourceOffset(), op.getOffset());
+ rewriter.updateRootInPlace(op, [&]() {
+ op.getOperandMutable().assign(subspanOp.getSource());
+ op.getOperandSizeMutable().assign(subspanOp.getSourceSize());
+ SmallPtrSet<Operation *, 2> exceptions;
+ exceptions.insert(op);
+ if (auto newOffsetOp = newOffset.getDefiningOp()) {
+ exceptions.insert(newOffsetOp);
+ }
+ op.getOffset().replaceAllUsesExcept(newOffset, exceptions);
+ });
+ return success();
+ }
+};
+
+} // namespace
+
+void BufferStorageOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.insert<FoldSubspansIntoStorageOp>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// util.buffer.copy
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Folds subspan ranges into copy ranges.
+//
+// Example:
+// %0 = util.buffer.subspan %src[%subspan_offset] ... -> {%subspan_length}
+// %1 = util.buffer.subspan %dst[%subspan_offset] ... -> {%subspan_length}
+// util.buffer.copy %0[%offset], %1[%offset], %length
+// ->
+// %new_offset = arith.addi %offset, %subspan_offset
+//  util.buffer.copy %src[%new_offset], %dst[%new_offset], %length
+struct FoldSubspansIntoCopyOp : public OpRewritePattern<BufferCopyOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(BufferCopyOp op,
+ PatternRewriter &rewriter) const override {
+ auto sourceSubspanOp = BufferSubspanOp::findSubspanOp(op.getSource());
+ auto targetSubspanOp = BufferSubspanOp::findSubspanOp(op.getTarget());
+ if (!sourceSubspanOp && !targetSubspanOp) return failure();
+ if (sourceSubspanOp) {
+ auto fusedLoc =
+ rewriter.getFusedLoc({sourceSubspanOp.getLoc(), op.getLoc()});
+ auto newOffset = rewriter.createOrFold<arith::AddIOp>(
+ fusedLoc, sourceSubspanOp.getSourceOffset(), op.getSourceOffset());
+ rewriter.updateRootInPlace(op, [&]() {
+ op.getSourceMutable().assign(sourceSubspanOp.getSource());
+ op.getSourceSizeMutable().assign(sourceSubspanOp.getSourceSize());
+ op.getSourceOffsetMutable().assign(newOffset);
+ });
+ }
+ if (targetSubspanOp) {
+ auto fusedLoc =
+ rewriter.getFusedLoc({targetSubspanOp.getLoc(), op.getLoc()});
+ auto newOffset = rewriter.createOrFold<arith::AddIOp>(
+ fusedLoc, targetSubspanOp.getSourceOffset(), op.getTargetOffset());
+ rewriter.updateRootInPlace(op, [&]() {
+ op.getTargetMutable().assign(targetSubspanOp.getSource());
+ op.getTargetSizeMutable().assign(targetSubspanOp.getSourceSize());
+ op.getTargetOffsetMutable().assign(newOffset);
+ });
+ }
+ return success();
+ }
+};
+
+} // namespace
+
+void BufferCopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.insert<FoldSubspansIntoCopyOp>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// util.buffer.compare
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Folds subspan ranges into compare ranges.
+//
+// Example:
+//  %0 = util.buffer.subspan %lhs[%subspan_offset] ... -> {%subspan_length}
+//  %1 = util.buffer.subspan %rhs[%subspan_offset] ... -> {%subspan_length}
+//  %2 = util.buffer.compare %0[%offset], %1[%offset], %length
+// ->
+//  %new_offset = arith.addi %offset, %subspan_offset
+//  %2 = util.buffer.compare %lhs[%new_offset], %rhs[%new_offset], %length
+struct FoldSubspansIntoCompareOp : public OpRewritePattern<BufferCompareOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(BufferCompareOp op,
+ PatternRewriter &rewriter) const override {
+ auto sourceSubspanOp = BufferSubspanOp::findSubspanOp(op.getLhs());
+ auto targetSubspanOp = BufferSubspanOp::findSubspanOp(op.getRhs());
+ if (!sourceSubspanOp && !targetSubspanOp) return failure();
+ if (sourceSubspanOp) {
+ auto fusedLoc =
+ rewriter.getFusedLoc({sourceSubspanOp.getLoc(), op.getLoc()});
+ auto newOffset = rewriter.createOrFold<arith::AddIOp>(
+ fusedLoc, sourceSubspanOp.getSourceOffset(), op.getLhsOffset());
+ rewriter.updateRootInPlace(op, [&]() {
+ op.getLhsMutable().assign(sourceSubspanOp.getSource());
+ op.getLhsSizeMutable().assign(sourceSubspanOp.getSourceSize());
+ op.getLhsOffsetMutable().assign(newOffset);
+ });
+ }
+ if (targetSubspanOp) {
+ auto fusedLoc =
+ rewriter.getFusedLoc({targetSubspanOp.getLoc(), op.getLoc()});
+ auto newOffset = rewriter.createOrFold<arith::AddIOp>(
+ fusedLoc, targetSubspanOp.getSourceOffset(), op.getRhsOffset());
+ rewriter.updateRootInPlace(op, [&]() {
+ op.getRhsMutable().assign(targetSubspanOp.getSource());
+ op.getRhsSizeMutable().assign(targetSubspanOp.getSourceSize());
+ op.getRhsOffsetMutable().assign(newOffset);
+ });
+ }
+ return success();
+ }
+};
+
+} // namespace
+
+void BufferCompareOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.insert<FoldSubspansIntoCompareOp>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// util.buffer.fill
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Folds subspan ranges into fill ranges.
+//
+// Example:
+// %0 = util.buffer.subspan %dst[%subspan_offset] ... -> {%subspan_length}
+// util.buffer.fill %cst, %0[%offset for %length]
+// ->
+// %new_offset = arith.addi %offset, %subspan_offset
+//  util.buffer.fill %cst, %dst[%new_offset for %length]
+struct FoldSubspansIntoFillOp : public OpRewritePattern<BufferFillOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(BufferFillOp op,
+ PatternRewriter &rewriter) const override {
+ auto subspanOp = BufferSubspanOp::findSubspanOp(op.getTarget());
+ if (!subspanOp) return failure();
+ auto fusedLoc = rewriter.getFusedLoc({subspanOp.getLoc(), op.getLoc()});
+ auto newOffset = rewriter.createOrFold<arith::AddIOp>(
+ fusedLoc, subspanOp.getSourceOffset(), op.getTargetOffset());
+ rewriter.updateRootInPlace(op, [&]() {
+ op.getTargetMutable().assign(subspanOp.getSource());
+ op.getTargetSizeMutable().assign(subspanOp.getSourceSize());
+ op.getTargetOffsetMutable().assign(newOffset);
+ });
+ return success();
+ }
+};
+
+} // namespace
+
+void BufferFillOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.insert<FoldSubspansIntoFillOp>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// util.buffer.load
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Folds subspan offsets into loads.
+//
+// Example:
+// %0 = util.buffer.subspan %src[%subspan_offset] ... -> {%subspan_length}
+// %1 = util.buffer.load %0[%offset]
+// ->
+// %new_offset = arith.addi %offset, %subspan_offset
+// %1 = util.buffer.load %src[%new_offset]
+struct FoldSubspanIntoLoadOp : public OpRewritePattern<BufferLoadOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(BufferLoadOp op,
+ PatternRewriter &rewriter) const override {
+ auto subspanOp = BufferSubspanOp::findSubspanOp(op.getSource());
+ if (!subspanOp) return failure();
+ auto fusedLoc = rewriter.getFusedLoc({subspanOp.getLoc(), op.getLoc()});
+ auto newOffset = rewriter.createOrFold<arith::AddIOp>(
+ fusedLoc, subspanOp.getSourceOffset(), op.getSourceOffset());
+ rewriter.updateRootInPlace(op, [&]() {
+ op.getSourceMutable().assign(subspanOp.getSource());
+ op.getSourceSizeMutable().assign(subspanOp.getSourceSize());
+ op.getSourceOffsetMutable().assign(newOffset);
+ });
+ return success();
+ }
+};
+
+} // namespace
+
+OpFoldResult BufferLoadOp::fold(ArrayRef<Attribute> operands) {
+ // TODO(benvanik): if source is a constant then perform the load.
+ return {};
+}
+
+void BufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.insert<FoldSubspanIntoLoadOp>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// util.buffer.store
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Folds subspan offsets into stores.
+//
+// Example:
+// %0 = util.buffer.subspan %dst[%subspan_offset] ... -> {%subspan_length}
+// util.buffer.store %c123_i32, %0[%offset]
+// ->
+// %new_offset = arith.addi %offset, %subspan_offset
+// util.buffer.store %c123_i32, %dst[%new_offset]
+struct FoldSubspanIntoStoreOp : public OpRewritePattern<BufferStoreOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(BufferStoreOp op,
+ PatternRewriter &rewriter) const override {
+ auto subspanOp = BufferSubspanOp::findSubspanOp(op.getTarget());
+ if (!subspanOp) return failure();
+ auto fusedLoc = rewriter.getFusedLoc({subspanOp.getLoc(), op.getLoc()});
+ auto newOffset = rewriter.createOrFold<arith::AddIOp>(
+ fusedLoc, subspanOp.getSourceOffset(), op.getTargetOffset());
+ rewriter.updateRootInPlace(op, [&]() {
+ op.getTargetMutable().assign(subspanOp.getSource());
+ op.getTargetSizeMutable().assign(subspanOp.getSourceSize());
+ op.getTargetOffsetMutable().assign(newOffset);
+ });
+ return success();
+ }
+};
+
+} // namespace
+
+void BufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.insert<FoldSubspanIntoStoreOp>(context);
+}
+
} // namespace Util
} // namespace IREE
} // namespace iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
index 6f85b82..77fa2d7 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
@@ -718,7 +718,7 @@
}
//===----------------------------------------------------------------------===//
-// Structural ops
+// util.initializer
//===----------------------------------------------------------------------===//
void InitializerOp::build(OpBuilder &builder, OperationState &result,
@@ -764,7 +764,7 @@
}
//===----------------------------------------------------------------------===//
-// Globals
+// util.global
//===----------------------------------------------------------------------===//
// Returns true if the given |accessType| is compatible with the |globalType|.
@@ -950,7 +950,7 @@
}
//===----------------------------------------------------------------------===//
-// Lists
+// !util.list<T>
//===----------------------------------------------------------------------===//
static ParseResult parseListTypeGet(OpAsmParser &parser, Type &listType,
@@ -1048,6 +1048,101 @@
return success();
}
+//===----------------------------------------------------------------------===//
+// !util.buffer
+//===----------------------------------------------------------------------===//
+
+void BufferConstantOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResult(), "buffer_cst");
+}
+
+LogicalResult BufferConstantOp::verify() {
+ if (!getValue().isa<IREE::Util::SerializableAttrInterface>()) {
+ return emitOpError("unsupported non-serializable constant attribute type");
+ }
+ if (auto minAlignmentAttr = getAlignmentAttr()) {
+ int64_t minAlignment = minAlignmentAttr.getInt();
+ if (minAlignment > 0 && !llvm::isPowerOf2_64(minAlignment)) {
+ return emitOpError("invalid alignment; must be a power of two");
+ }
+ }
+ return success();
+}
+
+void BufferAllocOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResult(), "buffer");
+}
+
+LogicalResult BufferAllocOp::verify() {
+ if (auto minAlignmentAttr = getAlignmentAttr()) {
+ int64_t minAlignment = minAlignmentAttr.getInt();
+ if (minAlignment > 0 && !llvm::isPowerOf2_64(minAlignment)) {
+ return emitOpError("invalid alignment; must be a power of two");
+ }
+ }
+ return success();
+}
+
+void BufferSliceOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResult(), "buffer");
+}
+
+void BufferSubspanOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResult(), "buffer_span");
+}
+
+Value BufferSubspanOp::getViewSource() { return getSource(); }
+
+Value BufferSubspanOp::getTiedResult(unsigned resultIndex) {
+ return IREE::Util::TiedOpInterface::findTiedBaseValue(getSource());
+}
+
+::llvm::Optional<unsigned> BufferSubspanOp::getTiedResultOperandIndex(
+ unsigned resultIndex) {
+ return {0}; // source
+}
+
+SmallVector<int64_t, 4> BufferSubspanOp::getTiedResultOperandIndices() {
+ return {0}; // source
+}
+
+// static
+IREE::Util::BufferSubspanOp BufferSubspanOp::findSubspanOp(Value value) {
+ while (value) {
+ auto *definingOp = value.getDefiningOp();
+ if (!definingOp) {
+ // Defined as a block argument - stop walk.
+ break;
+ } else if (auto subviewOp =
+ dyn_cast<IREE::Util::BufferSubspanOp>(definingOp)) {
+ // Found!
+ return subviewOp;
+ } else if (auto tiedOp =
+ dyn_cast<IREE::Util::TiedOpInterface>(definingOp)) {
+ // Continue walking up through the tied operand.
+ value = tiedOp.getTiedResultOperand(value);
+ } else {
+ break;
+ }
+ }
+ return {};
+}
+
+void BufferSizeOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResult(), "buffer_size");
+}
+
+void BufferStorageOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResult(), "buffer_storage");
+ setNameFn(getOffset(), "buffer_offset");
+}
+
} // namespace Util
} // namespace IREE
} // namespace iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.h b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.h
index e1128e9..6678386 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.h
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.h
@@ -20,6 +20,7 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Transforms/DialectConversion.h"
#define GET_OP_CLASSES
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
index 852fa74..26815e2 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
@@ -17,6 +17,7 @@
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/ViewLikeInterface.td"
//===----------------------------------------------------------------------===//
// Op types
@@ -689,8 +690,10 @@
// !util.buffer
//===----------------------------------------------------------------------===//
-def Util_BufferConstantOp : Util_PureOp<"buffer.constant"> {
- let summary = "constant host-side byte buffer";
+def Util_BufferConstantOp : Util_PureOp<"buffer.constant", [
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+]> {
+ let summary = [{constant host-side byte buffer}];
let description = [{
Defines a compile-time byte buffer based on the given attribute value.
The attribute will be serialized into the canonical IREE format for the
@@ -699,15 +702,408 @@
let arguments = (ins
AnyAttr:$value,
- OptionalAttr<I64Attr>:$alignment
+ OptionalAttr<IndexAttr>:$alignment
);
let results = (outs
Util_BufferType:$result
);
let assemblyFormat = [{
- attr-dict `:` type($result) `=` $value
+ attr-dict `:`
+ type($result)
+ `=` $value
}];
+
+ let hasVerifier = 1;
+}
+
+def Util_BufferAllocOp : Util_PureOp<"buffer.alloc", [
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ MemoryEffects<[MemAlloc]>,
+ Util_SizeAwareOp,
+]> {
+ let summary = [{allocates a buffer with undefined contents}];
+ let description = [{
+    Allocates a buffer with undefined contents. Consumers of the allocated
+    result must assume nothing about its initial contents.
+ }];
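+
+  // Illustrative usage following the assembly format below (SSA value names
+  // are placeholders, not part of this definition):
+  //   %buf = util.buffer.alloc uninitialized : !util.buffer{%storage_size}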
+
+ let arguments = (ins
+ Util_Size:$storage_size,
+ OptionalAttr<IndexAttr>:$alignment
+ );
+ let results = (outs
+ Util_BufferType:$result
+ );
+
+ let assemblyFormat = [{
+ `uninitialized`
+ attr-dict
+ `:`
+ type($result) `` `{` $storage_size `}`
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return {}; }
+ Value getResultSize(unsigned idx) { return getStorageSize(); }
+ }];
+
+ let hasVerifier = 1;
+ let hasCanonicalizer = 1;
+}
+
+def Util_BufferDeallocOp : Util_PureOp<"buffer.dealloc", [
+ MemoryEffects<[MemFree]>,
+ Util_SizeAwareOp,
+]> {
+ let summary = [{deallocates a buffer}];
+ let description = [{
+ Hints that the buffer contents can be discarded. Buffers are reference
+    counted and other owners may keep the buffer live beyond the dealloc.
+ }];
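+
+  // Illustrative usage, a sketch following the assembly format below (SSA
+  // value names are placeholders):
+  //   util.buffer.dealloc %buf : !util.buffer{%buf_size}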
+
+ let arguments = (ins
+ Util_BufferType:$operand,
+ Util_Size:$operand_size
+ );
+
+ let assemblyFormat = [{
+ $operand `:` type($operand) `{` $operand_size `}`
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return getOperandSize(); }
+ Value getResultSize(unsigned idx) { return {}; }
+ }];
+}
+
+def Util_BufferSliceOp : Util_PureOp<"buffer.slice", [
+ AllTypesMatch<["source", "result"]>,
+ MemoryEffects<[MemAlloc, MemRead]>,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ Util_SizeAwareOp,
+]> {
+ let summary = [{clones a subregion of a buffer}];
+ let description = [{
+    Returns a copy of the requested byte range of the source buffer contents.
+ }];
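+
+  // Illustrative usage (SSA value names are placeholders):
+  //   %slice = util.buffer.slice %src[%offset] : !util.buffer{%src_size} -> !util.buffer{%slice_size}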
+
+ let arguments = (ins
+ Util_BufferType:$source,
+ Util_Size:$source_size,
+ Util_Offset:$source_offset,
+ Util_Size:$result_size,
+ OptionalAttr<IndexAttr>:$alignment
+ );
+ let results = (outs
+ Util_BufferType:$result
+ );
+
+ let assemblyFormat = [{
+ $source `[` $source_offset `]` attr-dict `:`
+ type($source) `` `{` $source_size `}` `->`
+ type($result) `` `{` $result_size `}`
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return getSourceSize(); }
+ Value getResultSize(unsigned idx) { return getResultSize(); }
+ }];
+
+ let hasCanonicalizer = 1;
+}
+
+def Util_BufferSubspanOp : Util_PureOp<"buffer.subspan", [
+ AllTypesMatch<["source", "result"]>,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DeclareOpInterfaceMethods<ViewLikeOpInterface>,
+ Util_SizeAwareOp,
+ Util_SubrangeOp,
+ DeclareOpInterfaceMethods<Util_TiedOpInterface, [
+ "getTiedResult",
+ "getTiedResultOperandIndex",
+ "getTiedResultOperandIndices",
+ ]>,
+]> {
+  let summary = [{returns a subrange reference into a buffer}];
+  let description = [{
+    Returns a logical view into the given byte range of the source buffer.
+    The result aliases the source storage; no data is copied.
+  }];
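+
+  // Illustrative usage (SSA value names are placeholders):
+  //   %span = util.buffer.subspan %src[%offset] : !util.buffer{%src_size} -> !util.buffer{%span_size}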
+
+ let arguments = (ins
+ Util_BufferType:$source,
+ Util_Size:$source_size,
+ Util_Offset:$source_offset,
+ Util_Size:$result_size
+ );
+ let results = (outs
+ Util_BufferType:$result
+ );
+
+ let assemblyFormat = [{
+ $source `[` $source_offset `]` `:`
+ type($source) `` `{` $source_size `}` `->`
+ type($result) `` `{` $result_size `}`
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return getSourceSize(); }
+ Value getResultSize(unsigned idx) { return getResultSize(); }
+
+ Value getSubrangeResource() { return getSource(); }
+ Value getSubrangeResourceSize() { return getSourceSize(); }
+ Value getSubrangeOffset() { return getSourceOffset(); }
+ Value getSubrangeLength() { return getResultSize(); }
+
+ // Walks up the use-def chain to find a subspan op that feeds into |value|.
+ static IREE::Util::BufferSubspanOp findSubspanOp(Value value);
+ }];
+
+ let hasCanonicalizer = 1;
+ let hasFolder = 1;
+}
+
+def Util_BufferSizeOp : Util_PureOp<"buffer.size", [
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ Util_SizeAwareOp,
+]> {
+ let summary = [{returns the total buffer storage size in bytes}];
+ let description = [{
+ Returns the total length of the buffer in bytes from its base offset.
+ }];
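+
+  // Illustrative usage (SSA value names are placeholders):
+  //   %size = util.buffer.size %buf : !util.buffer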
+
+ let arguments = (ins
+ Util_BufferType:$operand
+ );
+ let results = (outs
+ Util_Size:$result
+ );
+
+ let assemblyFormat = [{
+ $operand
+ `:` type($operand)
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return getResult(); }
+ Value getResultSize(unsigned idx) { return {}; }
+ }];
+
+ let hasCanonicalizer = 1;
+ let hasFolder = 1;
+}
+
+def Util_BufferStorageOp : Util_PureOp<"buffer.storage", [
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ Util_SizeAwareOp,
+]> {
+ let summary = [{returns the underlying buffer storage range}];
+ let description = [{
+    Returns the underlying buffer storage as a memref along with the byte
+    offset at which the buffer begins; accesses must be adjusted by that offset
+    and restricted to the buffer range. The memref may be of any type and the
+    user is responsible for ensuring that the reinterpret_cast-like behavior
+    makes sense for the data they are accessing.
+ }];
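+
+  // Illustrative usage (SSA value names are placeholders):
+  //   %storage, %offset = util.buffer.storage %buf : !util.buffer{%buf_size} -> (memref<?xi8>, index)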
+
+ let arguments = (ins
+ Util_BufferType:$operand,
+ Util_Size:$operand_size
+ );
+ let results = (outs
+ AnyMemRef:$result,
+ Util_Offset:$offset
+ );
+
+ let assemblyFormat = [{
+ $operand
+ `:` type($operand) `` `{` $operand_size `}` `->` `(` type($result) `,` type($offset) `)`
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return getOperandSize(); }
+ Value getResultSize(unsigned idx) { return {}; }
+ }];
+
+ let hasCanonicalizer = 1;
+}
+
+def Util_BufferCopyOp : Util_Op<"buffer.copy", [
+ MemoryEffects<[MemRead, MemWrite]>,
+ Util_SizeAwareOp,
+]> {
+ let summary = [{copies a range of bytes between buffers}];
+ let description = [{
+    Copies a range of bytes between buffers as with memcpy; the source and
+    target ranges must not overlap.
+ }];
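+
+  // Illustrative usage (SSA value names are placeholders):
+  //   util.buffer.copy %src[%src_offset], %dst[%dst_offset], %length
+  //       : !util.buffer{%src_size} -> !util.buffer{%dst_size}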
+
+ let arguments = (ins
+ Util_BufferType:$source,
+ Util_Size:$source_size,
+ Util_Offset:$source_offset,
+ Util_BufferType:$target,
+ Util_Size:$target_size,
+ Util_Offset:$target_offset,
+ Util_Size:$length
+ );
+
+ let assemblyFormat = [{
+ $source `[` $source_offset `]` `,`
+ $target `[` $target_offset `]` `,`
+ $length `:`
+ type($source) `` `{` $source_size `}` `->`
+ type($target) `` `{` $target_size `}`
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return idx == 0 ? getSourceSize() : getTargetSize(); }
+ Value getResultSize(unsigned idx) { return {}; }
+ }];
+
+ let hasCanonicalizer = 1;
+}
+
+def Util_BufferCompareOp : Util_PureOp<"buffer.compare", [
+ MemoryEffects<[MemRead]>,
+ Util_SizeAwareOp,
+]> {
+ let summary = [{compares a range of two buffers}];
+ let description = [{
+ Returns true if the two ranges are bitwise equivalent, somewhat like memcmp.
+ }];
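+
+  // Illustrative usage (SSA value names are placeholders):
+  //   %eq = util.buffer.compare %lhs[%lhs_offset], %rhs[%rhs_offset], %length
+  //       : !util.buffer{%lhs_size}, !util.buffer{%rhs_size}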
+
+ let arguments = (ins
+ Util_BufferType:$lhs,
+ Util_Size:$lhs_size,
+ Util_Offset:$lhs_offset,
+ Util_BufferType:$rhs,
+ Util_Size:$rhs_size,
+ Util_Offset:$rhs_offset,
+ Util_Size:$length
+ );
+ let results = (outs
+ I1:$result
+ );
+
+ let assemblyFormat = [{
+ $lhs `[` $lhs_offset `]` `,`
+ $rhs `[` $rhs_offset `]` `,`
+ $length `:`
+ type($lhs) `` `{` $lhs_size `}` `,`
+ type($rhs) `` `{` $rhs_size `}`
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return idx == 0 ? getLhsSize() : getRhsSize(); }
+ Value getResultSize(unsigned idx) { return {}; }
+ }];
+
+ let hasCanonicalizer = 1;
+}
+
+def Util_BufferFillOp : Util_Op<"buffer.fill", [
+ MemoryEffects<[MemWrite]>,
+ Util_SizeAwareOp,
+]> {
+ let summary = [{fills a range of bytes with a value}];
+ let description = [{
+ Fills the contents of the buffer in the given byte range with a pattern.
+ The offset and length must match the natural alignment of the pattern type.
+ }];
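+
+  // Illustrative usage (SSA value names are placeholders):
+  //   util.buffer.fill %pattern, %dst[%offset for %length] : i32 -> !util.buffer{%dst_size}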
+
+ let arguments = (ins
+ Util_FillPattern:$pattern,
+ Util_BufferType:$target,
+ Util_Size:$target_size,
+ Util_Offset:$target_offset,
+ Util_Size:$length
+ );
+
+ let assemblyFormat = [{
+ $pattern `,`
+ $target `[` $target_offset `for` $length `]` `:`
+ type($pattern) `->`
+ type($target) `` `{` $target_size `}`
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return getTargetSize(); }
+ Value getResultSize(unsigned idx) { return {}; }
+ }];
+
+ let hasCanonicalizer = 1;
+}
+
+def Util_BufferLoadOp : Util_Op<"buffer.load", [
+ MemoryEffects<[MemRead]>,
+ Util_SizeAwareOp,
+]> {
+ let summary = [{loads a value from a buffer}];
+ let description = [{
+    Loads a value at a byte offset. The offset must be aligned to the natural
+    size of the result type.
+ }];
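+
+  // Illustrative usage (SSA value names are placeholders):
+  //   %value = util.buffer.load %src[%offset] : !util.buffer{%src_size} -> i32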
+
+ let arguments = (ins
+ Util_BufferType:$source,
+ Util_Size:$source_size,
+ Util_Offset:$source_offset
+ );
+ let results = (outs
+ Util_Primitive:$result
+ );
+
+ let assemblyFormat = [{
+ $source `[` $source_offset `]`
+ `:` type($source) `` `{` $source_size `}` `->` type($result)
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return getSourceSize(); }
+ Value getResultSize(unsigned idx) { return {}; }
+ }];
+
+ let hasCanonicalizer = 1;
+ let hasFolder = 1;
+}
+
+def Util_BufferStoreOp : Util_Op<"buffer.store", [
+ MemoryEffects<[MemWrite]>,
+ Util_SizeAwareOp,
+]> {
+ let summary = [{stores a value into a buffer}];
+ let description = [{
+    Stores a value at a byte offset. The offset must be aligned to the natural
+    size of the source type.
+ }];
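+
+  // Illustrative usage (SSA value names are placeholders):
+  //   util.buffer.store %value, %dst[%offset] : i32 -> !util.buffer{%dst_size}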
+
+ let arguments = (ins
+ Util_Primitive:$source,
+ Util_BufferType:$target,
+ Util_Size:$target_size,
+ Util_Offset:$target_offset
+ );
+
+ let assemblyFormat = [{
+ $source `,`
+ $target `[` $target_offset `]`
+ `:` type($source) `->` type($target) `` `{` $target_size `}`
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return getTargetSize(); }
+ Value getResultSize(unsigned idx) { return {}; }
+ }];
+
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.cpp
index c159014..b7c2410 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.cpp
@@ -7,6 +7,7 @@
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "llvm/ADT/BitVector.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -29,7 +30,29 @@
namespace Util {
//===----------------------------------------------------------------------===//
-// ListType
+// !util.buffer
+//===----------------------------------------------------------------------===//
+
+bool BufferType::isAccessStorageCompatible(Type accessType) const {
+ return accessType.isa<IREE::Util::BufferType>();
+}
+
+Value BufferType::inferSizeFromValue(Location loc, Value value,
+ OpBuilder &builder) const {
+ return builder.createOrFold<IREE::Util::BufferSizeOp>(
+ loc, builder.getIndexType(), value);
+}
+
+Value BufferType::createSubrangeOp(Location loc, Value resource,
+ Value resourceSize, Value subrangeOffset,
+ Value subrangeLength,
+ OpBuilder &builder) const {
+ return builder.create<IREE::Util::BufferSubspanOp>(
+ loc, resource, resourceSize, subrangeOffset, subrangeLength);
+}
+
+//===----------------------------------------------------------------------===//
+// !util.list<T>
//===----------------------------------------------------------------------===//
static LogicalResult parseListElementType(AsmParser &parser,
@@ -77,7 +100,7 @@
}
//===----------------------------------------------------------------------===//
-// PtrType
+// !util.ptr<T>
//===----------------------------------------------------------------------===//
// static
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.td
index 539d9dc..e8895db 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.td
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.td
@@ -16,14 +16,16 @@
def Util_BufferType : TypeDef<Util_Dialect, "Buffer", [
Util_ReferenceType,
- // TODO(benvanik): make buffer size-aware.
- // Util_SizeAwareType,
- // DeclareTypeInterfaceMethods<Util_GlobalTypeInterface, [
- // "isAccessStorageCompatible",
- // ]>,
- // DeclareTypeInterfaceMethods<Util_InferTypeSize, [
- // "inferSizeFromValue",
- // ]>,
+ Util_SizeAwareType,
+ DeclareTypeInterfaceMethods<Util_GlobalType, [
+ "isAccessStorageCompatible",
+ ]>,
+ DeclareTypeInterfaceMethods<Util_InferTypeSize, [
+ "inferSizeFromValue",
+ ]>,
+ DeclareTypeInterfaceMethods<Util_SubrangeType, [
+ "createSubrangeOp",
+ ]>,
]> {
let mnemonic = "buffer";
@@ -32,8 +34,6 @@
A reference counted byte buffer that models a pointer, offset, and length.
}];
- let parameters = (ins);
-
let builders = [
TypeBuilder<(ins), [{
return $_get($_ctxt);
@@ -48,9 +48,9 @@
def Util_ListType : TypeDef<Util_Dialect, "List"> {
let mnemonic = "list";
- let summary = [{a pointer-like reference}];
+ let summary = [{dense list container type}];
let description = [{
- // DO NOT SUBMIT
+ Typed container supporting variant storage.
}];
let parameters = (ins
@@ -101,7 +101,8 @@
let summary = [{a pointer-like reference}];
let description = [{
- // DO NOT SUBMIT
+ A typed indirect reference to a value. These define a runtime addressable
+ value that is strongly referenced.
}];
let parameters = (ins
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/test/BUILD b/compiler/src/iree/compiler/Dialect/Util/IR/test/BUILD
index a907e2d..19f1d7b 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/test/BUILD
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/test/BUILD
@@ -20,6 +20,7 @@
"alignment_folding.mlir",
"alignment_ops.mlir",
"attributes.mlir",
+ "buffer_folding.mlir",
"buffer_ops.mlir",
"global_folding.mlir",
"global_ops.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/IR/test/CMakeLists.txt
index b9afc4e..966b885 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/test/CMakeLists.txt
@@ -17,6 +17,7 @@
"alignment_folding.mlir"
"alignment_ops.mlir"
"attributes.mlir"
+ "buffer_folding.mlir"
"buffer_ops.mlir"
"global_folding.mlir"
"global_ops.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/test/buffer_folding.mlir b/compiler/src/iree/compiler/Dialect/Util/IR/test/buffer_folding.mlir
new file mode 100644
index 0000000..5014d31
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/test/buffer_folding.mlir
@@ -0,0 +1,186 @@
+// RUN: iree-opt --split-input-file --canonicalize %s | iree-opt --split-input-file | FileCheck %s
+
+// CHECK-LABEL: @FoldSubspansIntoSliceOp
+func.func @FoldSubspansIntoSliceOp(%arg0: !util.buffer, %arg1: index, %arg2: index, %arg3: index) -> !util.buffer {
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ // CHECK: %[[OFFSET:.+]] = arith.addi %arg2, %c100
+ %0 = util.buffer.subspan %arg0[%arg2] : !util.buffer{%arg1} -> !util.buffer{%arg3}
+ // CHECK: util.buffer.slice %arg0[%[[OFFSET]]] : !util.buffer{%arg1} -> !util.buffer{%c200}
+ %1 = util.buffer.slice %0[%c100] : !util.buffer{%arg3} -> !util.buffer{%c200}
+ return %1 : !util.buffer
+}
+
+// -----
+
+// CHECK-LABEL: @FoldBufferSubspanOp
+func.func @FoldBufferSubspanOp(%arg0: !util.buffer, %arg1: index, %arg2: index) -> !util.buffer {
+ // CHECK-NOT: util.buffer.subspan
+ %0 = util.buffer.subspan %arg0[%arg1] : !util.buffer{%arg2} -> !util.buffer{%arg2}
+ // CHECK: return %arg0
+ return %0 : !util.buffer
+}
+
+// -----
+
+// CHECK-LABEL: @FoldBufferSubspanOps
+func.func @FoldBufferSubspanOps(%arg0: !util.buffer, %arg1: index) -> !util.buffer {
+ %c100 = arith.constant 100 : index
+ %c300 = arith.constant 300 : index
+ %c400 = arith.constant 400 : index
+ %c500 = arith.constant 500 : index
+ // CHECK: %[[RET:.+]] = util.buffer.subspan %arg0[%c300] : !util.buffer{%arg1} -> !util.buffer{%c300}
+ %0 = util.buffer.subspan %arg0[%c100] : !util.buffer{%arg1} -> !util.buffer{%c500}
+ %1 = util.buffer.subspan %0[%c100] : !util.buffer{%c500} -> !util.buffer{%c400}
+ %2 = util.buffer.subspan %1[%c100] : !util.buffer{%c400} -> !util.buffer{%c300}
+ // CHECK: return %[[RET]]
+ return %2 : !util.buffer
+}
+
+// -----
+
+// CHECK-LABEL: @SinkSubspanAcrossSelectOps
+func.func @SinkSubspanAcrossSelectOps(%arg0: !util.buffer, %arg1: i1) -> !util.buffer {
+ %c0 = arith.constant 0 : index
+ %c128 = arith.constant 128 : index
+ %c256 = arith.constant 256 : index
+ // CHECK-NOT: util.buffer.subspan
+ %0 = util.buffer.subspan %arg0[%c0] : !util.buffer{%c256} -> !util.buffer{%c128}
+ // CHECK-NOT: util.buffer.subspan
+ %1 = util.buffer.subspan %arg0[%c128] : !util.buffer{%c256} -> !util.buffer{%c128}
+ // CHECK: %[[OFFSET:.+]] = arith.select %arg1, %c0, %c128 : index
+ %2 = arith.select %arg1, %0, %1 : !util.buffer
+ // CHECK-NEXT: %[[SUBSPAN:.+]] = util.buffer.subspan %arg0[%[[OFFSET]]] : !util.buffer{%c256} -> !util.buffer{%c128}
+ // CHECK-NEXT: return %[[SUBSPAN]]
+ return %2 : !util.buffer
+}
+
+// -----
+
+// CHECK-LABEL: @FoldBufferSizeOp
+func.func @FoldBufferSizeOp(%arg0: !util.buffer, %arg1: index) -> (index, i32) {
+ %c0 = arith.constant 0 : index
+ // CHECK-NOT: util.buffer.size
+ %0 = util.buffer.size %arg0 : !util.buffer
+ // CHECK: %[[LOAD:.+]] = util.buffer.load
+ %1 = util.buffer.load %arg0[%c0] : !util.buffer{%arg1} -> i32
+ // CHECK: return %arg1, %[[LOAD]]
+ return %0, %1 : index, i32
+}
+
+// -----
+
+// CHECK-LABEL: @FoldConstantBufferSizeOp
+func.func @FoldConstantBufferSizeOp() -> index {
+ // CHECK-NOT: util.buffer.constant
+ %0 = util.buffer.constant : !util.buffer = dense<[1, 2, 3]> : tensor<3xi32>
+ // CHECK-NOT: util.buffer.size
+ %1 = util.buffer.size %0 : !util.buffer
+ // CHECK: return %c12
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @SelectBufferSizeOp
+func.func @SelectBufferSizeOp(%arg0: !util.buffer, %arg1: index, %arg2: !util.buffer, %arg3: index, %arg4: i1) -> (!util.buffer, index) {
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[ARG0_T:.+]] = util.buffer.slice %arg0[%c0] : !util.buffer{%[[ARG0_SZ:.+]]} ->
+ %0 = util.buffer.slice %arg0[%c0] : !util.buffer{%arg1} -> !util.buffer{%arg1}
+ // CHECK: %[[ARG2_T:.+]] = util.buffer.slice %arg2[%c0] : !util.buffer{%[[ARG2_SZ:.+]]} ->
+ %1 = util.buffer.slice %arg2[%c0] : !util.buffer{%arg3} -> !util.buffer{%arg3}
+ // CHECK: %[[RET_T:.+]] = arith.select %arg4, %[[ARG0_T]], %[[ARG2_T]] : !util.buffer
+ %2 = arith.select %arg4, %0, %1 : !util.buffer
+ // CHECK: %[[RET_SIZE:.+]] = arith.select %arg4, %[[ARG0_SZ]], %[[ARG2_SZ]] : index
+ %3 = util.buffer.size %2 : !util.buffer
+ // CHECK: = util.buffer.slice %[[RET_T]][%c0] : !util.buffer{%[[RET_SIZE]]} ->
+ %4 = util.buffer.slice %2[%c0] : !util.buffer{%3} -> !util.buffer{%3}
+ return %4, %3 : !util.buffer, index
+}
+
+// -----
+
+// CHECK-LABEL: @FoldSubspansIntoStorageOp
+func.func @FoldSubspansIntoStorageOp(%arg0: !util.buffer, %arg1: index, %arg2: index, %arg3: index) -> (memref<?xi8>, index) {
+ // CHECK-NOT: util.buffer.subspan
+ %0 = util.buffer.subspan %arg0[%arg2] : !util.buffer{%arg1} -> !util.buffer{%arg3}
+ // CHECK: %[[STORAGE:.+]], %[[OFFSET:.+]] = util.buffer.storage %arg0 : !util.buffer{%arg1} -> (memref<?xi8>, index)
+ %1:2 = util.buffer.storage %0 : !util.buffer{%arg3} -> (memref<?xi8>, index)
+ // CHECK: %[[ADJUSTED_OFFSET:.+]] = arith.addi %arg2, %[[OFFSET]]
+ // CHECK: return %[[STORAGE]], %[[ADJUSTED_OFFSET]]
+ return %1#0, %1#1 : memref<?xi8>, index
+}
+
+// -----
+
+// CHECK-LABEL: @FoldSubspansIntoCopyOp
+func.func @FoldSubspansIntoCopyOp(%arg0: !util.buffer, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) {
+ %c1 = arith.constant 1 : index
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ // CHECK: %[[OFFSET_SRC:.+]] = arith.addi %arg2, %c100
+ %0 = util.buffer.subspan %arg0[%arg2] : !util.buffer{%arg1} -> !util.buffer{%arg3}
+ // CHECK: %[[OFFSET_DST:.+]] = arith.addi %arg4, %c200
+ %1 = util.buffer.subspan %arg0[%arg4] : !util.buffer{%arg1} -> !util.buffer{%arg5}
+ // CHECK: util.buffer.copy %arg0[%[[OFFSET_SRC]]], %arg0[%[[OFFSET_DST]]], %c1 : !util.buffer{%arg1} -> !util.buffer{%arg1}
+ util.buffer.copy %0[%c100], %1[%c200], %c1 : !util.buffer{%arg3} -> !util.buffer{%arg5}
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @FoldSubspansIntoCompareOp
+func.func @FoldSubspansIntoCompareOp(%arg0: !util.buffer, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) -> i1 {
+ %c1 = arith.constant 1 : index
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ // CHECK: %[[OFFSET_LHS:.+]] = arith.addi %arg2, %c100
+ %0 = util.buffer.subspan %arg0[%arg2] : !util.buffer{%arg1} -> !util.buffer{%arg3}
+ // CHECK: %[[OFFSET_RHS:.+]] = arith.addi %arg4, %c200
+ %1 = util.buffer.subspan %arg0[%arg4] : !util.buffer{%arg1} -> !util.buffer{%arg5}
+ // CHECK: = util.buffer.compare %arg0[%[[OFFSET_LHS]]], %arg0[%[[OFFSET_RHS]]], %c1 : !util.buffer{%arg1}, !util.buffer{%arg1}
+ %2 = util.buffer.compare %0[%c100], %1[%c200], %c1 : !util.buffer{%arg3}, !util.buffer{%arg5}
+ return %2 : i1
+}
+
+// -----
+
+// CHECK-LABEL: @FoldSubspansIntoFillOp
+func.func @FoldSubspansIntoFillOp(%arg0: !util.buffer, %arg1: index, %arg2: i32, %arg3: index, %arg4: index) {
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ // CHECK: %[[OFFSET:.+]] = arith.addi %arg3, %c100
+ %0 = util.buffer.subspan %arg0[%arg3] : !util.buffer{%arg1} -> !util.buffer{%arg4}
+ // CHECK: util.buffer.fill %arg2, %arg0[%[[OFFSET]] for %c200] : i32 -> !util.buffer{%arg1}
+ util.buffer.fill %arg2, %0[%c100 for %c200] : i32 -> !util.buffer{%arg4}
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @FoldSubspanIntoLoadOp
+func.func @FoldSubspanIntoLoadOp(%arg0: !util.buffer, %arg1: index) -> i32 {
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c256 = arith.constant 256 : index
+ // CHECK-NOT: util.buffer.subspan
+ %0 = util.buffer.subspan %arg0[%c128] : !util.buffer{%arg1} -> !util.buffer{%c256}
+ // CHECK: = util.buffer.load %arg0[%c192] : !util.buffer{%arg1} -> i32
+ %1 = util.buffer.load %0[%c64] : !util.buffer{%c256} -> i32
+ return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @FoldSubspanIntoStoreOp
+func.func @FoldSubspanIntoStoreOp(%arg0: !util.buffer, %arg1: index) {
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c256 = arith.constant 256 : index
+ %c123_i32 = arith.constant 123 : i32
+ // CHECK-NOT: util.buffer.subspan
+ %0 = util.buffer.subspan %arg0[%c128] : !util.buffer{%arg1} -> !util.buffer{%c256}
+ // CHECK: util.buffer.store %c123_i32, %arg0[%c192] : i32 -> !util.buffer{%arg1}
+ util.buffer.store %c123_i32, %0[%c64] : i32 -> !util.buffer{%c256}
+ return
+}
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/test/buffer_ops.mlir b/compiler/src/iree/compiler/Dialect/Util/IR/test/buffer_ops.mlir
index 7c463b8..b98c7df 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/test/buffer_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/test/buffer_ops.mlir
@@ -6,3 +6,112 @@
%0 = util.buffer.constant : !util.buffer = dense<[1, 2, 3]> : tensor<3xi32>
return %0 : !util.buffer
}
+
+// -----
+
+// CHECK-LABEL: @buffer_alloc
+func.func @buffer_alloc(%arg0: index) -> !util.buffer {
+ // CHECK: = util.buffer.alloc uninitialized {alignment = 16 : index} : !util.buffer{%arg0}
+ %0 = util.buffer.alloc uninitialized {alignment = 16 : index} : !util.buffer{%arg0}
+ return %0 : !util.buffer
+}
+
+// -----
+
+// CHECK-LABEL: @buffer_dealloc
+func.func @buffer_dealloc(%arg0: !util.buffer, %arg1: index) {
+ // CHECK: util.buffer.dealloc %arg0 : !util.buffer{%arg1}
+ util.buffer.dealloc %arg0 : !util.buffer{%arg1}
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @buffer_slice
+func.func @buffer_slice(%arg0: !util.buffer, %arg1: index, %arg2: index, %arg3: index) -> !util.buffer {
+ // CHECK: = util.buffer.slice %arg0[%arg1] : !util.buffer{%arg2} -> !util.buffer{%arg3}
+ %0 = util.buffer.slice %arg0[%arg1] : !util.buffer{%arg2} -> !util.buffer{%arg3}
+ return %0 : !util.buffer
+}
+
+// -----
+
+// CHECK-LABEL: @buffer_subspan
+func.func @buffer_subspan(%arg0: !util.buffer, %arg1: index, %arg2: index, %arg3: index) -> !util.buffer {
+ // CHECK: = util.buffer.subspan %arg0[%arg1] : !util.buffer{%arg2} -> !util.buffer{%arg3}
+ %0 = util.buffer.subspan %arg0[%arg1] : !util.buffer{%arg2} -> !util.buffer{%arg3}
+ return %0 : !util.buffer
+}
+
+// -----
+
+// CHECK-LABEL: @buffer_size
+func.func @buffer_size(%arg0: !util.buffer) -> index {
+ // CHECK: = util.buffer.size %arg0 : !util.buffer
+ %0 = util.buffer.size %arg0 : !util.buffer
+ return %0 : index
+}
+
+// -----
+
+// CHECK-LABEL: @buffer_storage
+func.func @buffer_storage(%arg0: !util.buffer, %arg1: index) -> (memref<?xi8>, index) {
+ // CHECK: = util.buffer.storage %arg0 : !util.buffer{%arg1} -> (memref<?xi8>, index)
+ %0, %1 = util.buffer.storage %arg0 : !util.buffer{%arg1} -> (memref<?xi8>, index)
+ return %0, %1 : memref<?xi8>, index
+}
+
+// -----
+
+// CHECK-LABEL: @buffer_copy
+func.func @buffer_copy(%arg0: !util.buffer, %arg1: index) {
+ %c1 = arith.constant 1 : index
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ // CHECK: util.buffer.copy %arg0[%c100], %arg0[%c200], %c1 : !util.buffer{%arg1} -> !util.buffer{%arg1}
+ util.buffer.copy %arg0[%c100], %arg0[%c200], %c1 : !util.buffer{%arg1} -> !util.buffer{%arg1}
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @buffer_compare
+func.func @buffer_compare(%arg0: !util.buffer, %arg1: index) -> i1 {
+ %c1 = arith.constant 1 : index
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ // CHECK: = util.buffer.compare %arg0[%c100], %arg0[%c200], %c1 : !util.buffer{%arg1}, !util.buffer{%arg1}
+ %0 = util.buffer.compare %arg0[%c100], %arg0[%c200], %c1 : !util.buffer{%arg1}, !util.buffer{%arg1}
+ return %0 : i1
+}
+
+// -----
+
+// CHECK-LABEL: @buffer_fill
+func.func @buffer_fill(%arg0: !util.buffer, %arg1: index, %arg2: i32) {
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ // CHECK: util.buffer.fill %arg2, %arg0[%c100 for %c200] : i32 -> !util.buffer{%arg1}
+ util.buffer.fill %arg2, %arg0[%c100 for %c200] : i32 -> !util.buffer{%arg1}
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @buffer_load
+func.func @buffer_load(%arg0: !util.buffer, %arg1: index) -> i32 {
+ %c100 = arith.constant 100 : index
+ // CHECK: = util.buffer.load %arg0[%c100] : !util.buffer{%arg1} -> i32
+ %0 = util.buffer.load %arg0[%c100] : !util.buffer{%arg1} -> i32
+ return %0 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @buffer_store
+func.func @buffer_store(%arg0: !util.buffer, %arg1: index, %arg2: i32) {
+ %c100 = arith.constant 100 : index
+ // CHECK: util.buffer.store %arg2, %arg0[%c100] : i32 -> !util.buffer{%arg1}
+ util.buffer.store %arg2, %arg0[%c100] : i32 -> !util.buffer{%arg1}
+ return
+}
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD b/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD
index f99209c..01954ef 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD
@@ -24,8 +24,10 @@
"FuseGlobals.cpp",
"HoistIntoGlobals.cpp",
"Patterns.cpp",
+ "PropagateSubrange.cpp",
"SimplifyGlobalAccesses.cpp",
"StripDebugOps.cpp",
+ "TestConversion.cpp",
"TestFloatRangeAnalysis.cpp",
],
hdrs = [
@@ -37,13 +39,20 @@
"//compiler/src/iree/compiler/Dialect/Util/Analysis/Attributes",
"//compiler/src/iree/compiler/Dialect/Util/Analysis/Constant",
"//compiler/src/iree/compiler/Dialect/Util/Analysis/DFX",
+ "//compiler/src/iree/compiler/Dialect/Util/Conversion",
+ "//compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil",
"//compiler/src/iree/compiler/Dialect/Util/IR",
+ "//compiler/src/iree/compiler/Utils",
"@llvm-project//llvm:Support",
+ "@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:ArithmeticDialect",
+ "@llvm-project//mlir:ArithmeticTransforms",
"@llvm-project//mlir:ControlFlowDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
+ "@llvm-project//mlir:MathDialect",
+ "@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt
index f34d173..8c53010 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt
@@ -26,16 +26,22 @@
"FuseGlobals.cpp"
"HoistIntoGlobals.cpp"
"Patterns.cpp"
+ "PropagateSubrange.cpp"
"SimplifyGlobalAccesses.cpp"
"StripDebugOps.cpp"
+ "TestConversion.cpp"
"TestFloatRangeAnalysis.cpp"
DEPS
LLVMSupport
+ MLIRAffineDialect
MLIRAnalysis
MLIRArithmeticDialect
+ MLIRArithmeticTransforms
MLIRControlFlowDialect
MLIRFuncDialect
MLIRIR
+ MLIRMathDialect
+ MLIRMemRefDialect
MLIRPass
MLIRSupport
MLIRTransforms
@@ -43,7 +49,10 @@
iree::compiler::Dialect::Util::Analysis::Attributes
iree::compiler::Dialect::Util::Analysis::Constant
iree::compiler::Dialect::Util::Analysis::DFX
+ iree::compiler::Dialect::Util::Conversion
+ iree::compiler::Dialect::Util::Conversion::MemRefToUtil
iree::compiler::Dialect::Util::IR
+ iree::compiler::Utils
PUBLIC
)
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h
index c3b37e4..7fbf271 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h
@@ -23,6 +23,7 @@
std::unique_ptr<OperationPass<mlir::ModuleOp>> createFoldGlobalsPass();
std::unique_ptr<OperationPass<mlir::ModuleOp>> createFuseGlobalsPass();
std::unique_ptr<OperationPass<mlir::ModuleOp>> createHoistIntoGlobalsPass();
+std::unique_ptr<OperationPass<mlir::ModuleOp>> createPropagateSubrangesPass();
std::unique_ptr<OperationPass<void>> createSimplifyGlobalAccessesPass();
std::unique_ptr<OperationPass<void>> createStripDebugOpsPass();
@@ -33,7 +34,8 @@
std::unique_ptr<OperationPass<mlir::ModuleOp>> createPromoteF16ToF32Pass();
// Test passes.
-std::unique_ptr<OperationPass<void>> createTestFloatRangeAnalysis();
+std::unique_ptr<OperationPass<void>> createTestConversionPass();
+std::unique_ptr<OperationPass<void>> createTestFloatRangeAnalysisPass();
// Register all Passes
// TODO: Switch this directory to declarative registration.
@@ -45,6 +47,7 @@
createFoldGlobalsPass();
createFuseGlobalsPass();
createHoistIntoGlobalsPass();
+ createPropagateSubrangesPass();
createSimplifyGlobalAccessesPass();
createStripDebugOpsPass();
@@ -53,7 +56,8 @@
createDemoteF64ToF32Pass();
createPromoteF16ToF32Pass();
- createTestFloatRangeAnalysis();
+ createTestConversionPass();
+ createTestFloatRangeAnalysisPass();
}
} // namespace Util
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PropagateSubviews.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/PropagateSubrange.cpp
similarity index 65%
rename from compiler/src/iree/compiler/Dialect/Stream/Transforms/PropagateSubviews.cpp
rename to compiler/src/iree/compiler/Dialect/Util/Transforms/PropagateSubrange.cpp
index 26c1198..47046f0 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PropagateSubviews.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/PropagateSubrange.cpp
@@ -6,12 +6,9 @@
#include <utility>
-#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
-#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
-#include "iree/compiler/Dialect/Stream/Transforms/PassDetail.h"
-#include "iree/compiler/Dialect/Stream/Transforms/Passes.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
#include "iree/compiler/Dialect/Util/Transforms/Patterns.h"
#include "iree/compiler/Utils/IndexSet.h"
#include "llvm/ADT/BreadthFirstIterator.h"
@@ -27,18 +24,19 @@
#include "mlir/IR/Dominance.h"
#include "mlir/Pass/Pass.h"
-#define DEBUG_TYPE "iree-stream-propagate-subviews"
+#define DEBUG_TYPE "iree-util-propagate-subranges"
namespace mlir {
namespace iree_compiler {
namespace IREE {
-namespace Stream {
+namespace Util {
namespace {
-// TODO(benvanik): factor out into a generic util pass base that lets us share
-// with other expanded type propagation passes. The walking of
-// functions/blocks/globals/etc are the same across all of them and only the
-// exact type expansion and consumption/query ops differ.
+// This pass is keyed off the subrange type interface; any type implementing
+// IREE::Util::SubrangeTypeInterface can be expanded and propagated.
+static bool isResourceType(Type type) {
+ return type.isa<IREE::Util::SubrangeTypeInterface>();
+}
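+
+// For illustration, a subrange-typed value such as a !util.buffer %r gets
+// expanded into a (%r, %storage_size, %offset, %length) tuple, and a subrange
+// op (util.buffer.subspan for !util.buffer) is re-inserted over the expanded
+// values where the original value was used; see test/propagate_subranges.mlir.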
//===----------------------------------------------------------------------===//
// Global handling
@@ -47,13 +45,13 @@
struct ExpandedGlobal {
IREE::Util::GlobalOp resourceOp;
IREE::Util::GlobalOp resourceSizeOp;
- IREE::Util::GlobalOp subviewOffsetOp;
- IREE::Util::GlobalOp subviewLengthOp;
+ IREE::Util::GlobalOp subrangeOffsetOp;
+ IREE::Util::GlobalOp subrangeLengthOp;
};
using ExpandedGlobalMap = DenseMap<StringRef, ExpandedGlobal>;
// Expands each !stream.resource global in |rootOp| to have a matching
-// parent resource size and subview range. Does not behave optimally if there
+// parent resource size and subrange offset/length. Does not behave optimally if there
// already exist offset globals as duplicates will get added and we'll need to
// rely on global fusion to get rid of them. Note that this only expands globals
// and does not yet update use sites - we just need the ops to reference.
@@ -63,7 +61,7 @@
// Gather all of the resource globals in the root.
for (auto ®ion : rootOp->getRegions()) {
for (auto globalOp : region.getOps<IREE::Util::GlobalOp>()) {
- if (!globalOp.getType().isa<IREE::Stream::ResourceType>()) continue;
+ if (!isResourceType(globalOp.getType())) continue;
expandedGlobals[globalOp.getName()].resourceOp = globalOp;
}
}
@@ -90,7 +88,7 @@
/*isMutable=*/true, indexType);
offsetOp.setVisibility(global.resourceOp.getVisibility());
symbolTable.insert(offsetOp);
- global.subviewOffsetOp = offsetOp;
+ global.subrangeOffsetOp = offsetOp;
auto lengthName = (global.resourceOp.getName() + "__length").str();
auto lengthOp = builder.create<IREE::Util::GlobalOp>(
@@ -98,7 +96,7 @@
/*isMutable=*/true, indexType);
lengthOp.setVisibility(global.resourceOp.getVisibility());
symbolTable.insert(lengthOp);
- global.subviewLengthOp = lengthOp;
+ global.subrangeLengthOp = lengthOp;
}
return expandedGlobals;
@@ -108,10 +106,6 @@
// Structural IR rewriting patterns
//===----------------------------------------------------------------------===//
-static bool isResourceType(Type type) {
- return type.isa<IREE::Stream::ResourceType>();
-}
-
// Returns true if any operands or results of |op| use !stream.resources.
static bool usesResources(Operation *op) {
return llvm::any_of(op->getOperandTypes(), isResourceType) ||
@@ -129,69 +123,73 @@
newTypes.push_back(type);
if (isResourceType(type)) {
newTypes.push_back(indexType); // resource size
- newTypes.push_back(indexType); // subview offset
- newTypes.push_back(indexType); // subview length
+ newTypes.push_back(indexType); // subrange offset
+ newTypes.push_back(indexType); // subrange length
}
}
return newTypes;
}
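+// e.g. expanding the type list (!util.buffer, i32) above yields
+// (!util.buffer, index, index, index, i32).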
-struct Subview {
+struct Subrange {
Value resource;
Value resourceSize;
- Value subviewOffset;
- Value subviewLength;
+ Value subrangeOffset;
+ Value subrangeLength;
+ IREE::Util::SubrangeTypeInterface getResourceType() {
+ return resource.getType().cast<IREE::Util::SubrangeTypeInterface>();
+ }
};
-using SubviewMap = llvm::DenseMap<Value, Subview>;
+using SubrangeMap = llvm::DenseMap<Value, Subrange>;
-// Attempts to find and consume a subview associated with |value|.
-// Returns the subview - which may point at a different resource than |value|.
-// In cases where no associated subview was found the subview will cover the
+// Attempts to find and consume a subrange associated with |value|.
+// Returns the subrange - which may point at a different resource than |value|.
+// In cases where no associated subrange was found the subrange will cover the
// entire resource (offset at 0, length at size).
-static Subview consumeSubview(Location loc, Value value, SubviewMap &subviewMap,
- IndexSet &indexSet, OpBuilder &builder) {
+static Subrange consumeSubrange(Location loc, Value value,
+ SubrangeMap &subrangeMap, IndexSet &indexSet,
+ OpBuilder &builder) {
// TODO(benvanik): follow ties on value to try to consume there; there are a
// few other ops we could look through as well (such as select, where we could
// join). For now we just look at immediate defining ops.
- auto mapIt = subviewMap.find(value);
- if (mapIt != subviewMap.end()) {
+ auto mapIt = subrangeMap.find(value);
+ if (mapIt != subrangeMap.end()) {
return mapIt->second;
}
- if (auto subviewOp = dyn_cast_or_null<IREE::Stream::ResourceSubviewOp>(
+ if (auto subrangeOp = dyn_cast_or_null<IREE::Util::SubrangeOpInterface>(
value.getDefiningOp())) {
- Subview subview;
- subview.resource = subviewOp.getSource();
- subview.resourceSize = subviewOp.getSourceSize();
- subview.subviewOffset = subviewOp.getSourceOffset();
- subview.subviewLength = subviewOp.getResultSize();
- return subview;
+ Subrange subrange;
+ subrange.resource = subrangeOp.getSubrangeResource();
+ subrange.resourceSize = subrangeOp.getSubrangeResourceSize();
+ subrange.subrangeOffset = subrangeOp.getSubrangeOffset();
+ subrange.subrangeLength = subrangeOp.getSubrangeLength();
+ return subrange;
} else {
- Subview subview;
- subview.resource = value;
- subview.resourceSize =
+ Subrange subrange;
+ subrange.resource = value;
+ subrange.resourceSize =
IREE::Util::SizeAwareTypeInterface::queryValueSize(loc, value, builder);
- subview.subviewOffset = indexSet.get(0);
- subview.subviewLength = subview.resourceSize;
- return subview;
+ subrange.subrangeOffset = indexSet.get(0);
+ subrange.subrangeLength = subrange.resourceSize;
+ return subrange;
}
}
// Expands resources in |operands| into (resource, size, offset, length) tuples.
static SmallVector<Value> expandOperands(Location loc, ValueRange operands,
- SubviewMap &subviewMap,
+ SubrangeMap &subrangeMap,
IndexSet &indexSet,
OpBuilder &builder) {
SmallVector<Value> result;
result.reserve(operands.size() * 2);
for (auto operand : operands) {
if (isResourceType(operand.getType())) {
- auto subview =
- consumeSubview(loc, operand, subviewMap, indexSet, builder);
- result.push_back(subview.resource);
- result.push_back(subview.resourceSize);
- result.push_back(subview.subviewOffset);
- result.push_back(subview.subviewLength);
+ auto subrange =
+ consumeSubrange(loc, operand, subrangeMap, indexSet, builder);
+ result.push_back(subrange.resource);
+ result.push_back(subrange.resourceSize);
+ result.push_back(subrange.subrangeOffset);
+ result.push_back(subrange.subrangeLength);
} else {
result.push_back(operand);
}
@@ -199,14 +197,14 @@
return result;
}
-static void expandSubviews(Operation *op, ExpandedGlobalMap &globalMap,
- IndexSet &indexSet, SubviewMap &subviewMap);
+static void expandSubranges(Operation *op, ExpandedGlobalMap &globalMap,
+ IndexSet &indexSet, SubrangeMap &subrangeMap);
// Recursively expands resources into (resource, size, offset, length) tuples
// within the given |region|. All branches, ops, and nested regions will be
// processed.
static void expandRegion(Region ®ion, ExpandedGlobalMap &globalMap,
- IndexSet &indexSet, SubviewMap subviewMap) {
+ IndexSet &indexSet, SubrangeMap subrangeMap) {
if (region.empty()) return;
// Update all block arguments.
@@ -215,29 +213,30 @@
if (!llvm::any_of(block.getArgumentTypes(), isResourceType)) continue;
// Insert and build a list of expanded (resource, size, offset) tuples.
- SmallVector<Subview> expansions;
+ SmallVector<Subrange> expansions;
for (int i = block.getNumArguments() - 1; i >= 0; --i) {
auto arg = block.getArgument(i);
if (!isResourceType(arg.getType())) continue;
- Subview subview;
- subview.resource = arg;
- subview.resourceSize =
+ Subrange subrange;
+ subrange.resource = arg;
+ subrange.resourceSize =
block.insertArgument(i + 1, indexType, arg.getLoc());
- subview.subviewOffset =
+ subrange.subrangeOffset =
block.insertArgument(i + 2, indexType, arg.getLoc());
- subview.subviewLength =
+ subrange.subrangeLength =
block.insertArgument(i + 3, indexType, arg.getLoc());
- expansions.push_back(subview);
- subviewMap[arg] = subview;
+ expansions.push_back(subrange);
+ subrangeMap[arg] = subrange;
}
- // Insert subviews that we've sunk from callers.
+ // Insert subranges that we've sunk from callers.
auto builder = OpBuilder::atBlockBegin(&block);
for (auto &expansion : llvm::reverse(expansions)) {
- auto subviewOp = builder.create<IREE::Stream::ResourceSubviewOp>(
+ auto newSubrange = expansion.getResourceType().createSubrangeOp(
region.getLoc(), expansion.resource, expansion.resourceSize,
- expansion.subviewOffset, expansion.subviewLength);
- expansion.resource.replaceAllUsesExcept(subviewOp.getResult(), subviewOp);
+ expansion.subrangeOffset, expansion.subrangeLength, builder);
+ expansion.resource.replaceAllUsesExcept(newSubrange,
+ newSubrange.getDefiningOp());
}
}
@@ -247,20 +246,20 @@
if (region.hasOneBlock()) {
for (auto &op :
llvm::make_early_inc_range(region.front().getOperations())) {
- expandSubviews(&op, globalMap, indexSet, subviewMap);
+ expandSubranges(&op, globalMap, indexSet, subrangeMap);
}
} else {
DominanceInfo domInfo(region.getParentOp());
for (auto *blockInfo : llvm::breadth_first(domInfo.getRootNode(®ion))) {
auto *block = blockInfo->getBlock();
for (auto &op : llvm::make_early_inc_range(block->getOperations())) {
- expandSubviews(&op, globalMap, indexSet, subviewMap);
+ expandSubranges(&op, globalMap, indexSet, subrangeMap);
}
}
}
}
-// Moves resource subviews from global stores to loads.
+// Moves resource subranges from global stores to loads.
// Requires that the ExpandGlobalStoreOp pattern elides the await.
//
// Example:
@@ -274,37 +273,37 @@
// !stream.resource<*>{%s} -> !stream.resource<*>{%l}
static void expandGlobalLoadOp(IREE::Util::GlobalLoadOp op,
ExpandedGlobalMap &globalMap, IndexSet &indexSet,
- SubviewMap &subviewMap) {
+ SubrangeMap &subrangeMap) {
if (!usesResources(op)) return;
OpBuilder builder(op);
builder.setInsertionPointAfter(op);
auto indexType = builder.getIndexType();
auto &expandedGlobal = globalMap[op.getGlobal()];
- Subview subview;
- subview.resource = op.getResult();
- subview.resourceSize =
+ Subrange subrange;
+ subrange.resource = op.getResult();
+ subrange.resourceSize =
builder
.create<IREE::Util::GlobalLoadOp>(
op.getLoc(), indexType, expandedGlobal.resourceSizeOp.getName())
.getResult();
- subview.subviewOffset =
+ subrange.subrangeOffset =
builder
.create<IREE::Util::GlobalLoadOp>(
- op.getLoc(), indexType, expandedGlobal.subviewOffsetOp.getName())
+ op.getLoc(), indexType, expandedGlobal.subrangeOffsetOp.getName())
.getResult();
- subview.subviewLength =
+ subrange.subrangeLength =
builder
.create<IREE::Util::GlobalLoadOp>(
- op.getLoc(), indexType, expandedGlobal.subviewLengthOp.getName())
+ op.getLoc(), indexType, expandedGlobal.subrangeLengthOp.getName())
.getResult();
- subviewMap[op.getResult()] = subview;
- auto subviewOp = builder.create<IREE::Stream::ResourceSubviewOp>(
- op.getLoc(), subview.resource, subview.resourceSize,
- subview.subviewOffset, subview.subviewLength);
- op.getResult().replaceAllUsesExcept(subviewOp.getResult(), subviewOp);
+ subrangeMap[op.getResult()] = subrange;
+ auto newSubrange = subrange.getResourceType().createSubrangeOp(
+ op.getLoc(), subrange.resource, subrange.resourceSize,
+ subrange.subrangeOffset, subrange.subrangeLength, builder);
+ op.getResult().replaceAllUsesExcept(newSubrange, newSubrange.getDefiningOp());
}
-// Moves resource subviews from global stores to loads.
+// Moves resource subranges from global stores to loads.
// Requires that the ExpandGlobalLoadOp pattern inserts the await.
//
// Example:
@@ -318,38 +317,38 @@
// util.global.store %l, @foo_length : index
static void expandGlobalStoreOp(IREE::Util::GlobalStoreOp op,
ExpandedGlobalMap &globalMap,
- IndexSet &indexSet, SubviewMap &subviewMap) {
+ IndexSet &indexSet, SubrangeMap &subrangeMap) {
if (!usesResources(op)) return;
OpBuilder builder(op);
builder.setInsertionPointAfter(op);
- auto subview =
- consumeSubview(op.getLoc(), op.getValue(), subviewMap, indexSet, builder);
+ auto subrange = consumeSubrange(op.getLoc(), op.getValue(), subrangeMap,
+ indexSet, builder);
auto &expandedGlobal = globalMap[op.getGlobal()];
builder.create<IREE::Util::GlobalStoreOp>(
- op.getLoc(), subview.resource, expandedGlobal.resourceOp.getName());
+ op.getLoc(), subrange.resource, expandedGlobal.resourceOp.getName());
builder.create<IREE::Util::GlobalStoreOp>(
- op.getLoc(), subview.resourceSize,
+ op.getLoc(), subrange.resourceSize,
expandedGlobal.resourceSizeOp.getName());
builder.create<IREE::Util::GlobalStoreOp>(
- op.getLoc(), subview.subviewOffset,
- expandedGlobal.subviewOffsetOp.getName());
+ op.getLoc(), subrange.subrangeOffset,
+ expandedGlobal.subrangeOffsetOp.getName());
builder.create<IREE::Util::GlobalStoreOp>(
- op.getLoc(), subview.subviewLength,
- expandedGlobal.subviewLengthOp.getName());
+ op.getLoc(), subrange.subrangeLength,
+ expandedGlobal.subrangeLengthOp.getName());
op.erase();
}
static void expandInitializerOp(IREE::Util::InitializerOp op,
ExpandedGlobalMap &globalMap,
- IndexSet &indexSet, SubviewMap &subviewMap) {
- expandRegion(op.getRegion(), globalMap, indexSet, subviewMap);
+ IndexSet &indexSet, SubrangeMap &subrangeMap) {
+ expandRegion(op.getRegion(), globalMap, indexSet, subrangeMap);
}
-// Inserts subviews on resource arguments.
+// Inserts subranges on resource arguments.
// Requires that the ExpandCallOp/ExpandReturnOp patterns handle migrating the
// await.
//
-// NOTE: this needs IPO to remove redundant subviews in cases where the call
+// NOTE: this needs IPO to remove redundant subranges in cases where the call
// sites don't need a wait.
//
// Example:
@@ -358,7 +357,7 @@
// func.func @foo(%0: !stream.resource, %sz: index, %o: index, %l: index) {
// %1 = stream.resource.subview %0[%o] : {%sz} -> {%l}
static void expandFuncOp(mlir::func::FuncOp op, ExpandedGlobalMap &globalMap,
- IndexSet &indexSet, SubviewMap &subviewMap) {
+ IndexSet &indexSet, SubrangeMap &subrangeMap) {
auto oldType = op.getFunctionType();
auto inputTypes = expandTypes(oldType.getInputs());
auto resultTypes = expandTypes(oldType.getResults());
@@ -366,16 +365,16 @@
if (newType != oldType) {
op.setType(newType);
}
- expandRegion(op.getRegion(), globalMap, indexSet, subviewMap);
+ expandRegion(op.getRegion(), globalMap, indexSet, subrangeMap);
}
// Splits resource operands and results into (resource, resourceSize,
-// subviewOffset, subviewLength).
+// subrangeOffset, subrangeLength).
// Requires that the ExpandFuncOp/ExpandReturnOp patterns handle migrating the
// await.
//
// NOTE: this needs IPO to remove redundant values in cases where the call sites
-// don't need a subview.
+// don't need a subrange.
//
// Example:
// %1 = stream.resource.subview %0[%o] : {%sz} -> {%l}
@@ -384,19 +383,19 @@
// %r, %rsz, %ro, %rl = call @foo(%0, %sz, %o, %l)
// %2 = stream.resource.subview %r[%ro] : {%rsz} -> {%rl}
static void expandCallOp(mlir::func::CallOp op, IndexSet &indexSet,
- SubviewMap &subviewMap) {
+ SubrangeMap &subrangeMap) {
if (!usesResources(op)) return;
// Build the new call op with expanded operands and results.
OpBuilder builder(op);
- auto operands =
- expandOperands(op.getLoc(), op.operands(), subviewMap, indexSet, builder);
+ auto operands = expandOperands(op.getLoc(), op.operands(), subrangeMap,
+ indexSet, builder);
auto resultTypes = expandTypes(op.getResultTypes());
auto newOp = builder.create<mlir::func::CallOp>(op.getLoc(), op.getCallee(),
resultTypes, operands);
- // Insert subviews on results that we are sinking across the call edge.
- // The hope is that by moving the subviews here we can fold with uses inside
+ // Insert subranges on results that we are sinking across the call edge.
+ // The hope is that by moving the subranges here we can fold with uses inside
// of this function.
builder.setInsertionPointAfter(newOp);
unsigned newIdx = 0;
@@ -407,22 +406,22 @@
oldResult.replaceAllUsesWith(newResult);
continue;
}
- Subview subview;
- subview.resource = newOp.getResult(newIdx++);
- subview.resourceSize = newOp.getResult(newIdx++);
- subview.subviewOffset = newOp.getResult(newIdx++);
- subview.subviewLength = newOp.getResult(newIdx++);
- subviewMap[subview.resource] = subview;
- auto subviewOp = builder.create<IREE::Stream::ResourceSubviewOp>(
- op.getLoc(), subview.resource, subview.resourceSize,
- subview.subviewOffset, subview.subviewLength);
- oldResult.replaceAllUsesWith(subviewOp.getResult());
+ Subrange subrange;
+ subrange.resource = newOp.getResult(newIdx++);
+ subrange.resourceSize = newOp.getResult(newIdx++);
+ subrange.subrangeOffset = newOp.getResult(newIdx++);
+ subrange.subrangeLength = newOp.getResult(newIdx++);
+ subrangeMap[subrange.resource] = subrange;
+ auto newSubrange = subrange.getResourceType().createSubrangeOp(
+ op.getLoc(), subrange.resource, subrange.resourceSize,
+ subrange.subrangeOffset, subrange.subrangeLength, builder);
+ oldResult.replaceAllUsesWith(newSubrange);
}
op.erase();
}
-// Moves subviews to callers upon return.
+// Moves subranges to callers upon return.
// Requires that the ExpandFuncOp/ExpandCallOp patterns handle migrating the
// await.
//
@@ -432,16 +431,16 @@
// ->
// return %0, %sz, %o, %l
static void expandReturnOp(mlir::func::ReturnOp op, IndexSet &indexSet,
- SubviewMap &subviewMap) {
+ SubrangeMap &subrangeMap) {
if (!usesResources(op)) return;
OpBuilder builder(op);
- auto operands =
- expandOperands(op.getLoc(), op.operands(), subviewMap, indexSet, builder);
+ auto operands = expandOperands(op.getLoc(), op.operands(), subrangeMap,
+ indexSet, builder);
builder.create<mlir::func::ReturnOp>(op.getLoc(), operands);
op.erase();
}
-// Moves subviews across branches.
+// Moves subranges across branches.
// Requires that the ExpandFuncOp pattern handles modifying the block args.
//
// Example:
@@ -453,52 +452,52 @@
// ^bb1(%a, %b, %c, %d):
// %1 = stream.resource.subview %a[%b] : {%c} -> {%d}
static void expandBranchOp(mlir::cf::BranchOp op, IndexSet &indexSet,
- SubviewMap &subviewMap) {
+ SubrangeMap &subrangeMap) {
OpBuilder builder(op);
- auto operands = expandOperands(op.getLoc(), op.getDestOperands(), subviewMap,
+ auto operands = expandOperands(op.getLoc(), op.getDestOperands(), subrangeMap,
indexSet, builder);
builder.create<mlir::cf::BranchOp>(op.getLoc(), op.getDest(), operands);
op.erase();
}
static void expandCondBranchOp(mlir::cf::CondBranchOp op, IndexSet &indexSet,
- SubviewMap &subviewMap) {
+ SubrangeMap &subrangeMap) {
if (!usesResources(op)) return;
OpBuilder builder(op);
builder.create<mlir::cf::CondBranchOp>(
op.getLoc(), op.getCondition(), op.getTrueDest(),
- expandOperands(op.getLoc(), op.getTrueDestOperands(), subviewMap,
+ expandOperands(op.getLoc(), op.getTrueDestOperands(), subrangeMap,
indexSet, builder),
op.getFalseDest(),
- expandOperands(op.getLoc(), op.getFalseDestOperands(), subviewMap,
+ expandOperands(op.getLoc(), op.getFalseDestOperands(), subrangeMap,
indexSet, builder));
op.erase();
}
// Recursively expands resources into (resource, size, offset, length) in |op|.
-static void expandSubviews(Operation *op, ExpandedGlobalMap &globalMap,
- IndexSet &indexSet, SubviewMap &subviewMap) {
+static void expandSubranges(Operation *op, ExpandedGlobalMap &globalMap,
+ IndexSet &indexSet, SubrangeMap &subrangeMap) {
if (auto loadOp = dyn_cast<IREE::Util::GlobalLoadOp>(op)) {
- expandGlobalLoadOp(loadOp, globalMap, indexSet, subviewMap);
+ expandGlobalLoadOp(loadOp, globalMap, indexSet, subrangeMap);
} else if (auto storeOp = dyn_cast<IREE::Util::GlobalStoreOp>(op)) {
- expandGlobalStoreOp(storeOp, globalMap, indexSet, subviewMap);
+ expandGlobalStoreOp(storeOp, globalMap, indexSet, subrangeMap);
} else if (auto initializerOp = dyn_cast<IREE::Util::InitializerOp>(op)) {
- expandInitializerOp(initializerOp, globalMap, indexSet, subviewMap);
+ expandInitializerOp(initializerOp, globalMap, indexSet, subrangeMap);
} else if (auto funcOp = dyn_cast<mlir::func::FuncOp>(op)) {
- expandFuncOp(funcOp, globalMap, indexSet, subviewMap);
+ expandFuncOp(funcOp, globalMap, indexSet, subrangeMap);
} else if (auto callOp = dyn_cast<mlir::func::CallOp>(op)) {
- expandCallOp(callOp, indexSet, subviewMap);
+ expandCallOp(callOp, indexSet, subrangeMap);
} else if (auto returnOp = dyn_cast<mlir::func::ReturnOp>(op)) {
- expandReturnOp(returnOp, indexSet, subviewMap);
+ expandReturnOp(returnOp, indexSet, subrangeMap);
} else if (auto branchOp = dyn_cast<mlir::cf::BranchOp>(op)) {
- expandBranchOp(branchOp, indexSet, subviewMap);
+ expandBranchOp(branchOp, indexSet, subrangeMap);
} else if (auto condBranchOp = dyn_cast<mlir::cf::CondBranchOp>(op)) {
- expandCondBranchOp(condBranchOp, indexSet, subviewMap);
+ expandCondBranchOp(condBranchOp, indexSet, subrangeMap);
}
}
//===----------------------------------------------------------------------===//
-// -iree-stream-propagate-subviews
+// -iree-util-propagate-subranges
//===----------------------------------------------------------------------===//
// This does a relatively mechanical transformation of a module to expand all
@@ -506,25 +505,33 @@
//
// This is designed to be composed with generic optimization passes like global
// fusion/folding and IPO and as such performs all transformations locally. For
-// example, calls are always updated to take/return subview ranges and results
-// are always wrapped in a stream.resource.subview, with the
-// elision/deduplication/etc left until cleanup.
-class PropagateSubviewsPass
- : public PropagateSubviewsBase<PropagateSubviewsPass> {
+// example, calls are always updated to take/return expanded subranges and results
+// are always wrapped in a subrange op, with the elision/deduplication/etc left
+// until cleanup.
+class PropagateSubrangesPass
+ : public PassWrapper<PropagateSubrangesPass,
+ OperationPass<mlir::ModuleOp>> {
public:
- PropagateSubviewsPass() = default;
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PropagateSubrangesPass)
+
+ StringRef getArgument() const override {
+ return "iree-util-propagate-subranges";
+ }
+
+ StringRef getDescription() const override {
+ return "Propagates resource subranges across the program.";
+ }
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<mlir::func::FuncDialect>();
registry.insert<mlir::arith::ArithmeticDialect>();
- registry.insert<IREE::Stream::StreamDialect>();
registry.insert<IREE::Util::UtilDialect>();
}
void runOnOperation() override {
auto rootOp = getOperation();
- // Expand all util.global ops holding resources into resource and subview.
+ // Expand all util.global ops holding resources into resource, size, offset,
+ // and length globals.
auto globalMap = expandResourceGlobals(rootOp);
// Walk the entire IR tree and expand the globals.
@@ -539,19 +546,21 @@
!region || region->empty()
? OpBuilder(callableOp)
: OpBuilder::atBlockBegin(®ion->front()));
- SubviewMap subviewMap;
- expandSubviews(callableOp, globalMap, indexSet, subviewMap);
+ SubrangeMap subrangeMap;
+ expandSubranges(callableOp, globalMap, indexSet, subrangeMap);
}
}
};
} // namespace
-std::unique_ptr<OperationPass<mlir::ModuleOp>> createPropagateSubviewsPass() {
- return std::make_unique<PropagateSubviewsPass>();
+std::unique_ptr<OperationPass<mlir::ModuleOp>> createPropagateSubrangesPass() {
+ return std::make_unique<PropagateSubrangesPass>();
}
-} // namespace Stream
+static PassRegistration<PropagateSubrangesPass> pass;
+
+} // namespace Util
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/TestConversion.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/TestConversion.cpp
new file mode 100644
index 0000000..30fe43a
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/TestConversion.cpp
@@ -0,0 +1,80 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Util/Conversion/ConversionPatterns.h"
+#include "iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.h"
+#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Util {
+
+namespace {
+
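+// Test-only pass exercising the util conversion patterns (util, generic
+// structural, and memref->util) via a partial conversion. Lit tests would
+// typically drive it as, e.g.:
+//   // RUN: iree-opt --iree-util-test-conversion %s | FileCheck %s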
+class TestConversionPass
+ : public PassWrapper<TestConversionPass, OperationPass<void>> {
+ public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConversionPass)
+
+ StringRef getArgument() const override { return "iree-util-test-conversion"; }
+
+ StringRef getDescription() const override {
+ return "Tests util dialect conversion patterns";
+ }
+
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<IREE::Util::UtilDialect, func::FuncDialect,
+ mlir::arith::ArithmeticDialect, math::MathDialect,
+ mlir::AffineDialect, memref::MemRefDialect>();
+ }
+
+ void runOnOperation() override {
+ auto *context = &getContext();
+
+ ConversionTarget conversionTarget(*context);
+ conversionTarget.addLegalDialect<arith::ArithmeticDialect>();
+ conversionTarget.addLegalDialect<IREE::Util::UtilDialect>();
+
+ TypeConverter typeConverter;
+ typeConverter.addConversion([](Type type) { return type; });
+
+ RewritePatternSet patterns(&getContext());
+ populateUtilConversionPatterns(context, conversionTarget, typeConverter,
+ patterns);
+ populateGenericStructuralConversionPatterns(context, conversionTarget,
+ typeConverter, patterns);
+ populateMemRefToUtilPatterns(context, conversionTarget, typeConverter,
+ patterns);
+
+ if (failed(applyPartialConversion(getOperation(), conversionTarget,
+ std::move(patterns)))) {
+ getOperation()->emitError() << "conversion to util failed";
+ return signalPassFailure();
+ }
+ }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<void>> createTestConversionPass() {
+ return std::make_unique<TestConversionPass>();
+}
+
+static PassRegistration<TestConversionPass> pass;
+
+} // namespace Util
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/TestFloatRangeAnalysis.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/TestFloatRangeAnalysis.cpp
index 1158570..297c28f 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/TestFloatRangeAnalysis.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/TestFloatRangeAnalysis.cpp
@@ -66,7 +66,7 @@
} // namespace
-std::unique_ptr<OperationPass<void>> createTestFloatRangeAnalysis() {
+std::unique_ptr<OperationPass<void>> createTestFloatRangeAnalysisPass() {
return std::make_unique<TestFloatRangeAnalysisPass>();
}
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD
index 81de2f1..2ac99b9 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD
@@ -28,6 +28,7 @@
"hoist_into_globals.mlir",
"hoist_into_globals_linalg.mlir",
"promote_f16_to_f32.mlir",
+ "propagate_subranges.mlir",
"simplify_global_accesses.mlir",
"strip_debug_ops.mlir",
"test_float_range_analysis.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt
index a6f269d..9180723 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt
@@ -25,6 +25,7 @@
"hoist_into_globals.mlir"
"hoist_into_globals_linalg.mlir"
"promote_f16_to_f32.mlir"
+ "propagate_subranges.mlir"
"simplify_global_accesses.mlir"
"strip_debug_ops.mlir"
"test_float_range_analysis.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/propagate_subranges.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/propagate_subranges.mlir
new file mode 100644
index 0000000..1653269
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/propagate_subranges.mlir
@@ -0,0 +1,149 @@
+// RUN: iree-opt --split-input-file --iree-util-propagate-subranges %s | FileCheck %s
+
+// Tests that resource global loads also load all the subrange params.
+//
+// This rotates subranges through stores and into loads.
+
+// CHECK: util.global private mutable @constantGlobal : !util.buffer
+// CHECK-NEXT: util.global private mutable @constantGlobal__storage_size : index
+// CHECK-NEXT: util.global private mutable @constantGlobal__offset : index
+// CHECK-NEXT: util.global private mutable @constantGlobal__length : index
+util.global private mutable @constantGlobal : !util.buffer
+
+// CHECK-LABEL: @globalLoad
+func.func @globalLoad() {
+ // CHECK-NEXT: %[[RESOURCE:.+]] = util.global.load @constantGlobal : !util.buffer
+ // CHECK-NEXT: %[[STORAGE_SIZE:.+]] = util.global.load @constantGlobal__storage_size : index
+ // CHECK-NEXT: %[[OFFSET:.+]] = util.global.load @constantGlobal__offset : index
+ // CHECK-NEXT: %[[LENGTH:.+]] = util.global.load @constantGlobal__length : index
+ // CHECK: %[[SUBRANGE:.+]] = util.buffer.subspan %[[RESOURCE]][%[[OFFSET]]] : !util.buffer{%[[STORAGE_SIZE]]} -> !util.buffer{%[[LENGTH]]}
+ %0 = util.global.load @constantGlobal : !util.buffer
+ // CHECK-NEXT: util.do_not_optimize(%[[SUBRANGE]])
+ util.do_not_optimize(%0) : !util.buffer
+ return
+}
+
+// -----
+
+// Tests that resource global stores consume their incoming subranges.
+//
+// This rotates subranges through stores and into loads.
+
+// CHECK: util.global private mutable @mutableGlobal : !util.buffer
+// CHECK-NEXT: util.global private mutable @mutableGlobal__storage_size : index
+// CHECK-NEXT: util.global private mutable @mutableGlobal__offset : index
+// CHECK-NEXT: util.global private mutable @mutableGlobal__length : index
+util.global private mutable @mutableGlobal : !util.buffer
+
+// CHECK-LABEL: @globalStore
+// CHECK-SAME: (%[[RESOURCE:.+]]: !util.buffer, %[[STORAGE_SIZE:.+]]: index, %[[OFFSET:.+]]: index, %[[LENGTH:.+]]: index)
+func.func @globalStore(%resource: !util.buffer) {
+ // CHECK: util.global.store %[[RESOURCE]], @mutableGlobal : !util.buffer
+ // CHECK: util.global.store %[[STORAGE_SIZE]], @mutableGlobal__storage_size : index
+ // CHECK: util.global.store %[[OFFSET]], @mutableGlobal__offset : index
+ // CHECK: util.global.store %[[LENGTH]], @mutableGlobal__length : index
+ util.global.store %resource, @mutableGlobal : !util.buffer
+ return
+}
+
+// -----
+
+// Tests that function arguments are expanded into an explicit subrange of
+// (resource, size, offset, length).
+//
+// This rotates subranges from callers into callees.
+
+// CHECK-LABEL: @funcArgs
+// CHECK-SAME: (%[[RESOURCE0:.+]]: !util.buffer, %[[STORAGE_SIZE0:.+]]: index, %[[OFFSET0:.+]]: index, %[[LENGTH0:.+]]: index, %[[RESOURCE1:.+]]: !util.buffer, %[[STORAGE_SIZE1:.+]]: index, %[[OFFSET1:.+]]: index, %[[LENGTH1:.+]]: index)
+func.func @funcArgs(%resource0: !util.buffer, %resource1: !util.buffer) {
+ // CHECK-NEXT: %[[SUBRANGE0:.+]] = util.buffer.subspan %[[RESOURCE0]][%[[OFFSET0]]] : !util.buffer{%[[STORAGE_SIZE0]]} -> !util.buffer{%[[LENGTH0]]}
+ // CHECK-NEXT: %[[SUBRANGE1:.+]] = util.buffer.subspan %[[RESOURCE1]][%[[OFFSET1]]] : !util.buffer{%[[STORAGE_SIZE1]]} -> !util.buffer{%[[LENGTH1]]}
+
+ // CHECK-NEXT: util.do_not_optimize(%[[SUBRANGE0]])
+ util.do_not_optimize(%resource0) : !util.buffer
+ // CHECK-NEXT: util.do_not_optimize(%[[SUBRANGE1]])
+ util.do_not_optimize(%resource1) : !util.buffer
+ return
+}
+
+// -----
+
+// Tests that function results are expanded into an explicit subrange of
+// (resource, size, offset, length).
+//
+// This rotates subranges from callees into callers.
+
+// CHECK-LABEL: @funcResults
+// CHECK-SAME: (%[[RESOURCE0:.+]]: !util.buffer, %[[STORAGE_SIZE0:.+]]: index, %[[OFFSET0:.+]]: index, %[[LENGTH0:.+]]: index, %[[RESOURCE1:.+]]: !util.buffer, %[[STORAGE_SIZE1:.+]]: index, %[[OFFSET1:.+]]: index, %[[LENGTH1:.+]]: index)
+// CHECK-SAME: -> (!util.buffer, index, index, index, !util.buffer, index, index, index)
+func.func @funcResults(%resource0: !util.buffer, %resource1: !util.buffer) -> (!util.buffer, !util.buffer) {
+ // NOTE: there will be extra IR here from the argument expansion. Since the
+ // return consumes the subranges that were inserted we expect the function
+ // arguments to be used directly.
+
+ // CHECK: return %[[RESOURCE0]], %[[STORAGE_SIZE0]], %[[OFFSET0]], %[[LENGTH0]], %[[RESOURCE1]], %[[STORAGE_SIZE1]], %[[OFFSET1]], %[[LENGTH1]]
+ return %resource0, %resource1 : !util.buffer, !util.buffer
+}
+
+// -----
+
+// Tests that function calls have their args and results expanded into
+// (resource, size, offset, length).
+//
+// This rotates subranges on args from callers to callees and subranges on results
+// from callees to callers.
+
+// CHECK-LABEL: @caller
+// CHECK-SAME: (%[[RESOURCE0:.+]]: !util.buffer, %[[STORAGE_SIZE0:.+]]: index, %[[OFFSET0:.+]]: index, %[[LENGTH0:.+]]: index, %[[RESOURCE1:.+]]: !util.buffer, %[[STORAGE_SIZE1:.+]]: index, %[[OFFSET1:.+]]: index, %[[LENGTH1:.+]]: index)
+func.func @caller(%resource0: !util.buffer, %resource1: !util.buffer) {
+ // NOTE: there will be extra IR here from the argument expansion. The call
+ // consumes the subranges and we expect the args to be passed directly.
+
+ // CHECK: %[[RET:.+]]:8 = call @callee(%[[RESOURCE0]], %[[STORAGE_SIZE0]], %[[OFFSET0]], %[[LENGTH0]],
+ // CHECK-SAME: %[[RESOURCE1]], %[[STORAGE_SIZE1]], %[[OFFSET1]], %[[LENGTH1]])
+ // CHECK-SAME: : (!util.buffer, index, index, index, !util.buffer, index, index, index)
+ // CHECK-SAME: -> (!util.buffer, index, index, index, !util.buffer, index, index, index)
+ %0:2 = call @callee(%resource0, %resource1) : (!util.buffer, !util.buffer) -> (!util.buffer, !util.buffer)
+ // CHECK-NEXT: %[[RET_SUBRANGE0:.+]] = util.buffer.subspan %[[RET]]#0[%[[RET]]#2] : !util.buffer{%[[RET]]#1} -> !util.buffer{%[[RET]]#3}
+ // CHECK-NEXT: %[[RET_SUBRANGE1:.+]] = util.buffer.subspan %[[RET]]#4[%[[RET]]#6] : !util.buffer{%[[RET]]#5} -> !util.buffer{%[[RET]]#7}
+
+ // CHECK-NEXT: util.do_not_optimize(%[[RET_SUBRANGE0]]) : !util.buffer
+ util.do_not_optimize(%0#0) : !util.buffer
+ // CHECK-NEXT: util.do_not_optimize(%[[RET_SUBRANGE1]]) : !util.buffer
+ util.do_not_optimize(%0#1) : !util.buffer
+
+ return
+}
+
+func.func private @callee(%arg0: !util.buffer, %arg1: !util.buffer) -> (!util.buffer, !util.buffer)
+
+// -----
+
+// Tests that branch arguments are expanded into an explicit subrange of
+// (resource, size, offset, length).
+//
+// This rotates subranges on branch operands into successors.
+
+// CHECK-LABEL: @br
+// CHECK-SAME: (%[[RESOURCE0:.+]]: !util.buffer, %[[STORAGE_SIZE0:.+]]: index, %[[OFFSET0:.+]]: index, %[[LENGTH0:.+]]: index, %[[RESOURCE1:.+]]: !util.buffer, %[[STORAGE_SIZE1:.+]]: index, %[[OFFSET1:.+]]: index, %[[LENGTH1:.+]]: index)
+func.func @br(%resource0: !util.buffer, %resource1: !util.buffer) {
+ // NOTE: there will be extra IR here from the argument expansion. The branch
+ // consumes the subranges and we expect the args to be passed directly to the
+ // cf.br.
+
+ // CHECK: cf.br ^bb1(%[[RESOURCE0]], %[[STORAGE_SIZE0]], %[[OFFSET0]], %[[LENGTH0]],
+ // CHECK-SAME: %[[RESOURCE1]], %[[STORAGE_SIZE1]], %[[OFFSET1]], %[[LENGTH1]] :
+ cf.br ^bb1(%resource0, %resource1 : !util.buffer, !util.buffer)
+
+// CHECK-NEXT: ^bb1(%[[BB1_RESOURCE0:.+]]: !util.buffer, %[[BB1_STORAGE_SIZE0:.+]]: index, %[[BB1_OFFSET0:.+]]: index, %[[BB1_LENGTH0:.+]]: index, %[[BB1_RESOURCE1:.+]]: !util.buffer, %[[BB1_STORAGE_SIZE1:.+]]: index, %[[BB1_OFFSET1:.+]]: index, %[[BB1_LENGTH1:.+]]: index):
+^bb1(%bb1_resource0: !util.buffer, %bb1_resource1: !util.buffer):
+ // CHECK-NEXT: %[[BB1_SUBRANGE0:.+]] = util.buffer.subspan %[[BB1_RESOURCE0]][%[[BB1_OFFSET0]]] : !util.buffer{%[[BB1_STORAGE_SIZE0]]} -> !util.buffer{%[[BB1_LENGTH0]]}
+ // CHECK-NEXT: %[[BB1_SUBRANGE1:.+]] = util.buffer.subspan %[[BB1_RESOURCE1]][%[[BB1_OFFSET1]]] : !util.buffer{%[[BB1_STORAGE_SIZE1]]} -> !util.buffer{%[[BB1_LENGTH1]]}
+
+ // CHECK-NEXT: util.do_not_optimize(%[[BB1_SUBRANGE0]])
+ util.do_not_optimize(%bb1_resource0) : !util.buffer
+ // CHECK-NEXT: util.do_not_optimize(%[[BB1_SUBRANGE1]])
+ util.do_not_optimize(%bb1_resource1) : !util.buffer
+
+ return
+}
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/BUILD b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/BUILD
index fbc2f6d..38b8028 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/BUILD
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/BUILD
@@ -16,6 +16,7 @@
name = "UtilToVM",
srcs = [
"ConvertAlignmentOps.cpp",
+ "ConvertBufferOps.cpp",
"ConvertGlobalOps.cpp",
"ConvertListOps.cpp",
"ConvertStatusOps.cpp",
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/CMakeLists.txt
index 646faaf..0b580c4 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/CMakeLists.txt
@@ -17,6 +17,7 @@
"ConvertUtilToVM.h"
SRCS
"ConvertAlignmentOps.cpp"
+ "ConvertBufferOps.cpp"
"ConvertGlobalOps.cpp"
"ConvertListOps.cpp"
"ConvertStatusOps.cpp"
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertBufferOps.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertBufferOps.cpp
new file mode 100644
index 0000000..3c50f36
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertBufferOps.cpp
@@ -0,0 +1,313 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
+#include "iree/compiler/Dialect/VM/Conversion/TypeConverter.h"
+#include "iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertUtilToVM.h"
+#include "iree/compiler/Dialect/VM/IR/VMOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace {
+
+static Value castToI64(Value value, OpBuilder &builder) {
+ if (value.getType().isInteger(64)) return value;
+ return builder.createOrFold<IREE::VM::ExtI32I64UOp>(
+ value.getLoc(), builder.getI64Type(), value);
+}
+
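+// Lowers util.buffer.constant to an inline rodata reference. Roughly
+// (illustrative; exact assembly may differ):
+//   %0 = util.buffer.constant : !util.buffer = dense<[1, 2, 3]> : tensor<3xi32>
+// becomes a vm.rodata.inline producing a !vm.ref<!vm.buffer>.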
+struct BufferConstantOpConversion
+ : public OpConversionPattern<IREE::Util::BufferConstantOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ IREE::Util::BufferConstantOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto alignmentAttr = op.getAlignmentAttr();
+ if (alignmentAttr) {
+ alignmentAttr = rewriter.getI64IntegerAttr(alignmentAttr.getInt());
+ }
+ rewriter.replaceOpWithNewOp<IREE::VM::RodataInlineOp>(
+ op,
+ IREE::VM::RefType::get(
+ IREE::VM::BufferType::get(rewriter.getContext())),
+ /*name=*/nullptr, op.getValue(), alignmentAttr);
+ return success();
+ }
+};
+
+struct BufferAllocOpConversion
+ : public OpConversionPattern<IREE::Util::BufferAllocOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ IREE::Util::BufferAllocOp allocOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // TODO(#9165): support alignment on vm.buffer.alloc. For now the alignment
+ // attribute is ignored when lowering the op to the VM dialect.
+ (void)adaptor.getAlignment();
+ auto resultType =
+ getTypeConverter()->convertType(allocOp.getResult().getType());
+ rewriter.replaceOpWithNewOp<IREE::VM::BufferAllocOp>(
+ allocOp, resultType, castToI64(adaptor.getStorageSize(), rewriter));
+ return success();
+ }
+};
+
+struct BufferDeallocOpConversion
+ : public OpConversionPattern<IREE::Util::BufferDeallocOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ IREE::Util::BufferDeallocOp deallocOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // No-op today. We could make this force a dealloc of the underlying storage,
+ // or add something like a vm.hint.reset op to force the reference to be
+ // dropped.
+ rewriter.eraseOp(deallocOp);
+ return success();
+ }
+};
+
+// Expands util.buffer.slice -> vm.buffer.alloc + vm.buffer.copy.
+// We could add a dedicated vm.buffer.slice op if we wanted; today the runtime
+// would do nothing beyond this alloc + copy anyway.
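+// Roughly (illustrative only; exact VM assembly may differ):
+//   %0 = util.buffer.slice %src[%offset] : !util.buffer{%src_size} -> !util.buffer{%length}
+// becomes a vm.buffer.alloc of %length followed by a vm.buffer.copy from
+// %src[%offset] into the new buffer at offset 0 for %length bytes.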
+struct BufferSliceOpConversion
+ : public OpConversionPattern<IREE::Util::BufferSliceOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ IREE::Util::BufferSliceOp sliceOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // TODO(#9165): support alignment on vm.buffer.alloc. For now the alignment
+ // attribute is ignored when lowering the op to the VM dialect.
+ (void)adaptor.getAlignment();
+ auto resultType =
+ getTypeConverter()->convertType(sliceOp.getResult().getType());
+ auto sliceLength = castToI64(adaptor.getResultSize(), rewriter);
+ Value newBuffer = rewriter.create<IREE::VM::BufferAllocOp>(
+ sliceOp.getLoc(), resultType, sliceLength);
+ Value zero = rewriter.create<IREE::VM::ConstI64ZeroOp>(sliceOp.getLoc());
+ rewriter.create<IREE::VM::BufferCopyOp>(
+ sliceOp.getLoc(), adaptor.getSource(),
+ castToI64(adaptor.getSourceOffset(), rewriter), newBuffer, zero,
+ sliceLength);
+ rewriter.replaceOp(sliceOp, newBuffer);
+ return success();
+ }
+};
+
+struct BufferSizeOpConversion
+ : public OpConversionPattern<IREE::Util::BufferSizeOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ IREE::Util::BufferSizeOp sizeOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Value size = rewriter.create<IREE::VM::BufferLengthOp>(
+ sizeOp.getLoc(), rewriter.getI64Type(), adaptor.getOperand());
+ rewriter.replaceOp(sizeOp, size);
+ return success();
+ }
+};
+
+struct BufferCopyOpConversion
+ : public OpConversionPattern<IREE::Util::BufferCopyOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ IREE::Util::BufferCopyOp copyOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<IREE::VM::BufferCopyOp>(
+ copyOp, adaptor.getSource(),
+ castToI64(adaptor.getSourceOffset(), rewriter), adaptor.getTarget(),
+ castToI64(adaptor.getTargetOffset(), rewriter),
+ castToI64(adaptor.getLength(), rewriter));
+ return success();
+ }
+};
+
+struct BufferCompareOpConversion
+ : public OpConversionPattern<IREE::Util::BufferCompareOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ IREE::Util::BufferCompareOp compareOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto resultType =
+ getTypeConverter()->convertType(compareOp.getResult().getType());
+ rewriter.replaceOpWithNewOp<IREE::VM::BufferCompareOp>(
+ compareOp, resultType, adaptor.getLhs(),
+ castToI64(adaptor.getLhsOffset(), rewriter), adaptor.getRhs(),
+ castToI64(adaptor.getRhsOffset(), rewriter),
+ castToI64(adaptor.getLength(), rewriter));
+ return success();
+ }
+};
+
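+// Lowers util.buffer.fill to the element-width-specific VM fill op. For
+// example (illustrative), an i32 fill such as
+//   util.buffer.fill %value, %target[%offset for %length] : i32 -> !util.buffer{%size}
+// maps to the 32-bit variant (BufferFillI32Op) with i64 byte offset/length.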
+struct BufferFillOpConversion
+ : public OpConversionPattern<IREE::Util::BufferFillOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ IREE::Util::BufferFillOp fillOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto oldType = fillOp.getPattern().getType();
+ auto byteOffset = castToI64(adaptor.getTargetOffset(), rewriter);
+ auto byteLength = castToI64(adaptor.getLength(), rewriter);
+ auto pattern = adaptor.getPattern();
+ if (auto integerType = oldType.dyn_cast<IntegerType>()) {
+ if (integerType.isInteger(1) || integerType.isInteger(8)) {
+ rewriter.replaceOpWithNewOp<IREE::VM::BufferFillI8Op>(
+ fillOp, adaptor.getTarget(), byteOffset, byteLength, pattern);
+ } else if (integerType.isInteger(16)) {
+ rewriter.replaceOpWithNewOp<IREE::VM::BufferFillI16Op>(
+ fillOp, adaptor.getTarget(), byteOffset, byteLength, pattern);
+ } else if (integerType.isInteger(32)) {
+ rewriter.replaceOpWithNewOp<IREE::VM::BufferFillI32Op>(
+ fillOp, adaptor.getTarget(), byteOffset, byteLength, pattern);
+ } else if (integerType.isInteger(64)) {
+ rewriter.replaceOpWithNewOp<IREE::VM::BufferFillI64Op>(
+ fillOp, adaptor.getTarget(), byteOffset, byteLength, pattern);
+ } else {
+ return rewriter.notifyMatchFailure(
+ fillOp, "invalid integer buffer element type");
+ }
+ } else if (oldType.isF32()) {
+ rewriter.replaceOpWithNewOp<IREE::VM::BufferFillF32Op>(
+ fillOp, adaptor.getTarget(), byteOffset, byteLength, pattern);
+ } else if (oldType.isF64()) {
+ rewriter.replaceOpWithNewOp<IREE::VM::BufferFillF64Op>(
+ fillOp, adaptor.getTarget(), byteOffset, byteLength, pattern);
+ } else {
+ return rewriter.notifyMatchFailure(fillOp,
+ "invalid float buffer element type");
+ }
+ return success();
+ }
+};
+
+struct BufferLoadOpConversion
+ : public OpConversionPattern<IREE::Util::BufferLoadOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ IREE::Util::BufferLoadOp loadOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto oldType = loadOp.getResult().getType();
+ auto newType = getTypeConverter()->convertType(oldType);
+ auto byteOffset = castToI64(adaptor.getSourceOffset(), rewriter);
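+ // i1/i8/i16 loads widen into an i32 register: signed and signless integers
+ // take the sign-extending (.s) variants, unsigned the zero-extending (.u)
+ // ones; i32/i64/f32/f64 load at their natural width.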
+ if (auto integerType = oldType.dyn_cast<IntegerType>()) {
+ if (integerType.isInteger(1) || integerType.isInteger(8)) {
+ if (integerType.isSigned() || integerType.isSignless()) {
+ rewriter.replaceOpWithNewOp<IREE::VM::BufferLoadI8SOp>(
+ loadOp, newType, adaptor.getSource(), byteOffset);
+ } else {
+ rewriter.replaceOpWithNewOp<IREE::VM::BufferLoadI8UOp>(
+ loadOp, newType, adaptor.getSource(), byteOffset);
+ }
+ } else if (integerType.isInteger(16)) {
+ if (integerType.isSigned() || integerType.isSignless()) {
+ rewriter.replaceOpWithNewOp<IREE::VM::BufferLoadI16SOp>(
+ loadOp, newType, adaptor.getSource(), byteOffset);
+ } else {
+ rewriter.replaceOpWithNewOp<IREE::VM::BufferLoadI16UOp>(
+ loadOp, newType, adaptor.getSource(), byteOffset);
+ }
+ } else if (integerType.isInteger(32)) {
+ rewriter.replaceOpWithNewOp<IREE::VM::BufferLoadI32Op>(
+ loadOp, newType, adaptor.getSource(), byteOffset);
+ } else if (integerType.isInteger(64)) {
+ rewriter.replaceOpWithNewOp<IREE::VM::BufferLoadI64Op>(
+ loadOp, newType, adaptor.getSource(), byteOffset);
+ } else {
+ return rewriter.notifyMatchFailure(
+ loadOp, "invalid integer buffer element type");
+ }
+ } else if (oldType.isF32()) {
+ rewriter.replaceOpWithNewOp<IREE::VM::BufferLoadF32Op>(
+ loadOp, newType, adaptor.getSource(), byteOffset);
+ } else if (oldType.isF64()) {
+ rewriter.replaceOpWithNewOp<IREE::VM::BufferLoadF64Op>(
+ loadOp, newType, adaptor.getSource(), byteOffset);
+ } else {
+ return rewriter.notifyMatchFailure(loadOp,
+ "invalid float buffer element type");
+ }
+ return success();
+ }
+};
+
+struct BufferStoreOpConversion
+ : public OpConversionPattern<IREE::Util::BufferStoreOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ IREE::Util::BufferStoreOp storeOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto oldType = storeOp.getSource().getType();
+ auto byteOffset = castToI64(adaptor.getTargetOffset(), rewriter);
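+ // Dispatch on the original element width: i1 stores a single byte via
+ // vm.buffer.store.i8, mirroring the fill/load handling above.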
+ if (oldType.isInteger(1) || oldType.isInteger(8)) {
+ rewriter.replaceOpWithNewOp<IREE::VM::BufferStoreI8Op>(
+ storeOp, adaptor.getTarget(), byteOffset, adaptor.getSource());
+ } else if (oldType.isInteger(16)) {
+ rewriter.replaceOpWithNewOp<IREE::VM::BufferStoreI16Op>(
+ storeOp, adaptor.getTarget(), byteOffset, adaptor.getSource());
+ } else if (oldType.isInteger(32)) {
+ rewriter.replaceOpWithNewOp<IREE::VM::BufferStoreI32Op>(
+ storeOp, adaptor.getTarget(), byteOffset, adaptor.getSource());
+ } else if (oldType.isInteger(64)) {
+ rewriter.replaceOpWithNewOp<IREE::VM::BufferStoreI64Op>(
+ storeOp, adaptor.getTarget(), byteOffset, adaptor.getSource());
+ } else if (oldType.isF32()) {
+ rewriter.replaceOpWithNewOp<IREE::VM::BufferStoreF32Op>(
+ storeOp, adaptor.getTarget(), byteOffset, adaptor.getSource());
+ } else if (oldType.isF64()) {
+ rewriter.replaceOpWithNewOp<IREE::VM::BufferStoreF64Op>(
+ storeOp, adaptor.getTarget(), byteOffset, adaptor.getSource());
+ } else {
+ return rewriter.notifyMatchFailure(storeOp,
+ "invalid buffer element type");
+ }
+ return success();
+ }
+};
+
+} // namespace
+
+void populateUtilBufferToVMPatterns(MLIRContext *context,
+ ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter,
+ RewritePatternSet &patterns) {
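+ // !util.buffer values become refs to the VM's built-in buffer type; their
+ // lifetime is then managed by VM ref counting.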
+ typeConverter.addConversion(
+ [](IREE::Util::BufferType type) -> Optional<Type> {
+ return IREE::VM::RefType::get(
+ IREE::VM::BufferType::get(type.getContext()));
+ });
+
+ // TODO(benvanik): some way to handle subspans if they survive. For today we
+ // require that they have all been removed by propagation; that won't be the
+ // case if buffer subspans are returned across the ABI boundary.
+ conversionTarget.addIllegalOp<IREE::Util::BufferStorageOp>();
+ conversionTarget.addIllegalOp<IREE::Util::BufferSubspanOp>();
+
+ conversionTarget
+ .addIllegalOp<IREE::Util::BufferConstantOp, IREE::Util::BufferAllocOp,
+ IREE::Util::BufferDeallocOp, IREE::Util::BufferSliceOp,
+ IREE::Util::BufferSizeOp, IREE::Util::BufferCopyOp,
+ IREE::Util::BufferCompareOp, IREE::Util::BufferFillOp,
+ IREE::Util::BufferLoadOp, IREE::Util::BufferStoreOp>();
+
+ patterns.insert<BufferConstantOpConversion>(typeConverter, context);
+ patterns.insert<BufferAllocOpConversion>(typeConverter, context);
+ patterns.insert<BufferDeallocOpConversion>(typeConverter, context);
+ patterns.insert<BufferSliceOpConversion>(typeConverter, context);
+ patterns.insert<BufferSizeOpConversion>(typeConverter, context);
+ patterns.insert<BufferCopyOpConversion>(typeConverter, context);
+ patterns.insert<BufferCompareOpConversion>(typeConverter, context);
+ patterns.insert<BufferFillOpConversion>(typeConverter, context);
+ patterns.insert<BufferLoadOpConversion>(typeConverter, context);
+ patterns.insert<BufferStoreOpConversion>(typeConverter, context);
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertUtilToVM.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertUtilToVM.cpp
index 98d5ddf..c944af0 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertUtilToVM.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertUtilToVM.cpp
@@ -20,6 +20,14 @@
namespace mlir {
namespace iree_compiler {
+void populateUtilAlignmentToVMPatterns(MLIRContext *context,
+ ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter,
+ RewritePatternSet &patterns);
+void populateUtilBufferToVMPatterns(MLIRContext *context,
+ ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter,
+ RewritePatternSet &patterns);
void populateUtilGlobalToVMPatterns(MLIRContext *context,
ConversionTarget &conversionTarget,
TypeConverter &typeConverter,
@@ -32,10 +40,6 @@
ConversionTarget &conversionTarget,
TypeConverter &typeConverter,
RewritePatternSet &patterns);
-void populateUtilAlignmentToVMPatterns(MLIRContext *context,
- ConversionTarget &conversionTarget,
- TypeConverter &typeConverter,
- RewritePatternSet &patterns);
namespace {
@@ -74,25 +78,6 @@
};
//===----------------------------------------------------------------------===//
-// util.byte_buffer.*
-//===----------------------------------------------------------------------===//
-
-struct BufferConstantOpConversion
- : public OpConversionPattern<IREE::Util::BufferConstantOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- IREE::Util::BufferConstantOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<IREE::VM::RodataInlineOp>(
- op,
- IREE::VM::RefType::get(
- IREE::VM::BufferType::get(rewriter.getContext())),
- /*name=*/nullptr, op.getValue(), op.getAlignmentAttr());
- return success();
- }
-};
-
-//===----------------------------------------------------------------------===//
// Compiler hints
//===----------------------------------------------------------------------===//
@@ -120,23 +105,18 @@
RewritePatternSet &patterns) {
patterns.insert<NullOpConversion>(typeConverter, context);
patterns.insert<CmpEQOpConversion>(typeConverter, context);
- patterns.insert<BufferConstantOpConversion>(typeConverter, context);
patterns.insert<UnreachableOpConversion>(typeConverter, context);
- typeConverter.addConversion(
- [](IREE::Util::BufferType type) -> Optional<Type> {
- return IREE::VM::RefType::get(
- IREE::VM::BufferType::get(type.getContext()));
- });
-
+ populateUtilAlignmentToVMPatterns(context, conversionTarget, typeConverter,
+ patterns);
+ populateUtilBufferToVMPatterns(context, conversionTarget, typeConverter,
+ patterns);
populateUtilGlobalToVMPatterns(context, conversionTarget, typeConverter,
patterns);
populateUtilListToVMPatterns(context, conversionTarget, typeConverter,
patterns);
populateUtilStatusToVMPatterns(context, conversionTarget, typeConverter,
patterns);
- populateUtilAlignmentToVMPatterns(context, conversionTarget, typeConverter,
- patterns);
}
} // namespace iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/buffer_ops.mlir b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/buffer_ops.mlir
index d3fced7..bfd2936 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/buffer_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/buffer_ops.mlir
@@ -1,13 +1,224 @@
-// RUN: iree-opt --split-input-file --iree-vm-conversion %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-vm-conversion --iree-vm-target-index-bits=32 %s | FileCheck %s --check-prefixes=CHECK,CHECK-32
+// RUN: iree-opt --split-input-file --iree-vm-conversion --iree-vm-target-index-bits=64 %s | FileCheck %s --check-prefixes=CHECK,CHECK-64
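+
+// util.buffer offsets and sizes are index-typed while the vm.buffer.* ops take
+// i64 operands, so the 32-bit index configuration additionally checks for the
+// vm.ext.i32.i64.u casts inserted during conversion.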
// CHECK-LABEL: @buffer_constant
-module @buffer_constant {
-module {
- // CHECK: vm.func private @my_fn
- func.func @my_fn() {
- // CHECK-NEXT: = vm.rodata.inline : !vm.buffer = dense<[1, 2, 3]> : tensor<3xi32>
- %0 = util.buffer.constant : !util.buffer = dense<[1, 2, 3]> : tensor<3xi32>
- return
- }
+func.func @buffer_constant() -> !util.buffer {
+ // CHECK-64: %[[BUFFER:.+]] = vm.rodata.inline {alignment = 64 : i64} : !vm.buffer = dense<[1, 2, 3]> : tensor<3xi32>
+ %0 = util.buffer.constant {alignment = 64 : index} : !util.buffer = dense<[1, 2, 3]> : tensor<3xi32>
+ // CHECK-64: return %[[BUFFER]]
+ return %0 : !util.buffer
}
+
+// -----
+
+// CHECK-LABEL: @buffer_alloc
+func.func @buffer_alloc(%arg0: index) -> !util.buffer {
+ // CHECK-32: %[[SIZE_64:.+]] = vm.ext.i32.i64.u %arg0 : i32 -> i64
+ // CHECK-32: %[[BUFFER:.+]] = vm.buffer.alloc %[[SIZE_64]] : !vm.buffer
+ // CHECK-64: %[[BUFFER:.+]] = vm.buffer.alloc %arg0 : !vm.buffer
+ %0 = util.buffer.alloc uninitialized {alignment = 16 : index} : !util.buffer{%arg0}
+ // CHECK-32: return %[[BUFFER]]
+ return %0 : !util.buffer
+}
+
+// -----
+
+// NOTE: util.buffer.dealloc is currently unused; the conversion just erases
+// the op and relies on ref counting to release the buffer.
+
+// CHECK-LABEL: @buffer_dealloc
+func.func @buffer_dealloc(%arg0: !util.buffer, %arg1: index) {
+ // CHECK-32-NOT: util.buffer.dealloc
+ util.buffer.dealloc %arg0 : !util.buffer{%arg1}
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @buffer_slice
+func.func @buffer_slice(%arg0: !util.buffer, %arg1: index, %arg2: index, %arg3: index) -> !util.buffer {
+ // CHECK-32: %[[SIZE_64:.+]] = vm.ext.i32.i64.u %arg3 : i32 -> i64
+ // CHECK-32: %[[BUFFER:.+]] = vm.buffer.alloc %[[SIZE_64]] : !vm.buffer
+ // CHECK-32-DAG: %[[ZERO:.+]] = vm.const.i64.zero
+ // CHECK-32-DAG: %[[OFFSET_64:.+]] = vm.ext.i32.i64.u %arg1 : i32 -> i64
+ // CHECK-32: vm.buffer.copy %arg0, %[[OFFSET_64]], %[[BUFFER]], %[[ZERO]], %[[SIZE_64]] : !vm.buffer -> !vm.buffer
+ // CHECK-64-DAG: %[[BUFFER:.+]] = vm.buffer.alloc %arg3 : !vm.buffer
+ // CHECK-64-DAG: %[[ZERO:.+]] = vm.const.i64.zero
+ // CHECK-64: vm.buffer.copy %arg0, %arg1, %[[BUFFER]], %[[ZERO]], %arg3 : !vm.buffer -> !vm.buffer
+ %0 = util.buffer.slice %arg0[%arg1] {alignment = 16 : index} : !util.buffer{%arg2} -> !util.buffer{%arg3}
+ // CHECK-32: return %[[BUFFER]]
+ return %0 : !util.buffer
+}
+
+// -----
+
+// CHECK-LABEL: @buffer_size
+func.func @buffer_size(%arg0: !util.buffer) -> index {
+ // CHECK-32: %[[SIZE:.+]] = vm.buffer.length %arg0 : !vm.buffer -> i64
+ %0 = util.buffer.size %arg0 : !util.buffer
+ // CHECK-32: return %[[SIZE]]
+ return %0 : index
+}
+
+// -----
+
+// CHECK-LABEL: @buffer_copy
+func.func @buffer_copy(%arg0: !util.buffer, %arg1: index) {
+ %c3 = arith.constant 3 : index
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ // CHECK-32-DAG: %[[C3:.+]] = vm.const.i64 3
+ // CHECK-32-DAG: %[[C100:.+]] = vm.const.i64 100
+ // CHECK-32-DAG: %[[C200:.+]] = vm.const.i64 200
+ // CHECK-32: vm.buffer.copy %arg0, %[[C100]], %arg0, %[[C200]], %[[C3]] : !vm.buffer -> !vm.buffer
+ // CHECK-64: vm.buffer.copy %arg0, %c100, %arg0, %c200, %c3 : !vm.buffer -> !vm.buffer
+ util.buffer.copy %arg0[%c100], %arg0[%c200], %c3 : !util.buffer{%arg1} -> !util.buffer{%arg1}
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @buffer_compare
+func.func @buffer_compare(%arg0: !util.buffer, %arg1: index) -> i1 {
+ %c3 = arith.constant 3 : index
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ // CHECK-32-DAG: %[[C3:.+]] = vm.const.i64 3
+ // CHECK-32-DAG: %[[C100:.+]] = vm.const.i64 100
+ // CHECK-32-DAG: %[[C200:.+]] = vm.const.i64 200
+ // CHECK-32: %[[RESULT:.+]] = vm.buffer.compare %arg0, %[[C100]], %arg0, %[[C200]], %[[C3]] : !vm.buffer, !vm.buffer
+ // CHECK-64: %[[RESULT:.+]] = vm.buffer.compare %arg0, %c100, %arg0, %c200, %c3 : !vm.buffer, !vm.buffer
+ %0 = util.buffer.compare %arg0[%c100], %arg0[%c200], %c3 : !util.buffer{%arg1}, !util.buffer{%arg1}
+ // CHECK: return %[[RESULT]]
+ return %0 : i1
+}
+
+// -----
+
+// CHECK-LABEL: @buffer_fill_i1
+func.func @buffer_fill_i1(%arg0: !util.buffer, %arg1: index, %arg2: i1) {
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ // CHECK-32-DAG: %[[C100:.+]] = vm.const.i64 100
+ // CHECK-32-DAG: %[[C200:.+]] = vm.const.i64 200
+ // CHECK-32: vm.buffer.fill.i8 %arg0, %[[C100]], %[[C200]], %arg2 : i32 -> !vm.buffer
+ // CHECK-64: vm.buffer.fill.i8 %arg0, %c100, %c200, %arg2 : i32 -> !vm.buffer
+ util.buffer.fill %arg2, %arg0[%c100 for %c200] : i1 -> !util.buffer{%arg1}
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @buffer_fill_i8
+func.func @buffer_fill_i8(%arg0: !util.buffer, %arg1: index, %arg2: i8) {
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ // CHECK-32-DAG: %[[C100:.+]] = vm.const.i64 100
+ // CHECK-32-DAG: %[[C200:.+]] = vm.const.i64 200
+ // CHECK-32: vm.buffer.fill.i8 %arg0, %[[C100]], %[[C200]], %arg2 : i32 -> !vm.buffer
+ // CHECK-64: vm.buffer.fill.i8 %arg0, %c100, %c200, %arg2 : i32 -> !vm.buffer
+ util.buffer.fill %arg2, %arg0[%c100 for %c200] : i8 -> !util.buffer{%arg1}
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @buffer_fill_i32
+func.func @buffer_fill_i32(%arg0: !util.buffer, %arg1: index, %arg2: i32) {
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ // CHECK-32-DAG: %[[C100:.+]] = vm.const.i64 100
+ // CHECK-32-DAG: %[[C200:.+]] = vm.const.i64 200
+ // CHECK-32: vm.buffer.fill.i32 %arg0, %[[C100]], %[[C200]], %arg2 : i32 -> !vm.buffer
+ // CHECK-64: vm.buffer.fill.i32 %arg0, %c100, %c200, %arg2 : i32 -> !vm.buffer
+ util.buffer.fill %arg2, %arg0[%c100 for %c200] : i32 -> !util.buffer{%arg1}
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @buffer_fill_i64
+func.func @buffer_fill_i64(%arg0: !util.buffer, %arg1: index, %arg2: i64) {
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ // CHECK-32-DAG: %[[C100:.+]] = vm.const.i64 100
+ // CHECK-32-DAG: %[[C200:.+]] = vm.const.i64 200
+ // CHECK-32: vm.buffer.fill.i64 %arg0, %[[C100]], %[[C200]], %arg2 : i64 -> !vm.buffer
+ // CHECK-64: vm.buffer.fill.i64 %arg0, %c100, %c200, %arg2 : i64 -> !vm.buffer
+ util.buffer.fill %arg2, %arg0[%c100 for %c200] : i64 -> !util.buffer{%arg1}
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @buffer_load_i1
+func.func @buffer_load_i1(%arg0: !util.buffer, %arg1: index) -> i1 {
+ %c100 = arith.constant 100 : index
+ // CHECK-32-DAG: %[[C100:.+]] = vm.const.i64 100
+ // CHECK-32: %[[VALUE:.+]] = vm.buffer.load.i8.s %arg0[%[[C100]]] : !vm.buffer -> i32
+ // CHECK-64: %[[VALUE:.+]] = vm.buffer.load.i8.s %arg0[%c100] : !vm.buffer -> i32
+ %0 = util.buffer.load %arg0[%c100] : !util.buffer{%arg1} -> i1
+ // CHECK: return %[[VALUE]]
+ return %0 : i1
+}
+
+// -----
+
+// CHECK-LABEL: @buffer_load_i32
+func.func @buffer_load_i32(%arg0: !util.buffer, %arg1: index) -> i32 {
+ %c100 = arith.constant 100 : index
+ // CHECK-32-DAG: %[[C100:.+]] = vm.const.i64 100
+ // CHECK-32: %[[VALUE:.+]] = vm.buffer.load.i32 %arg0[%[[C100]]] : !vm.buffer -> i32
+ // CHECK-64: %[[VALUE:.+]] = vm.buffer.load.i32 %arg0[%c100] : !vm.buffer -> i32
+ %0 = util.buffer.load %arg0[%c100] : !util.buffer{%arg1} -> i32
+ // CHECK: return %[[VALUE]]
+ return %0 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @buffer_load_i64
+func.func @buffer_load_i64(%arg0: !util.buffer, %arg1: index) -> i64 {
+ %c100 = arith.constant 100 : index
+ // CHECK-32-DAG: %[[C100:.+]] = vm.const.i64 100
+ // CHECK-32: %[[VALUE:.+]] = vm.buffer.load.i64 %arg0[%[[C100]]] : !vm.buffer -> i64
+ // CHECK-64: %[[VALUE:.+]] = vm.buffer.load.i64 %arg0[%c100] : !vm.buffer -> i64
+ %0 = util.buffer.load %arg0[%c100] : !util.buffer{%arg1} -> i64
+ // CHECK: return %[[VALUE]]
+ return %0 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @buffer_store_i1
+func.func @buffer_store_i1(%arg0: !util.buffer, %arg1: index, %arg2: i1) {
+ %c100 = arith.constant 100 : index
+ // CHECK-32-DAG: %[[C100:.+]] = vm.const.i64 100
+ // CHECK-32: vm.buffer.store.i8 %arg2, %arg0[%[[C100]]] : i32 -> !vm.buffer
+ // CHECK-64: vm.buffer.store.i8 %arg2, %arg0[%c100] : i32 -> !vm.buffer
+ util.buffer.store %arg2, %arg0[%c100] : i1 -> !util.buffer{%arg1}
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @buffer_store_i32
+func.func @buffer_store_i32(%arg0: !util.buffer, %arg1: index, %arg2: i32) {
+ %c100 = arith.constant 100 : index
+ // CHECK-32-DAG: %[[C100:.+]] = vm.const.i64 100
+ // CHECK-32: vm.buffer.store.i32 %arg2, %arg0[%[[C100]]] : i32 -> !vm.buffer
+ // CHECK-64: vm.buffer.store.i32 %arg2, %arg0[%c100] : i32 -> !vm.buffer
+ util.buffer.store %arg2, %arg0[%c100] : i32 -> !util.buffer{%arg1}
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @buffer_store_i64
+func.func @buffer_store_i64(%arg0: !util.buffer, %arg1: index, %arg2: i64) {
+ %c100 = arith.constant 100 : index
+ // CHECK-32-DAG: %[[C100:.+]] = vm.const.i64 100
+ // CHECK-32: vm.buffer.store.i64 %arg2, %arg0[%[[C100]]] : i64 -> !vm.buffer
+ // CHECK-64: vm.buffer.store.i64 %arg2, %arg0[%c100] : i64 -> !vm.buffer
+ util.buffer.store %arg2, %arg0[%c100] : i64 -> !util.buffer{%arg1}
+ return
}
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/BUILD b/compiler/src/iree/compiler/Dialect/VM/Transforms/BUILD
index d3eb105..ba3cced 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/BUILD
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/BUILD
@@ -30,6 +30,7 @@
deps = [
"//compiler/src/iree/compiler/Dialect/Util/Conversion",
"//compiler/src/iree/compiler/Dialect/Util/IR",
+ "//compiler/src/iree/compiler/Dialect/Util/Transforms",
"//compiler/src/iree/compiler/Dialect/VM/Conversion",
"//compiler/src/iree/compiler/Dialect/VM/Conversion/MathToVM",
"//compiler/src/iree/compiler/Dialect/VM/Conversion/MemRefToVM",
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt
index 96b1ad8..c51c863 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt
@@ -44,6 +44,7 @@
MLIRTransforms
iree::compiler::Dialect::Util::Conversion
iree::compiler::Dialect::Util::IR
+ iree::compiler::Dialect::Util::Transforms
iree::compiler::Dialect::VM::Conversion
iree::compiler::Dialect::VM::Conversion::MathToVM
iree::compiler::Dialect::VM::Conversion::MemRefToVM
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/Conversion.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/Conversion.cpp
index 7e2d4f8..4df294d 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/Conversion.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/Conversion.cpp
@@ -69,18 +69,18 @@
explicit ConversionPass(TargetOptions targetOptions)
: targetOptions_(targetOptions) {}
- void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<IREE::Util::UtilDialect, IREE::VM::VMDialect,
- func::FuncDialect, mlir::arith::ArithmeticDialect,
- math::MathDialect, AffineDialect, memref::MemRefDialect>();
- }
-
StringRef getArgument() const override { return "iree-vm-conversion"; }
StringRef getDescription() const override {
return "Converts from various dialects to the VM dialect";
}
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<IREE::Util::UtilDialect, IREE::VM::VMDialect,
+ func::FuncDialect, mlir::arith::ArithmeticDialect,
+ math::MathDialect, AffineDialect, memref::MemRefDialect>();
+ }
+
void runOnOperation() override {
if (getOperation().getBody()->empty()) return;
@@ -116,17 +116,17 @@
}
}
- RewritePatternSet conversionPatterns(&getContext());
+ RewritePatternSet patterns(&getContext());
populateUtilConversionPatterns(context, conversionTarget, typeConverter,
- conversionPatterns);
+ patterns);
populateUtilToVMPatterns(context, conversionTarget, typeConverter,
- conversionPatterns);
- arith::populateArithmeticExpandOpsPatterns(conversionPatterns);
- populateStandardToVMPatterns(context, typeConverter, conversionPatterns);
- populateMathToVMPatterns(context, typeConverter, conversionPatterns);
+ patterns);
+ arith::populateArithmeticExpandOpsPatterns(patterns);
+ populateStandardToVMPatterns(context, typeConverter, patterns);
+ populateMathToVMPatterns(context, typeConverter, patterns);
populateMemRefToVMPatterns(context, conversionTarget, typeConverter,
- conversionPatterns);
- populateAffineToStdConversionPatterns(conversionPatterns);
+ patterns);
+ populateAffineToStdConversionPatterns(patterns);
conversionTarget
.addIllegalDialect<func::FuncDialect, mlir::arith::ArithmeticDialect>();
@@ -138,11 +138,11 @@
SymbolTable importSymbols(innerModuleOp);
for (auto *dialectInterface : usedDialects) {
dialectInterface->populateVMConversionPatterns(
- importSymbols, conversionPatterns, conversionTarget, typeConverter);
+ importSymbols, patterns, conversionTarget, typeConverter);
}
if (failed(applyPartialConversion(outerModuleOp, conversionTarget,
- std::move(conversionPatterns)))) {
+ std::move(patterns)))) {
outerModuleOp.emitError() << "conversion to vm.module failed";
return signalPassFailure();
}
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp
index 6d8e8d7..db45bdf 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp
@@ -9,6 +9,7 @@
#include <memory>
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
#include "iree/compiler/Dialect/VM/IR/VMOps.h"
#include "iree/compiler/Utils/PassUtils.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
@@ -26,6 +27,32 @@
using FunctionLikeNest = MultiOpNest<func::FuncOp, IREE::Util::InitializerOp>;
+//===----------------------------------------------------------------------===//
+// Utilities
+//===----------------------------------------------------------------------===//
+
+static void addCleanupPatterns(OpPassManager &passManager) {
+ // TODO(benvanik): run in a fixed-point iteration pipeline.
+
+ // Standard MLIR cleanup.
+ passManager.addPass(mlir::createCanonicalizerPass());
+ passManager.addPass(mlir::createCSEPass());
+
+ // Simplify util.global accesses; this can help with data-flow tracking as
+ // redundant store/load sequences are removed.
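+ // (e.g. storing %v to some global @g and then loading @g again in the same
+ // block lets the load reuse %v directly.)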
+ FunctionLikeNest(passManager)
+ .addPass(IREE::Util::createSimplifyGlobalAccessesPass);
+
+ // Cleanup and canonicalization of util.global (and other util ops).
+ passManager.addPass(IREE::Util::createApplyPatternsPass());
+ passManager.addPass(IREE::Util::createFoldGlobalsPass());
+ passManager.addPass(IREE::Util::createFuseGlobalsPass());
+}
+
+//===----------------------------------------------------------------------===//
+// -iree-vm-transformation-pipeline
+//===----------------------------------------------------------------------===//
+
void buildVMTransformPassPipeline(OpPassManager &passManager,
TargetOptions targetOptions) {
passManager.addNestedPass<mlir::func::FuncOp>(createLoopCoalescingPass());
@@ -35,9 +62,13 @@
.addPass(createLoopInvariantCodeMotionPass)
.addPass(createConvertSCFToCFPass);
- passManager.addPass(createCanonicalizerPass());
- passManager.addPass(createCSEPass());
+ // Propagate buffer subranges throughout the program - this should remove any
+ // remaining subspans and give us a smaller surface area during conversion.
+ passManager.addPass(IREE::Util::createPropagateSubrangesPass());
+ addCleanupPatterns(passManager);
+ // Convert std/util/etc -> VM, along with any other dialects implementing the
+ // VM conversion dialect interface.
passManager.addPass(createConversionPass(targetOptions));
passManager.addNestedPass<IREE::VM::ModuleOp>(createHoistInlinedRodataPass());