[LinalgExt] Add IndexingMapOpInterface to ArgCompareOp (#24173)
This PR adds IndexingMapOpInterface to ArgCompareOp so that we can
refactor the `setReductionConfig` to work for both `linalg::LinalgOp`
and `LinalgExt::ArgCompareOp`.
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 0c333e2..04d6c57 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -1617,6 +1617,20 @@
return llvm::to_vector(getInputType().getShape());
}
+ArrayAttr IREE::LinalgExt::ArgCompareOp::getIndexingMaps() {
+ SmallVector<AffineMap> maps = getIndexingMapsArray();
+ return Builder(getContext()).getAffineMapArrayAttr(maps);
+}
+
+AffineMap
+IREE::LinalgExt::ArgCompareOp::getMatchingIndexingMap(OpOperand *operand) {
+ SmallVector<AffineMap> maps = getIndexingMapsArray();
+ unsigned idx = operand->getOperandNumber();
+ assert(idx < maps.size() &&
+ "operand does not have an indexing map (e.g. index_base)");
+ return maps[idx];
+}
+
MutableOperandRange ArgCompareOp::getDpsInitsMutable() {
return MutableOperandRange(*this, /*numInputs=*/getInputIndex() ? 2 : 1,
/*numInits=*/2);
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index 0daaf22..82d588c 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -785,6 +785,7 @@
DeclareOpInterfaceMethods<LinalgFusionInterface,
["getIndexingMapsForResults", "getIndexingMapsForOperands",
"getStaticLoopRanges"]>,
+ DeclareOpInterfaceMethods<IndexingMapOpInterface, ["getMatchingIndexingMap"]>,
DeclareOpInterfaceMethods<LinalgExtInterface>,
DeclareOpInterfaceMethods<TilingInterface,
["generateScalarImplementation",