Improving util.align folding. (#14805)
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/IR/BUILD.bazel
index 213481e..9a8c582 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/BUILD.bazel
@@ -95,6 +95,7 @@
":UtilOpsGen",
":UtilTypesGen",
"@llvm-project//llvm:Support",
+ "@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:CastInterfaces",
"@llvm-project//mlir:ControlFlowDialect",
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/IR/CMakeLists.txt
index c4ca0ce..5abbea5 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/CMakeLists.txt
@@ -63,6 +63,7 @@
::UtilOpsGen
::UtilTypesGen
LLVMSupport
+ MLIRAffineDialect
MLIRArithDialect
MLIRCastInterfaces
MLIRControlFlowDialect
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp
index 2b664cd..0e26efc 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp
@@ -395,7 +395,7 @@
case 64:
return serializeGenericIntegerElements<uint64_t>(attr, endian, os);
default:
- if (bitWidth != 1 && bitWidth < 64) {
+ if (bitWidth < 64) {
// Special case for bit-packing of sub-byte aligned types.
// This could be extended to handle larger widths (i33, etc) but they
// are rare today.
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
index fe3954a..ea55cd4 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
@@ -4,10 +4,13 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include <numeric>
+
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/SmallPtrSet.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
@@ -364,9 +367,11 @@
// a large majority of the cases we generate ourselves from packing/allocation.
static bool isAlignedTo(Value value, Value alignment) {
APInt staticValue;
+ bool hasStaticValue = matchPattern(value, m_ConstantInt(&staticValue));
APInt staticAlignment;
- if (matchPattern(value, m_ConstantInt(&staticValue)) &&
- matchPattern(alignment, m_ConstantInt(&staticAlignment))) {
+ bool hasStaticAlignment =
+ matchPattern(alignment, m_ConstantInt(&staticAlignment));
+ if (hasStaticValue && hasStaticAlignment) {
// If this value is itself a multiple of the alignment then we can fold.
if (staticValue.urem(staticAlignment).isZero()) {
return true; // value % alignment == 0
@@ -381,11 +386,12 @@
// If the alignments are constant we can compare them inline.
APInt sourceAlignment;
- APInt selfAlignment;
- if (matchPattern(sourceAlignOp.getAlignment(),
- m_ConstantInt(&sourceAlignment)) &&
- matchPattern(alignment, m_ConstantInt(&selfAlignment))) {
- if (sourceAlignment.uge(selfAlignment)) {
+ if (hasStaticAlignment && matchPattern(sourceAlignOp.getAlignment(),
+ m_ConstantInt(&sourceAlignment))) {
+ if (sourceAlignment.uge(staticAlignment) &&
+ std::gcd(sourceAlignment.getZExtValue(),
+ staticAlignment.getZExtValue()) ==
+ staticAlignment.getZExtValue()) {
return true; // source alignment is >= our alignment
}
}
@@ -395,6 +401,15 @@
return isAlignedTo(sourceAlignOp.getValue(), alignment);
}
+ // affine.apply ops that produce the value being aligned typically already
+ // encode the alignment; check the map's largest known static divisor.
+ if (auto affineOp = value.getDefiningOp<affine::AffineApplyOp>()) {
+ if (hasStaticAlignment) {
+ return (affineOp.getAffineMap().getLargestKnownDivisorOfMapExprs() %
+ staticAlignment.getZExtValue()) == 0;
+ }
+ }
+
// If we are sourced from add/mul we peephole check to see if what is being
// added is also aligned. This should be part of a larger pass doing IPO but
// as the common case is that we align+add+align this is worth having in a
@@ -414,7 +429,7 @@
}
} else if (auto sourceMulOp = value.getDefiningOp<arith::MulIOp>()) {
// Two aligned values multiplied together are still aligned.
- if (isAlignedTo(sourceMulOp.getLhs(), alignment) &&
+ if (isAlignedTo(sourceMulOp.getLhs(), alignment) ||
isAlignedTo(sourceMulOp.getRhs(), alignment)) {
return true;
}
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/test/alignment_folding.mlir b/compiler/src/iree/compiler/Dialect/Util/IR/test/alignment_folding.mlir
index 750746e..3477dbe 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/test/alignment_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/test/alignment_folding.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --canonicalize %s | iree-opt --split-input-file | FileCheck %s
+// RUN: iree-opt --split-input-file --canonicalize --mlir-print-local-scope %s | iree-opt --split-input-file --mlir-print-local-scope | FileCheck %s
// CHECK-LABEL: @foldSameAlignment
// CHECK-SAME: (%[[VALUE:.+]]: index, %[[ALIGNMENT:.+]]: index)
@@ -43,6 +43,21 @@
// -----
+// CHECK-LABEL: @dontFoldMixedAlignment
+// CHECK-SAME: (%[[VALUE:.+]]: index)
+func.func @dontFoldMixedAlignment(%value: index) -> index {
+ %c9 = arith.constant 9 : index
+ %c16 = arith.constant 16 : index
+ // CHECK: %[[ALIGN16:.+]] = util.align %[[VALUE]], %c16
+ %0 = util.align %value, %c16 : index
+ // CHECK: %[[ALIGN9:.+]] = util.align %[[ALIGN16]], %c9
+ %1 = util.align %0, %c9 : index
+ // CHECK: return %[[ALIGN9]]
+ return %1 : index
+}
+
+// -----
+
// CHECK-LABEL: @foldAlignmentRecursively
// CHECK-SAME: (%[[VALUE:.+]]: index, %[[ALIGNMENT:.+]]: index)
func.func @foldAlignmentRecursively(%value: index, %alignment: index) -> index {
@@ -94,6 +109,21 @@
// -----
+// CHECK-LABEL: @foldMulAlignmentConstant
+// CHECK-SAME: (%[[LHS:.+]]: index)
+func.func @foldMulAlignmentConstant(%lhs: index) -> index {
+ %c64 = arith.constant 64 : index
+ %c2048 = arith.constant 2048 : index
+ // CHECK: %[[RESULT:.+]] = arith.muli %[[LHS]], %c2048
+ %lhs_mul = arith.muli %lhs, %c2048 : index
+ // CHECK-NOT: util.align
+ %result = util.align %lhs_mul, %c64 : index
+ // CHECK: return %[[RESULT]]
+ return %result : index
+}
+
+// -----
+
// CHECK-LABEL: @foldConstantAlign
func.func @foldConstantAlign() -> (index, index, index) {
%c0 = arith.constant 0 : index
@@ -110,6 +140,22 @@
// -----
+// CHECK-LABEL: @foldAffineAlign
+func.func @foldAffineAlign(%arg0: index) -> (index, index) {
+ // CHECK: %[[A0:.+]] = affine.apply affine_map<()[s0] -> (s0 * 16384)>()[%arg0]
+ %a0 = affine.apply affine_map<()[s0] -> (s0 * 16384)>()[%arg0]
+ %c64 = arith.constant 64 : index
+ %a1 = util.align %a0, %c64 : index
+ // CHECK: %[[B0:.+]] = affine.apply affine_map<()[s0] -> ((s0 * s0) * 4)>()[%arg0]
+ %b0 = affine.apply affine_map<()[s0] -> ((s0 * s0) * 4)>()[%arg0]
+ %c4 = arith.constant 4 : index
+ %b1 = util.align %b0, %c4 : index
+ // CHECK: return %[[A0]], %[[B0]]
+ return %a1, %b1 : index, index
+}
+
+// -----
+
// CHECK-LABEL: @sizeofWholeInt
func.func @sizeofWholeInt() -> index {
// CHECK: = arith.constant 4 : index