Generalization for ElementsAttr coverage (#15433)
This PR addresses the issue that there was code written in IREE from a
time when there was only DenseElementsAttr. DenseElementsAttr is a
specific type and ElementsAttr the interface. There are actually
multiple implementations of the interface now and DenseElementsAttr is
being phased out for general usage. This PR deals with generalization,
so we are not boxed into only supporting DenseElementsAttr.
diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ArithToF32.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ArithToF32.cpp
index dc3e268..d721e05 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ArithToF32.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ArithToF32.cpp
@@ -214,7 +214,7 @@
return rewriter.notifyMatchFailure(op, "Not casting from vector-scalar");
}
- mlir::DenseElementsAttr vectorCst;
+ mlir::ElementsAttr vectorCst;
if (!matchPattern(operand, m_Constant(&vectorCst))) {
return failure();
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp
index b92780a..9daa80f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp
@@ -516,7 +516,7 @@
bool modifiedOutput = false;
Location loc = op.getLoc();
for (OpOperand &opOperand : op.getDpsInitsMutable()) {
- DenseElementsAttr attr;
+ ElementsAttr attr;
if (!matchPattern(opOperand.get(), m_Constant(&attr)))
continue;
if (!attr.isSplat())
diff --git a/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp b/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp
index 744b293..0fd4129 100644
--- a/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp
@@ -51,6 +51,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
@@ -178,6 +179,10 @@
return splatAttr.reshape(newType);
} else if (auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(value)) {
return denseAttr.reshape(newType);
+ } else if (auto denseResourceAttr =
+ llvm::dyn_cast<DenseResourceElementsAttr>(value)) {
+ return DenseResourceElementsAttr::get(newType,
+ denseResourceAttr.getRawHandle());
}
return {};
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
index ef5f14b..f682f5a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
@@ -105,7 +105,7 @@
LogicalResult
matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
- auto attr = llvm::cast<DenseElementsAttr>(constantOp.getValue());
+ auto attr = llvm::cast<ElementsAttr>(constantOp.getValue());
auto attrType = llvm::dyn_cast<ShapedType>(attr.getType());
if (!attrType) {
return rewriter.notifyMatchFailure(
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPULayoutAnalysisAndDistribution.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPULayoutAnalysisAndDistribution.cpp
index 33dc874..6320d45 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPULayoutAnalysisAndDistribution.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPULayoutAnalysisAndDistribution.cpp
@@ -1060,7 +1060,7 @@
Value constant = constantOp.getResult();
if (!layoutMap.count(constant))
return;
- auto attr = llvm::cast<DenseElementsAttr>(constantOp.getValue());
+ auto attr = llvm::cast<ElementsAttr>(constantOp.getValue());
// Only handle splat values for now
if (!attr.isSplat())
return;
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
index 3e2ffd2..6deb796 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
@@ -529,14 +529,13 @@
auto constantType = op->getResult(0).getType();
if (llvm::isa<SplatElementsAttr>(constantValueAttr)) {
return true;
- } else if (auto denseAttr =
- llvm::dyn_cast<DenseElementsAttr>(constantValueAttr)) {
+ } else if (auto attr = llvm::dyn_cast<ElementsAttr>(constantValueAttr)) {
auto shapedType = llvm::cast<ShapedType>(constantType);
uint64_t estimatedByteLength =
(shapedType.getNumElements() *
IREE::Util::getTypeBitWidth(shapedType.getElementType())) /
8;
- return denseAttr.isSplat() ||
+ return attr.isSplat() ||
estimatedByteLength <= clInlineConstantByteLength;
} else if (constantType.isIntOrIndexOrFloat() ||
isa<ComplexType>(constantType)) {
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/ClosureOpUtils.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/ClosureOpUtils.cpp
index bd17e0b..66e7865 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/ClosureOpUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/ClosureOpUtils.cpp
@@ -134,14 +134,12 @@
// know they are a splat - which is why it's so important we inline them
// here so we know when they are used that's the case.
return true;
- } else if (auto denseAttr =
- llvm::dyn_cast<DenseElementsAttr>(constantValueAttr)) {
+ } else if (auto attr = llvm::dyn_cast<ElementsAttr>(constantValueAttr)) {
// Smallish constants are worth moving inside.
auto shapedType = llvm::cast<ShapedType>(constantType);
uint64_t estimatedByteLength =
IREE::Util::getRoundedPhysicalStorageSize(shapedType);
- return denseAttr.isSplat() ||
- estimatedByteLength <= maxInlinedConstantBytes;
+ return attr.isSplat() || estimatedByteLength <= maxInlinedConstantBytes;
} else if (constantType.isIntOrIndexOrFloat()) {
// Primitives can always go in.
return true;
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/OutlineConstants.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/OutlineConstants.cpp
index f0d31e3..3ea0127 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/OutlineConstants.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/OutlineConstants.cpp
@@ -25,7 +25,7 @@
// Returns true if |value| is worth outlining (large, etc).
static bool isOutlinableValue(Attribute value) {
- if (auto elementsAttr = llvm::dyn_cast<DenseElementsAttr>(value)) {
+ if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(value)) {
// Don't outline splats - we want those fused.
return !elementsAttr.isSplat();
}
diff --git a/compiler/src/iree/compiler/GlobalOptimization/DetachElementwiseFromNamedOps.cpp b/compiler/src/iree/compiler/GlobalOptimization/DetachElementwiseFromNamedOps.cpp
index 062cd41..67d06d4 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/DetachElementwiseFromNamedOps.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/DetachElementwiseFromNamedOps.cpp
@@ -154,7 +154,7 @@
if (!resultType || !resultType.getElementType().isIntOrFloat())
continue;
- auto attr = llvm::dyn_cast<DenseElementsAttr>(constOp.getValue());
+ auto attr = llvm::dyn_cast<ElementsAttr>(constOp.getValue());
if (!attr || !attr.isSplat())
continue;
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp
index dc00886..9adbf95 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp
@@ -414,7 +414,7 @@
// Simplify when the condition is a constant.
Value pred = op.getPred();
- DenseElementsAttr cond;
+ ElementsAttr cond;
if (!matchPattern(pred, m_Constant(&cond))) {
return failure();
}
@@ -430,11 +430,11 @@
if (cond.getNumElements() > kFoldOpEltLimit)
return failure();
- DenseElementsAttr trueAttr;
+ ElementsAttr trueAttr;
if (!matchPattern(trueVal, m_Constant(&trueAttr)))
return failure();
- DenseElementsAttr falseAttr;
+ ElementsAttr falseAttr;
if (!matchPattern(falseVal, m_Constant(&falseAttr)))
return failure();
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalg.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalg.cpp
index c3b055e..c36ee73 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalg.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalg.cpp
@@ -1417,16 +1417,26 @@
return rewriter.notifyMatchFailure(constOp, "type conversion failed");
ElementsAttr replacementAttr = constOp.getValue();
- if (replacementType != constOp.getType()) {
- if (auto denseAttr = dyn_cast<DenseElementsAttr>(replacementAttr)) {
- // Signedness conversion.
- replacementAttr = denseAttr.mapValues(replacementType.getElementType(),
- [](const APInt &i) { return i; });
+ if (replacementType == constOp.getType()) {
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, replacementType,
+ replacementAttr);
+ return success();
+ } else {
+ auto denseAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
+ if (!denseAttr) {
+ return rewriter.notifyMatchFailure(
+ constOp,
+ "DenseElementsAttr cast failed (only DenseElementsAttr supported)");
}
+ // Signedness conversion.
+ // TODO(#15442): Add generic mapping utility, so we aren't limited to
+ // supporting only DenseElementsAttr.
+ replacementAttr = denseAttr.mapValues(replacementType.getElementType(),
+ [](const APInt &i) { return i; });
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, replacementType,
+ replacementAttr);
+ return success();
}
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, replacementType,
- replacementAttr);
- return success();
}
};