Revert "[LinalgExt] Add online_attention op" (#17658)
Reverts iree-org/iree#17536
This caused `sdxl-scheduled-unet-3-tank` to hit timeouts when compiling
for CPU:
https://github.com/iree-org/iree/actions/runs/9484305572/job/26134004282
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index cc07354..9ae470d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -890,7 +890,7 @@
/// Splits the tile sizes in `parallelSizes` into `reductionSizes` for the
/// reduction loops.
static void splitParallelAndReductionTiles(
- Operation *op, SmallVectorImpl<int64_t> &parallelSizes,
+ linalg::LinalgOp op, SmallVectorImpl<int64_t> &parallelSizes,
SmallVectorImpl<int64_t> &reductionSizes,
SmallVectorImpl<bool> *parallelScalableFlags = nullptr,
SmallVectorImpl<bool> *reductionScalableFlags = nullptr) {
@@ -900,9 +900,8 @@
reductionScalableFlags->assign(parallelScalableFlags->begin(),
parallelScalableFlags->end());
}
- TilingInterface tilingOp = cast<TilingInterface>(op);
for (auto [index, iteratorType] :
- llvm::enumerate(tilingOp.getLoopIteratorTypes())) {
+ llvm::enumerate(op.getIteratorTypesArray())) {
if (iteratorType == utils::IteratorType::parallel) {
reductionSizes[index] = 0;
if (reductionScalableFlags)
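For readers skimming the hunk above: `splitParallelAndReductionTiles` starts with `reductionSizes` as a copy of `parallelSizes`, then zeroes each entry at the level where that loop should not be tiled. A minimal standalone sketch of the split (the hunk is truncated before the `else` branch; its mirror-image behavior here is an assumption):

```cpp
#include <cstdint>
#include <vector>

enum class IteratorType { parallel, reduction };

// Sketch: split one combined tile-size vector into a parallel level and a
// reduction level. Parallel loops keep their size at the parallel level and
// get 0 at the reduction level; reduction loops are assumed to be the mirror.
void splitTiles(const std::vector<IteratorType> &iters,
                std::vector<int64_t> &parallelSizes,
                std::vector<int64_t> &reductionSizes) {
  reductionSizes.assign(parallelSizes.begin(), parallelSizes.end());
  for (size_t i = 0; i < iters.size(); ++i) {
    if (iters[i] == IteratorType::parallel)
      reductionSizes[i] = 0; // tiled only at the parallel level
    else
      parallelSizes[i] = 0;  // assumed: tiled only at the reduction level
  }
}
```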
@@ -1122,9 +1121,9 @@
SmallVector<int64_t> parallelTileSizes = vecTileSizes;
SmallVector<int64_t> reductionTileSizes;
SmallVector<bool> reductionScalableFlags;
- splitParallelAndReductionTiles(op, parallelTileSizes, reductionTileSizes,
- &parallelScalableFlags,
- &reductionScalableFlags);
+ splitParallelAndReductionTiles(
+ cast<linalg::LinalgOp>(op.getOperation()), parallelTileSizes,
+ reductionTileSizes, &parallelScalableFlags, &reductionScalableFlags);
if (vecPreProcStrategy == VectorPreProcStrategy::None) {
setVectorSizesForDynamicShapes(cast<linalg::LinalgOp>(op.getOperation()),
@@ -1752,13 +1751,14 @@
// Batch, M and N (parallel dimensions) are distributed on workgroups.
DistributionHeuristicConfig config;
- SmallVector<int64_t> distTileSizes =
- getDefaultDistributedLevelTileSizes(attnOp, config);
+ SmallVector<int64_t> distTileSizes = getDefaultDistributedLevelTileSizes(
+ attnOp, DistributionHeuristicConfig{});
// Batch, M and N (parallel dimensions) are distributed on workgroups.
SmallVector<int64_t> vecTileSizes(attnOp.getIterationDomainRank(), 1);
- // Mark k1 reduction dimensions not to distribute.
- for (int i : opInfo.getK1Dims()) {
+ // Mark reduction dimensions not to distribute.
+ for (int64_t i :
+ llvm::concat<const int64_t>(opInfo.getK1Dims(), opInfo.getK2Dims())) {
vecTileSizes[i] = 0;
}
int64_t vectorSize = getVectorSize(entryPointFn, attnOp.getOutputType());
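To make the `llvm::concat` change above concrete, here is a toy version with made-up dimension indices (batch=0, m=1, k1=2, k2=3, n=4 is an assumption for illustration):

```cpp
#include <cstdint>
#include <vector>

int main() {
  // Hypothetical attention loop order: batch=0, m=1, k1=2, k2=3, n=4.
  std::vector<int64_t> vecTileSizes = {1, 1, 1, 1, 1};
  std::vector<int64_t> k1Dims = {2}, k2Dims = {3};
  // The old code zeroed only k1; the restored code zeroes both reduction dims.
  for (int64_t i : k1Dims) vecTileSizes[i] = 0;
  for (int64_t i : k2Dims) vecTileSizes[i] = 0;
  // vecTileSizes is now {1, 1, 0, 0, 1}: reduction dims are not distributed.
}
```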
@@ -1773,17 +1773,18 @@
/*numElem=*/tileSize, vectorSize, vectorSize);
}
- SmallVector<int64_t> parallelTileSizes = vecTileSizes;
- SmallVector<int64_t> reductionTileSizes;
- splitParallelAndReductionTiles(attnOp, parallelTileSizes, reductionTileSizes);
+ // TODO (17467): Due to a bug in TileAndDecomposeAttention, N dimension
+ // cannot be tiled. Remove this once fixed.
+ for (int64_t i : opInfo.getNDims()) {
+ distTileSizes[i] = 0;
+ vecTileSizes[i] = 0;
+ }
- LLVM_DEBUG(KD_DBGS() << "Vectorization/unrolling tile sizes (parallel): "
- << parallelTileSizes << "\n");
- LLVM_DEBUG(KD_DBGS() << "Vectorization/unrolling tile sizes (reduction): "
- << reductionTileSizes << "\n");
+ TileSizesListType tileSizes = {distTileSizes, vecTileSizes};
- TileSizesListType tileSizes = {distTileSizes, parallelTileSizes,
- reductionTileSizes};
+ // TODO: (Groverkss): Tile K2 here using reduction tiling interface once we
+ // have it. TileAndDecomposeAttention pass only tiles K2. I think it should
+ // be possible to tile K1 also, but need to explore it more.
return setOpConfigAndEntryPointFnTranslation(
entryPointFn, attnOp, tileSizes,
@@ -1842,9 +1843,6 @@
tileSizes.push_back(distTileSizes);
SmallVector<int64_t> vecTileSizes(iterationRank, 1);
tileSizes.push_back(vecTileSizes);
- // Dummy tiling config for reduction level.
- SmallVector<int64_t> reductionTileSizes(iterationRank, 0);
- tileSizes.push_back(reductionTileSizes);
return setOpConfigAndEntryPointFnTranslation(
entryPointFn, winogradOp, tileSizes,
DispatchLoweringPassPipeline::CPULinalgExtTileAndVectorize);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index 6a4363f..31fcead 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -617,13 +617,10 @@
createLLVMCPUTilePass(tilingConfig.getVectorCommonParallelLevel()));
// TODO: Remove the pass once we have PartialReductionOpInterface implemented
// for AttentionOp.
- funcPassManager.addPass(
- IREE::LinalgExt::createConvertAttentionToOnlineAttentionPass());
- funcPassManager.addPass(
- createLLVMCPUTilePass(tilingConfig.getVectorReductionLevel()));
+ funcPassManager.addPass(IREE::LinalgExt::createTileAttentionPass());
+ funcPassManager.addPass(IREE::LinalgExt::createDecomposeAttentionPass());
funcPassManager.addPass(
IREE::LinalgExt::createDecomposeWinogradTransformPass());
- funcPassManager.addPass(IREE::LinalgExt::createDecomposeAttentionPass());
{
GenericVectorizationPassOptions options;
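In effect, this hunk moves the CPU pipeline back from the online-attention route to the tile-then-decompose route. Roughly (pass options and unrelated passes elided; ordering taken from the diff):

```cpp
// before: createConvertAttentionToOnlineAttentionPass();  // attention -> online_attention
//         createLLVMCPUTilePass(vectorReductionLevel);    // generic k2 tiling via TilingInterface
//         createDecomposeWinogradTransformPass();
//         createDecomposeAttentionPass();                 // online_attention -> linalg
//
// after:  createTileAttentionPass();                      // dedicated k2 tiling
//         createDecomposeAttentionPass();                 // tiled attention -> linalg
//         createDecomposeWinogradTransformPass();
```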
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir
index 6777e97..4df5605 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir
@@ -1531,7 +1531,7 @@
return
}
}
-// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 1, 6, 64], [1, 1, 1, 1], [0, 0, 0, 0]]>
+// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 1, 6, 64], [1, 1, 1, 1]]>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPULinalgExtTileAndVectorize>
// CHECK: func.func @winograd_output_transform()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
@@ -1556,7 +1556,7 @@
return
}
}
-// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 1, 6, 64], [1, 1, 1, 1], [0, 0, 0, 0]]>
+// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 1, 6, 64], [1, 1, 1, 1]]>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPULinalgExtTileAndVectorize>
// CHECK: func.func @winograd_input_transform()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
@@ -1581,7 +1581,7 @@
return
}
}
-// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[8, 64], [1, 1], [0, 0]]>
+// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[8, 64], [1, 1]]>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPULinalgExtTileAndVectorize>
// CHECK: func.func @winograd_filter_transform()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
@@ -1613,7 +1613,7 @@
return
}
}
-// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[20, 64, 0, 0, 64], [20, 32, 0, 0, 32], [0, 0, 0, 32, 0]]>
+// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[20, 64, 0, 0, 0], [20, 32, 0, 0, 0]]>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPULinalgExtTileAndVectorize>
// CHECK: func.func @attention()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
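The CHECK updates all follow from one structural change: `CPULinalgExtTileAndVectorize` configs drop from three tiling levels to two (and the attention case additionally zeroes its N entries per the TODO(17467) workaround above). Schematically, with values copied from the `winograd_output_transform` case:

```cpp
#include <cstdint>
#include <vector>
using TileSizesList = std::vector<std::vector<int64_t>>;

// winograd_output_transform, values from the CHECK lines above.
TileSizesList before = {{0, 1, 6, 64}, {1, 1, 1, 1}, {0, 0, 0, 0}}; // dist, vec-parallel, vec-reduction
TileSizesList after  = {{0, 1, 6, 64}, {1, 1, 1, 1}};               // dist, vec
```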
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel
index a771ee0..8675c43 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel
@@ -29,7 +29,6 @@
"@llvm-project//mlir:ControlFlowInterfacesTdFiles",
"@llvm-project//mlir:DestinationStyleOpInterfaceTdFiles",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
- "@llvm-project//mlir:LinalgOpsTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PDLDialectTdFiles",
"@llvm-project//mlir:SideEffectInterfacesTdFiles",
@@ -160,9 +159,7 @@
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "LinalgExtOps.td",
- deps = [
- ":td_files",
- ],
+ deps = [":td_files"],
)
iree_gentbl_cc_library(
@@ -215,7 +212,5 @@
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "LinalgExtOps.td",
- deps = [
- ":td_files",
- ],
+ deps = [":td_files"],
)
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 34ce4f4..c5c42ec 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -8,7 +8,6 @@
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
-#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
@@ -1316,9 +1315,6 @@
for (auto [i, dimExpr] : llvm::enumerate(indexingMap.getResults())) {
AffineDimExpr dim = cast<AffineDimExpr>(dimExpr);
int64_t pos = dim.getPosition();
- if (ShapedType::isDynamic(valShape[i])) {
- continue;
- }
if (!foundDims[pos]) {
foundDims[pos] = true;
shape[pos] = valShape[i];
@@ -1431,79 +1427,6 @@
return results;
}
-//===----------------------------------------------------------------------===//
-// OnlineAttentionOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult OnlineAttentionOp::verify() {
- OnlineAttentionOp attnOp = *this;
-
- SmallVector<AffineMap> indexingMaps = attnOp.getIndexingMapsArray();
-
- // Check if indexing maps can represent attention.
- FailureOr<AttentionOpDetail> maybeOpInfo =
- AttentionOpDetail::get(indexingMaps);
-
- // Check shape compatibility based on indexing maps.
- SmallVector<int64_t> shape(getIterationDomainRank());
- SmallVector<bool> foundDims(getIterationDomainRank(), false);
- auto checkShape = [&shape, &foundDims,
- &attnOp](StringRef operandName, ArrayRef<int64_t> valShape,
- AffineMap indexingMap) -> LogicalResult {
- if (indexingMap.getNumResults() != valShape.size()) {
- return attnOp->emitError("Rank Mismatch for ")
- << operandName << ". Expected: " << indexingMap.getNumResults()
- << " Got: " << valShape.size();
- }
- for (auto [i, dimExpr] : llvm::enumerate(indexingMap.getResults())) {
- AffineDimExpr dim = cast<AffineDimExpr>(dimExpr);
- int64_t pos = dim.getPosition();
- if (ShapedType::isDynamic(valShape[i])) {
- continue;
- }
- if (!foundDims[pos]) {
- foundDims[pos] = true;
- shape[pos] = valShape[i];
- }
- if (shape[pos] != valShape[i]) {
- return attnOp->emitError("Shape Mismatch for ")
- << operandName << ". Expected: " << shape[pos]
- << " Got: " << valShape[i];
- }
- }
- return success();
- };
-
- if (failed(checkShape("Query", getQuery().getType().getShape(),
- getQueryMap())) ||
- failed(checkShape("Key", getKey().getType().getShape(), getKeyMap())) ||
- failed(checkShape("Value", getValue().getType().getShape(),
- getValueMap())) ||
- failed(checkShape("Output", getOutput().getType().getShape(),
- getOutputMap())) ||
- failed(checkShape("Max", getMax().getType().getShape(), getMaxMap())) ||
- failed(checkShape("Sum", getSum().getType().getShape(), getSumMap()))) {
- return failure();
- }
-
- return success();
-}
-
-MutableOperandRange OnlineAttentionOp::getDpsInitsMutable() {
- return MutableOperandRange(*this, /*numInputs=*/4, /*numInits=*/3);
-}
-
-LogicalResult OnlineAttentionOp::reifyResultShapes(
- OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
- return cast<LinalgExtOp>(getOperation())
- .reifyResultShapes(b, reifiedReturnShapes);
-}
-
-SmallVector<AffineMap> OnlineAttentionOp::getIndexingMapsArray() {
- return SmallVector<AffineMap>(
- getIndexingMaps().getAsValueRange<AffineMapAttr>());
-}
-
#define DEFINE_OP_GET_EFFECTS(OP_NAME) \
void OP_NAME::getEffects( \
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> \
@@ -1523,7 +1446,6 @@
DEFINE_OP_GET_EFFECTS(WinogradFilterTransformOp)
DEFINE_OP_GET_EFFECTS(WinogradOutputTransformOp)
DEFINE_OP_GET_EFFECTS(AttentionOp)
-DEFINE_OP_GET_EFFECTS(OnlineAttentionOp)
} // namespace mlir::iree_compiler::IREE::LinalgExt
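The verifier hunk above drops the dynamic-dim skip, so every indexing-map result now participates in the shape cross-check again. A standalone sketch of that binding logic (simplified types; not the exact source):

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Bind loop dims to operand dims via the indexing map's dim positions, and
// flag a mismatch when two operands disagree about the same loop dim.
bool bindShapes(const std::vector<int> &dimPositions,     // map results
                const std::vector<int64_t> &operandShape, // operand dims
                std::vector<std::optional<int64_t>> &loopSizes) {
  for (size_t i = 0; i < dimPositions.size(); ++i) {
    int pos = dimPositions[i];
    if (!loopSizes[pos])
      loopSizes[pos] = operandShape[i];     // first binding wins
    if (*loopSizes[pos] != operandShape[i]) // later operands must agree
      return false;                         // shape mismatch
  }
  return true;
}
```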
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h
index 3d52ae6..97caaab 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h
@@ -7,7 +7,6 @@
#ifndef IREE_COMPILER_DIALECT_LINALGEXT_IR_LINALGEXTOPS_H_
#define IREE_COMPILER_DIALECT_LINALGEXT_IR_LINALGEXTOPS_H_
-#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/Attributes.h"
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index 0eebd2e..bf9694d 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -9,7 +9,6 @@
include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtBase.td"
include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td"
-include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -679,96 +678,6 @@
}];
}
-//===----------------------------------------------------------------------===//
-// OnlineAttention
-//===----------------------------------------------------------------------===//
-
-def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_PureOp<"online_attention",
- [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
- DestinationStyleOpInterface, LinalgExtInterface,
- DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
- DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
- DeclareOpInterfaceMethods<TilingInterface,
- ["getIterationDomain",
- "getLoopIteratorTypes",
- "getResultTilePosition",
- "getTiledImplementation"]>]> {
- let summary = "Online Attention operator";
- let description = [{
- Traditional scaled dot product attention computes:
-
- attention(Q, K, V, scale) = softmax(Q @ K.T * scale) @ V
-
- Online Attention on the other hand, uses an online normalizer instead of
- softmax:
-
- online_attention(Q, K, V, scale, running_max, running_sum)
- = online_normalizer(Q @ K.T * scale, running_max, running_sum) @ V
-
- The advantage of this online_normalizer is that it can be tiled along
- its reduction dimension, making the online_attention operator:
- - Tileable along softmax reduction dimension
- - Associative along softmax reduction dimension
- - Commutative along softmax reduction dimension
-
- Note: The results of online_attention need to be combined after computing
- them over the entire softmax reduction dimension by:
- x, _, sum : results
- x = (1 / sum) * x
- }];
-
- let arguments = (ins AnyShaped:$query,
- AnyShaped:$key,
- AnyShaped:$value,
- AnyFloat:$scale,
- AnyShaped:$output,
- AnyShaped:$max,
- AnyShaped:$sum,
- AffineMapArrayAttr:$indexing_maps
- );
-
- let results = (outs Variadic<AnyRankedTensor>:$results);
- let hasVerifier = 1;
- let hasCustomAssemblyFormat = 1;
- let assemblyFormat = [{
- attr-dict
- `ins` `(` $query `,` $key `,` $value `,` $scale `:` type($query) `,` type($key) `,` type($value) `,` type($scale) `)`
- `outs` `(` $output `,` $max `,` $sum `:` type($output) `,` type($max) `,` type($sum) `)`
- (`->` type($results)^)?
- }];
-
- let extraClassDeclaration = [{
- // Method to implement for specifying output range for
- // DestinationStyleOpInterface
- MutableOperandRange getDpsInitsMutable();
-
- SmallVector<AffineMap> getIndexingMapsArray();
-
- AffineMap getQueryMap() {
- return getIndexingMapsArray()[0];
- }
- AffineMap getKeyMap() {
- return getIndexingMapsArray()[1];
- }
- AffineMap getValueMap() {
- return getIndexingMapsArray()[2];
- }
- AffineMap getOutputMap() {
- return getIndexingMapsArray()[3];
- }
- AffineMap getMaxMap() {
- return getIndexingMapsArray()[4];
- }
- AffineMap getSumMap() {
- return getIndexingMapsArray()[5];
- }
-
- int64_t getIterationDomainRank() {
- return getQueryMap().getNumDims();
- }
- }];
-}
-
} // OpGroupNonStructuredOps
//===----------------------------------------------------------------------===//
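For reference, the removed description amounts to the standard flash-attention bookkeeping. In symbols (notation mine; `s` is the scale, `(x, m, σ)` the three results, `m₀, σ₀` the running max/sum inputs):

```latex
\operatorname{attention}(Q, K, V, s) = \operatorname{softmax}(s \cdot Q K^{T})\, V,
\qquad
(x,\, m,\, \sigma) = \operatorname{online\_attention}(Q, K, V, s,\, m_0,\, \sigma_0)
```

After a full pass over the softmax reduction dimension, the final result is recovered row-wise as $x_{ij} / \sigma_i$ — exactly the `x = (1 / sum) * x` combine step in the description above.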
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp
index 923b30a..00bb383 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp
@@ -51,17 +51,5 @@
return DiagnosedSilenceableFailure::success();
}
-DiagnosedSilenceableFailure LinalgExt::ConvertToOnlineAttention::applyToOne(
- transform::TransformRewriter &rewriter, LinalgExt::AttentionOp attentionOp,
- transform::ApplyToEachResultList &results,
- transform::TransformState &state) {
- SmallVector<Operation *> ops;
- LinalgExt::convertToOnlineAttention(attentionOp, ops, rewriter);
- for (Operation *op : ops) {
- results.push_back(op);
- }
- return DiagnosedSilenceableFailure::success();
-}
-
#define GET_OP_CLASSES
#include "iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp.inc"
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.td
index 84e588a..c3a6310 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.td
@@ -86,36 +86,4 @@
}];
}
-def ConvertToOnlineAttention : Op<Transform_Dialect, "iree.convert_to_online_attention",
- [FunctionalStyleTransformOpTrait,
- MemoryEffectsOpInterface,
- TransformOpInterface,
- TransformEachOpTrait,
- ReportTrackingListenerFailuresOpTrait]> {
- let description = [{
- Target iree_linalg_ext.attention ops and decompose them.
- This transform consumes the target handle and produces a result handle.
- }];
-
- let arguments = (
- ins TransformHandleTypeInterface:$target
- );
- let results = (outs Variadic<TransformHandleTypeInterface>:$result);
-
- let assemblyFormat = "attr-dict $target `:` functional-type(operands, results)";
- let cppNamespace = "mlir::iree_compiler::IREE::LinalgExt";
-
- let assemblyFormat = [{
- $target attr-dict `:` functional-type(operands, results)
- }];
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::iree_compiler::IREE::LinalgExt::AttentionOp target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
-}
-
#endif // IREE_DIALECT_LINALGEXT_TRANSFORMOPS
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
deleted file mode 100644
index b4370e7..0000000
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
+++ /dev/null
@@ -1,228 +0,0 @@
-// Copyright 2024 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
-#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include "mlir/Dialect/Arith/Utils/Utils.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/Utils/Utils.h"
-#include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
-
-namespace mlir::iree_compiler::IREE::LinalgExt {
-
-static Value scaleValueInPlace(OpBuilder &builder, Location loc,
- AffineMap inputMap, AffineMap scaleMap,
- Value value, Value scale) {
- SmallVector<AffineMap> compressedMaps =
- compressUnusedDims(SmallVector<AffineMap>{inputMap, scaleMap});
- inputMap = compressedMaps[0];
- scaleMap = compressedMaps[1];
-
- SmallVector<utils::IteratorType> iteratorTypes(inputMap.getNumDims(),
- utils::IteratorType::parallel);
-
- auto genericOp = builder.create<linalg::GenericOp>(
- loc, value.getType(), scale, value,
- SmallVector<AffineMap>{scaleMap, inputMap}, iteratorTypes,
- [&](OpBuilder &b, Location loc, ValueRange args) {
- // Convert scale to the same datatype as input.
- Value scale = convertScalarToDtype(b, loc, args[0], args[1].getType(),
- /*isUnsignedCast=*/false);
- Value result = b.create<arith::MulFOp>(loc, scale, args[1]);
- b.create<linalg::YieldOp>(loc, result);
- });
- return genericOp.getResult(0);
-}
-
-template <typename T>
-static Value reduce(OpBuilder &builder, Location loc, AffineMap inputMap,
- AffineMap outputMap, Value input, Value output) {
- SmallVector<AffineMap> compressedMaps =
- compressUnusedDims(SmallVector<AffineMap>{inputMap, outputMap});
- inputMap = compressedMaps[0];
- outputMap = compressedMaps[1];
-
- // Dims not present in outputMap are reductionDims.
- SmallVector<utils::IteratorType> iteratorTypes(
- inputMap.getNumDims(), utils::IteratorType::reduction);
- for (AffineExpr dim : outputMap.getResults()) {
- int pos = cast<AffineDimExpr>(dim).getPosition();
- iteratorTypes[pos] = utils::IteratorType::parallel;
- }
-
- auto genericOp = builder.create<linalg::GenericOp>(
- loc, output.getType(), input, output,
- SmallVector<AffineMap>{inputMap, outputMap}, iteratorTypes,
- [&](OpBuilder &b, Location loc, ValueRange args) {
- // Convert input to the same datatype as acc.
- Value in = convertScalarToDtype(b, loc, args[0], args[1].getType(),
- /*isUnsignedCast=*/false);
- Value result = b.create<T>(loc, in, args[1]);
- b.create<linalg::YieldOp>(loc, result);
- });
-
- return genericOp.getResult(0);
-}
-
-static Value computeMatmul(OpBuilder &builder, Location loc, AffineMap lhsMap,
- AffineMap rhsMap, AffineMap accMap, Value lhs,
- Value rhs, Value acc) {
-
- SmallVector<AffineMap> compressedMaps =
- compressUnusedDims(SmallVector<AffineMap>{lhsMap, rhsMap, accMap});
- lhsMap = compressedMaps[0];
- rhsMap = compressedMaps[1];
- accMap = compressedMaps[2];
-
- // Dims not present in accMap are reduction dims.
- SmallVector<utils::IteratorType> iteratorTypes(
- accMap.getNumDims(), utils::IteratorType::reduction);
- for (AffineExpr dim : accMap.getResults()) {
- int pos = cast<AffineDimExpr>(dim).getPosition();
- iteratorTypes[pos] = utils::IteratorType::parallel;
- }
-
- auto genericOp = builder.create<linalg::GenericOp>(
- loc, acc.getType(), SmallVector<Value>{lhs, rhs}, acc,
- SmallVector<AffineMap>{lhsMap, rhsMap, accMap}, iteratorTypes,
- [&](OpBuilder &b, Location loc, ValueRange args) {
- // Cast inputs to match output datatype.
- Value lhs = convertScalarToDtype(b, loc, args[0], args[2].getType(),
- /*isUnsignedCast=*/false);
- Value rhs = convertScalarToDtype(b, loc, args[1], args[2].getType(),
- /*isUnsignedCast=*/false);
- Value mul = b.create<arith::MulFOp>(loc, lhs, rhs);
- Value add = b.create<arith::AddFOp>(loc, mul, args[2]);
- b.create<linalg::YieldOp>(loc, add);
- });
-
- return genericOp.getResult(0);
-}
-
-// Compute output = exp2(output - input)
-static Value computeSubAndExp2(OpBuilder &builder, Location loc,
- AffineMap inputMap, AffineMap outputMap,
- Value input, Value output) {
- SmallVector<AffineMap> compressedMaps =
- compressUnusedDims(SmallVector<AffineMap>{inputMap, outputMap});
- inputMap = compressedMaps[0];
- outputMap = compressedMaps[1];
-
- SmallVector<utils::IteratorType> iteratorTypes(inputMap.getNumDims(),
- utils::IteratorType::parallel);
- auto genericOp = builder.create<linalg::GenericOp>(
- loc, output.getType(), input, output,
- SmallVector<AffineMap>{inputMap, outputMap}, iteratorTypes,
- [&](OpBuilder &b, Location loc, ValueRange args) {
- // Convert input to the same datatype as output.
- Value in = convertScalarToDtype(b, loc, args[0], args[1].getType(),
- /*isUnsignedCast=*/false);
- Value diff = b.create<arith::SubFOp>(loc, args[1], in);
- Value weight = b.create<math::Exp2Op>(loc, diff);
- b.create<linalg::YieldOp>(loc, weight);
- });
- return genericOp.getResult(0);
-}
-
-FailureOr<SmallVector<Value>>
-OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
- Location loc = getLoc();
- Value query = getQuery();
- Value key = getKey();
- Value value = getValue();
- Value oldAcc = getOutput();
- Value oldMax = getMax();
- Value oldSum = getSum();
- Type elementType = getQuery().getType().getElementType();
-
- FailureOr<AttentionOpDetail> maybeOpInfo =
- AttentionOpDetail::get(getIndexingMapsArray());
- assert(succeeded(maybeOpInfo) && "Invalid attention indexing maps");
- AttentionOpDetail opInfo = maybeOpInfo.value();
-
- SmallVector<OpFoldResult> sizes = llvm::map_to_vector(
- getIterationDomain(b), [](Range x) { return x.size; });
-
- // Since we use exp2 for attention instead of the original exp, we have to
- // multiply the scale by log2(e). We use exp2 instead of exp as most platforms
- // have better support for exp2 (we verified that we gain some speedup on
- // some GPUs).
- Value scale = getScale();
- Value log2e =
- b.create<arith::ConstantOp>(loc, b.getFloatAttr(elementType, M_LOG2E));
- scale = b.create<arith::MulFOp>(loc, scale, log2e);
-
- // In the original algorithm, the scaling is done after the softmax:
- // softmax(Q @ K.T * scale) @ V
- //
- // But, it is mathematically equivalent to do it on Q first and then multiply
- // it by K.T. This just allows us to do the scaling once, instead of each
- // iteration of the loop.
- AffineMap qMap = getQueryMap();
- AffineMap scaleMap = AffineMap::get(/*dimCount=*/qMap.getNumInputs(),
- /*symbolCount=*/0, getContext());
- query = scaleValueInPlace(b, loc, qMap, scaleMap, query, scale);
-
- // ---- Matmul 1 ----
-
- // Get sizes for S.
- AffineMap sMap = opInfo.getSMap();
- SmallVector<OpFoldResult> sSizes;
- for (AffineExpr dimExpr : sMap.getResults()) {
- int dim = cast<AffineDimExpr>(dimExpr).getPosition();
- sSizes.push_back(sizes[dim]);
- }
-
- // S = Q @ K
- // SMap = QMap @ KMap
- Value emptyS = b.create<tensor::EmptyOp>(loc, sSizes, elementType);
- Value sZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
- Value s = b.create<linalg::FillOp>(loc, sZero, emptyS).getResult(0);
- s = computeMatmul(b, loc, getQueryMap(), getKeyMap(), sMap, query, key, s);
-
- // TODO: This decomposition should be in a separate op called
- // "online softmax".
- // ---- Online Softmax ----
-
- // newMax = max(oldMax, rowMax(S))
- AffineMap maxMap = getMaxMap();
- Value newMax = reduce<arith::MaximumFOp>(b, loc, sMap, maxMap, s, oldMax);
-
- // P = exp2(S - newMax)
- // PMap = SMap
- AffineMap pMap = sMap;
- Value p = computeSubAndExp2(b, loc, maxMap, sMap, newMax, s);
-
- // norm = exp2(oldMax - newMax)
- // normMap = maxMap
- AffineMap normMap = getMaxMap();
- Value norm = computeSubAndExp2(b, loc, maxMap, normMap, newMax, oldMax);
-
- // normSum = norm * oldSum
- AffineMap sumMap = getSumMap();
- Value normSum = scaleValueInPlace(b, loc, sumMap, normMap, oldSum, norm);
-
- // newSum = normSum + rowSum(P)
- Value newSum = reduce<arith::AddFOp>(b, loc, pMap, sumMap, p, normSum);
-
- // newAcc = norm * oldAcc
- AffineMap accMap = getOutputMap();
- Value newAcc = scaleValueInPlace(b, loc, accMap, normMap, oldAcc, norm);
-
- // ---- Matmul 2 ----
-
- // newAcc = P @ V + newAcc
- newAcc = computeMatmul(b, loc, pMap, getValueMap(), accMap, p, value, newAcc);
-
- return SmallVector<Value>{newAcc, newMax, newSum};
-}
-
-} // namespace mlir::iree_compiler::IREE::LinalgExt
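The deleted decomposition is easiest to sanity-check on a tiny example. Take one row with scores `s = [0, 1]` in the first k2 block and `s = [2]` in the second, using the same exp2 convention as the code and starting from `m₀ = -∞, σ₀ = 0`:

```latex
\begin{aligned}
\text{block 1 } (s = [0, 1]):\quad & m_1 = \max(-\infty, 0, 1) = 1,
  & \sigma_1 &= 2^{-\infty - 1} \cdot 0 + \left(2^{0-1} + 2^{1-1}\right) = 1.5 \\
\text{block 2 } (s = [2]):\quad & m_2 = \max(1, 2) = 2,
  & \sigma_2 &= 2^{1-2} \cdot 1.5 + 2^{2-2} = 1.75
\end{aligned}
```

A direct single-pass computation gives $2^{0-2} + 2^{1-2} + 2^{2-2} = 0.25 + 0.5 + 1 = 1.75$, matching the blocked result — this associativity is what the deleted comments rely on when tiling along k2.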
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel
index 1a20a8f..1d08da9 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel
@@ -30,7 +30,6 @@
iree_compiler_cc_library(
name = "Transforms",
srcs = [
- "AggregatedOpInterfaceImpl.cpp",
"ConvertConv2DToWinograd.cpp",
"ConvertToLoops.cpp",
"DecomposeAttention.cpp",
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt
index 19c7522..668d28a 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt
@@ -26,7 +26,6 @@
"Passes.h"
"Passes.h.inc"
SRCS
- "AggregatedOpInterfaceImpl.cpp"
"ConvertConv2DToWinograd.cpp"
"ConvertToLoops.cpp"
"DecomposeAttention.cpp"
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp
index 2cd851d..c70000f 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp
@@ -366,16 +366,6 @@
SmallVector<Operation *> ops;
decomposeTiledAttention(attnOp, ops, rewriter, optionalTileSize);
});
- getOperation().walk([&](OnlineAttentionOp onlineAtt) {
- rewriter.setInsertionPoint(onlineAtt);
- FailureOr<SmallVector<Value>> results =
- onlineAtt.decomposeOperation(rewriter);
- if (failed(results)) {
- onlineAtt->emitOpError("Could not decompose online attention");
- return signalPassFailure();
- }
- rewriter.replaceOp(onlineAtt, results.value());
- });
}
std::unique_ptr<Pass> createDecomposeAttentionPass() {
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
index 165430a..43d44a3 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
@@ -56,18 +56,12 @@
RewriterBase &rewriter,
std::optional<uint64_t> tileSize = std::nullopt);
-void convertToOnlineAttention(IREE::LinalgExt::AttentionOp attnOp,
- SmallVectorImpl<Operation *> &ops,
- RewriterBase &rewriter);
-
// Creates a pass to tile the attention op along the reduction dim.
std::unique_ptr<Pass> createTileAttentionPass();
// Creates a pass to convert the attention op into a sequence of linalg ops.
std::unique_ptr<Pass> createDecomposeAttentionPass();
-std::unique_ptr<Pass> createConvertAttentionToOnlineAttentionPass();
-
//===---------------------------------------------------------------------===//
// Codegen Strategy passes that are moved into IREE.
//===---------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
index ee801b6..77bb105 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
@@ -98,12 +98,4 @@
];
}
-def ConvertAttentionToOnlineAttention :
- InterfacePass<"iree-linalg-ext-convert-attention-to-online-attention",
- "mlir::FunctionOpInterface"> {
- let summary = "";
- let constructor = "mlir::iree_compiler::IREE::LinalgExt::"
- "createConvertAttentionToOnlineAttentionPass()";
-}
-
#endif // IREE_DIALECT_LINALGEXT_PASSES
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
index df0d87e..4c862b5 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
@@ -8,7 +8,6 @@
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
-#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
@@ -159,17 +158,6 @@
void runOnOperation() override;
};
-struct ConvertAttentionToOnlineAttentionPass final
- : ConvertAttentionToOnlineAttentionBase<
- ConvertAttentionToOnlineAttentionPass> {
- void getDependentDialects(DialectRegistry &registry) const override {
- registry
- .insert<affine::AffineDialect, IREE::LinalgExt::IREELinalgExtDialect,
- linalg::LinalgDialect, tensor::TensorDialect>();
- }
- void runOnOperation() override;
-};
-
} // namespace
/// Tile iree_linalg_ext.attention.
@@ -304,103 +292,6 @@
return tiledAttentionOp;
}
-void convertToOnlineAttention(IREE::LinalgExt::AttentionOp attnOp,
- SmallVectorImpl<Operation *> &ops,
- RewriterBase &rewriter) {
- rewriter.setInsertionPoint(attnOp);
-
- Location loc = attnOp.getLoc();
- MLIRContext *ctx = attnOp.getContext();
-
- FailureOr<AttentionOpDetail> maybeOpInfo =
- AttentionOpDetail::get(attnOp.getIndexingMapsArray());
- assert(succeeded(maybeOpInfo) && "Invalid attention indexing maps");
- AttentionOpDetail opInfo = maybeOpInfo.value();
-
- // Create standard maps for max and sum: (batch, m)
- int64_t rank = opInfo.getDomainRank();
- AffineMap maxMap = AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, ctx);
- for (auto dim :
- llvm::concat<const int64_t>(opInfo.getBatchDims(), opInfo.getMDims())) {
- maxMap = maxMap.insertResult(rewriter.getAffineDimExpr(dim),
- maxMap.getNumResults());
- }
- AffineMap sumMap = maxMap;
-
- SmallVector<Range> sizes = attnOp.getIterationDomain(rewriter);
-
- // Create fill for acc, max and sum.
- // TODO: Acc should not need a fill. The attention op should get a filled
- // input instead of an empty input.
- Value zeroAcc = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getZeroAttr(attnOp.getOutputType().getElementType()));
- Value accFill =
- rewriter
- .create<linalg::FillOp>(loc, ValueRange{zeroAcc}, attnOp.getOutput())
- .result();
-
- SmallVector<OpFoldResult> rowRedSize =
- llvm::map_to_vector(sizes, [](Range x) { return x.size; });
- rowRedSize = applyPermutationMap<OpFoldResult>(maxMap, rowRedSize);
-
- Type f32Type = rewriter.getF32Type();
- Value rowRedEmpty =
- rewriter.create<tensor::EmptyOp>(loc, rowRedSize, f32Type);
-
- Value maxInit =
- arith::getIdentityValue(arith::AtomicRMWKind::maximumf, f32Type, rewriter,
- loc, /*useOnlyFiniteValue=*/true);
- Value sumInit = arith::getIdentityValue(arith::AtomicRMWKind::addf, f32Type,
- rewriter, loc);
-
- Value maxFill =
- rewriter.create<linalg::FillOp>(loc, ValueRange{maxInit}, rowRedEmpty)
- .getResult(0);
- Value sumFill =
- rewriter.create<linalg::FillOp>(loc, ValueRange{sumInit}, rowRedEmpty)
- .getResult(0);
-
- // Create online attention op.
- SmallVector<AffineMap> indexingMaps = attnOp.getIndexingMapsArray();
- indexingMaps.push_back(maxMap);
- indexingMaps.push_back(sumMap);
- OnlineAttentionOp onlineAttn = rewriter.create<OnlineAttentionOp>(
- loc, TypeRange{accFill.getType(), maxFill.getType(), sumFill.getType()},
- attnOp.getQuery(), attnOp.getKey(), attnOp.getValue(), attnOp.getScale(),
- accFill, maxFill, sumFill, rewriter.getAffineMapArrayAttr(indexingMaps));
- onlineAttn->setDiscardableAttrs(attnOp->getDiscardableAttrDictionary());
- ops.push_back(onlineAttn);
-
- Value x = onlineAttn.getResult(0);
- Value sum = onlineAttn.getResult(2);
-
- // Merge the outputs of online attention:
- // x = (1 / sum) * x
-
- // Compress the indexing maps.
- SmallVector<AffineMap> compressedMaps =
- compressUnusedDims(SmallVector<AffineMap>{sumMap, attnOp.getOutputMap()});
-
- SmallVector<utils::IteratorType> iteratorTypes(compressedMaps[0].getNumDims(),
- utils::IteratorType::parallel);
-
- auto genericOp = rewriter.create<linalg::GenericOp>(
- loc, x.getType(), sum, x, compressedMaps, iteratorTypes,
- [&](OpBuilder &b, Location loc, ValueRange args) {
- Value one = b.create<arith::ConstantOp>(
- loc, b.getFloatAttr(args[0].getType(), 1.0));
- Value reciprocal = b.create<arith::DivFOp>(loc, one, args[0]);
- // Convert sum to the same datatype as x.
- reciprocal = convertScalarToDtype(b, loc, reciprocal, args[1].getType(),
- /*isUnsignedCast=*/false);
- Value result = b.create<arith::MulFOp>(loc, reciprocal, args[1]);
- b.create<linalg::YieldOp>(loc, result);
- });
- ops.push_back(genericOp);
-
- rewriter.replaceOp(attnOp, genericOp);
-}
-
void TileAttentionPass::runOnOperation() {
MLIRContext *context = &getContext();
IRRewriter rewriter(context);
@@ -414,21 +305,8 @@
});
}
-void ConvertAttentionToOnlineAttentionPass::runOnOperation() {
- MLIRContext *context = &getContext();
- IRRewriter rewriter(context);
- getOperation().walk([&](AttentionOp attnOp) {
- SmallVector<Operation *> ops;
- convertToOnlineAttention(attnOp, ops, rewriter);
- });
-}
-
std::unique_ptr<Pass> createTileAttentionPass() {
return std::make_unique<TileAttentionPass>();
}
-std::unique_ptr<Pass> createConvertAttentionToOnlineAttentionPass() {
- return std::make_unique<ConvertAttentionToOnlineAttentionPass>();
-}
-
} // namespace mlir::iree_compiler::IREE::LinalgExt
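The deleted conversion seeds the running reductions with their identities via `arith::getIdentityValue`. Concretely (a sketch; the values match the `-3.40282347E+38` and `0.0` fill constants in the deleted tiling test further down):

```cpp
#include <limits>

// Initial values for the two running reductions in online attention:
//   max reduction -> lowest finite float (any real score replaces it;
//                    this is the useOnlyFiniteValue=true case above)
//   add reduction -> 0.0f
const float maxInit = std::numeric_limits<float>::lowest(); // -3.40282347e+38f
const float sumInit = 0.0f;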
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp
index 2ecfd80..6b291c7 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp
@@ -4,9 +4,11 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -14,7 +16,7 @@
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
namespace mlir::iree_compiler::IREE::LinalgExt {
@@ -1603,16 +1605,16 @@
}
//===----------------------------------------------------------------------===//
-// Attention Helpers
+// AttentionOp
//===----------------------------------------------------------------------===//
-static SmallVector<Range>
-getAttentionIterationDomain(Location loc, OpBuilder &b, int64_t domainRank,
- ArrayRef<Value> values,
- ArrayRef<AffineMap> indexingMaps) {
+SmallVector<Range> AttentionOp::getIterationDomain(OpBuilder &builder) {
+ int64_t domainRank = getIterationDomainRank();
+
SmallVector<Range> loopBounds(domainRank);
- OpFoldResult zero = b.getIndexAttr(0);
- OpFoldResult one = b.getIndexAttr(1);
+ Location loc = getLoc();
+ Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+ Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
for (auto dim : llvm::seq<int64_t>(0, domainRank)) {
loopBounds[dim].offset = zero;
@@ -1629,27 +1631,26 @@
continue;
}
dimsFound[pos] = true;
- loopBounds[pos].size = getDimValue(b, loc, val, idx);
+ loopBounds[pos].size = getDimValue(builder, loc, val, idx);
}
};
- for (auto [val, indexingMap] : llvm::zip_equal(values, indexingMaps)) {
- fillSizes(val, indexingMap);
- }
+ // Sizes can be found from Q, K, V alone.
+ fillSizes(getQuery(), getQueryMap());
+ fillSizes(getKey(), getKeyMap());
+ fillSizes(getValue(), getValueMap());
return loopBounds;
}
-static SmallVector<utils::IteratorType>
-getAttentionIteratorTypes(int64_t domainRank,
- ArrayRef<AffineMap> indexingMaps) {
+SmallVector<utils::IteratorType> AttentionOp::getLoopIteratorTypes() {
FailureOr<AttentionOpDetail> maybeOpInfo =
- AttentionOpDetail::get(indexingMaps);
+ AttentionOpDetail::get(getIndexingMapsArray());
assert(succeeded(maybeOpInfo) && "Failed to infer attention op details");
AttentionOpDetail opInfo = maybeOpInfo.value();
// All dimensions other than k1 and k2 are parallel.
- SmallVector<utils::IteratorType> iteratorTypes(domainRank,
+ SmallVector<utils::IteratorType> iteratorTypes(getIterationDomainRank(),
utils::IteratorType::parallel);
for (auto dim :
@@ -1660,42 +1661,6 @@
return iteratorTypes;
}
-static SmallVector<Range> getPermutedSlice(AffineMap permutation,
- ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes) {
- auto one = IntegerAttr::get(IndexType::get(permutation.getContext()), 1);
- assert(permutation.isProjectedPermutation() &&
- "Indexing map should be a projected permutation");
- SmallVector<Range> output;
- for (AffineExpr dimExpr : permutation.getResults()) {
- int dim = cast<AffineDimExpr>(dimExpr).getPosition();
- Range dimRange;
- dimRange.offset = offsets[dim];
- dimRange.size = sizes[dim];
- dimRange.stride = one;
- output.push_back(dimRange);
- }
- return output;
-}
-
-//===----------------------------------------------------------------------===//
-// AttentionOp
-//===----------------------------------------------------------------------===//
-
-SmallVector<Range> AttentionOp::getIterationDomain(OpBuilder &b) {
- // Attention shape can be determined from Q, K, V alone.
- SmallVector<Value> shapedValues = {getQuery(), getKey(), getValue()};
- SmallVector<AffineMap> indexingMaps = {getQueryMap(), getKeyMap(),
- getValueMap()};
- return getAttentionIterationDomain(getLoc(), b, getIterationDomainRank(),
- shapedValues, indexingMaps);
-}
-
-SmallVector<utils::IteratorType> AttentionOp::getLoopIteratorTypes() {
- return getAttentionIteratorTypes(getIterationDomainRank(),
- getIndexingMapsArray());
-}
-
FailureOr<TilingResult>
AttentionOp::getTiledImplementation(OpBuilder &builder,
ArrayRef<OpFoldResult> offsets,
@@ -1704,36 +1669,59 @@
assert(sizes.size() == getIterationDomainRank());
Location loc = getLoc();
+ auto one = builder.getIndexAttr(1);
- SmallVector<Range> querySlice =
- getPermutedSlice(getQueryMap(), offsets, sizes);
- SmallVector<Range> keySlice = getPermutedSlice(getKeyMap(), offsets, sizes);
- SmallVector<Range> valueSlice =
- getPermutedSlice(getValueMap(), offsets, sizes);
- SmallVector<Range> outputSlice =
- getPermutedSlice(getOutputMap(), offsets, sizes);
+ auto tileValue = [&](Value val, AffineMap indexingMap)
+ -> std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
+ SmallVector<OpFoldResult>> {
+ assert(indexingMap.isProjectedPermutation() &&
+ "Indexing map should be a projected permutation");
+ SmallVector<OpFoldResult> outputOffsets;
+ SmallVector<OpFoldResult> outputSizes;
+ SmallVector<OpFoldResult> outputStrides(indexingMap.getNumResults(), one);
+ for (AffineExpr dimExpr : indexingMap.getResults()) {
+ int dim = cast<AffineDimExpr>(dimExpr).getPosition();
+ outputOffsets.push_back(offsets[dim]);
+ outputSizes.push_back(sizes[dim]);
+ }
+ return {outputOffsets, outputSizes, outputStrides};
+ };
+
+ auto [queryOffsets, querySizes, queryStrides] =
+ tileValue(getQuery(), getQueryMap());
+ auto [keyOffsets, keySizes, keyStrides] = tileValue(getKey(), getKeyMap());
+ auto [valueOffsets, valueSizes, valueStrides] =
+ tileValue(getValue(), getValueMap());
+ auto [outputOffsets, outputSizes, outputStrides] =
+ tileValue(getOutput(), getOutputMap());
Value scale = getScale();
SmallVector<Value> tiledOperands;
- tiledOperands.emplace_back(getSlice(builder, loc, getQuery(), querySlice));
- tiledOperands.emplace_back(getSlice(builder, loc, getKey(), keySlice));
- tiledOperands.emplace_back(getSlice(builder, loc, getValue(), valueSlice));
+ tiledOperands.emplace_back(getSlice(builder, loc, getQuery(), queryOffsets,
+ querySizes, queryStrides));
+ tiledOperands.emplace_back(
+ getSlice(builder, loc, getKey(), keyOffsets, keySizes, keyStrides));
+ tiledOperands.emplace_back(getSlice(builder, loc, getValue(), valueOffsets,
+ valueSizes, valueStrides));
tiledOperands.emplace_back(scale);
- tiledOperands.emplace_back(getSlice(builder, loc, getOutput(), outputSlice));
+ tiledOperands.emplace_back(getSlice(builder, loc, getOutput(), outputOffsets,
+ outputSizes, outputStrides));
std::optional<Value> max = getMax();
if (max) {
- SmallVector<Range> maxSlice =
- getPermutedSlice(*getMaxMap(), offsets, sizes);
- tiledOperands.emplace_back(getSlice(builder, loc, max.value(), maxSlice));
+ auto [maxOffsets, maxSizes, maxStrides] =
+ tileValue(max.value(), *getMaxMap());
+ tiledOperands.emplace_back(
+ getSlice(builder, loc, max.value(), maxOffsets, maxSizes, maxStrides));
}
std::optional<Value> sum = getSum();
if (sum) {
- SmallVector<Range> sumSlice =
- getPermutedSlice(*getSumMap(), offsets, sizes);
- tiledOperands.emplace_back(getSlice(builder, loc, sum.value(), sumSlice));
+ auto [sumOffsets, sumSizes, sumStrides] =
+ tileValue(sum.value(), *getSumMap());
+ tiledOperands.emplace_back(
+ getSlice(builder, loc, sum.value(), sumOffsets, sumSizes, sumStrides));
}
SmallVector<Type> resultTypes;
@@ -1783,93 +1771,4 @@
return success();
}
-//===----------------------------------------------------------------------===//
-// OnlineAttentionOp
-//===----------------------------------------------------------------------===//
-
-SmallVector<Range> OnlineAttentionOp::getIterationDomain(OpBuilder &b) {
- // Attention shape can be determined from Q, K, V alone.
- SmallVector<Value> shapedValues = {getQuery(), getKey(), getValue()};
- SmallVector<AffineMap> indexingMaps = {getQueryMap(), getKeyMap(),
- getValueMap()};
- return getAttentionIterationDomain(getLoc(), b, getIterationDomainRank(),
- shapedValues, indexingMaps);
-}
-
-SmallVector<utils::IteratorType> OnlineAttentionOp::getLoopIteratorTypes() {
- return getAttentionIteratorTypes(getIterationDomainRank(),
- getIndexingMapsArray());
-}
-
-FailureOr<TilingResult>
-OnlineAttentionOp::getTiledImplementation(OpBuilder &builder,
- ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes) {
- assert(offsets.size() == getIterationDomainRank());
- assert(sizes.size() == getIterationDomainRank());
-
- Location loc = getLoc();
-
- SmallVector<Range> querySlice =
- getPermutedSlice(getQueryMap(), offsets, sizes);
- SmallVector<Range> keySlice = getPermutedSlice(getKeyMap(), offsets, sizes);
- SmallVector<Range> valueSlice =
- getPermutedSlice(getValueMap(), offsets, sizes);
- SmallVector<Range> outputSlice =
- getPermutedSlice(getOutputMap(), offsets, sizes);
- SmallVector<Range> maxSlice = getPermutedSlice(getMaxMap(), offsets, sizes);
- SmallVector<Range> sumSlice = getPermutedSlice(getSumMap(), offsets, sizes);
-
- Value scale = getScale();
-
- SmallVector<Value> tiledOperands;
- tiledOperands.emplace_back(getSlice(builder, loc, getQuery(), querySlice));
- tiledOperands.emplace_back(getSlice(builder, loc, getKey(), keySlice));
- tiledOperands.emplace_back(getSlice(builder, loc, getValue(), valueSlice));
- tiledOperands.emplace_back(scale);
- tiledOperands.emplace_back(getSlice(builder, loc, getOutput(), outputSlice));
- tiledOperands.emplace_back(getSlice(builder, loc, getMax(), maxSlice));
- tiledOperands.emplace_back(getSlice(builder, loc, getSum(), sumSlice));
-
- SmallVector<Type> resultTypes;
- resultTypes.push_back(tiledOperands[4].getType());
- resultTypes.push_back(tiledOperands[5].getType());
- resultTypes.push_back(tiledOperands[6].getType());
-
- Operation *tiledOp =
- mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
-
- return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
-}
-
-LogicalResult OnlineAttentionOp::getResultTilePosition(
- OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
- SmallVector<OpFoldResult> &resultSizes) {
- resultOffsets.clear();
- resultSizes.clear();
-
- AffineMap resultIndexingMap;
- switch (resultNumber) {
- case 0:
- resultIndexingMap = getOutputMap();
- break;
- case 1:
- resultIndexingMap = getMaxMap();
- break;
- case 2:
- resultIndexingMap = getSumMap();
- break;
- default:
- return failure();
- }
-
- for (AffineExpr dimExpr : resultIndexingMap.getResults()) {
- int dim = cast<AffineDimExpr>(dimExpr).getPosition();
- resultOffsets.push_back(offsets[dim]);
- resultSizes.push_back(sizes[dim]);
- }
- return success();
-}
-
} // namespace mlir::iree_compiler::IREE::LinalgExt
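The removed `getPermutedSlice` helper and the `tileValue` lambda that replaces it implement the same projection: pick each operand's offsets and sizes out of the loop-space vectors by indexing-map position. A standalone sketch (simplified types; "projected permutation" means each map result is a distinct loop dim):

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Given per-loop offsets/sizes and the loop positions an operand indexes
// (e.g. Q -> {batch, m, k1}), pick out that operand's slice parameters.
std::pair<std::vector<int64_t>, std::vector<int64_t>>
projectSlice(const std::vector<int> &operandDims,
             const std::vector<int64_t> &offsets,
             const std::vector<int64_t> &sizes) {
  std::vector<int64_t> opOffsets, opSizes;
  for (int dim : operandDims) {
    opOffsets.push_back(offsets[dim]);
    opSizes.push_back(sizes[dim]);
  }
  return {opOffsets, opSizes}; // strides are all 1 for these ops
}
```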
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
index fe21d60..1917631 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
@@ -19,7 +19,6 @@
"conv2d_to_winograd.mlir",
"convert_to_loops.mlir",
"decompose_attention.mlir",
- "decompose_online_attention.mlir",
"decompose_winograd.mlir",
"distribution.mlir",
"pad_contraction_to_block_size.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
index 7ef5fd7..92abdf7 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
@@ -17,7 +17,6 @@
"conv2d_to_winograd.mlir"
"convert_to_loops.mlir"
"decompose_attention.mlir"
- "decompose_online_attention.mlir"
"decompose_winograd.mlir"
"distribution.mlir"
"pad_contraction_to_block_size.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
deleted file mode 100644
index 945cff8..0000000
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
+++ /dev/null
@@ -1,64 +0,0 @@
-// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-linalg-ext-decompose-attention),canonicalize,cse)" %s | FileCheck %s
-
-#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
-#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
-#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
-#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
-#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
-
-func.func @attention_f16(%query: tensor<192x1024x64xf16>,
- %key: tensor<192x1024x64xf16>,
- %value: tensor<192x1024x64xf16>,
- %output: tensor<192x1024x64xf32>,
- %max: tensor<192x1024xf32>,
- %sum: tensor<192x1024xf32>)
- -> (tensor<192x1024x64xf32>, tensor<192x1024xf32>) {
- %scale = arith.constant 1.0 : f16
-
- %out:3 = iree_linalg_ext.online_attention
- { indexing_maps = [#mapQ, #mapK, #mapV, #mapO, #mapR, #mapR] }
- ins(%query, %key, %value, %scale : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16)
- outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>)
- -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
-
- return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32>
-}
-
-// We just want to check if we are using the correct algorithm.
-// CHECK-LABEL: @attention_f16
-// S = Q @ K
-// CHECK: linalg.generic
-// CHECK: arith.mulf
-// CHECK: arith.addf
-// CHECK: linalg.yield
-// newMax = max(oldMax, rowMax(S))
-// CHECK: linalg.generic
-// CHECK: arith.maximumf
-// CHECK: linalg.yield
-// P = exp2(S - newMax)
-// CHECK: linalg.generic
-// CHECK: arith.subf
-// CHECK: math.exp2
-// CHECK: linalg.yield
-// norm = exp2(oldMax - newMax)
-// CHECK: linalg.generic
-// CHECK: arith.subf
-// CHECK: math.exp2
-// CHECK: linalg.yield
-// normSum = norm * oldSum
-// CHECK: linalg.generic
-// CHECK: arith.mulf
-// CHECK: linalg.yield
-// newSum = normSum + rowSum(P)
-// CHECK: linalg.generic
-// CHECK: arith.addf
-// CHECK: linalg.yield
-// newAcc = norm * oldAcc
-// CHECK: linalg.generic
-// CHECK: arith.mulf
-// CHECK: linalg.yield
-// newAcc = P @ V + newAcc
-// CHECK: linalg.generic
-// CHECK: arith.mulf
-// CHECK: arith.addf
-// CHECK: linalg.yield
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
index a974f3a..f92bdb8 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
@@ -1536,66 +1536,3 @@
// CHECK: }
// CHECK: return
// CHECK: }
-
-// -----
-
-#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
-#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
-#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
-#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
-#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
-
-func.func @online_attention(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf32>, %value: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> {
- %scale = arith.constant 1.0 : f32
-
- %output_empty = tensor.empty() : tensor<192x1024x64xf32>
- %row_red_empty = tensor.empty() : tensor<192x1024xf32>
-
- %sum_ident = arith.constant 0.000000e+00 : f32
- %max_ident = arith.constant -3.40282347E+38 : f32
-
- %output_fill = linalg.fill ins(%sum_ident : f32) outs(%output_empty : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
- %acc_fill = linalg.fill ins(%max_ident : f32) outs(%row_red_empty : tensor<192x1024xf32>) -> tensor<192x1024xf32>
- %sum_fill = linalg.fill ins(%sum_ident : f32) outs(%row_red_empty : tensor<192x1024xf32>) -> tensor<192x1024xf32>
-
- %out:3 = iree_linalg_ext.online_attention
- { indexing_maps = [#mapQ, #mapK, #mapV, #mapO, #mapR, #mapR] }
- ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32)
- outs(%output_fill, %acc_fill, %sum_fill : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>)
- -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
-
- return %out#0 : tensor<192x1024x64xf32>
-}
-
-// CHECK-DAG: #[[$IDXMAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
-// CHECK-DAG: #[[$IDXMAP1:.+]] = affine_map<(d0) -> (d0 * 128)>
-// CHECK-DAG: #[[$IDXMAP2:.+]] = affine_map<(d0) -> (d0 * 32)>
-// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
-// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
-// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
-// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
-// CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
-// CHECK-LABEL: @online_attention
-// CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]], %[[IV2:.+]]) in (48, 8, 2)
-// CHECK-DAG: %[[I0:.+]] = affine.apply #[[$IDXMAP0]](%[[IV0]])
-// CHECK-DAG: %[[I1:.+]] = affine.apply #[[$IDXMAP1]](%[[IV1]])
-// CHECK-DAG: %[[I2:.+]] = affine.apply #[[$IDXMAP2]](%[[IV2]])
-// CHECK-DAG: %[[Q:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]], 0] [4, 128, 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x128x64xf32>
-// CHECK-DAG: %[[K:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], 0, 0] [4, 1024, 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x1024x64xf32>
-// CHECK-DAG: %[[V:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], 0, %[[I2]]] [4, 1024, 32] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x1024x32xf32>
-// CHECK-DAG: %[[O:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]], %[[I2]]] [4, 128, 32] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x128x32xf32>
-// CHECK-DAG: %[[M:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]]] [4, 128] [1, 1] : tensor<192x1024xf32> to tensor<4x128xf32>
-// CHECK-DAG: %[[S:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]]] [4, 128] [1, 1] : tensor<192x1024xf32> to tensor<4x128xf32>
-// CHECK-DAG: iree_linalg_ext.online_attention
-// CHECK-SAME: {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]], #[[$MAP4]], #[[$MAP4]]]}
-// CHECK-SAME: ins(%[[Q]], %[[K]], %[[V]], %{{.*}} : tensor<4x128x64xf32>, tensor<4x1024x64xf32>, tensor<4x1024x32xf32>, f32)
-// CHECK-SAME: outs(%[[O]], %[[M]], %[[S]] : tensor<4x128x32xf32>, tensor<4x128xf32>, tensor<4x128xf32>)
-// CHECK: scf.forall.in_parallel
-
-module attributes { transform.with_named_sequence } {
- transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op
- %tiled_att, %grid = transform.structured.tile_using_forall %0 tile_sizes [4, 128, 0, 0, 32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.yield
- }
-}
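(Sanity check for the deleted CHECK lines: the `scf.forall` bounds (48, 8, 2) follow from the tile sizes [4, 128, 0, 0, 32] as 192/4 = 48, 1024/128 = 8, and 64/32 = 2; k1 and k2 have tile size 0 and remain untiled inside the op.)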
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.cpp
index ac5f5c0..7feee69 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.cpp
@@ -27,7 +27,7 @@
void AttentionOpDetail::inferFromIndexingMaps(
ArrayRef<AffineMap> indexingMaps) {
- assert(indexingMaps.size() >= 4);
+ assert(indexingMaps.size() == 4);
AffineMap qMap = indexingMaps[0];
AffineMap kMap = indexingMaps[1];
AffineMap vMap = indexingMaps[2];
@@ -82,23 +82,7 @@
AttentionOpDetail opInfo;
opInfo.inferFromIndexingMaps(indexingMaps);
- opInfo.maps = SmallVector<AffineMap>(indexingMaps);
return opInfo;
}
-AffineMap AttentionOpDetail::getSMap() const {
- // We need to create an indexing map for the intermediate result of first
- // matmul. There could be other options, but we choose to create a standard
- // indexing map:
- // SMap = (batch, m, k1, k2, n) -> (batch, m, k2)
- AffineMap sMap = AffineMap::get(/*dimCount=*/getDomainRank(),
- /*symbolCount=*/0, getContext());
- for (auto dim :
- llvm::concat<const int64_t>(getBatchDims(), getMDims(), getK2Dims())) {
- AffineExpr dimExpr = getAffineDimExpr(dim, getContext());
- sMap = sMap.insertResult(dimExpr, sMap.getNumResults());
- }
- return sMap;
-}
-
}; // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h
index e66bc86..cfba4ff 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h
@@ -42,29 +42,20 @@
public:
static FailureOr<AttentionOpDetail> get(ArrayRef<AffineMap> indexingMaps);
- int64_t getDomainRank() const { return maps[0].getNumDims(); }
ArrayRef<int64_t> getBatchDims() const { return batch; }
ArrayRef<int64_t> getMDims() const { return m; }
ArrayRef<int64_t> getK1Dims() const { return k1; }
ArrayRef<int64_t> getK2Dims() const { return k2; }
ArrayRef<int64_t> getNDims() const { return n; }
- ArrayRef<AffineMap> getIndexingMaps() const { return maps; }
-
- AffineMap getSMap() const;
-
private:
void inferFromIndexingMaps(ArrayRef<AffineMap> indexingMaps);
- MLIRContext *getContext() const { return maps[0].getContext(); }
-
SmallVector<int64_t> batch;
SmallVector<int64_t> m;
SmallVector<int64_t> k1;
SmallVector<int64_t> k2;
SmallVector<int64_t> n;
-
- SmallVector<AffineMap> maps;
};
}; // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
index e6c3548..d231322 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
@@ -44,13 +44,6 @@
[&](int64_t dim) { return getDim(builder, loc, shapedTypeValue, dim); });
}
-Value getSlice(OpBuilder &b, Location loc, Value src, ArrayRef<Range> slice) {
- return getSlice(b, loc, src,
- llvm::map_to_vector(slice, [](Range x) { return x.offset; }),
- llvm::map_to_vector(slice, [](Range x) { return x.size; }),
- llvm::map_to_vector(slice, [](Range x) { return x.stride; }));
-}
-
Value getSlice(OpBuilder &b, Location loc, Value src,
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides) {
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
index eec973f..9b40a54 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
@@ -13,10 +13,6 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
-namespace mlir {
-struct Range;
-}; // namespace mlir
-
namespace mlir::iree_compiler::IREE::LinalgExt {
/// Returns a `memref.dim` or `tensor.dim` operation to get the shape of `v` at
@@ -30,7 +26,6 @@
/// Returns a `memref.subview` or a `tensor.extract_slice` based on the type of
/// `src`.
-Value getSlice(OpBuilder &b, Location loc, Value src, ArrayRef<Range> slice);
Value getSlice(OpBuilder &b, Location loc, Value src,
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides);