Integrate llvm-project and bump dependencies 20230322 (#12730)
* llvm-project: 411b1d8f0795
* mlir-hlo: fa4f6f47d7a7c0123cbbc85c8929a29a51143b3b
* tensorflow: 6bedca8e818b152ea594a911faf6a6add9f7d795
mlir-hlo patch:
* move LLVMSupport to LINK_COMPONENTS
tensorflow patch:
* revert
https://github.com/tensorflow/tensorflow/commit/c5920f0727124a374e72146a70a2b32153cdcfab
---------
Co-authored-by: Mahesh Ravishankar <ravishankarm@google.com>
Co-authored-by: Nicolas Vasilache <nicolas.vasilache@gmail.com>
Co-authored-by: Lei Zhang <antiagainst@google.com>
Co-authored-by: Hanhan Wang <hanchung@google.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPUPipelining.cpp b/compiler/src/iree/compiler/Codegen/Common/GPUPipelining.cpp
index 5d66cc9..c7675de 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPUPipelining.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPUPipelining.cpp
@@ -647,9 +647,9 @@
} // namespace
FailureOr<scf::ForOp> pipelineSharedMemoryCopy(
- scf::ForOp forOp, PipeliningSchedulingStrategy startegy, bool peelEpilogue,
- int64_t depth, RewriterBase& rewriter) {
- return applyPipelining(forOp, depth, peelEpilogue, startegy);
+ RewriterBase& rewriter, scf::ForOp forOp,
+ PipeliningSchedulingStrategy strategy, bool peelEpilogue, int64_t depth) {
+ return applyPipelining(forOp, depth, peelEpilogue, strategy);
}
/// Pass options
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
index df6bb45..c79e5be 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
@@ -425,8 +425,8 @@
IRRewriter rewriter(context);
if (failed(tileAndFuseDispatchUsingSCFForOp(
- cast<TilingInterface>(computeOps.back()), linalgTilingOptions,
- rewriter))) {
+ rewriter, cast<TilingInterface>(computeOps.back()),
+ linalgTilingOptions))) {
funcOp.emitOpError("Tile+Distribute failed");
return signalPassFailure();
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp
index 978cf70..d03f8d4 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp
@@ -180,9 +180,12 @@
/// `flow.dispatch.tensor.store` that writes only a tile of the result at
/// offsets given by `tiledOffsets` and sizes given by `tiledSizes`, using
/// `tiledValue` as the source.
-static LogicalResult replaceStoreWithTiledVersion(
- RewriterBase &rewriter, OpResult untiledValue, OpResult tiledValue,
- ArrayRef<OpFoldResult> tileOffsets, ArrayRef<OpFoldResult> tileSizes) {
+static LogicalResult replaceStoresWithTiledVersion(
+ RewriterBase &rewriter, OpResult untiledValue, Value tiledValue,
+ ArrayRef<OpFoldResult> tileOffsets, ArrayRef<OpFoldResult> tileSizes,
+ Block *innerLoopBody) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(innerLoopBody->getTerminator());
SmallVector<IREE::Flow::DispatchTensorStoreOp> storeOps;
for (OpOperand &use : untiledValue.getUses()) {
auto storeOp = dyn_cast<IREE::Flow::DispatchTensorStoreOp>(use.getOwner());
@@ -226,7 +229,9 @@
static LogicalResult replaceAllStoresWithTiledVersion(
RewriterBase &rewriter, TilingInterface untiledOp,
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
- Operation *tiledOp) {
+ ValueRange tiledValues, Block *innerLoopBody) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(innerLoopBody->getTerminator());
for (auto [index, result] : llvm::enumerate(untiledOp->getResults())) {
SmallVector<OpFoldResult> resultOffsets, resultSizes;
if (failed(untiledOp.getResultTilePosition(rewriter, index, offsets, sizes,
@@ -234,12 +239,9 @@
return rewriter.notifyMatchFailure(
untiledOp, "failed to rewrite destructive update");
}
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(tiledOp->getBlock()->getTerminator());
- if (failed(replaceStoreWithTiledVersion(
- rewriter, result.cast<OpResult>(),
- tiledOp->getResult(index).cast<OpResult>(), resultOffsets,
- resultSizes))) {
+ if (failed(replaceStoresWithTiledVersion(rewriter, result.cast<OpResult>(),
+ tiledValues[index], resultOffsets,
+ resultSizes, innerLoopBody))) {
return failure();
}
}
@@ -248,8 +250,9 @@
namespace {
/// Result of the tiled operation.
-struct TilingResult {
- Operation *tiledOp = nullptr;
+struct IREETilingResult {
+ SmallVector<Operation *> tiledOps;
+ SmallVector<Value> tiledValues;
SmallVector<scf::ForOp> loops;
llvm::SmallBitVector tiledLoops;
SmallVector<OpFoldResult> tileOffsets;
@@ -257,9 +260,9 @@
};
} // namespace
-static FailureOr<TilingResult> tileDispatchUsingSCFFopOp(
- TilingInterface op, linalg::LinalgTilingOptions options,
- RewriterBase &rewriter) {
+static FailureOr<IREETilingResult> tileDispatchUsingSCFFopOp(
+ RewriterBase &rewriter, TilingInterface op,
+ linalg::LinalgTilingOptions options) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointAfter(op);
@@ -273,7 +276,7 @@
Location loc = op.getLoc();
size_t numLoops = iterationDomainOfr.size();
if (numLoops == 0) {
- return TilingResult();
+ return IREETilingResult();
}
auto iterationDomain =
llvm::to_vector(llvm::map_range(iterationDomainOfr, [&](Range r) {
@@ -295,7 +298,7 @@
}
tileSizeVector.resize(numLoops);
- TilingResult tilingResult;
+ IREETilingResult tilingResult;
tilingResult.tiledLoops.resize(numLoops, false);
for (auto [index, tileSize] : llvm::enumerate(tileSizeVector)) {
if (!isZero(tileSize)) {
@@ -304,10 +307,9 @@
}
if (!tilingResult.tiledLoops.any()) {
- return TilingResult();
+ return IREETilingResult();
}
- SmallVector<Operation *> tiledImplementation;
{
SmallVector<OpFoldResult> offsets, sizes;
// If there is an interchange specified, permute the iteration domain and
@@ -383,8 +385,14 @@
if (!tilingResult.loops.empty())
rewriter.setInsertionPoint(
tilingResult.loops.back().getBody()->getTerminator());
- tiledImplementation = op.getTiledImplementation(rewriter, offsets, sizes);
- tilingResult.tiledOp = tiledImplementation.back();
+ FailureOr<TilingResult> tiledImplementation =
+ op.getTiledImplementation(rewriter, offsets, sizes);
+ if (failed(tiledImplementation)) {
+ return rewriter.notifyMatchFailure(
+ op, "failed to generate tiled implementation");
+ }
+ std::swap(tilingResult.tiledOps, tiledImplementation->tiledOps);
+ std::swap(tilingResult.tiledValues, tiledImplementation->tiledValues);
LLVM_DEBUG({
if (!tilingResult.loops.empty()) {
@@ -406,7 +414,7 @@
// of the store. Its valid to this for all stores of the root untiled op.
if (failed(replaceAllStoresWithTiledVersion(
rewriter, op, tilingResult.tileOffsets, tilingResult.tileSizes,
- tilingResult.tiledOp))) {
+ tilingResult.tiledValues, tilingResult.loops.back().getBody()))) {
return failure();
}
return tilingResult;
@@ -458,19 +466,19 @@
return sliceOps;
}
-FailureOr<TileAndFuseResult> tileAndFuseDispatchUsingSCFForOp(
- TilingInterface op, linalg::LinalgTilingOptions tilingOptions,
- RewriterBase &rewriter) {
- TileAndFuseResult tileAndFuseResult;
+FailureOr<IREETileAndFuseResult> tileAndFuseDispatchUsingSCFForOp(
+ RewriterBase &rewriter, TilingInterface op,
+ linalg::LinalgTilingOptions tilingOptions) {
+ IREETileAndFuseResult tileAndFuseResult;
auto fusableProducers = getAllFusableProducers(op);
// Apply the tiling pattern.
- FailureOr<TilingResult> tilingResult =
- tileDispatchUsingSCFFopOp(op, tilingOptions, rewriter);
+ FailureOr<IREETilingResult> tilingResult =
+ tileDispatchUsingSCFFopOp(rewriter, op, tilingOptions);
if (failed(tilingResult)) {
return failure();
}
- tileAndFuseResult.tiledAndFusedOps.push_back(tilingResult->tiledOp);
- tileAndFuseResult.loops.append(tilingResult->loops);
+ tileAndFuseResult.tiledAndFusedOps = tilingResult->tiledOps;
+ tileAndFuseResult.loops = tilingResult->loops;
// If there is no tiling then there is nothing to do for fusion.
if (!tilingResult->tiledLoops.any()) {
@@ -488,67 +496,34 @@
// could use the first slice as the insertion point.
auto sliceOps = getAllFusableProducerUses(
fusableProducer, tileAndFuseResult.tiledAndFusedOps);
- if (sliceOps.empty()) continue;
- tensor::ExtractSliceOp sliceOp = sliceOps.front();
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(sliceOp);
- // Generate the tiled implementation of the producer.
- FailureOr<Value> tiledProducerVal =
- tensor::replaceExtractSliceWithTiledProducer(
- rewriter, sliceOp, sliceOp.getSource().cast<OpResult>());
- if (failed(tiledProducerVal)) {
- return rewriter.notifyMatchFailure(sliceOp,
- "fusion along slice op failed");
- }
- auto tiledProducer = tiledProducerVal->getDefiningOp();
- if (tiledProducer->getNumResults() != fusableProducer->getNumResults()) {
- return rewriter.notifyMatchFailure(fusableProducer,
- "fused operation expected to produce "
- "an op with same number of results");
- }
-
- // 2b. Assume that the tile sizes used are such that all tiled loops are
- // "common parallel loops" for the consumer and all pulled in
- // producers. So using the tile size of the tiled consumer op, and the
- // information about which loops are tiled and which arent, compute
- // the tile sizes to use for the producer as well.
- SmallVector<OpFoldResult> producerOffset, producerSizes;
- SmallVector<Range> producerIterationDomain =
- fusableProducer.getIterationDomain(rewriter);
- for (auto [index, range] : llvm::enumerate(producerIterationDomain)) {
- if (index < tilingResult->tiledLoops.size() &&
- index < sliceOp.getMixedOffsets().size() &&
- tilingResult->tiledLoops.test(index)) {
- // It is not true that the tiling sizes for produces are always as same
- // as the tiling sizes for consumers. The tensor.extract_slice op
- // carries the information, so we can get the tiling sizes and offsets
- // from it.
- producerOffset.push_back(sliceOp.getMixedOffsets()[index]);
- producerSizes.push_back(sliceOp.getMixedSizes()[index]);
- } else {
- producerOffset.push_back(range.offset);
- producerSizes.push_back(range.size);
- }
- }
-
- // 2c. Finally replace any `flow.dispatch.tensor.store` operation with
- // tiled version of the operation. It is only valid to do this under the
- // above assumption that the producer and consumer share the loops
- // that can be tiled.
- if (failed(replaceAllStoresWithTiledVersion(rewriter, fusableProducer,
- producerOffset, producerSizes,
- tiledProducer))) {
- return failure();
- }
- // Replace all uses of the slices processed in this step with values from
- // the producer.
for (auto sliceOp : sliceOps) {
- unsigned resultNumber =
- sliceOp.getSource().cast<OpResult>().getResultNumber();
- rewriter.replaceOp(sliceOp, tiledProducer->getResult(resultNumber));
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(sliceOp);
+
+ // Generate the tiled implementation of the producer.
+ OpResult untiledValue = sliceOp.getSource().cast<OpResult>();
+ FailureOr<TilingResult> swapSliceResult =
+ tensor::replaceExtractSliceWithTiledProducer(rewriter, sliceOp,
+ untiledValue);
+ if (failed(swapSliceResult) || swapSliceResult->tiledValues.size() != 1) {
+ return rewriter.notifyMatchFailure(sliceOp,
+ "fusion along slice op failed");
+ }
+
+ // 2c. Finally replace any `flow.dispatch.tensor.store` operation with
+ // tiled version of the operation. It is only valid to do this under
+ // the above assumption that the producer and consumer share the loops
+ // that can be tiled.
+ if (failed(replaceStoresWithTiledVersion(
+ rewriter, untiledValue, swapSliceResult->tiledValues[0],
+ sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
+ tileAndFuseResult.loops.back().getBody()))) {
+ return failure();
+ }
+ rewriter.replaceOp(sliceOp, swapSliceResult->tiledValues[0]);
+ tileAndFuseResult.tiledAndFusedOps.append(swapSliceResult->tiledOps);
}
- tileAndFuseResult.tiledAndFusedOps.push_back(tiledProducer);
}
return tileAndFuseResult;
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel
index e966c9d..31cabb3 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel
@@ -80,6 +80,7 @@
"@llvm-project//mlir:ArithUtils",
"@llvm-project//mlir:BufferizationDialect",
"@llvm-project//mlir:BufferizationTransforms",
+ "@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgDialect",
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index 0a186a0..de9c6c7 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -38,6 +38,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Pass/PassManager.h"
diff --git a/compiler/src/iree/compiler/Codegen/Common/Transforms.h b/compiler/src/iree/compiler/Codegen/Common/Transforms.h
index 083fb5b..ca99cb5 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Transforms.h
+++ b/compiler/src/iree/compiler/Codegen/Common/Transforms.h
@@ -30,17 +30,18 @@
/// For a given operation within a dispatch, tile and distribute the operation
/// to workgroups as well as tile + fuse its producers. Returns the
/// generated tiled and fused ops, as well as the loops used for distribution.
-struct TileAndFuseResult {
+struct IREETileAndFuseResult {
SmallVector<Operation *> tiledAndFusedOps;
SmallVector<scf::ForOp> loops;
};
-FailureOr<TileAndFuseResult> tileAndFuseDispatchUsingSCFForOp(
- TilingInterface op, linalg::LinalgTilingOptions tilingOptions,
- RewriterBase &rewriter);
+
+FailureOr<IREETileAndFuseResult> tileAndFuseDispatchUsingSCFForOp(
+ RewriterBase &rewriter, TilingInterface op,
+ linalg::LinalgTilingOptions tilingOptions);
FailureOr<scf::ForOp> pipelineSharedMemoryCopy(
- scf::ForOp forOp, PipeliningSchedulingStrategy startegy, bool peelEpilogue,
- int64_t depth, RewriterBase &rewriter);
+ RewriterBase &rewriter, scf::ForOp forOp,
+ PipeliningSchedulingStrategy startegy, bool peelEpilogue, int64_t depth);
/// Populate patterns related to clean up the IR after tile and distribute to
/// workgroups.
diff --git a/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
index 68982b2..4ff32ab 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
@@ -378,13 +378,13 @@
// Move all the regions from the old op to the new op and legalize its
// signature.
- for (auto &[index, region] : llvm::enumerate(op->getRegions())) {
+ for (const auto &[index, region] : llvm::enumerate(op->getRegions())) {
Region &newOpRegion = newOp->getRegion(index);
rewriter.inlineRegionBefore(region, newOpRegion, newOpRegion.begin());
TypeConverter::SignatureConversion signatureConverter(
newOpRegion.getNumArguments());
bool doSignatureConversion = false;
- for (auto &[argIndex, arg] :
+ for (const auto &[argIndex, arg] :
llvm::enumerate(newOpRegion.getArguments())) {
Type argType = arg.getType();
Type legalizedType = this->typeConverter->convertType(argType);
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/flatten_memref_subspan.mlir b/compiler/src/iree/compiler/Codegen/Common/test/flatten_memref_subspan.mlir
index 8936261..59dd467 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/flatten_memref_subspan.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/flatten_memref_subspan.mlir
@@ -81,6 +81,7 @@
// -----
+
func.func @store_subspan_with_all_dynamic_dim(%value: f32, %offset : index, %i0: index, %i1: index, %i2: index, %i3: index) {
%dim0 = hal.interface.constant.load[0] : index
%dim1 = hal.interface.constant.load[1] : index
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir
index 299fcdb..2a45439 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir
@@ -1855,7 +1855,7 @@
// CHECK-DAG: flow.dispatch.tensor.store %[[GENERIC0]], %[[RESULT_BINDING2]], offsets = [%[[IV0]], %[[IV1]]]
// -----
-
+
hal.executable private @no_tile {
hal.executable.variant public @embedded_elf_x86_64, target = <"llvm-cpu", "embedded-elf-x86_64", {}> {
hal.executable.export public @no_tile ordinal(0) layout(#hal.pipeline.layout<
@@ -2279,7 +2279,6 @@
%18 = tensor.pack %16#0 inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %17 {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[8, 64]]>} : tensor<384x512xf32> -> tensor<48x512x8x1xf32>
flow.dispatch.tensor.store %18, %6, offsets = [0, 0, 0, 0], sizes = [48, 512, 8, 1], strides = [1, 1, 1, 1] : tensor<48x512x8x1xf32> -> !flow.dispatch.tensor<writeonly:tensor<48x512x8x1xf32>>
flow.dispatch.tensor.store %16#0, %7, offsets = [0, 0], sizes = [384, 512], strides = [1, 1] : tensor<384x512xf32> -> !flow.dispatch.tensor<writeonly:tensor<384x512xf32>>
- flow.dispatch.tensor.store %16#1, %8, offsets = [0, 0], sizes = [384, 512], strides = [1, 1] : tensor<384x512xf32> -> !flow.dispatch.tensor<writeonly:tensor<384x512xf32>>
return
}
}
@@ -2292,7 +2291,6 @@
// CHECK: %[[PACK:.+]] = tensor.pack
// CHECK-DAG: flow.dispatch.tensor.store %[[PACK]], {{.*}} sizes = [8, 64, 8, 1]
// CHECK-DAG: flow.dispatch.tensor.store %[[ELEM]]#0, {{.*}} sizes = [64, 64]
-// CHECK-DAG: flow.dispatch.tensor.store %[[ELEM]]#1, {{.*}} sizes = [64, 64]
// -----
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndFuse.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndFuse.cpp
index 0788de8..d7afcd6 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndFuse.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndFuse.cpp
@@ -179,10 +179,9 @@
}
// Add more fusion candidates to the worklist.
- if (auto fusedProducerOp =
- fusedProducer->tiledAndFusedProducer.getDefiningOp()) {
- addCandidateSlices(fusedProducerOp, candidates);
- tiledOps.push_back(fusedProducerOp);
+ for (auto tiledOp : fusedProducer->tiledOps) {
+ addCandidateSlices(tiledOp, candidates);
+ tiledOps.push_back(tiledOp);
}
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUDistribute.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUDistribute.cpp
index 735ee94..f915ee7 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUDistribute.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUDistribute.cpp
@@ -42,29 +42,11 @@
[&](Attribute attr) { return attr.cast<IntegerAttr>().getInt(); }));
IRRewriter rewriter(funcOp->getContext());
- rewriter.setInsertionPoint(funcOp);
- MLIRContext* ctx = funcOp->getContext();
- SmallVector<DeviceMappingAttrInterface> threadMappingAttributes = {
- gpu::GPUThreadMappingAttr::get(ctx, gpu::Threads::DimX),
- gpu::GPUThreadMappingAttr::get(ctx, gpu::Threads::DimY),
- gpu::GPUThreadMappingAttr::get(ctx, gpu::Threads::DimZ)};
-
- auto threadIdGenerator = [](RewriterBase& rewriter, scf::ForallOp forallOp,
- SmallVectorImpl<Value>& threadIds) {
- Location loc = forallOp.getLoc();
- IndexType indexType = rewriter.getIndexType();
- threadIds.assign(
- {rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpu::Dimension::x),
- rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpu::Dimension::y),
- rewriter.create<gpu::ThreadIdOp>(loc, indexType,
- gpu::Dimension::z)});
- };
-
- DiagnosedSilenceableFailure const result =
+ rewriter.setInsertionPointToStart(&funcOp.getBody().front());
+ DiagnosedSilenceableFailure result =
mlir::transform::gpu::mapNestedForallToThreadsImpl(
- rewriter, funcOp, workgroupSize, threadIdGenerator, false,
- std::nullopt, threadMappingAttributes);
-
+ rewriter, std::nullopt, funcOp, workgroupSize, /*warpDims=*/{},
+ false);
if (!result.succeeded()) return signalPassFailure();
}
};
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD.bazel
index bf73d9e..1199289 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD.bazel
@@ -63,8 +63,11 @@
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//llvm-external-projects/iree-dialects:IREEDialectsTransforms",
"//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:BufferizationDialect",
+ "@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:GPUTransformOps",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/CMakeLists.txt
index 521b5ae..00deb7a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/CMakeLists.txt
@@ -33,6 +33,8 @@
::LLVMGPUExtensionsOpGen
IREEDialectsTransforms
IREELinalgTransformDialect
+ LLVMSupport
+ MLIRAffineDialect
MLIRArithDialect
MLIRBufferizationDialect
MLIRFuncDialect
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
index 9090e3c..f52c5a6a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
@@ -11,7 +11,10 @@
#include "iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -22,13 +25,25 @@
#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+using llvm::dbgs;
+
+#define DEBUG_TYPE "transform-llvmgpu-extensions"
+
+#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
+#define LDBG(X) LLVM_DEBUG(dbgs() << '[' << DEBUG_TYPE << "] " << X)
+
using namespace mlir;
using namespace mlir::iree_compiler::IREE;
@@ -61,8 +76,10 @@
transform::TransformState &state) {
if (!isa<HAL::ExecutableOp, HAL::ExecutableVariantOp>(state.getTopLevel())) {
state.getTopLevel()->emitOpError(
- "requires HAL::ExecutableOp or HAL::ExecutableVariantOp toplevel to "
- "attach the workgroup size information to a nested ExecutableExportOp");
+ "requires HAL::ExecutableOp or HAL::ExecutableVariantOp "
+ "toplevel to "
+ "attach the workgroup size information to a nested "
+ "ExecutableExportOp");
return emitDefaultDefiniteFailure(target);
}
@@ -75,95 +92,22 @@
return emitDefaultDefiniteFailure(target);
}
- SmallVector<int64_t> workgroupSize{getWorkgroupDims()};
- // TODO: no magic constant but IREE uses this extensively.
- workgroupSize.resize(/*size=*/3, /*value=*/1);
-
auto transformOp = cast<transform::TransformOpInterface>(getOperation());
- SimplePatternRewriter rewriter(target);
- MLIRContext *ctx = target->getContext();
- SmallVector<DeviceMappingAttrInterface> threadMappingAttributes = {
- gpu::GPUThreadMappingAttr::get(ctx, gpu::Threads::DimX),
- gpu::GPUThreadMappingAttr::get(ctx, gpu::Threads::DimY),
- gpu::GPUThreadMappingAttr::get(ctx, gpu::Threads::DimZ)};
-
- auto threadIdGenerator = [](RewriterBase &rewriter, scf::ForallOp forallOp,
- SmallVectorImpl<Value> &threadIds) {
- Location loc = forallOp.getLoc();
- IndexType indexType = rewriter.getIndexType();
- threadIds.assign(
- {rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpu::Dimension::x),
- rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpu::Dimension::y),
- rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpu::Dimension::z)});
- };
-
+ IRRewriter rewriter(target->getContext());
+ rewriter.setInsertionPointToStart(&target.getBody().front());
DiagnosedSilenceableFailure diag =
mlir::transform::gpu::mapNestedForallToThreadsImpl(
- rewriter, target, workgroupSize, threadIdGenerator, true, transformOp,
- threadMappingAttributes);
+ rewriter, transformOp, target, getWorkgroupDims(), getWarpDims(),
+ true);
if (diag.succeeded()) {
- auto newAttr = rewriter.getIndexArrayAttr(workgroupSize);
- // TODO: should really be: exportOp.setWorkgroupSizeAttr(newAttr);
+ auto newAttr = rewriter.getIndexArrayAttr(getWorkgroupDims());
+ rewriter.startRootUpdate(exportOp);
exportOp->setAttr(exportOp.getWorkgroupSizeAttrName(), newAttr);
+ rewriter.finalizeRootUpdate(exportOp);
}
- // Map warpIds, only if threadIdx.x is a multiple of the warp size.
- // TODO: more advanced mechanism to linearize/delinearize the threadIds to
- // warps.
- SmallVector<DeviceMappingAttrInterface> warpMappingAttributes = {
- gpu::GPUWarpMappingAttr::get(ctx, gpu::Warps::DimX),
- gpu::GPUWarpMappingAttr::get(ctx, gpu::Warps::DimY),
- gpu::GPUWarpMappingAttr::get(ctx, gpu::Warps::DimZ)};
- if (diag.succeeded() && (workgroupSize[0] % kWarpSize == 0)) {
- auto warpIdGenerator = [](RewriterBase &rewriter, scf::ForallOp forallOp,
- SmallVectorImpl<Value> &warpIds) {
- Location loc = forallOp.getLoc();
- IndexType indexType = rewriter.getIndexType();
- Value threadIdX =
- rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpu::Dimension::x);
- Value cstWarpSize =
- rewriter.create<arith::ConstantIndexOp>(loc, kWarpSize);
- Value warpIdX =
- rewriter.create<arith::DivUIOp>(loc, threadIdX, cstWarpSize);
- warpIds.assign(
- {warpIdX,
- rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpu::Dimension::y),
- rewriter.create<gpu::ThreadIdOp>(loc, indexType,
- gpu::Dimension::z)});
- };
- SmallVector<int64_t> numWarps = {workgroupSize[0] / kWarpSize,
- workgroupSize[1], workgroupSize[2]};
- diag = mlir::transform::gpu::mapNestedForallToThreadsImpl(
- rewriter, target, numWarps, warpIdGenerator, true, transformOp,
- warpMappingAttributes);
- }
-
- auto walkResult = target->walk([&warpMappingAttributes](
- scf::ForallOp forallOp) -> WalkResult {
- auto maybeMapping = forallOp.getMapping();
- if (!maybeMapping) return WalkResult::advance();
- for (Attribute attr : *maybeMapping) {
- for (const auto &warpAttr : warpMappingAttributes) {
- if (attr == warpAttr) {
- forallOp->emitOpError(
- "Mapping failed: is threadIdx.x a multiple of the warp size?");
- return WalkResult::interrupt();
- }
- }
- }
- return WalkResult::advance();
- });
- if (walkResult.wasInterrupted()) {
- return emitDefaultDefiniteFailure(target);
- }
-
- auto newAttr = rewriter.getIndexArrayAttr(workgroupSize);
- // TODO: should really be: exportOp.setWorkgroupSizeAttr(newAttr);
- rewriter.startRootUpdate(exportOp);
- exportOp->setAttr(exportOp.getWorkgroupSizeAttrName(), newAttr);
- rewriter.finalizeRootUpdate(exportOp);
return diag;
}
@@ -273,19 +217,21 @@
isThreadIdxxZeroPredicate(ifOp);
if (failed(maybeThreadIdxxOp)) return failure();
- // All the code below will be executed on a single warp given a fixed
- // (threadIdxy, threadIdxz).
- // Note, we reuse `maybeThreadIdxxOp` here because we later want to replace
- // this op instance by 0 without relying on CSE or canonicalizations.
+ // All the code below will be executed on a single warp given a
+ // fixed (threadIdxy, threadIdxz). Note, we reuse
+ // `maybeThreadIdxxOp` here because we later want to replace this
+ // op instance by 0 without relying on CSE or canonicalizations.
Value threadIdxx = *maybeThreadIdxxOp;
assert(workgroupSizeX % warpSize == 0);
if (workgroupSizeX != warpSize) {
- // Add a guard for `threadIdxx < warp size` around the WarpExecuteOnLane0Op.
+ // Add a guard for `threadIdxx < warp size` around the
+ // WarpExecuteOnLane0Op.
Value predicate = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ult, threadIdxx,
rewriter.create<arith::ConstantIndexOp>(loc, warpSize));
- // Note: return-less IfOp is built with a terminator, no need to add one.
+ // Note: return-less IfOp is built with a terminator, no need to
+ // add one.
auto newIfOp =
rewriter.create<scf::IfOp>(loc, predicate, /*withElseRegion=*/false);
rewriter.setInsertionPointToStart(&newIfOp.getThenRegion().front());
@@ -293,7 +239,8 @@
auto warpOp = rewriter.create<vector::WarpExecuteOnLane0Op>(
loc, TypeRange(), threadIdxx, warpSize);
- // Move the code from the previous ifOp to the WarpExecuteOnLane0Op.
+ // Move the code from the previous ifOp to the
+ // WarpExecuteOnLane0Op.
Block &sourceBlock = ifOp.getThenRegion().front();
Block &targetBlock = warpOp.getWarpRegion().front();
Block::iterator insertionPoint = targetBlock.begin();
@@ -310,8 +257,8 @@
// This simple rewrite propagates zero in lieu of laneId within the
// warp_execute_on_lane_0 op.
// Atm, this **must** occur before any hoisting of code.
- // TODO: Replace this by a more robust scoped SCCP that will make it more
- // robust re. hoisting.
+ // TODO: Replace this by a more robust scoped SCCP that will make
+ // it more robust re. hoisting.
(void)replaceAllUsesOfLaneWithin(rewriter, warpOp);
// Hoist the scalar code outside of the warp region.
@@ -341,9 +288,12 @@
if (!isa<HAL::ExecutableOp, HAL::ExecutableVariantOp>(state.getTopLevel())) {
results.assign(1, nullptr);
return emitDefaultSilenceableFailure(state.getTopLevel())
- << "requires HAL::ExecutableOp or HAL::ExecutableVariantOp toplevel "
- "so that IR is properly isolated. This is required so we can "
- "safely inspect the HAL::ExecutableExportOp under multi-threaded "
+ << "requires HAL::ExecutableOp or "
+ "HAL::ExecutableVariantOp toplevel "
+ "so that IR is properly isolated. This is required so "
+ "we can "
+ "safely inspect the HAL::ExecutableExportOp under "
+ "multi-threaded "
"pass assumptions.";
}
@@ -353,29 +303,35 @@
HAL::ExecutableExportOp exportOp =
getExecutableExportOpForFunc(halExecutableVariantOp, funcOp);
if (!halExecutableVariantOp || !funcOp || !exportOp) {
- // Return a silenceable failure and set the expected 1 result to nullptr.
+ // Return a silenceable failure and set the expected 1 result to
+ // nullptr.
results.assign(1, nullptr);
return emitDefaultSilenceableFailure(target)
- << "export op is missing --- the transform is not applied";
+ << "export op is missing --- the transform is not "
+ "applied";
}
std::optional<ArrayAttr> maybeAttr = exportOp.getWorkgroupSize();
// TODO: Pervasive 3 constant in IREE.
if (!maybeAttr || maybeAttr->size() != 3) {
- // Return a silenceable failure and set the expected 1 result to nullptr.
+ // Return a silenceable failure and set the expected 1 result to
+ // nullptr.
results.assign(1, nullptr);
return emitDefaultSilenceableFailure(target)
- << "export op must have workgroup_size attribute set with 3 entries "
+ << "export op must have workgroup_size attribute set "
+ "with 3 entries "
"--- the transform is not applied";
}
int64_t workgroupSizeX = (*maybeAttr)[0].cast<IntegerAttr>().getInt();
int64_t warpSize = getWarpSize();
if (workgroupSizeX % warpSize != 0) {
- // Return a silenceable failure and set the expected 1 result to nullptr.
+ // Return a silenceable failure and set the expected 1 result to
+ // nullptr.
results.assign(1, nullptr);
return emitDefaultSilenceableFailure(target)
- << "vector distribution requires workgroup size for x to be a "
+ << "vector distribution requires workgroup size for x to "
+ "be a "
<< "multiple of the warp size: " << workgroupSizeX << " vs "
<< warpSize << " --- the transform is not applied";
}
@@ -385,10 +341,12 @@
rewriteScfIfAsWarpExecuteOnLane0(rewriter, target->getLoc(), target,
workgroupSizeX, warpSize);
if (failed(vectorDistributionResult)) {
- // Return a silenceable failure and set the expected 1 result to nullptr.
+ // Return a silenceable failure and set the expected 1 result to
+ // nullptr.
results.assign(1, nullptr);
return emitDefaultSilenceableFailure(target)
- << "scf::ifOp needs to be predicated on threadIdx.x == 0 --- the "
+ << "scf::ifOp needs to be predicated on threadIdx.x == 0 "
+ "--- the "
"transform is not applied";
}
results.push_back(vectorDistributionResult->warpOp);
@@ -499,15 +457,15 @@
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(warpOp);
// TODO: generalize this.
- // options.warpSyncronizationFn currently must take a WarpExecuteOnLane0Op
- // which we don't have here.
+ // options.warpSyncronizationFn currently must take a
+ // WarpExecuteOnLane0Op which we don't have here.
rewriter.create<gpu::BarrierOp>(load.getLoc());
Value newRead = rewriter.create<memref::LoadOp>(
load.getLoc(), distributedVal.getType(), load.getMemref(), indices);
- // The result type of WarpExecuteOnLane0Op may or may not match the yielded
- // type depending on whether the op has "broadcast" behavior (see the doc
- // of WarpExecuteOnLane0Op).
+ // The result type of WarpExecuteOnLane0Op may or may not match
+ // the yielded type depending on whether the op has "broadcast"
+ // behavior (see the doc of WarpExecuteOnLane0Op).
for (OpOperand &use : distributedVal.getUses()) {
rewriter.startRootUpdate(use.getOwner());
Value replacement = newRead;
@@ -539,8 +497,8 @@
auto warpParent = alloc->getParentOfType<vector::WarpExecuteOnLane0Op>();
if (!warpParent) return failure();
alloc->moveBefore(warpParent);
- // Conservatively move the dealloc after the warpOp. This may extend the
- // liverange of the allocation but is always correct.
+ // Conservatively move the dealloc after the warpOp. This may
+ // extend the live range of the allocation but is always correct.
for (Operation *user : alloc->getUsers()) {
if (isa<memref::DeallocOp>(user)) user->moveAfter(warpParent);
}
@@ -633,7 +591,8 @@
transform::TransformState &state) {
if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
target->emitOpError(
- "applies only to isolated-from-above targets because it needs to apply "
+ "applies only to isolated-from-above targets because it "
+ "needs to apply "
"patterns greedily");
return emitDefaultDefiniteFailure(target);
}
@@ -642,8 +601,8 @@
// automatically get listening capabilities.
MLIRContext *ctx = target->getContext();
- // MultiReduction lowering is necessary until we have explicit support for
- // distributing that op.
+ // MultiReduction lowering is necessary until we have explicit
+ // support for distributing that op.
RewritePatternSet preProcessingPatterns(ctx);
populateMultiReductionLoweringPatterns(target, preProcessingPatterns,
/*benefit=*/1);
@@ -657,8 +616,10 @@
}
RewritePatternSet patterns(ctx);
- populateVectorTransferWriteDistribution(target, patterns, /*benefit=*/2);
- populatePropagateVectorDistribution(target, patterns, /*benefit=*/1);
+ populateVectorTransferWriteDistribution(target, patterns,
+ /*benefit=*/2);
+ populatePropagateVectorDistribution(target, patterns,
+ /*benefit=*/1);
if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) {
return mlir::emitDefiniteFailure(
target, "warp distribution patterns failed to apply");
@@ -668,7 +629,8 @@
vector::WarpExecuteOnLane0LoweringOptions options;
options.warpAllocationFn = allocateGlobalSharedMemory;
options.warpSyncronizationFn = warpSyncronizationFn;
- populateWarpExecuteOnLane0ToScf(target, endPatterns, options, /*benefit=*/0);
+ populateWarpExecuteOnLane0ToScf(target, endPatterns, options,
+ /*benefit=*/0);
if (failed(applyPatternsAndFoldGreedily(target, std::move(endPatterns)))) {
return mlir::emitDefiniteFailure(
target, "warp execute on lane 0 to scf patterns failed to apply");
@@ -689,7 +651,8 @@
transform::TransformState &state) {
if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
target->emitOpError(
- "applies only to isolated-from-above targets because it needs to apply "
+ "applies only to isolated-from-above targets because it "
+ "needs to apply "
"patterns greedily");
return emitDefaultDefiniteFailure(target);
}
@@ -709,8 +672,8 @@
MLIRContext *ctx = target->getContext();
// Unrolling to native vector size must have previously occurred.
- // TODO: Add pattern to propagate the extract through the scf.for ops.
- // Convert slice of contract operations to mma_sync/wmma ops.
+ // TODO: Add pattern to propagate the extract through the scf.for ops.
+ // Convert slice of contract operations to mma_sync/wmma ops.
RewritePatternSet patterns(ctx);
mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
populatePrepareVectorToMMAPatterns(patterns, getUseMmaSync());
@@ -788,8 +751,8 @@
IRRewriter rewriter(getContext());
int64_t depth(getDepth());
FailureOr<scf::ForOp> pipelinedFor = iree_compiler::pipelineSharedMemoryCopy(
- forOp, PipeliningSchedulingStrategy::loadGlobalStage0, false, depth,
- rewriter);
+ rewriter, forOp, PipeliningSchedulingStrategy::loadGlobalStage0, false,
+ depth);
if (failed(pipelinedFor)) return emitDefaultSilenceableFailure(forOp);
results.push_back(pipelinedFor.value());
return DiagnosedSilenceableFailure::success();
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td
index 34e4d38..768041d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td
@@ -88,12 +88,14 @@
}];
let arguments = (ins TransformHandleTypeInterface:$target,
- DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$workgroup_dims);
+ DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$workgroup_dims,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$warp_dims);
let results = (outs);
let assemblyFormat = [{
$target
`workgroup_dims` `=` $workgroup_dims
+ (`warp_dims` `=` $warp_dims^)?
attr-dict
`:` functional-type($target, results)
}];
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir
index 167798c..f53464f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir
@@ -19,7 +19,7 @@
%memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation
transform.iree.erase_hal_descriptor_type_from_memref %memref_func : (!pdl.operation) -> ()
transform.iree.map_nested_forall_to_gpu_threads %memref_func
- workgroup_dims = [10, 11] : (!pdl.operation) -> ()
+ workgroup_dims = [10, 11, 1] : (!pdl.operation) -> ()
// Late canonicalizations to cleanup and pass the checks
transform.iree.apply_patterns %memref_func
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_distribute_forall.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_distribute_forall.mlir
index 4173942..3267723 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_distribute_forall.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_distribute_forall.mlir
@@ -13,6 +13,8 @@
hal.return %arg1, %c1, %c1 : index, index, index
}
builtin.module {
+
+// CHECK: #[[$DIV32MOD8:.*]] = affine_map<()[s0] -> ((s0 floordiv 32) mod 8)>
// CHECK-LABEL: func.func @distribute
func.func @distribute() {
%cst_0 = arith.constant dense<0.000000e+00> : vector<1xf16>
@@ -23,7 +25,7 @@
memref.assume_alignment %1, 64 : memref<2xf16>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%subview = memref.subview %1[%workgroup_id_x] [1] [1] : memref<2xf16> to memref<1xf16, strided<[1], offset: ?>>
-// CHECK: %[[C32:.+]] = arith.constant 32 : index
+
// CHECK: %[[TX:.+]] = gpu.thread_id x
// CHECK: %[[COND:.*]] = arith.cmpi ult
// CHECK: scf.if %[[COND]] {
@@ -32,7 +34,8 @@
vector.transfer_write %cst_0, %subview[%arg0]
{in_bounds = [true]} : vector<1xf16>, memref<1xf16, strided<[1], offset: ?>>
} {mapping = [#gpu.thread<x>]}
-// CHECK: %[[WX:.+]] = arith.divui %[[TX]], %[[C32]] : index
+
+// CHECK: %[[WX:.+]] = affine.apply #[[$DIV32MOD8]]()[%[[TX]]]
// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[WX]]] {in_bounds = [true]} : vector<1xf16>, memref<1xf16, strided<[1], offset: ?>>
scf.forall (%arg0) in (%c8) {
vector.transfer_write %cst_0, %subview[%arg0]
@@ -46,7 +49,7 @@
%17 = transform.structured.match ops{["func.func"]} in %variant_op
: (!pdl.operation) -> !pdl.operation
transform.iree.map_nested_forall_to_gpu_threads %17
- workgroup_dims = [256, 1, 1] : (!pdl.operation) -> ()
+ workgroup_dims = [256, 1, 1] warp_dims = [8, 1, 1] : (!pdl.operation) -> ()
// Late canonicalizations to cleanup and pass the checks.
// Needs to occur on the whole variant to perform cse on the workgroup_count region
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir
index 7e24df8..24ffbf1 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir
@@ -200,8 +200,8 @@
}
// CHECK-LABEL: hal.executable public @transpose_3d_yes_dispatch_0_generic_10x768x2048 {
-// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[D0:.*]] = gpu.thread_id x
// CHECK: %[[D1:.*]] = gpu.thread_id y
// CHECK: %[[D2:.*]] = gpu.thread_id z
@@ -267,8 +267,8 @@
}
// CHECK-LABEL: hal.executable public @transpose_3d_trans_out_dispatch_0_generic_10x2048x768 {
-// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[D0:.*]] = gpu.thread_id x
// CHECK: %[[D1:.*]] = gpu.thread_id y
// CHECK: %[[D2:.*]] = gpu.thread_id z
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matmul_promotion.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matmul_promotion.mlir
index 9919567..babde9d 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matmul_promotion.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matmul_promotion.mlir
@@ -182,15 +182,15 @@
// CHECK-COUNT-32: vector.transfer_read %{{.+}}, %[[CST0]] {in_bounds = [true]} : memref<3x64x20xf32, #gpu.address_space<workgroup>>, vector<4xf32>
// CHECK-COUNT-16: vector.transfer_read %{{.+}}, %[[CST0]] {in_bounds = [true]} : memref<3x16x68xf32, #gpu.address_space<workgroup>>, vector<4xf32>
// CHECK-COUNT-128: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<4xf32>
-// CHECK: affine.apply #[[MAP]]
+// CHECK-DAG: %[[APPLY:.+]] = affine.apply #[[MAP]]
+// CHECK-DAG: vector.transfer_read %{{.+}}, %[[CST0]] {__pipelining_first_stage__, in_bounds = [true]} : memref<128x512xf32, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
+// CHECK: vector.transfer_write %{{.+}}, %{{.+}}[%[[APPLY]], {{.+}}] {__pipelining_first_stage__, in_bounds = [true]} : vector<4xf32>, memref<3x64x20xf32, #gpu.address_space<workgroup>>
// CHECK: vector.transfer_read %{{.+}}, %[[CST0]] {__pipelining_first_stage__, in_bounds = [true]} : memref<128x512xf32, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
-// CHECK: vector.transfer_write %{{.+}}, %{{.+}} {__pipelining_first_stage__, in_bounds = [true]} : vector<4xf32>, memref<3x64x20xf32, #gpu.address_space<workgroup>>
-// CHECK: vector.transfer_read %{{.+}}, %[[CST0]] {__pipelining_first_stage__, in_bounds = [true]} : memref<128x512xf32, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
-// CHECK: vector.transfer_write %{{.+}}, %{{.+}} {__pipelining_first_stage__, in_bounds = [true]} : vector<4xf32>, memref<3x64x20xf32, #gpu.address_space<workgroup>>
+// CHECK: vector.transfer_write %{{.+}}, %{{.+}}[%[[APPLY]], {{.+}}] {__pipelining_first_stage__, in_bounds = [true]} : vector<4xf32>, memref<3x64x20xf32, #gpu.address_space<workgroup>>
// CHECK: vector.transfer_read %{{.+}}, %[[CST0]] {__pipelining_first_stage__, in_bounds = [true]} : memref<512x256xf32, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
-// CHECK: vector.transfer_write %{{.+}}, %{{.+}} {__pipelining_first_stage__, in_bounds = [true]} : vector<4xf32>, memref<3x16x68xf32, #gpu.address_space<workgroup>>
+// CHECK: vector.transfer_write %{{.+}}, %{{.+}}[%[[APPLY]], {{.+}}] {__pipelining_first_stage__, in_bounds = [true]} : vector<4xf32>, memref<3x16x68xf32, #gpu.address_space<workgroup>>
// CHECK: vector.transfer_read %{{.+}}, %[[CST0]] {__pipelining_first_stage__, in_bounds = [true]} : memref<512x256xf32, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
-// CHECK: vector.transfer_write %{{.+}}, %{{.+}} {__pipelining_first_stage__, in_bounds = [true]} : vector<4xf32>, memref<3x16x68xf32, #gpu.address_space<workgroup>>
+// CHECK: vector.transfer_write %{{.+}}, %{{.+}}[%[[APPLY]], {{.+}}] {__pipelining_first_stage__, in_bounds = [true]} : vector<4xf32>, memref<3x16x68xf32, #gpu.address_space<workgroup>>
// CHECK: gpu.barrier {__pipelining_first_stage__}
// CHECK: scf.yield
// CHECK-COUNT-32: vector.transfer_read %{{.+}}, %[[CST0]] {in_bounds = [true]} : memref<3x64x20xf32, #gpu.address_space<workgroup>>, vector<4xf32>
diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
index 5a35273..01c2510 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
@@ -96,7 +96,7 @@
// For zero dim tensor, consider it's too small to access using all threads.
if (shape.size() == 0) return false;
int64_t threadsAvailable = threadCount;
- for (auto &[index, dim] : llvm::enumerate(llvm::reverse(shape))) {
+ for (const auto &[index, dim] : llvm::enumerate(llvm::reverse(shape))) {
int64_t numElementPerThread = index == 0 ? vectorSize : 1;
int64_t numThreads = dim / numElementPerThread;
if (numThreads == 0) return false;
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp
index 0fcda13..3a37335 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp
@@ -139,7 +139,7 @@
llvm::PassInstrumentationCallbacks pic;
llvm::StandardInstrumentations si(module.getContext(), false);
- si.registerCallbacks(pic, &fam);
+ si.registerCallbacks(pic, &mam);
llvm::PassBuilder pb(&targetMachine, pto, std::nullopt, &pic);
llvm::ModulePassManager mpm;
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h
index 3394389..0f3ef14 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h
@@ -93,7 +93,7 @@
auto newOperands =
llvm::to_vector<4>(adaptor.getODSOperands(inputSetIndex));
++inputSetIndex;
- if (auto inputTupleType = inputType.dyn_cast<TupleType>()) {
+ if (auto inputTupleType = inputType.template dyn_cast<TupleType>()) {
// Unpack a tuple<...> from the variadic.
// This only supports a single level of unpacking.
if (inputTupleType.size() != newOperands.size()) {
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
index af14977..56d0957 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
@@ -3032,7 +3032,7 @@
funcOp.getNumArguments() - op.getOperands().size();
IRMapping refMapping;
- for (auto &pair : llvm::enumerate(op.getOperands())) {
+ for (const auto &pair : llvm::enumerate(op.getOperands())) {
Value operand = pair.value();
size_t index = pair.index();
@@ -3492,7 +3492,7 @@
this->template getTypeConverter<IREE::VM::EmitCTypeConverter>();
SmallVector<Value> unwrappedOperands;
- for (auto &operand : llvm::enumerate(adaptor.getOperands())) {
+ for (const auto &operand : llvm::enumerate(adaptor.getOperands())) {
if (refArgumentIndices.contains(operand.index())) {
Type originalType =
op.getOperation()->getOperand(operand.index()).getType();
diff --git a/integrations/tensorflow/WORKSPACE b/integrations/tensorflow/WORKSPACE
index 4cf13ed..7b66de5 100644
--- a/integrations/tensorflow/WORKSPACE
+++ b/integrations/tensorflow/WORKSPACE
@@ -7,7 +7,7 @@
load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
-TENSORFLOW_COMMIT = "49ffedb82a7e6597b5e5dae9ec6f948fab858842"
+TENSORFLOW_COMMIT = "b4c19400d3a5be519a487cb8461fb7b57e2150ca"
git_repository(
name = "org_tensorflow",
diff --git a/integrations/tensorflow/test/iree_tf_tests/layers/vulkan__dynamic_dims_Concatenate.run b/integrations/tensorflow/test/iree_tf_tests/layers/vulkan__dynamic_dims_Concatenate.run
index c376280..594b7a6 100644
--- a/integrations/tensorflow/test/iree_tf_tests/layers/vulkan__dynamic_dims_Concatenate.run
+++ b/integrations/tensorflow/test/iree_tf_tests/layers/vulkan__dynamic_dims_Concatenate.run
@@ -1,3 +1,2 @@
# REQUIRES: vulkan
# RUN: %PYTHON -m iree_tf_tests.layers.layers_test --target_backends=iree_vulkan --dynamic_dims=true --training=false --test_default_kwargs_only=true --layer=Concatenate --artifacts_dir=%t
-# XFAIL: *
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
index 398a356..c57bb75 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
@@ -30,7 +30,7 @@
struct SwapTilingInterfaceOp : public OpRewritePattern<tensor::ExtractSliceOp> {
using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
- FailureOr<Operation *>
+ FailureOr<TilingResult>
returningMatchAndRewrite(tensor::ExtractSliceOp sliceOp,
PatternRewriter &rewriter) const;
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index d00a15e..bf2cd6b 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -267,7 +267,7 @@
return ranges;
}
-SmallVector<Operation *>
+FailureOr<TilingResult>
ScatterOp::getTiledImplementation(OpBuilder &builder,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) {
@@ -316,7 +316,8 @@
Operation *tiledScatterOp =
mlir::clone(builder, getOperation(), resultTypes,
ValueRange{tiledUpdate, tiledIndices, tiledOriginal});
- return {tiledScatterOp};
+ return TilingResult{{tiledScatterOp},
+ SmallVector<Value>(tiledScatterOp->getResults())};
}
LogicalResult ScatterOp::getResultTilePosition(
@@ -484,7 +485,7 @@
return loopBounds;
}
-SmallVector<Operation *>
+FailureOr<TilingResult>
SortOp::getTiledImplementation(OpBuilder &builder,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) {
@@ -506,7 +507,8 @@
}
Operation *tiledSortOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
- return {tiledSortOp};
+ return TilingResult{{tiledSortOp},
+ SmallVector<Value>{tiledSortOp->getResults()}};
}
LogicalResult SortOp::getResultTilePosition(
@@ -807,7 +809,7 @@
return success();
}
-SmallVector<Operation *>
+FailureOr<TilingResult>
FftOp::getTiledImplementation(OpBuilder &builder,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) {
@@ -828,7 +830,8 @@
}
Operation *tiledFftOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
- return {tiledFftOp};
+ return TilingResult{{tiledFftOp},
+ SmallVector<Value>(tiledFftOp->getResults())};
}
LogicalResult FftOp::getResultTilePosition(
@@ -1006,7 +1009,7 @@
return success();
}
-SmallVector<Operation *>
+FailureOr<TilingResult>
ScanOp::getTiledImplementation(OpBuilder &builder,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) {
@@ -1042,7 +1045,8 @@
Operation *tiledScanOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
- return {tiledScanOp};
+ return TilingResult{{tiledScanOp},
+ SmallVector<Value>(tiledScanOp->getResults())};
}
LogicalResult ScanOp::getResultTilePosition(
@@ -1161,7 +1165,7 @@
return success();
}
-SmallVector<Operation *>
+FailureOr<TilingResult>
ReverseOp::getTiledImplementation(OpBuilder &builder,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) {
@@ -1191,7 +1195,8 @@
Operation *tiledRevOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
- return {tiledRevOp};
+ return TilingResult{{tiledRevOp},
+ SmallVector<Value>(tiledRevOp->getResults())};
}
LogicalResult ReverseOp::getResultTilePosition(
@@ -1424,7 +1429,7 @@
return success();
}
-SmallVector<Operation *>
+FailureOr<TilingResult>
TopkOp::getTiledImplementation(OpBuilder &builder,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) {
@@ -1465,7 +1470,8 @@
Operation *tiledTopkOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
- return {tiledTopkOp};
+ return TilingResult{{tiledTopkOp},
+ SmallVector<Value>(tiledTopkOp->getResults())};
}
LogicalResult TopkOp::getResultTilePosition(
@@ -2170,7 +2176,7 @@
return iteratorTypes;
}
-SmallVector<Operation *>
+FailureOr<TilingResult>
WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) {
@@ -2213,7 +2219,7 @@
Operation *tiledOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
- return {tiledOp};
+ return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
}
LogicalResult WinogradInputTransformOp::getResultTilePosition(
@@ -2332,7 +2338,7 @@
return iteratorTypes;
}
-SmallVector<Operation *> WinogradOutputTransformOp::getTiledImplementation(
+FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) {
@@ -2374,7 +2380,7 @@
Operation *tiledOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
- return {tiledOp};
+ return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
}
LogicalResult WinogradOutputTransformOp::getResultTilePosition(
@@ -2450,7 +2456,7 @@
return iteratorTypes;
}
-SmallVector<Operation *>
+FailureOr<TilingResult>
SoftmaxOp::getTiledImplementation(OpBuilder &builder,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) {
@@ -2470,7 +2476,7 @@
Operation *tiledOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
- return {tiledOp};
+ return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
}
LogicalResult SoftmaxOp::getResultTilePosition(
@@ -2546,7 +2552,7 @@
return iteratorTypes;
}
-SmallVector<Operation *>
+FailureOr<TilingResult>
AttentionOp::getTiledImplementation(OpBuilder &builder,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) {
@@ -2594,7 +2600,7 @@
Operation *tiledOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
- return {tiledOp};
+ return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
}
LogicalResult AttentionOp::getResultTilePosition(
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/Tiling.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/Tiling.cpp
index ae77fb3..bb020cb 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/Tiling.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/Tiling.cpp
@@ -89,30 +89,27 @@
// the op by invoking the TiledOpInterface methods.
if (loopDepth == tileSizes.size()) {
TiledOp ret;
- SmallVector<Operation *> tiledOps =
+ FailureOr<TilingResult> tiledOps =
tilableOp.getTiledImplementation(builder, offsets, tileSizes);
- if (tiledOps.empty()) {
+ if (failed(tiledOps)) {
return static_cast<LogicalResult>(
tilableOp.emitOpError("failed to get tiled implementation"));
}
- assert(
- (tiledOps.size() == 1) &&
- "expected only a single operation returned from tiling implementation");
- ret.op.assign(tiledOps);
- for (auto result : llvm::enumerate(ret.op.back()->getResults())) {
- if (!result.value().getType().isa<RankedTensorType>()) {
- ret.results.push_back(result.value());
+ ret.op.append(tiledOps->tiledOps);
+ for (auto [index, result] : llvm::enumerate(tilableOp->getResults())) {
+ if (!result.getType().isa<RankedTensorType>()) {
+ ret.results.push_back(result);
continue;
}
SmallVector<OpFoldResult> resultOffsets, resultSizes;
- if (succeeded(tilableOp.getResultTilePosition(
- builder, result.index(), offsets, tileSizes, resultOffsets,
- resultSizes))) {
+ if (succeeded(tilableOp.getResultTilePosition(builder, index, offsets,
+ tileSizes, resultOffsets,
+ resultSizes))) {
SmallVector<OpFoldResult> resultStrides(resultOffsets.size(),
builder.getIndexAttr(1));
Value insertSlice = builder.create<tensor::InsertSliceOp>(
- loc, ret.op.back()->getResult(result.index()),
- outputs[result.index()], resultOffsets, resultSizes, resultStrides);
+ loc, tiledOps->tiledValues[index], outputs[index], resultOffsets,
+ resultSizes, resultStrides);
ret.results.push_back(insertSlice);
}
}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Fusion.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Fusion.cpp
index 7069bed..a73a90b 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Fusion.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Fusion.cpp
@@ -38,12 +38,18 @@
return failure();
// Tile the producer.
- FailureOr<Value> tiledProducer = producerOp.generateResultTileValue(
- rewriter, /*resultNumber=*/0, sliceOp.getMixedOffsets(),
- sliceOp.getMixedSizes());
- if (failed(tiledProducer))
+ FailureOr<TilingResult> tileAndFuseResult =
+ producerOp.generateResultTileValue(rewriter, /*resultNumber=*/0,
+ sliceOp.getMixedOffsets(),
+ sliceOp.getMixedSizes());
+ if (failed(tileAndFuseResult))
return failure();
- fusedOps.push_back(cast<TilingInterface>(tiledProducer->getDefiningOp()));
+ for (auto tileAndFusedOp : tileAndFuseResult->tiledOps) {
+ auto interfaceOp = dyn_cast<TilingInterface>(tileAndFusedOp);
+ if (!interfaceOp)
+ continue;
+ fusedOps.push_back(interfaceOp);
+ }
}
// Update the consumer in-place using the tiled producer results.
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Tiling.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Tiling.cpp
index de4bff6..7a6b068 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Tiling.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Tiling.cpp
@@ -195,15 +195,16 @@
/// Second pattern to implement the switch of `TilingInterface ->
/// tensor.extract_slice` to `tensor.extract_slice -> `TilingInterface`.
-FailureOr<Operation *> SwapTilingInterfaceOp::returningMatchAndRewrite(
+FailureOr<TilingResult> SwapTilingInterfaceOp::returningMatchAndRewrite(
tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
auto sourceOp = sliceOp.getSource().getDefiningOp<TilingInterface>();
if (!sourceOp)
return failure();
- SmallVector<Operation *> tiledOps = sourceOp.getTiledImplementation(
+ FailureOr<TilingResult> tilingResult = sourceOp.getTiledImplementation(
rewriter, sliceOp.getMixedOffsets(), sliceOp.getMixedSizes());
- assert(tiledOps.size() && "expected single tiled op");
- Operation *tiledOp = tiledOps.front();
- rewriter.replaceOp(sliceOp, tiledOp->getResults());
- return tiledOp;
+ if (failed(tilingResult)) {
+ return failure();
+ }
+ rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
+ return tilingResult.value();
}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
index ddbca39..474f4ab 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
@@ -295,10 +295,9 @@
}
// Add more fusion candidates to the worklist.
- if (auto fusedProducerOp =
- fusedProducer->tiledAndFusedProducer.getDefiningOp()) {
- addCandidateSlices(fusedProducerOp, candidates);
- tiledOps.push_back(fusedProducerOp);
+ for (auto tiledOp : fusedProducer->tiledOps) {
+ addCandidateSlices(tiledOp, candidates);
+ tiledOps.push_back(tiledOp);
}
}
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 75e86f4..411b1d8 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 75e86f472075c3c34a5b2466fa42e3248a432f78
+Subproject commit 411b1d8f079533860a990ee615abae3b0e6dbd8b
diff --git a/third_party/mlir-hlo b/third_party/mlir-hlo
index 34b9f70..1f096a7 160000
--- a/third_party/mlir-hlo
+++ b/third_party/mlir-hlo
@@ -1 +1 @@
-Subproject commit 34b9f70f784ff67797fffddf6aad94a44321a09f
+Subproject commit 1f096a793ab7f73ae8f62deb8b6502c543763ca1