Adding memref->util.buffer conversion patterns.
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/BUILD b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/BUILD new file mode 100644 index 0000000..070b616 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/BUILD
@@ -0,0 +1,35 @@ +# Copyright 2022 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library") + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_compiler_cc_library( + name = "MemRefToUtil", + srcs = [ + "ConvertMemRefToUtil.cpp", + ], + hdrs = [ + "ConvertMemRefToUtil.h", + ], + deps = [ + "//compiler/src/iree/compiler/Dialect/Util/IR", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:ArithmeticDialect", + "@llvm-project//mlir:BufferizationDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +)
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/CMakeLists.txt new file mode 100644 index 0000000..f1db35b --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/CMakeLists.txt
@@ -0,0 +1,34 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/BUILD # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +iree_cc_library( + NAME + MemRefToUtil + HDRS + "ConvertMemRefToUtil.h" + SRCS + "ConvertMemRefToUtil.cpp" + DEPS + MLIRAffineDialect + MLIRArithmeticDialect + MLIRBufferizationDialect + MLIRFuncDialect + MLIRIR + MLIRMemRefDialect + MLIRPass + MLIRTransformUtils + MLIRTransforms + iree::compiler::Dialect::Util::IR + PUBLIC +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.cpp b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.cpp new file mode 100644 index 0000000..4de1861 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.cpp
@@ -0,0 +1,277 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.h" + +#include "iree/compiler/Dialect/Util/IR/UtilOps.h" +#include "iree/compiler/Dialect/Util/IR/UtilTypes.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +/// Returns true if the given `type` is a MemRef of rank 0 or 1. +static bool isRankZeroOrOneMemRef(Type type) { + if (auto memrefType = type.dyn_cast<MemRefType>()) { + return memrefType.hasRank() && memrefType.getRank() <= 1 && + memrefType.getLayout().isIdentity(); + } + return false; +} + +/// Returns the offset, in bytes, of an index within a linearized dense buffer. +/// Expects that the |memrefValue| has been linearized already. +static Value getBufferOffset(Location loc, Value memrefValue, + ValueRange indices, Type elementType, + ConversionPatternRewriter &rewriter) { + auto memrefType = memrefValue.getType().cast<ShapedType>(); + if (memrefType.getRank() == 0) { + // Rank 0 buffers (like memref<i32>) have only a single valid offset at 0. + return rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0); + } + assert(memrefType.getRank() == 1 && "memrefs should have been flattened"); + + // Element type byte length as the base. + auto elementSize = rewriter.createOrFold<arith::ConstantIndexOp>( + loc, IREE::Util::getRoundedElementByteWidth(elementType)); + + // Rank 1 memrefs are just offset by their element width by the offset. + auto elementCount = indices.front(); + return rewriter.create<arith::MulIOp>(loc, elementSize, elementCount); +} + +/// Pattern to lower operations that become a no-ops at this level. +/// Passes through operands to results. +template <typename OpTy> +struct FoldAsNoOp final : public OpConversionPattern<OpTy> { + using OpConversionPattern<OpTy>::OpConversionPattern; + LogicalResult matchAndRewrite( + OpTy op, typename OpTy::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOp(op, adaptor.getOperands()); + return success(); + } +}; + +/// Pattern to lower operations that become a no-ops at this level. +/// Erases the op entirely. +template <typename OpTy> +struct ElideNoOp final : public OpConversionPattern<OpTy> { + using OpConversionPattern<OpTy>::OpConversionPattern; + LogicalResult matchAndRewrite( + OpTy op, typename OpTy::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.eraseOp(op); + return success(); + } +}; + +struct ConvertMemRefGlobalOp : public OpConversionPattern<memref::GlobalOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + memref::GlobalOp globalOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isRankZeroOrOneMemRef(globalOp.getType())) { + return rewriter.notifyMatchFailure( + globalOp, + "only rank-0 and rank-1 memrefs are supported; flatten first"); + } + + // For mutable values we'd want to either have a RwdataOp or a global + // !vm.buffer that we initialized with rodata. + if (!globalOp.getConstant()) { + return rewriter.notifyMatchFailure( + globalOp, "mutable global memrefs not yet implemented"); + } + + auto newOp = rewriter.replaceOpWithNewOp<IREE::Util::GlobalOp>( + globalOp, globalOp.getSymName(), /*isMutable=*/false, + rewriter.getType<IREE::Util::BufferType>()); + newOp.setPrivate(); + + auto initializerOp = + rewriter.create<IREE::Util::InitializerOp>(globalOp.getLoc()); + auto initializerBuilder = + OpBuilder::atBlockBegin(initializerOp.addEntryBlock()); + auto alignmentAttr = globalOp.getAlignmentAttr() + ? initializerBuilder.getIndexAttr( + globalOp.getAlignmentAttr().getInt()) + : IntegerAttr{}; + auto constantOp = initializerBuilder.create<IREE::Util::BufferConstantOp>( + globalOp.getLoc(), initializerBuilder.getType<IREE::Util::BufferType>(), + globalOp.getInitialValueAttr(), alignmentAttr); + initializerBuilder.create<IREE::Util::GlobalStoreOp>( + globalOp.getLoc(), constantOp.getResult(), newOp.getName()); + initializerBuilder.create<IREE::Util::InitializerReturnOp>( + globalOp.getLoc()); + + return success(); + } +}; + +struct ConvertMemRefGetGlobalOp + : public OpConversionPattern<memref::GetGlobalOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + memref::GetGlobalOp getOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isRankZeroOrOneMemRef(getOp.getResult().getType())) { + return rewriter.notifyMatchFailure( + getOp, "only rank-0 and rank-1 memrefs are supported; flatten first"); + } + rewriter.replaceOpWithNewOp<IREE::Util::GlobalLoadOp>( + getOp, rewriter.getType<IREE::Util::BufferType>(), getOp.getName()); + return success(); + } +}; + +struct ConvertMemRefAllocaOp : public OpConversionPattern<memref::AllocaOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + memref::AllocaOp allocaOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto type = allocaOp.getType().cast<ShapedType>(); + if (!type.hasStaticShape()) { + return rewriter.notifyMatchFailure( + allocaOp, "unable to create buffers for dynamic shapes"); + } + int64_t memRefLength = + type.getNumElements() * + IREE::Util::getRoundedElementByteWidth(type.getElementType()); + Value allocationSize = rewriter.create<arith::ConstantIndexOp>( + allocaOp.getLoc(), memRefLength); + rewriter.replaceOpWithNewOp<IREE::Util::BufferAllocOp>( + allocaOp, rewriter.getType<IREE::Util::BufferType>(), allocationSize); + return success(); + } +}; + +struct ConvertMemRefDimOp : public OpConversionPattern<memref::DimOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + memref::DimOp dimOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isRankZeroOrOneMemRef(dimOp.getSource().getType())) { + return rewriter.notifyMatchFailure( + dimOp, "only rank-0 and rank-1 memrefs are supported; flatten first"); + } + auto newElementType = getTypeConverter()->convertType( + dimOp.getSource().getType().cast<MemRefType>().getElementType()); + if (!newElementType) { + return rewriter.notifyMatchFailure(dimOp, "unsupported element type"); + } + Value elementSize = rewriter.create<arith::ConstantIndexOp>( + dimOp.getLoc(), IREE::Util::getRoundedElementByteWidth(newElementType)); + Value bufferSize = rewriter.create<IREE::Util::BufferSizeOp>( + dimOp.getLoc(), rewriter.getIndexType(), adaptor.getSource()); + rewriter.replaceOpWithNewOp<arith::FloorDivSIOp>(dimOp, bufferSize, + elementSize); + return success(); + } +}; + +struct ConvertMemRefLoadOp : public OpConversionPattern<memref::LoadOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + memref::LoadOp loadOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isRankZeroOrOneMemRef(loadOp.getMemref().getType())) { + return rewriter.notifyMatchFailure( + loadOp, + "only rank-0 and rank-1 memrefs are supported; flatten first"); + } + auto oldType = loadOp.getResult().getType(); + auto newType = getTypeConverter()->convertType(oldType); + auto newElementType = getTypeConverter()->convertType( + loadOp.getMemRef().getType().cast<MemRefType>().getElementType()); + if (!newElementType) { + return rewriter.notifyMatchFailure(loadOp, "unsupported element type"); + } + auto memRefSize = rewriter.createOrFold<IREE::Util::BufferSizeOp>( + loadOp.getLoc(), rewriter.getIndexType(), adaptor.getMemref()); + auto byteOffset = + getBufferOffset(loadOp.getLoc(), loadOp.getMemref(), + loadOp.getIndices(), newElementType, rewriter); + rewriter.replaceOpWithNewOp<IREE::Util::BufferLoadOp>( + loadOp, newType, adaptor.getMemref(), memRefSize, byteOffset); + return success(); + } +}; + +struct ConvertMemRefStoreOp : public OpConversionPattern<memref::StoreOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + memref::StoreOp storeOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isRankZeroOrOneMemRef(storeOp.getMemref().getType())) { + return rewriter.notifyMatchFailure( + storeOp, + "only rank-0 and rank-1 memrefs are supported; flatten first"); + } + auto newElementType = getTypeConverter()->convertType( + storeOp.getMemRef().getType().cast<MemRefType>().getElementType()); + if (!newElementType) { + return rewriter.notifyMatchFailure(storeOp, "unsupported element type"); + } + auto memRefSize = rewriter.createOrFold<IREE::Util::BufferSizeOp>( + storeOp.getLoc(), rewriter.getIndexType(), adaptor.getMemref()); + auto byteOffset = + getBufferOffset(storeOp.getLoc(), storeOp.getMemref(), + storeOp.getIndices(), newElementType, rewriter); + rewriter.replaceOpWithNewOp<IREE::Util::BufferStoreOp>( + storeOp, adaptor.getValue(), adaptor.getMemref(), memRefSize, + byteOffset); + return success(); + } +}; + +} // namespace + +void populateMemRefToUtilPatterns(MLIRContext *context, + ConversionTarget &conversionTarget, + TypeConverter &typeConverter, + RewritePatternSet &patterns) { + conversionTarget.addIllegalDialect<memref::MemRefDialect>(); + + typeConverter.addConversion([&](MemRefType type) -> llvm::Optional<Type> { + if (isRankZeroOrOneMemRef(type)) { + return IREE::Util::BufferType::get(type.getContext()); + } + return llvm::None; + }); + + // Unranked memrefs are emitted for library call integration when we just + // need void* semantics. An unranked memref is basically just a (pointer, + // memory-space, element-type). + typeConverter.addConversion( + [&](UnrankedMemRefType type) -> llvm::Optional<Type> { + return IREE::Util::BufferType::get(type.getContext()); + }); + + patterns + .insert<FoldAsNoOp<bufferization::ToMemrefOp>, + ElideNoOp<memref::AssumeAlignmentOp>, FoldAsNoOp<memref::CastOp>>( + typeConverter, context); + patterns.insert<ConvertMemRefGlobalOp, ConvertMemRefGetGlobalOp, + ConvertMemRefAllocaOp, ConvertMemRefDimOp, + ConvertMemRefLoadOp, ConvertMemRefStoreOp>(typeConverter, + context); +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.h b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.h new file mode 100644 index 0000000..b7d991b --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.h
@@ -0,0 +1,25 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_DIALECT_UTIL_CONVERSION_MEMREFTOUTIL_CONVERTMEMREFTOUTIL_H_ +#define IREE_COMPILER_DIALECT_UTIL_CONVERSION_MEMREFTOUTIL_CONVERTMEMREFTOUTIL_H_ + +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace iree_compiler { + +// Appends memref dialect to vm dialect patterns to the given pattern list. +void populateMemRefToUtilPatterns(MLIRContext *context, + ConversionTarget &conversionTarget, + TypeConverter &typeConverter, + RewritePatternSet &patterns); + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_DIALECT_UTIL_CONVERSION_MEMREFTOUTIL_CONVERTMEMREFTOUTIL_H_
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/BUILD b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/BUILD new file mode 100644 index 0000000..d022190 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/BUILD
@@ -0,0 +1,29 @@ +# Copyright 2022 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_lit_test_suite( + name = "lit", + srcs = enforce_glob( + [ + "memref_ops.mlir", + ], + include = ["*.mlir"], + ), + cfg = "//compiler:lit.cfg.py", + tools = [ + "//tools:iree-opt", + "@llvm-project//llvm:FileCheck", + ], +)
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/CMakeLists.txt new file mode 100644 index 0000000..f18b4d2 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/CMakeLists.txt
@@ -0,0 +1,23 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/BUILD # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +iree_lit_test_suite( + NAME + lit + SRCS + "memref_ops.mlir" + TOOLS + FileCheck + iree-opt +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/memref_ops.mlir b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/memref_ops.mlir new file mode 100644 index 0000000..0d68e1f --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/memref_ops.mlir
@@ -0,0 +1,89 @@ +// RUN: iree-opt --split-input-file --iree-util-test-conversion --cse --canonicalize --verify-diagnostics %s | FileCheck %s + +// Must be rank-0 or rank-1. + +// expected-error @-5 {{conversion to util failed}} +func.func @verify_invalid_rank_2(%buffer: memref<4x2xf32>, %idx: index) { + // expected-error @below {{failed to legalize operation 'memref.load'}} + memref.load %buffer[%idx, %idx] : memref<4x2xf32> + return +} + +// ----- + +// Must have an identity map. + +#map = affine_map<(d0)[s0] -> (d0 * s0)> +// expected-error @-6 {{conversion to util failed}} +func.func @verify_invalid_non_identity_map(%buffer: memref<4xf32, #map>, %idx: index) { + // expected-error @below {{failed to legalize operation 'memref.load'}} + memref.load %buffer[%idx] : memref<4xf32, #map> + return +} + +// ----- + +// CHECK-LABEL: @assume_alignment +func.func @assume_alignment(%buffer: memref<?xf32>) { + // CHECK-NOT: assume_alignment + memref.assume_alignment %buffer, 64 : memref<?xf32> + func.return +} + +// ----- + +// CHECK-LABEL: @cast +func.func @cast(%buffer: memref<?xf32>) -> memref<5xf32> { + // CHECK-NOT: memref.cast + %0 = memref.cast %buffer : memref<?xf32> to memref<5xf32> + // CHECK: return %arg0 : !util.buffer + func.return %0 : memref<5xf32> +} + +// ----- + +// CHECK-LABEL: @alloca() -> !util.buffer +func.func @alloca() -> memref<16xi32> { + // CHECK: %[[ALLOCATION_SIZE:.+]] = arith.constant 64 : index + // CHECK: %[[BUFFER:.+]] = util.buffer.alloc uninitialized : !util.buffer{%[[ALLOCATION_SIZE]]} + %0 = memref.alloca() : memref<16xi32> + // CHECK: return %[[BUFFER]] + return %0 : memref<16xi32> +} + +// ----- + +// CHECK-LABEL: @load_store_f32 +// CHECK-SAME: (%[[BUFFER:.+]]: !util.buffer, %[[IDX0:.+]]: index, %[[IDX1:.+]]: index) -> f32 { +func.func @load_store_f32(%buffer: memref<?xf32>, %idx0: index, %idx1: index) -> f32 { + // CHECK: %[[BUFFER_SIZE:.+]] = util.buffer.size %[[BUFFER]] + // CHECK: %[[IDX0_BYTES:.+]] = arith.muli %[[IDX0]], %c4 + // CHECK: %[[VALUE:.+]] = util.buffer.load %[[BUFFER]][%[[IDX0_BYTES]]] : !util.buffer{%[[BUFFER_SIZE]]} -> f32 + %0 = memref.load %buffer[%idx0] : memref<?xf32> + // CHECK: %[[IDX1_BYTES:.+]] = arith.muli %[[IDX1]], %c4 + // CHECK: util.buffer.store %[[VALUE]], %[[BUFFER]][%[[IDX1_BYTES]]] : f32 -> !util.buffer{%[[BUFFER_SIZE]]} + memref.store %0, %buffer[%idx1] : memref<?xf32> + // CHECK: return %[[VALUE]] : f32 + return %0 : f32 +} + +// ----- + +// CHECK: util.global private @__constant_f32 : !util.buffer +// CHECK: util.initializer { +// CHECK: %[[BUFFER:.+]] = util.buffer.constant : !util.buffer = dense<[0.0287729427, 0.0297581609]> : tensor<2xf32> +// CHECK: util.global.store %[[BUFFER]], @__constant_f32 : !util.buffer +memref.global "private" constant @__constant_f32 : memref<2xf32> = dense<[0.0287729427, 0.0297581609]> + +// CHECK-LABEL: @constant_global_f32 +// CHECK-SAME: (%[[IDX:.+]]: index) -> f32 { +func.func @constant_global_f32(%idx: index) -> f32 { + // CHECK: %[[BUFFER:.+]] = util.global.load @__constant_f32 : !util.buffer + %0 = memref.get_global @__constant_f32 : memref<2xf32> + // CHECK: %[[BUFFER_SIZE:.+]] = util.buffer.size %[[BUFFER]] + // CHECK: %[[IDX_BYTES:.+]] = arith.muli %[[IDX]], %c4 + // CHECK: %[[VALUE:.+]] = util.buffer.load %[[BUFFER]][%[[IDX_BYTES]]] : !util.buffer{%[[BUFFER_SIZE]]} -> f32 + %1 = memref.load %0[%idx] : memref<2xf32> + // CHECK: return %[[VALUE]] : f32 + return %1 : f32 +}
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD b/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD index 3d3babc..01954ef 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD
@@ -27,6 +27,7 @@ "PropagateSubrange.cpp", "SimplifyGlobalAccesses.cpp", "StripDebugOps.cpp", + "TestConversion.cpp", "TestFloatRangeAnalysis.cpp", ], hdrs = [ @@ -38,14 +39,20 @@ "//compiler/src/iree/compiler/Dialect/Util/Analysis/Attributes", "//compiler/src/iree/compiler/Dialect/Util/Analysis/Constant", "//compiler/src/iree/compiler/Dialect/Util/Analysis/DFX", + "//compiler/src/iree/compiler/Dialect/Util/Conversion", + "//compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil", "//compiler/src/iree/compiler/Dialect/Util/IR", "//compiler/src/iree/compiler/Utils", "@llvm-project//llvm:Support", + "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:ArithmeticDialect", + "@llvm-project//mlir:ArithmeticTransforms", "@llvm-project//mlir:ControlFlowDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms",
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt index b769dae..8c53010 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt
@@ -29,14 +29,19 @@ "PropagateSubrange.cpp" "SimplifyGlobalAccesses.cpp" "StripDebugOps.cpp" + "TestConversion.cpp" "TestFloatRangeAnalysis.cpp" DEPS LLVMSupport + MLIRAffineDialect MLIRAnalysis MLIRArithmeticDialect + MLIRArithmeticTransforms MLIRControlFlowDialect MLIRFuncDialect MLIRIR + MLIRMathDialect + MLIRMemRefDialect MLIRPass MLIRSupport MLIRTransforms @@ -44,6 +49,8 @@ iree::compiler::Dialect::Util::Analysis::Attributes iree::compiler::Dialect::Util::Analysis::Constant iree::compiler::Dialect::Util::Analysis::DFX + iree::compiler::Dialect::Util::Conversion + iree::compiler::Dialect::Util::Conversion::MemRefToUtil iree::compiler::Dialect::Util::IR iree::compiler::Utils PUBLIC
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h index 516ada7..7fbf271 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h
@@ -34,7 +34,8 @@ std::unique_ptr<OperationPass<mlir::ModuleOp>> createPromoteF16ToF32Pass(); // Test passes. -std::unique_ptr<OperationPass<void>> createTestFloatRangeAnalysis(); +std::unique_ptr<OperationPass<void>> createTestConversionPass(); +std::unique_ptr<OperationPass<void>> createTestFloatRangeAnalysisPass(); // Register all Passes // TODO: Switch this directory to declarative registration. @@ -55,7 +56,8 @@ createDemoteF64ToF32Pass(); createPromoteF16ToF32Pass(); - createTestFloatRangeAnalysis(); + createTestConversionPass(); + createTestFloatRangeAnalysisPass(); } } // namespace Util
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/TestConversion.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/TestConversion.cpp new file mode 100644 index 0000000..30fe43a --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/TestConversion.cpp
@@ -0,0 +1,80 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/Util/Conversion/ConversionPatterns.h" +#include "iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.h" +#include "iree/compiler/Dialect/Util/IR/UtilDialect.h" +#include "iree/compiler/Dialect/Util/IR/UtilOps.h" +#include "iree/compiler/Dialect/Util/Transforms/Passes.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/Transforms/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace Util { + +namespace { + +class TestConversionPass + : public PassWrapper<TestConversionPass, OperationPass<void>> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConversionPass) + + StringRef getArgument() const override { return "iree-util-test-conversion"; } + + StringRef getDescription() const override { + return "Tests util dialect conversion patterns"; + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<IREE::Util::UtilDialect, func::FuncDialect, + mlir::arith::ArithmeticDialect, math::MathDialect, + mlir::AffineDialect, memref::MemRefDialect>(); + } + + void runOnOperation() override { + auto *context = &getContext(); + + ConversionTarget conversionTarget(*context); + conversionTarget.addLegalDialect<arith::ArithmeticDialect>(); + conversionTarget.addLegalDialect<IREE::Util::UtilDialect>(); + + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + + RewritePatternSet patterns(&getContext()); + populateUtilConversionPatterns(context, conversionTarget, typeConverter, + patterns); + populateGenericStructuralConversionPatterns(context, conversionTarget, + typeConverter, patterns); + populateMemRefToUtilPatterns(context, conversionTarget, typeConverter, + patterns); + + if (failed(applyPartialConversion(getOperation(), conversionTarget, + std::move(patterns)))) { + getOperation()->emitError() << "conversion to util failed"; + return signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr<OperationPass<void>> createTestConversionPass() { + return std::make_unique<TestConversionPass>(); +} + +static PassRegistration<TestConversionPass> pass; + +} // namespace Util +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/TestFloatRangeAnalysis.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/TestFloatRangeAnalysis.cpp index 1158570..297c28f 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/TestFloatRangeAnalysis.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/TestFloatRangeAnalysis.cpp
@@ -66,7 +66,7 @@ } // namespace -std::unique_ptr<OperationPass<void>> createTestFloatRangeAnalysis() { +std::unique_ptr<OperationPass<void>> createTestFloatRangeAnalysisPass() { return std::make_unique<TestFloatRangeAnalysisPass>(); }