Use the new memref->util->vm path. (#9883)
* Use the new memref->util->vm path.
This was a little more involved than anticipated:
* A number of type-conversion and A->B->C lowering issues in the patterns were fixed (these only show up in type conversion settings; see the widening sketch after this list).
* The memref->util buffer path introduces initializers for rodata constants, whereas the old path short-circuited these directly to vm rodata. This highlighted missing VM support for optimizing such cases, and it mattered because VMVX does not support initializers. We may want to fix that last part eventually, but the current approach at least gets the conversion to the same IR as before once it runs through the pipeline (see the initializer sketch after this list).
* The alternate path also exposed a bad VM folding pattern that did not account for initializers and folded loads of immutable ref globals to null even when they had a value. That pattern is deleted; the new pass applies the optimization only when it is correct to do so.
* There is still more that can be done, but this at least gets us to the point of deleting the original path and switching over.
* Linearize VMVX buffers to 1D, disallowing 0D (see the expand_shape sketch after this list). Aside from being one less thing to reason about, this matches the invariants of the memref flattener and now lowers properly through the memref->util->vm flow.
* Enable the last remaining test that was blocked on 0D.
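
A rough sketch of the widening behavior exercised by the new load_store_i16 test (SSA names here are illustrative, not the exact test output): when the type converter widens i16 to i32, the buffer load/store keeps its original i16 element semantics and an unrealized_conversion_cast bridges to the converted value type, to be resolved by later legalization:

    %ld = util.buffer.load %buffer[%ofs] : !util.buffer{%sz} -> i16
    %w = builtin.unrealized_conversion_cast %ld : i16 to i32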
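
Sketch of the initializer pattern that the new iree-vm-sink-global-buffer-loads pass cleans up (adapted from the new test file; symbol names are illustrative). The memref->util path produces an immutable global plus an initializer that stores rodata into it:

    vm.rodata private @_const_0 dense<[1, 0]> : tensor<2xi32>
    vm.global.ref private @__constant : !vm.buffer
    vm.initializer {
      %0 = vm.const.ref.rodata @_const_0 : !vm.buffer
      vm.global.store.ref %0, @__constant : !vm.buffer
      vm.return
    }

The pass sinks vm.const.ref.rodata @_const_0 into each vm.global.load.ref of @__constant and erases the global and initializer; the deleted folder instead folded such loads straight to null.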
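
Sketch of the 0D handling in the VMVX microkernel lowering (mirrors the new addf0d test): a rank-0 operand is expanded to rank-1 so every buffer the microkernels see is uniformly 1D:

    %lhs = memref.expand_shape %arg1 [] : memref<f32> into memref<1xf32>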
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/LowerLinalgMicrokernels.cpp b/compiler/src/iree/compiler/Codegen/VMVX/LowerLinalgMicrokernels.cpp
index a3bd923..8e8288b 100644
--- a/compiler/src/iree/compiler/Codegen/VMVX/LowerLinalgMicrokernels.cpp
+++ b/compiler/src/iree/compiler/Codegen/VMVX/LowerLinalgMicrokernels.cpp
@@ -101,21 +101,29 @@
/// with element-based addressing.
Value castToLinear(Location loc, OpBuilder &builder) {
BaseMemRefType sourceType = baseBuffer.getType().cast<MemRefType>();
- if (sourceType.getRank() <= 1) return baseBuffer;
+ if (sourceType.getRank() == 1) return baseBuffer;
// Insert the cast just after the original def to keep inner loops tidy.
OpBuilder::InsertionGuard restoreIp(builder);
Operation *def = baseBuffer.getDefiningOp();
if (def) builder.setInsertionPointAfter(def);
- // Collapse to 1D.
- ReassociationIndices reassociation;
- reassociation.resize(sourceType.getRank());
- for (int i = 0; i < sourceType.getRank(); ++i) {
- reassociation[i] = i;
+ if (sourceType.getRank() > 1) {
+ // Collapse to 1D.
+ ReassociationIndices reassociation;
+ reassociation.resize(sourceType.getRank());
+ for (int i = 0; i < sourceType.getRank(); ++i) {
+ reassociation[i] = i;
+ }
+ return builder.create<memref::CollapseShapeOp>(loc, baseBuffer,
+ reassociation);
+ } else {
+ // Expand 0D to 1D.
+ return builder.create<memref::ExpandShapeOp>(
+ loc, MemRefType::get({1}, sourceType.getElementType()), baseBuffer,
+ ArrayRef<ReassociationIndices>{});
}
- return builder.create<memref::CollapseShapeOp>(loc, baseBuffer,
- reassociation);
}
};
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/test/lower_linalg_microkernels.mlir b/compiler/src/iree/compiler/Codegen/VMVX/test/lower_linalg_microkernels.mlir
index 40920ba..367df02 100644
--- a/compiler/src/iree/compiler/Codegen/VMVX/test/lower_linalg_microkernels.mlir
+++ b/compiler/src/iree/compiler/Codegen/VMVX/test/lower_linalg_microkernels.mlir
@@ -116,3 +116,21 @@
}
func.return
}
+
+// CHECK-LABEL: @addf0d
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[ARG1_1D:.*]] = memref.expand_shape %arg1 [] : memref<f32> into memref<1xf32>
+// CHECK: vmvx.add lhs(%[[ARG1_1D]] offset %[[C0]] strides[%[[C0]], %[[C0]]] : memref<1xf32>)
+// CHECK-SAME: rhs(%arg0 offset %c0 strides[%[[C0]], %[[C1]]] : memref<2xf32>)
+// CHECK-SAME: out(%arg0 offset %[[C0]] strides[%[[C0]], %[[C1]]] : memref<2xf32>) sizes(%[[C1]], %[[C2]]) : f32
+func.func @addf0d(%arg0 : memref<2xf32>, %arg1 : memref<f32>) {
+ linalg.generic {indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]}
+ ins(%arg1 : memref<f32>) outs(%arg0 : memref<2xf32>) {
+ ^bb0(%arg2: f32, %arg3: f32):
+ %12 = arith.addf %arg2, %arg3 : f32
+ linalg.yield %12 : f32
+ }
+ func.return
+}
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.cpp b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.cpp
index 4de1861..e74bf2f 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.cpp
@@ -37,7 +37,7 @@
/// Returns the offset, in bytes, of an index within a linearized dense buffer.
/// Expects that the |memrefValue| has been linearized already.
static Value getBufferOffset(Location loc, Value memrefValue,
- ValueRange indices, Type elementType,
+ ValueRange indices,
ConversionPatternRewriter &rewriter) {
auto memrefType = memrefValue.getType().cast<ShapedType>();
if (memrefType.getRank() == 0) {
@@ -46,7 +46,12 @@
}
assert(memrefType.getRank() == 1 && "memrefs should have been flattened");
- // Element type byte length as the base.
+ // Element type byte length as the base. Note that this is the unconverted
+ // element type. Since these are storage types within a buffer, they are
+ // not subject to general type conversion (i.e. a general type converter
+ // may elect to represent all i8 registers as i32, but this does not mean
+ // that all memrefs are widened from i8 to i32).
+ auto elementType = memrefType.getElementType();
auto elementSize = rewriter.createOrFold<arith::ConstantIndexOp>(
loc, IREE::Util::getRoundedElementByteWidth(elementType));
@@ -170,13 +175,10 @@
return rewriter.notifyMatchFailure(
dimOp, "only rank-0 and rank-1 memrefs are supported; flatten first");
}
- auto newElementType = getTypeConverter()->convertType(
- dimOp.getSource().getType().cast<MemRefType>().getElementType());
- if (!newElementType) {
- return rewriter.notifyMatchFailure(dimOp, "unsupported element type");
- }
+ auto elementType =
+ dimOp.getSource().getType().cast<MemRefType>().getElementType();
Value elementSize = rewriter.create<arith::ConstantIndexOp>(
- dimOp.getLoc(), IREE::Util::getRoundedElementByteWidth(newElementType));
+ dimOp.getLoc(), IREE::Util::getRoundedElementByteWidth(elementType));
Value bufferSize = rewriter.create<IREE::Util::BufferSizeOp>(
dimOp.getLoc(), rewriter.getIndexType(), adaptor.getSource());
rewriter.replaceOpWithNewOp<arith::FloorDivSIOp>(dimOp, bufferSize,
@@ -197,18 +199,25 @@
}
auto oldType = loadOp.getResult().getType();
auto newType = getTypeConverter()->convertType(oldType);
- auto newElementType = getTypeConverter()->convertType(
- loadOp.getMemRef().getType().cast<MemRefType>().getElementType());
- if (!newElementType) {
- return rewriter.notifyMatchFailure(loadOp, "unsupported element type");
- }
auto memRefSize = rewriter.createOrFold<IREE::Util::BufferSizeOp>(
loadOp.getLoc(), rewriter.getIndexType(), adaptor.getMemref());
- auto byteOffset =
- getBufferOffset(loadOp.getLoc(), loadOp.getMemref(),
- loadOp.getIndices(), newElementType, rewriter);
- rewriter.replaceOpWithNewOp<IREE::Util::BufferLoadOp>(
- loadOp, newType, adaptor.getMemref(), memRefSize, byteOffset);
+ auto byteOffset = getBufferOffset(loadOp.getLoc(), loadOp.getMemref(),
+ loadOp.getIndices(), rewriter);
+ Value loaded = rewriter.create<IREE::Util::BufferLoadOp>(
+ loadOp.getLoc(), oldType, adaptor.getMemref(), memRefSize, byteOffset);
+ if (newType != oldType) {
+ // Since the BufferLoadOp semantics include its result type (i.e. a load
+ // of an i8 is different than a load of an i32), in the presence of type
+ // conversion, we must preserve the original type and emit an unrealized
+ // conversion cast for downstreams. In this case, further legalizations
+ // will be required to resolve it. This comes up in A->B->C lowerings
+ // where the BufferLoad is an intermediate stage.
+ loaded = rewriter
+ .create<UnrealizedConversionCastOp>(loadOp.getLoc(), newType,
+ loaded)
+ .getResult(0);
+ }
+ rewriter.replaceOp(loadOp, loaded);
return success();
}
};
@@ -223,19 +232,26 @@
storeOp,
"only rank-0 and rank-1 memrefs are supported; flatten first");
}
- auto newElementType = getTypeConverter()->convertType(
- storeOp.getMemRef().getType().cast<MemRefType>().getElementType());
- if (!newElementType) {
- return rewriter.notifyMatchFailure(storeOp, "unsupported element type");
- }
auto memRefSize = rewriter.createOrFold<IREE::Util::BufferSizeOp>(
storeOp.getLoc(), rewriter.getIndexType(), adaptor.getMemref());
- auto byteOffset =
- getBufferOffset(storeOp.getLoc(), storeOp.getMemref(),
- storeOp.getIndices(), newElementType, rewriter);
+ auto byteOffset = getBufferOffset(storeOp.getLoc(), storeOp.getMemref(),
+ storeOp.getIndices(), rewriter);
+ Value newValue = adaptor.getValue();
+ if (newValue.getType() != storeOp.getValue().getType()) {
+ // In combination with type conversion, the elemental type may change,
+      // and this is load-bearing with respect to buffer_store op semantics
+      // (i.e. storing an i32 is different from storing an i8, even if the
+      // conversion target widens). Insert an unrealized conversion cast to
+      // preserve the original semantics. Presumably, further legalization
+      // will resolve this.
+ newValue =
+ rewriter
+ .create<UnrealizedConversionCastOp>(
+ storeOp.getLoc(), storeOp.getValue().getType(), newValue)
+ .getResult(0);
+ }
rewriter.replaceOpWithNewOp<IREE::Util::BufferStoreOp>(
- storeOp, adaptor.getValue(), adaptor.getMemref(), memRefSize,
- byteOffset);
+ storeOp, newValue, adaptor.getMemref(), memRefSize, byteOffset);
return success();
}
};
@@ -245,22 +261,20 @@
void populateMemRefToUtilPatterns(MLIRContext *context,
ConversionTarget &conversionTarget,
TypeConverter &typeConverter,
- RewritePatternSet &patterns) {
+ RewritePatternSet &patterns,
+ Type convertedBufferType) {
conversionTarget.addIllegalDialect<memref::MemRefDialect>();
- typeConverter.addConversion([&](MemRefType type) -> llvm::Optional<Type> {
- if (isRankZeroOrOneMemRef(type)) {
- return IREE::Util::BufferType::get(type.getContext());
- }
- return llvm::None;
- });
-
- // Unranked memrefs are emitted for library call integration when we just
- // need void* semantics. An unranked memref is basically just a (pointer,
- // memory-space, element-type).
typeConverter.addConversion(
- [&](UnrankedMemRefType type) -> llvm::Optional<Type> {
- return IREE::Util::BufferType::get(type.getContext());
+ [convertedBufferType](MemRefType type) -> llvm::Optional<Type> {
+ if (isRankZeroOrOneMemRef(type)) {
+ if (convertedBufferType) {
+ return convertedBufferType;
+ } else {
+ return IREE::Util::BufferType::get(type.getContext());
+ }
+ }
+ return llvm::None;
});
patterns
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.h b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.h
index b7d991b..311bd06 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.h
+++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.h
@@ -14,10 +14,14 @@
namespace iree_compiler {
// Appends memref dialect to vm dialect patterns to the given pattern list.
+// Because these patterns are often used in A->B->C lowerings, we allow the
+// final buffer type to be specialized (this must be the buffer type that
+// is valid in the 'C' dialect). If null, a Util::BufferType is used.
void populateMemRefToUtilPatterns(MLIRContext *context,
ConversionTarget &conversionTarget,
TypeConverter &typeConverter,
- RewritePatternSet &patterns);
+ RewritePatternSet &patterns,
+ Type convertedBufferType = {});
} // namespace iree_compiler
} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/memref_ops.mlir b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/memref_ops.mlir
index 0d68e1f..fa7a099 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/memref_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/memref_ops.mlir
@@ -1,28 +1,27 @@
-// RUN: iree-opt --split-input-file --iree-util-test-conversion --cse --canonicalize --verify-diagnostics %s | FileCheck %s
+// RUN: iree-opt --split-input-file \
+// RUN: --pass-pipeline='iree-util-test-conversion{widen-integers}, cse, canonicalize' \
+// RUN: --verify-diagnostics %s | FileCheck %s
+// -----
// Must be rank-0 or rank-1.
-
-// expected-error @-5 {{conversion to util failed}}
-func.func @verify_invalid_rank_2(%buffer: memref<4x2xf32>, %idx: index) {
+// expected-error @-3 {{conversion to util failed}}
+func.func @verify_invalid_rank_2(%buffer: memref<4x2xf32>, %idx: index) -> f32 {
// expected-error @below {{failed to legalize operation 'memref.load'}}
- memref.load %buffer[%idx, %idx] : memref<4x2xf32>
- return
+ %0 = memref.load %buffer[%idx, %idx] : memref<4x2xf32>
+ return %0 : f32
}
// -----
-
// Must have an identity map.
-
+// expected-error @-3 {{conversion to util failed}}
#map = affine_map<(d0)[s0] -> (d0 * s0)>
-// expected-error @-6 {{conversion to util failed}}
-func.func @verify_invalid_non_identity_map(%buffer: memref<4xf32, #map>, %idx: index) {
+func.func @verify_invalid_non_identity_map(%buffer: memref<4xf32, #map>, %idx: index) -> f32 {
// expected-error @below {{failed to legalize operation 'memref.load'}}
- memref.load %buffer[%idx] : memref<4xf32, #map>
- return
+ %0 = memref.load %buffer[%idx] : memref<4xf32, #map>
+ return %0 : f32
}
// -----
-
// CHECK-LABEL: @assume_alignment
func.func @assume_alignment(%buffer: memref<?xf32>) {
// CHECK-NOT: assume_alignment
@@ -31,7 +30,6 @@
}
// -----
-
// CHECK-LABEL: @cast
func.func @cast(%buffer: memref<?xf32>) -> memref<5xf32> {
// CHECK-NOT: memref.cast
@@ -41,7 +39,6 @@
}
// -----
-
// CHECK-LABEL: @alloca() -> !util.buffer
func.func @alloca() -> memref<16xi32> {
// CHECK: %[[ALLOCATION_SIZE:.+]] = arith.constant 64 : index
@@ -52,7 +49,17 @@
}
// -----
+// CHECK-LABEL: @alloc_i16
+// CHECK-SAME: (%[[IDX0:.+]]: index) -> !util.buffer {
+func.func @alloc_i16(%idx0: index) -> memref<4xi16> {
+ // CHECK: %[[C8:.*]] = arith.constant 8 : index
+ // CHECK: %[[BUFFER:.*]] = util.buffer.alloc uninitialized : !util.buffer{%[[C8]]}
+ %0 = memref.alloca() : memref<4xi16>
+ // CHECK: return %[[BUFFER]]
+ return %0 : memref<4xi16>
+}
+// -----
// CHECK-LABEL: @load_store_f32
// CHECK-SAME: (%[[BUFFER:.+]]: !util.buffer, %[[IDX0:.+]]: index, %[[IDX1:.+]]: index) -> f32 {
func.func @load_store_f32(%buffer: memref<?xf32>, %idx0: index, %idx1: index) -> f32 {
@@ -68,7 +75,6 @@
}
// -----
-
// CHECK: util.global private @__constant_f32 : !util.buffer
// CHECK: util.initializer {
// CHECK: %[[BUFFER:.+]] = util.buffer.constant : !util.buffer = dense<[0.0287729427, 0.0297581609]> : tensor<2xf32>
@@ -87,3 +93,33 @@
// CHECK: return %[[VALUE]] : f32
return %1 : f32
}
+
+// -----
+// CHECK-LABEL: @load_store_i16
+// CHECK-SAME: (%[[BUFFER:.+]]: !util.buffer, %[[IDX0:.+]]: index, %[[IDX1:.+]]: index, %[[VALUE:.+]]: i32) -> i32 {
+func.func @load_store_i16(%buffer: memref<?xi16>, %idx0: index, %idx1: index, %value: i16) -> i16 {
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK-DAG: %[[SZ:.*]] = util.buffer.size %[[BUFFER]]
+ // CHECK-DAG: %[[OFS0:.*]] = arith.muli %[[IDX0]], %[[C2]] : index
+ // CHECK-DAG: %[[UCST0:.*]] = builtin.unrealized_conversion_cast %[[VALUE]] : i32 to i16
+ // CHECK: util.buffer.store %[[UCST0]], %[[BUFFER]][%[[OFS0]]] : i16 -> !util.buffer{%[[SZ]]}
+ memref.store %value, %buffer[%idx0] : memref<?xi16>
+ // CHECK: %[[OFS1:.*]] = arith.muli %[[IDX1]], %[[C2]] : index
+ // CHECK: %[[LD:.*]] = util.buffer.load %[[BUFFER]][%[[OFS1]]] : !util.buffer{%[[SZ]]} -> i16
+ // CHECK: %[[UCST1:.*]] = builtin.unrealized_conversion_cast %[[LD]] : i16 to i32
+ %1 = memref.load %buffer[%idx1] : memref<?xi16>
+ // CHECK: return %[[UCST1]]
+ return %1 : i16
+}
+
+// -----
+// CHECK-LABEL: @dim_i16
+// CHECK-SAME: (%[[BUFFER:.+]]: !util.buffer, %[[IDX0:.+]]: index) -> index {
+func.func @dim_i16(%buffer: memref<?xi16>, %idx0: index) -> index {
+ // CHECK: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK: %[[SZ:.*]] = util.buffer.size %[[BUFFER]] : !util.buffer
+ // CHECK: %[[DV:.*]] = arith.floordivsi %[[SZ]], %[[C2]] : index
+ %0 = memref.dim %buffer, %idx0 : memref<?xi16>
+ // CHECK: return %[[DV]]
+ return %0 : index
+}
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
index 955be21..ab7721c 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
@@ -694,13 +694,16 @@
OpFoldResult BufferSizeOp::fold(ArrayRef<Attribute> operands) {
// Try to find the size in the use-def chain.
// If it's out of the local scope we'll need IPO to help out.
+ // During A->B->C dialect conversion, the type may not be legal so be
+ // defensive.
auto operand = getOperand();
- auto sizeAwareType =
- operand.getType().cast<IREE::Util::SizeAwareTypeInterface>();
- Operation *op = this->getOperation();
- if (auto sizeValue = sizeAwareType.findSizeValue(operand, op->getBlock(),
- Block::iterator(op))) {
- return sizeValue;
+ if (auto sizeAwareType =
+ operand.getType().dyn_cast<IREE::Util::SizeAwareTypeInterface>()) {
+ Operation *op = this->getOperation();
+ if (auto sizeValue = sizeAwareType.findSizeValue(operand, op->getBlock(),
+ Block::iterator(op))) {
+ return sizeValue;
+ }
}
// If the source is a constant then we can calculate that immediately.
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h
index 7fbf271..c3bb1e8 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h
@@ -34,7 +34,7 @@
std::unique_ptr<OperationPass<mlir::ModuleOp>> createPromoteF16ToF32Pass();
// Test passes.
-std::unique_ptr<OperationPass<void>> createTestConversionPass();
+std::unique_ptr<OperationPass<mlir::ModuleOp>> createTestConversionPass();
std::unique_ptr<OperationPass<void>> createTestFloatRangeAnalysisPass();
// Register all Passes
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/TestConversion.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/TestConversion.cpp
index 30fe43a..7c2d7dc 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/TestConversion.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/TestConversion.cpp
@@ -24,8 +24,10 @@
namespace {
class TestConversionPass
- : public PassWrapper<TestConversionPass, OperationPass<void>> {
+ : public PassWrapper<TestConversionPass, OperationPass<ModuleOp>> {
public:
+ TestConversionPass() = default;
+ TestConversionPass(const TestConversionPass &) {}
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConversionPass)
StringRef getArgument() const override { return "iree-util-test-conversion"; }
@@ -42,13 +44,23 @@
void runOnOperation() override {
auto *context = &getContext();
-
ConversionTarget conversionTarget(*context);
conversionTarget.addLegalDialect<arith::ArithmeticDialect>();
conversionTarget.addLegalDialect<IREE::Util::UtilDialect>();
+ conversionTarget.addLegalOp<UnrealizedConversionCastOp>();
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
+ if (widenIntegers) {
+      // Promote all integers < 32 bits to 32 bits to exercise type conversion
+      // in tests that are sensitive to it.
+ typeConverter.addConversion([](IntegerType type) {
+ if (type.getWidth() < 32) {
+ return IntegerType::get(type.getContext(), 32);
+ }
+ return type;
+ });
+ }
RewritePatternSet patterns(&getContext());
populateUtilConversionPatterns(context, conversionTarget, typeConverter,
@@ -64,11 +76,15 @@
return signalPassFailure();
}
}
+
+ Option<bool> widenIntegers{
+ *this, "widen-integers",
+ llvm::cl::desc("Tests type conversion by widening integers to i32")};
};
} // namespace
-std::unique_ptr<OperationPass<void>> createTestConversionPass() {
+std::unique_ptr<OperationPass<ModuleOp>> createTestConversionPass() {
return std::make_unique<TestConversionPass>();
}
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/MemRefToVM/BUILD b/compiler/src/iree/compiler/Dialect/VM/Conversion/MemRefToVM/BUILD
deleted file mode 100644
index 0a6094d..0000000
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/MemRefToVM/BUILD
+++ /dev/null
@@ -1,37 +0,0 @@
-# Copyright 2021 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library")
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-iree_compiler_cc_library(
- name = "MemRefToVM",
- srcs = [
- "ConvertMemRefToVM.cpp",
- ],
- hdrs = [
- "ConvertMemRefToVM.h",
- ],
- deps = [
- "//compiler/src/iree/compiler/Dialect/Util/IR",
- "//compiler/src/iree/compiler/Dialect/VM/Conversion",
- "//compiler/src/iree/compiler/Dialect/VM/IR",
- "@llvm-project//mlir:AffineDialect",
- "@llvm-project//mlir:ArithmeticDialect",
- "@llvm-project//mlir:BufferizationDialect",
- "@llvm-project//mlir:FuncDialect",
- "@llvm-project//mlir:IR",
- "@llvm-project//mlir:MemRefDialect",
- "@llvm-project//mlir:Pass",
- "@llvm-project//mlir:TransformUtils",
- "@llvm-project//mlir:Transforms",
- ],
-)
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/MemRefToVM/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/VM/Conversion/MemRefToVM/CMakeLists.txt
deleted file mode 100644
index 8afca1d..0000000
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/MemRefToVM/CMakeLists.txt
+++ /dev/null
@@ -1,36 +0,0 @@
-################################################################################
-# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
-# compiler/src/iree/compiler/Dialect/VM/Conversion/MemRefToVM/BUILD #
-# #
-# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
-# CMake-only content. #
-# #
-# To disable autogeneration for this file entirely, delete this header. #
-################################################################################
-
-iree_add_all_subdirs()
-
-iree_cc_library(
- NAME
- MemRefToVM
- HDRS
- "ConvertMemRefToVM.h"
- SRCS
- "ConvertMemRefToVM.cpp"
- DEPS
- MLIRAffineDialect
- MLIRArithmeticDialect
- MLIRBufferizationDialect
- MLIRFuncDialect
- MLIRIR
- MLIRMemRefDialect
- MLIRPass
- MLIRTransformUtils
- MLIRTransforms
- iree::compiler::Dialect::Util::IR
- iree::compiler::Dialect::VM::Conversion
- iree::compiler::Dialect::VM::IR
- PUBLIC
-)
-
-### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/MemRefToVM/ConvertMemRefToVM.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/MemRefToVM/ConvertMemRefToVM.cpp
deleted file mode 100644
index 3e379cf..0000000
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/MemRefToVM/ConvertMemRefToVM.cpp
+++ /dev/null
@@ -1,288 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/Dialect/VM/Conversion/MemRefToVM/ConvertMemRefToVM.h"
-
-#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
-#include "iree/compiler/Dialect/VM/Conversion/TargetOptions.h"
-#include "iree/compiler/Dialect/VM/Conversion/TypeConverter.h"
-#include "iree/compiler/Dialect/VM/IR/VMOps.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
-#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/OperationSupport.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-/// Pattern to lower operations that become a no-ops at this level.
-/// Passes through operands to results.
-template <typename OpTy>
-struct FoldAsNoOp final : public OpConversionPattern<OpTy> {
- using OpConversionPattern<OpTy>::OpConversionPattern;
- LogicalResult matchAndRewrite(
- OpTy op, typename OpTy::Adaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOp(op, adaptor.getOperands());
- return success();
- }
-};
-
-/// Pattern to lower operations that become a no-ops at this level.
-/// Erases the op entirely.
-template <typename OpTy>
-struct ElideNoOp final : public OpConversionPattern<OpTy> {
- using OpConversionPattern<OpTy>::OpConversionPattern;
- LogicalResult matchAndRewrite(
- OpTy op, typename OpTy::Adaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.eraseOp(op);
- return success();
- }
-};
-
-/// Returns true if the given `type` is a MemRef of rank 0 or 1.
-static bool isRankZeroOrOneMemRef(Type type) {
- if (auto memrefType = type.dyn_cast<MemRefType>()) {
- return memrefType.hasRank() && memrefType.getRank() <= 1;
- }
- return false;
-}
-
-// Returns the offset, in bytes, of an index within a linearized dense buffer.
-// Expects that the |memrefValue| has been linearized already.
-static Value getBufferOffset(Location loc, Value memrefValue,
- ValueRange indices, Type indexType,
- ConversionPatternRewriter &rewriter) {
- auto memrefType = memrefValue.getType().cast<ShapedType>();
- if (memrefType.getRank() == 0) {
- // Rank 0 buffers (like memref<i32>) have only a single valid offset at 0.
- return rewriter.createOrFold<arith::ConstantIntOp>(loc, 0, indexType);
- }
- assert(memrefType.getRank() == 1 && "memrefs should have been flattened");
-
- // Element type byte length as the base.
- auto elementType = memrefType.getElementType();
- auto scalingExpr = getAffineBinaryOpExpr(
- AffineExprKind::Mul, getAffineSymbolExpr(0, rewriter.getContext()),
- getAffineConstantExpr(IREE::Util::getRoundedElementByteWidth(elementType),
- rewriter.getContext()));
-
- // Rank 1 memrefs are just offset by their element width by the offset.
- Value offset = rewriter.createOrFold<AffineApplyOp>(
- loc, scalingExpr, ArrayRef<Value>{indices.front()});
- return rewriter.create<arith::IndexCastOp>(loc, indexType, offset);
-}
-
-struct ConvertMemRefGlobalOp : public OpConversionPattern<memref::GlobalOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- memref::GlobalOp globalOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- if (!isRankZeroOrOneMemRef(globalOp.getType())) {
- return rewriter.notifyMatchFailure(
- globalOp,
- "only rank-0 and rank-1 memrefs are supported; flatten first");
- }
-
- // For mutable values we'd want to either have a RwdataOp or a global
- // !vm.buffer that we initialized with rodata.
- if (!globalOp.getConstant()) {
- return rewriter.notifyMatchFailure(
- globalOp, "mutable global memrefs not yet implemented");
- }
-
- auto rodataOp = rewriter.replaceOpWithNewOp<IREE::VM::RodataOp>(
- globalOp, globalOp.getSymName(),
- globalOp.getInitialValueAttr().cast<ElementsAttr>());
- rodataOp.setPrivate();
- return success();
- }
-};
-
-struct ConvertMemRefGetGlobalOp
- : public OpConversionPattern<memref::GetGlobalOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- memref::GetGlobalOp getOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- if (!isRankZeroOrOneMemRef(getOp.getResult().getType())) {
- return rewriter.notifyMatchFailure(
- getOp, "only rank-0 and rank-1 memrefs are supported; flatten first");
- }
- rewriter.replaceOpWithNewOp<IREE::VM::ConstRefRodataOp>(getOp,
- getOp.getName());
- return success();
- }
-};
-
-// TODO(#9165): Support alignment for vm.buffer.alloc. So far we ignore the
-// alignment attribute when lowering the op to VM dialect.
-struct ConvertMemRefAllocaOp : public OpConversionPattern<memref::AllocaOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- memref::AllocaOp allocaOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto type = allocaOp.getType().cast<ShapedType>();
- if (!type.hasStaticShape()) {
- return rewriter.notifyMatchFailure(
- allocaOp, "unable to create buffers for dynamic shapes");
- }
-
- // TODO: Support dynamic shapes.
- Type elementType = getTypeConverter()->convertType(type.getElementType());
- if (!elementType) {
- return rewriter.notifyMatchFailure(allocaOp, "unsupported element type");
- }
- assert(elementType.isIntOrFloat() && "must be int or float");
- int64_t length = IREE::Util::getRoundedElementByteWidth(elementType);
- for (auto extent : type.getShape()) {
- length *= extent;
- }
-
- auto oldType = allocaOp.getType();
- auto newType = getTypeConverter()->convertType(oldType);
- Value size =
- rewriter.create<IREE::VM::ConstI64Op>(allocaOp.getLoc(), length);
- rewriter.replaceOpWithNewOp<IREE::VM::BufferAllocOp>(allocaOp, newType,
- size);
- return success();
- }
-};
-
-struct ConvertMemRefLoadOp : public OpConversionPattern<memref::LoadOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- memref::LoadOp loadOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- if (!isRankZeroOrOneMemRef(loadOp.getMemref().getType())) {
- return rewriter.notifyMatchFailure(
- loadOp,
- "only rank-0 and rank-1 memrefs are supported; flatten first");
- }
- auto oldType = loadOp.getResult().getType();
- auto newType = getTypeConverter()->convertType(oldType);
- auto byteOffset =
- getBufferOffset(loadOp.getLoc(), loadOp.getMemref(),
- loadOp.getIndices(), rewriter.getI64Type(), rewriter);
- if (auto integerType = oldType.dyn_cast<IntegerType>()) {
- if (integerType.isInteger(1) || integerType.isInteger(8)) {
- if (integerType.isSigned() || integerType.isSignless()) {
- rewriter.replaceOpWithNewOp<IREE::VM::BufferLoadI8SOp>(
- loadOp, newType, adaptor.getMemref(), byteOffset);
- } else {
- rewriter.replaceOpWithNewOp<IREE::VM::BufferLoadI8UOp>(
- loadOp, newType, adaptor.getMemref(), byteOffset);
- }
- } else if (integerType.isInteger(16)) {
- if (integerType.isSigned() || integerType.isSignless()) {
- rewriter.replaceOpWithNewOp<IREE::VM::BufferLoadI16SOp>(
- loadOp, newType, adaptor.getMemref(), byteOffset);
- } else {
- rewriter.replaceOpWithNewOp<IREE::VM::BufferLoadI16UOp>(
- loadOp, newType, adaptor.getMemref(), byteOffset);
- }
- } else if (integerType.isInteger(32)) {
- rewriter.replaceOpWithNewOp<IREE::VM::BufferLoadI32Op>(
- loadOp, newType, adaptor.getMemref(), byteOffset);
- } else if (integerType.isInteger(64)) {
- rewriter.replaceOpWithNewOp<IREE::VM::BufferLoadI64Op>(
- loadOp, newType, adaptor.getMemref(), byteOffset);
- } else {
- return rewriter.notifyMatchFailure(
- loadOp, "invalid integer buffer element type");
- }
- } else if (oldType.isF32()) {
- rewriter.replaceOpWithNewOp<IREE::VM::BufferLoadF32Op>(
- loadOp, newType, adaptor.getMemref(), byteOffset);
- } else if (oldType.isF64()) {
- rewriter.replaceOpWithNewOp<IREE::VM::BufferLoadF64Op>(
- loadOp, newType, adaptor.getMemref(), byteOffset);
- } else {
- return rewriter.notifyMatchFailure(loadOp,
- "invalid float buffer element type");
- }
- return success();
- }
-};
-
-struct ConvertMemRefStoreOp : public OpConversionPattern<memref::StoreOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- memref::StoreOp storeOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- if (!isRankZeroOrOneMemRef(storeOp.getMemref().getType())) {
- return rewriter.notifyMatchFailure(
- storeOp,
- "only rank-0 and rank-1 memrefs are supported; flatten first");
- }
- auto oldType = storeOp.getValue().getType();
- auto byteOffset =
- getBufferOffset(storeOp.getLoc(), storeOp.getMemref(),
- storeOp.getIndices(), rewriter.getI64Type(), rewriter);
- if (oldType.isInteger(1) || oldType.isInteger(8)) {
- rewriter.replaceOpWithNewOp<IREE::VM::BufferStoreI8Op>(
- storeOp, adaptor.getMemref(), byteOffset, adaptor.getValue());
- } else if (oldType.isInteger(16)) {
- rewriter.replaceOpWithNewOp<IREE::VM::BufferStoreI16Op>(
- storeOp, adaptor.getMemref(), byteOffset, adaptor.getValue());
- } else if (oldType.isInteger(32)) {
- rewriter.replaceOpWithNewOp<IREE::VM::BufferStoreI32Op>(
- storeOp, adaptor.getMemref(), byteOffset, adaptor.getValue());
- } else if (oldType.isInteger(64)) {
- rewriter.replaceOpWithNewOp<IREE::VM::BufferStoreI64Op>(
- storeOp, adaptor.getMemref(), byteOffset, adaptor.getValue());
- } else if (oldType.isF32()) {
- rewriter.replaceOpWithNewOp<IREE::VM::BufferStoreF32Op>(
- storeOp, adaptor.getMemref(), byteOffset, adaptor.getValue());
- } else if (oldType.isF64()) {
- rewriter.replaceOpWithNewOp<IREE::VM::BufferStoreF64Op>(
- storeOp, adaptor.getMemref(), byteOffset, adaptor.getValue());
- } else {
- return rewriter.notifyMatchFailure(storeOp,
- "invalid buffer element type");
- }
- return success();
- }
-};
-
-} // namespace
-
-void populateMemRefToVMPatterns(MLIRContext *context,
- ConversionTarget &conversionTarget,
- TypeConverter &typeConverter,
- RewritePatternSet &patterns) {
- conversionTarget.addIllegalDialect<memref::MemRefDialect>();
-
- typeConverter.addConversion([&](MemRefType type) -> llvm::Optional<Type> {
- if (isRankZeroOrOneMemRef(type)) {
- return IREE::VM::RefType::get(
- IREE::VM::BufferType::get(type.getContext()));
- }
- return llvm::None;
- });
-
- patterns
- .insert<FoldAsNoOp<bufferization::ToMemrefOp>,
- ElideNoOp<memref::AssumeAlignmentOp>, FoldAsNoOp<memref::CastOp>>(
- typeConverter, context);
- patterns
- .insert<ConvertMemRefGlobalOp, ConvertMemRefGetGlobalOp,
- ConvertMemRefAllocaOp, ConvertMemRefLoadOp, ConvertMemRefStoreOp>(
- typeConverter, context);
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/MemRefToVM/ConvertMemRefToVM.h b/compiler/src/iree/compiler/Dialect/VM/Conversion/MemRefToVM/ConvertMemRefToVM.h
deleted file mode 100644
index 084f7f7..0000000
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/MemRefToVM/ConvertMemRefToVM.h
+++ /dev/null
@@ -1,25 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_COMPILER_DIALECT_VM_CONVERSION_MEMREFTOVM_CONVERTMEMREFTOVM_H_
-#define IREE_COMPILER_DIALECT_VM_CONVERSION_MEMREFTOVM_CONVERTMEMREFTOVM_H_
-
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-// Appends memref dialect to vm dialect patterns to the given pattern list.
-void populateMemRefToVMPatterns(MLIRContext *context,
- ConversionTarget &conversionTarget,
- TypeConverter &typeConverter,
- RewritePatternSet &patterns);
-
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_DIALECT_VM_CONVERSION_MEMREFTOVM_CONVERTMEMREFTOVM_H_
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/MemRefToVM/test/BUILD b/compiler/src/iree/compiler/Dialect/VM/Conversion/MemRefToVM/test/BUILD
deleted file mode 100644
index 6bceba1..0000000
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/MemRefToVM/test/BUILD
+++ /dev/null
@@ -1,29 +0,0 @@
-# Copyright 2019 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite")
-load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-iree_lit_test_suite(
- name = "lit",
- srcs = enforce_glob(
- [
- "memref_ops.mlir",
- ],
- include = ["*.mlir"],
- ),
- cfg = "//compiler:lit.cfg.py",
- tools = [
- "//tools:iree-opt",
- "@llvm-project//llvm:FileCheck",
- ],
-)
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/MemRefToVM/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/VM/Conversion/MemRefToVM/test/CMakeLists.txt
deleted file mode 100644
index 8ead80d..0000000
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/MemRefToVM/test/CMakeLists.txt
+++ /dev/null
@@ -1,23 +0,0 @@
-################################################################################
-# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
-# compiler/src/iree/compiler/Dialect/VM/Conversion/MemRefToVM/test/BUILD #
-# #
-# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
-# CMake-only content. #
-# #
-# To disable autogeneration for this file entirely, delete this header. #
-################################################################################
-
-iree_add_all_subdirs()
-
-iree_lit_test_suite(
- NAME
- lit
- SRCS
- "memref_ops.mlir"
- TOOLS
- FileCheck
- iree-opt
-)
-
-### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/MemRefToVM/test/memref_ops.mlir b/compiler/src/iree/compiler/Dialect/VM/Conversion/MemRefToVM/test/memref_ops.mlir
deleted file mode 100644
index 5b676e3..0000000
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/MemRefToVM/test/memref_ops.mlir
+++ /dev/null
@@ -1,94 +0,0 @@
-// RUN: iree-opt --split-input-file --iree-vm-conversion %s | FileCheck %s
-
-module {
- // CHECK-LABEL: vm.func private @alloca() -> !vm.buffer
- func.func @alloca() -> memref<16xi32> {
- // CHECK: %[[LEN_64:.+]] = vm.const.i64 64
- // CHECK: %[[BUF:.+]] = vm.buffer.alloc %[[LEN_64]] : !vm.buffer
- // return %[[BUF]]
- %0 = memref.alloca() : memref<16xi32>
- return %0 : memref<16xi32>
- }
-}
-
-// -----
-
-module {
- // CHECK-LABEL: vm.func private @load_store
- // CHECK-SAME: (%[[BUFFER:.+]]: !vm.buffer, %[[IDX0:.+]]: i32, %[[IDX1:.+]]: i32) -> f32 {
- func.func @load_store(%buffer: memref<?xf32>, %idx0: index, %idx1: index) -> f32 {
- // CHECK-NEXT: %[[C4_0:.+]] = vm.const.i32 4
- // CHECK-NEXT: %[[OFFSET0_32:.+]] = vm.mul.i32 %[[IDX0]], %[[C4_0]] : i32
- // CHECK-NEXT: %[[OFFSET0:.+]] = vm.ext.i32.i64.u %[[OFFSET0_32]]
- // CHECK-NEXT: %[[VALUE:.+]] = vm.buffer.load.f32 %[[BUFFER]][%[[OFFSET0]]] : !vm.buffer -> f32
- %0 = memref.load %buffer[%idx0] : memref<?xf32>
- // CHECK-NEXT: %[[C4_1:.+]] = vm.const.i32 4
- // CHECK-NEXT: %[[OFFSET1_32:.+]] = vm.mul.i32 %[[IDX1]], %[[C4_1]] : i32
- // CHECK-NEXT: %[[OFFSET1:.+]] = vm.ext.i32.i64.u %[[OFFSET1_32]]
- // CHECK-NEXT: vm.buffer.store.f32 %[[VALUE]], %[[BUFFER]][%[[OFFSET1]]] : f32 -> !vm.buffer
- memref.store %0, %buffer[%idx1] : memref<?xf32>
- // CHECK-NEXT: vm.return %[[VALUE]] : f32
- return %0 : f32
- }
-}
-
-// -----
-
-module {
- // CHECK: vm.rodata private @__constant dense<[0.0287729427, 0.0297581609]> : tensor<2xf32>
- memref.global "private" constant @__constant : memref<2xf32> = dense<[0.0287729427, 0.0297581609]>
- // CHECK-LABEL: vm.func private @load_global
- // CHECK-SAME: (%[[IDX:.+]]: i32) -> f32 {
- func.func @load_global_1d(%idx: index) -> f32 {
- // CHECK-NEXT: %[[BUFFER:.+]] = vm.const.ref.rodata @__constant : !vm.buffer
- %0 = memref.get_global @__constant : memref<2xf32>
- // CHECK-NEXT: %[[C4:.+]] = vm.const.i32 4
- // CHECK-NEXT: %[[OFFSET_32:.+]] = vm.mul.i32 %[[IDX]], %[[C4]] : i32
- // CHECK-NEXT: %[[OFFSET:.+]] = vm.ext.i32.i64.u %[[OFFSET_32]]
- // CHECK-NEXT: %[[VALUE:.+]] = vm.buffer.load.f32 %[[BUFFER]][%[[OFFSET]]] : !vm.buffer -> f32
- %1 = memref.load %0[%idx] : memref<2xf32>
- // vm.return %[[VALUE]] : f32
- return %1 : f32
- }
-}
-
-// -----
-
-module {
- // CHECK: vm.rodata private @__constant dense<0.0287729427> : tensor<f32>
- memref.global "private" constant @__constant : memref<f32> = dense<0.0287729427>
- // CHECK-LABEL: vm.func private @load_global
- func.func @load_global_0d() -> f32 {
- // CHECK-NEXT: %[[BUFFER:.+]] = vm.const.ref.rodata @__constant : !vm.buffer
- %0 = memref.get_global @__constant : memref<f32>
- // CHECK-NEXT: %[[OFFSET:.+]] = vm.const.i64.zero
- // CHECK-NEXT: %[[VALUE:.+]] = vm.buffer.load.f32 %[[BUFFER]][%[[OFFSET]]] : !vm.buffer -> f32
- %1 = memref.load %0[] : memref<f32>
- // vm.return %[[VALUE]] : f32
- return %1 : f32
- }
-}
-
-// -----
-
-module {
- // CHECK-LABEL: @assume_alignment
- func.func @assume_alignment(%buffer: memref<?xf32>) {
- // CHECK-NOT: assume_alignment
- memref.assume_alignment %buffer, 64 : memref<?xf32>
- // CHECK: return
- func.return
- }
-}
-
-// -----
-
-module {
- // CHECK-LABEL: @cast
- func.func @cast(%buffer: memref<?xf32>) {
- // CHECK-NOT: memref.cast
- memref.cast %buffer : memref<?xf32> to memref<5xf32>
- // CHECK: return
- func.return
- }
-}
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
index 09d8bca..f22e35b 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
@@ -233,31 +233,6 @@
namespace {
-/// Inlines immutable global constants into their loads.
-struct InlineConstGlobalLoadRefOp : public OpRewritePattern<GlobalLoadRefOp> {
- using OpRewritePattern<GlobalLoadRefOp>::OpRewritePattern;
- LogicalResult matchAndRewrite(GlobalLoadRefOp op,
- PatternRewriter &rewriter) const override {
- auto globalAttr = op->getAttrOfType<FlatSymbolRefAttr>("global");
- auto globalOp =
- op->getParentOfType<VM::ModuleOp>().lookupSymbol<GlobalRefOp>(
- globalAttr.getValue());
- if (!globalOp) return failure();
- if (globalOp.getIsMutable()) return failure();
- rewriter.replaceOpWithNewOp<ConstRefZeroOp>(op, op.getType());
- return success();
- }
-};
-
-} // namespace
-
-void GlobalLoadRefOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.insert<InlineConstGlobalLoadRefOp>(context);
-}
-
-namespace {
-
template <typename INDIRECT, typename DIRECT>
struct PropagateGlobalLoadAddress : public OpRewritePattern<INDIRECT> {
using OpRewritePattern<INDIRECT>::OpRewritePattern;
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td
index 51a7440..6e706f0 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td
@@ -697,8 +697,6 @@
VM_EncTypeOf<"value">,
VM_EncResult<"value">,
];
-
- let hasCanonicalizer = 1;
}
def VM_GlobalStoreRefOp : VM_GlobalStoreOp<VM_AnyRef, "global.store.ref"> {
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/test/global_folding.mlir b/compiler/src/iree/compiler/Dialect/VM/IR/test/global_folding.mlir
index 833d9db..1b6a80e 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/test/global_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/test/global_folding.mlir
@@ -28,19 +28,6 @@
// -----
-// CHECK-LABEL: @global_ref_folds_null
-vm.module @global_ref_folds_null {
- // CHECK: vm.global.ref public mutable @g0 : !vm.ref<?>
- vm.global.ref mutable @g0 : !vm.ref<?>
- vm.initializer {
- %null = vm.const.ref.zero : !vm.ref<?>
- vm.global.store.ref %null, @g0 : !vm.ref<?>
- vm.return
- }
-}
-
-// -----
-
// CHECK-LABEL: @global_load_i32_folds
vm.module @global_load_i32_folds {
vm.global.i32 @g0 = 123 : i32
@@ -65,30 +52,6 @@
// -----
-// CHECK-LABEL: @global_load_ref_folds
-vm.module @global_load_ref_folds {
- vm.global.ref @g0 : !vm.ref<?>
- // CHECK-LABEL: @inline_const_null
- vm.func @inline_const_null() -> !vm.ref<?> {
- // CHECK-NEXT: %null = vm.const.ref.zero : !vm.ref<?>
- // CHECK-NEXT: vm.return %null : !vm.ref<?>
- %g0 = vm.global.load.ref @g0 : !vm.ref<?>
- vm.return %g0 : !vm.ref<?>
- }
-
- vm.global.ref mutable @g1 : !vm.ref<?>
- // CHECK-LABEL: @ignore_nonconst_value
- vm.func @ignore_nonconst_value() -> !vm.ref<?> {
- // NOTE: ensure we don't inline non-constant values.
- // CHECK-NEXT: %g1 = vm.global.load.ref @g1 : !vm.ref<?>
- // CHECK-NEXT: vm.return %g1 : !vm.ref<?>
- %g1 = vm.global.load.ref @g1 : !vm.ref<?>
- vm.return %g1 : !vm.ref<?>
- }
-}
-
-// -----
-
// CHECK-LABEL: @global_indirect_folds
vm.module @global_indirect_folds {
vm.global.i32 mutable @g0 : i32
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/BUILD b/compiler/src/iree/compiler/Dialect/VM/Transforms/BUILD
index ba3cced..34c0e2e 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/BUILD
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/BUILD
@@ -23,17 +23,18 @@
"OrdinalAllocation.cpp",
"Passes.cpp",
"SinkDefiningOps.cpp",
+ "SinkGlobalBufferLoads.cpp",
],
hdrs = [
"Passes.h",
],
deps = [
"//compiler/src/iree/compiler/Dialect/Util/Conversion",
+ "//compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"//compiler/src/iree/compiler/Dialect/Util/Transforms",
"//compiler/src/iree/compiler/Dialect/VM/Conversion",
"//compiler/src/iree/compiler/Dialect/VM/Conversion/MathToVM",
- "//compiler/src/iree/compiler/Dialect/VM/Conversion/MemRefToVM",
"//compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM",
"//compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM",
"//compiler/src/iree/compiler/Dialect/VM/IR",
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt
index c51c863..f084125 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt
@@ -24,6 +24,7 @@
"OrdinalAllocation.cpp"
"Passes.cpp"
"SinkDefiningOps.cpp"
+ "SinkGlobalBufferLoads.cpp"
DEPS
LLVMSupport
MLIRAffineDialect
@@ -43,11 +44,11 @@
MLIRTransformUtils
MLIRTransforms
iree::compiler::Dialect::Util::Conversion
+ iree::compiler::Dialect::Util::Conversion::MemRefToUtil
iree::compiler::Dialect::Util::IR
iree::compiler::Dialect::Util::Transforms
iree::compiler::Dialect::VM::Conversion
iree::compiler::Dialect::VM::Conversion::MathToVM
- iree::compiler::Dialect::VM::Conversion::MemRefToVM
iree::compiler::Dialect::VM::Conversion::StandardToVM
iree::compiler::Dialect::VM::Conversion::UtilToVM
iree::compiler::Dialect::VM::IR
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/Conversion.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/Conversion.cpp
index 4df294d..7474985 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/Conversion.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/Conversion.cpp
@@ -8,13 +8,13 @@
#include <tuple>
#include "iree/compiler/Dialect/Util/Conversion/ConversionPatterns.h"
+#include "iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/VM/Conversion/ConversionDialectInterface.h"
#include "iree/compiler/Dialect/VM/Conversion/ConversionTarget.h"
#include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h"
#include "iree/compiler/Dialect/VM/Conversion/MathToVM/ConvertMathToVM.h"
-#include "iree/compiler/Dialect/VM/Conversion/MemRefToVM/ConvertMemRefToVM.h"
#include "iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.h"
#include "iree/compiler/Dialect/VM/Conversion/TypeConverter.h"
#include "iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertUtilToVM.h"
@@ -124,10 +124,16 @@
arith::populateArithmeticExpandOpsPatterns(patterns);
populateStandardToVMPatterns(context, typeConverter, patterns);
populateMathToVMPatterns(context, typeConverter, patterns);
- populateMemRefToVMPatterns(context, conversionTarget, typeConverter,
- patterns);
populateAffineToStdConversionPatterns(patterns);
+ // MemRef to Util (to VM) is an A->B->C lowering. We must instruct it
+ // specifically on what the correct C buffer type is.
+ auto utilBufferType =
+ typeConverter.convertType(IREE::Util::BufferType::get(&getContext()));
+ assert(utilBufferType);
+ populateMemRefToUtilPatterns(context, conversionTarget, typeConverter,
+ patterns, utilBufferType);
+
conversionTarget
.addIllegalDialect<func::FuncDialect, mlir::arith::ArithmeticDialect>();
conversionTarget.addIllegalDialect<AffineDialect>();
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp
index db45bdf..a44e063 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp
@@ -74,6 +74,8 @@
passManager.addNestedPass<IREE::VM::ModuleOp>(createHoistInlinedRodataPass());
passManager.addNestedPass<IREE::VM::ModuleOp>(createDeduplicateRodataPass());
passManager.addNestedPass<IREE::VM::ModuleOp>(
+ createSinkGlobalBufferLoadsPass());
+ passManager.addNestedPass<IREE::VM::ModuleOp>(
createGlobalInitializationPass());
passManager.addPass(createInlinerPass());
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.h
index 4ba3a27..8673354 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.h
@@ -59,6 +59,12 @@
std::unique_ptr<OperationPass<IREE::VM::ModuleOp>>
createDeduplicateRodataPass();
+// Sinks global buffer references into loads. This should result in the
+// elimination of initializers that are trivially initializing a global buffer
+// from rodata.
+std::unique_ptr<OperationPass<IREE::VM::ModuleOp>>
+createSinkGlobalBufferLoadsPass();
+
//===----------------------------------------------------------------------===//
// Module analysis and ordinal assignment
//===----------------------------------------------------------------------===//
@@ -106,6 +112,7 @@
createGlobalInitializationPass();
createOrdinalAllocationPass();
createSinkDefiningOpsPass();
+ createSinkGlobalBufferLoadsPass();
}
inline void registerVMTestPasses() {
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/SinkGlobalBufferLoads.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/SinkGlobalBufferLoads.cpp
new file mode 100644
index 0000000..b571112
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/SinkGlobalBufferLoads.cpp
@@ -0,0 +1,129 @@
+// Copyright 2020 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <utility>
+
+#include "iree/compiler/Dialect/VM/IR/VMDialect.h"
+#include "iree/compiler/Dialect/VM/IR/VMOps.h"
+#include "iree/compiler/Dialect/VM/IR/VMTypes.h"
+#include "iree/compiler/Dialect/VM/Transforms/Passes.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace VM {
+
+class SinkGlobalBufferLoadsPass
+ : public PassWrapper<SinkGlobalBufferLoadsPass,
+ OperationPass<IREE::VM::ModuleOp>> {
+ public:
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<IREE::VM::VMDialect>();
+ }
+
+ StringRef getArgument() const override {
+ return "iree-vm-sink-global-buffer-loads";
+ }
+
+ StringRef getDescription() const override {
+ return "Sinks global buffer references into loads.";
+ }
+
+ void runOnOperation() override {
+ auto moduleOp = getOperation();
+ SymbolTableCollection symbolTable;
+
+ // Find all vm.global.store.ref ops in the module and note the ones that
+ // meet our requirements:
+ // - Is in an initializer
+ // - Storing to an immutable global
+ // - Sourcing from a vm.const.ref.rodata or vm.const.ref.zero op
+ //
+ // We will set an info with a null initializerOp if there is a store but
+ // an unrecognized source.
+ struct GlobalInitInfo {
+ IREE::VM::GlobalStoreRefOp storeOp;
+ Operation *initializerOp;
+ };
+ DenseMap<IREE::VM::GlobalRefOp, GlobalInitInfo> globalInitInfos;
+ moduleOp.walk([&](IREE::VM::GlobalStoreRefOp storeOp) {
+ if (!storeOp->getParentOfType<IREE::VM::InitializerOp>()) {
+ return;
+ }
+ // Only consider it a constant for a couple of cases.
+ Operation *initializerOp = storeOp.getValue().getDefiningOp();
+ if (initializerOp &&
+ !llvm::isa<IREE::VM::ConstRefRodataOp, IREE::VM::ConstRefZeroOp>(
+ initializerOp)) {
+ initializerOp = nullptr;
+ }
+ auto globalOp =
+ symbolTable.lookupNearestSymbolFrom<IREE::VM::GlobalRefOp>(
+ storeOp->getParentOp(), storeOp.getGlobalAttr());
+ if (globalOp) {
+ if (!globalOp.isMutable()) {
+ globalInitInfos[globalOp] = GlobalInitInfo{storeOp, initializerOp};
+ }
+ }
+ });
+
+ // Walk over all loads and update.
+ moduleOp.walk([&](IREE::VM::GlobalLoadRefOp loadOp) {
+ auto globalOp =
+ symbolTable.lookupNearestSymbolFrom<IREE::VM::GlobalRefOp>(
+ loadOp->getParentOp(), loadOp.getGlobalAttr());
+ auto it = globalInitInfos.find(globalOp);
+ if (it != globalInitInfos.end()) {
+ auto &info = it->second;
+ if (info.initializerOp) {
+ // We are sourced from a constant. Clone/replace.
+ OpBuilder builder(loadOp);
+ Operation *newOp = builder.clone(*info.initializerOp);
+ loadOp.replaceAllUsesWith(newOp);
+ loadOp.erase();
+ }
+ return;
+ }
+
+ // Still here? If the global is immutable, we can replace with null.
+ // (i.e. there is no initializing store)
+ if (!globalOp.isMutable()) {
+ OpBuilder builder(loadOp);
+ Value zero = builder.create<IREE::VM::ConstRefZeroOp>(
+ loadOp.getLoc(), loadOp.getResult().getType());
+ loadOp.replaceAllUsesWith(zero);
+ loadOp.erase();
+ }
+ });
+
+ // Erase initializers no longer needed.
+ for (auto it : globalInitInfos) {
+ auto global = it.first;
+ auto &info = it.second;
+ if (info.initializerOp) {
+ info.storeOp.erase();
+ global.erase();
+ }
+ }
+ }
+};
+
+std::unique_ptr<OperationPass<IREE::VM::ModuleOp>>
+createSinkGlobalBufferLoadsPass() {
+ return std::make_unique<SinkGlobalBufferLoadsPass>();
+}
+
+static PassRegistration<SinkGlobalBufferLoadsPass> pass;
+
+} // namespace VM
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/test/BUILD b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/BUILD
index 7d81db8..5a0c23b 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/test/BUILD
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/BUILD
@@ -23,6 +23,7 @@
"hoist_inlined_rodata.mlir",
"ordinal_allocation.mlir",
"sink_defining_ops.mlir",
+ "sink_global_buffer_loads.mlir",
],
include = ["*.mlir"],
),
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/CMakeLists.txt
index f0cb3bf..3030c30 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/CMakeLists.txt
@@ -20,6 +20,7 @@
"hoist_inlined_rodata.mlir"
"ordinal_allocation.mlir"
"sink_defining_ops.mlir"
+ "sink_global_buffer_loads.mlir"
TOOLS
FileCheck
iree-opt
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/test/sink_global_buffer_loads.mlir b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/sink_global_buffer_loads.mlir
new file mode 100644
index 0000000..56c1e2a
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/sink_global_buffer_loads.mlir
@@ -0,0 +1,103 @@
+// RUN: iree-opt --split-input-file --pass-pipeline="vm.module(iree-vm-sink-global-buffer-loads)" --allow-unregistered-dialect %s | FileCheck %s
+
+// CHECK-LABEL: @rodata_ref
+module {
+ vm.module public @rodata_ref {
+ vm.rodata private @_const_0 dense<[1, 0]> : tensor<2xi32>
+ // CHECK: vm.func public @f
+ vm.func public @f() -> !vm.buffer {
+ // CHECK: %[[V:.*]] = vm.const.ref.rodata @_const_0 : !vm.buffer
+ %0 = vm.global.load.ref @__constant_1x2xi32 : !vm.buffer
+ // CHECK: vm.return %[[V]]
+ vm.return %0 : !vm.buffer
+ }
+ // CHECK-NOT: vm.global.ref
+ // CHECK-NOT: vm.global.store.ref
+ vm.global.ref private @__constant_1x2xi32 : !vm.buffer
+ vm.initializer {
+ %_const_0 = vm.const.ref.rodata @_const_0 : !vm.buffer
+ vm.global.store.ref %_const_0, @__constant_1x2xi32 : !vm.buffer
+ vm.return
+ }
+ }
+}
+
+// -----
+// CHECK-LABEL: @undefined
+module {
+ vm.module public @undefined {
+ // CHECK: vm.func public @f
+ vm.func public @f() -> !vm.buffer {
+ // CHECK: %[[V:.*]] = vm.const.ref.zero : !vm.buffer
+ %0 = vm.global.load.ref @__constant_1x2xi32 : !vm.buffer
+ // CHECK: vm.return %[[V]]
+ vm.return %0 : !vm.buffer
+ }
+ vm.global.ref private @__constant_1x2xi32 : !vm.buffer
+ }
+}
+
+// -----
+// CHECK-LABEL: @unknown_initializer
+module {
+ vm.module public @unknown_initializer {
+ // CHECK: vm.func public @f
+ vm.func public @f() -> !vm.buffer {
+ // CHECK: %[[V:.*]] = vm.global.load.ref @__constant_1x2xi32
+ %0 = vm.global.load.ref @__constant_1x2xi32 : !vm.buffer
+ // CHECK: vm.return %[[V]]
+ vm.return %0 : !vm.buffer
+ }
+ vm.global.ref private @__constant_1x2xi32 : !vm.buffer
+ // CHECK: vm.initializer
+ vm.initializer {
+ %_const_0 = "undefined.custom_initializer"() : () -> (!vm.buffer)
+ // CHECK: vm.global.store.ref
+ vm.global.store.ref %_const_0, @__constant_1x2xi32 : !vm.buffer
+ vm.return
+ }
+ }
+}
+
+// -----
+// CHECK-LABEL: @const.ref.zero
+module {
+ vm.module public @const.ref.zero {
+ // CHECK: vm.func public @f
+ vm.func public @f() -> !vm.buffer {
+ // CHECK: %[[NULL:.*]] = vm.const.ref.zero : !vm.buffer
+ %0 = vm.global.load.ref @__constant_1x2xi32 : !vm.buffer
+ // CHECK: vm.return %[[NULL]]
+ vm.return %0 : !vm.buffer
+ }
+ // CHECK-NOT: vm.global.ref
+ // CHECK-NOT: vm.global.store.ref
+ vm.global.ref private @__constant_1x2xi32 : !vm.buffer
+ vm.initializer {
+ %_const_0 = vm.const.ref.zero : !vm.buffer
+ vm.global.store.ref %_const_0, @__constant_1x2xi32 : !vm.buffer
+ vm.return
+ }
+ }
+}
+
+// -----
+// CHECK-LABEL: @mutable_ref
+module {
+ vm.module public @mutable_ref {
+ vm.rodata private @_const_0 dense<[1, 0]> : tensor<2xi32>
+ // CHECK: vm.func public @f
+ vm.func public @f() -> !vm.buffer {
+ // CHECK: %[[V:.*]] = vm.global.load.ref
+ %0 = vm.global.load.ref @__constant_1x2xi32 : !vm.buffer
+ // CHECK: vm.return %[[V]]
+ vm.return %0 : !vm.buffer
+ }
+ vm.global.ref private mutable @__constant_1x2xi32 : !vm.buffer
+ vm.initializer {
+ %_const_0 = vm.const.ref.rodata @_const_0 : !vm.buffer
+ vm.global.store.ref %_const_0, @__constant_1x2xi32 : !vm.buffer
+ vm.return
+ }
+ }
+}
diff --git a/tests/e2e/tosa_ops/BUILD b/tests/e2e/tosa_ops/BUILD
index 4b31afa..93fc1a3 100644
--- a/tests/e2e/tosa_ops/BUILD
+++ b/tests/e2e/tosa_ops/BUILD
@@ -143,13 +143,18 @@
"clamp.mlir",
"clz.mlir",
"const.mlir",
+ "equal.mlir",
"exp.mlir",
"floor.mlir",
+ "fully_connected.mlir",
+ "greater.mlir",
+ "greater_equal.mlir",
"if.mlir",
"log.mlir",
"logical_left_shift.mlir",
"logical_right_shift.mlir",
"matmul.mlir",
+ "max_pool.mlir",
"maximum.mlir",
"minimum.mlir",
"mul.mlir",
@@ -160,28 +165,18 @@
"reluN.mlir",
"reshape.mlir",
"rsqrt.mlir",
+ "select.mlir",
"sigmoid.mlir",
"sub.mlir",
"tanh.mlir",
"transpose.mlir",
+ "while.mlir",
],
include = ["*.mlir"],
exclude = [
- # Gather and table are failing on complicated buffer conversions that
- # will shortly be fixed separately.
+ # Decompositions produce tensor<index> which is not handled properly.
"gather.mlir",
"table.mlir",
- # equal, greater_equal, select have (suspected) i1 problems which
- # manifest at runtime.
- "equal.mlir",
- "greater.mlir",
- "greater_equal.mlir",
- "select.mlir",
- # fully_connected and while are failing on unrealizable 0d memref
- # conversions. The first is incidental to how the test is constructed.
- "fully_connected.mlir",
- "max_pool.mlir",
- "while.mlir",
],
)
diff --git a/tests/e2e/tosa_ops/CMakeLists.txt b/tests/e2e/tosa_ops/CMakeLists.txt
index 3c8bca3..853fd13 100644
--- a/tests/e2e/tosa_ops/CMakeLists.txt
+++ b/tests/e2e/tosa_ops/CMakeLists.txt
@@ -127,13 +127,18 @@
"clamp.mlir"
"clz.mlir"
"const.mlir"
+ "equal.mlir"
"exp.mlir"
"floor.mlir"
+ "fully_connected.mlir"
+ "greater.mlir"
+ "greater_equal.mlir"
"if.mlir"
"log.mlir"
"logical_left_shift.mlir"
"logical_right_shift.mlir"
"matmul.mlir"
+ "max_pool.mlir"
"maximum.mlir"
"minimum.mlir"
"mul.mlir"
@@ -144,10 +149,12 @@
"reluN.mlir"
"reshape.mlir"
"rsqrt.mlir"
+ "select.mlir"
"sigmoid.mlir"
"sub.mlir"
"tanh.mlir"
"transpose.mlir"
+ "while.mlir"
TARGET_BACKEND
"vmvx"
DRIVER