[gpu] Distribute fused producer elementwise ops in SIMT pipeline (#13908)
The fused producer ops have different dimensions than the consumer
matmul op, so tiling and distributing them with the matmul op's
configuration may not be the best choice.
This commit postpones the producer op tiling and distribution to where
copy ops are handled. There we use all threads in the workgroup to
distribute the work in a flat manner.
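
To illustrate what "flat" distribution means here, below is a conceptual C++ sketch only, not code from this patch: the workgroup size, thread id, and tile-shape names are illustrative assumptions. Each thread linearizes its 3-D thread id and then strides over the flattened elements of the producer tile, so the distribution is independent of the matmul's tiling configuration.

// Conceptual sketch: flat distribution of a 2-D elementwise producer tile
// across every thread in the workgroup. `wgSize`, `tid`, and the tile shape
// are made-up parameters for illustration.
#include <array>
#include <cstdint>
#include <vector>

void applyProducerFlat(const std::vector<float> &src, std::vector<float> &dst,
                       int64_t tileRows, int64_t tileCols,
                       std::array<int64_t, 3> wgSize,
                       std::array<int64_t, 3> tid) {
  // Linearize the (x, y, z) thread id so all workgroup threads participate.
  int64_t flatTid =
      tid[2] * wgSize[1] * wgSize[0] + tid[1] * wgSize[0] + tid[0];
  int64_t numThreads = wgSize[0] * wgSize[1] * wgSize[2];
  int64_t numElements = tileRows * tileCols;
  // Each thread handles elements flatTid, flatTid + numThreads, ...
  for (int64_t i = flatTid; i < numElements; i += numThreads)
    dst[i] = src[i] * 2.0f; // Placeholder for the fused elementwise op.
}

In the actual patch this is achieved by re-tagging the fused producer with the copy-to-workgroup-memory marker, so the existing copy distribution logic picks it up later, as the diff below shows.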
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp
index fbbfc53..5e5cc37 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp
@@ -92,7 +92,7 @@
private:
FailureOr<scf::SCFTilingResult>
- tileConsumerAndFuseInputProducer(RewriterBase &rewriter,
+ tileConsumerAndFuseInputProducer(PatternRewriter &rewriter,
TilingInterface consumer,
ArrayRef<int64_t> tileSizes) const {
// First tile the current op as the consumer op.
@@ -141,7 +141,18 @@
while (!candidates.empty()) {
tensor::ExtractSliceOp sliceOp = candidates.back();
candidates.pop_back();
- tileAndFuseProducerOfSlice(rewriter, sliceOp, tilingResult->loops);
+ std::optional<scf::SCFFuseProducerOfSliceResult> result =
+ tileAndFuseProducerOfSlice(rewriter, sliceOp, tilingResult->loops);
+ if (result) {
+ // Mark the fused input producer for distribution when writing to shared
+ // memory. We cannot use the current matmul op's tiling scheme here
+ // given dimensions are different.
+ IREE::LinalgExt::LinalgTransformationFilter f(
+ ArrayRef<StringAttr>(),
+ rewriter.getStringAttr(getCopyToWorkgroupMemoryMarker()));
+ f.replaceLinalgTransformationFilter(
+ rewriter, result->tiledAndFusedProducer.getDefiningOp());
+ }
}
return tilingResult;
}
@@ -207,7 +218,15 @@
SmallVector<TilingInterface> computeOps;
funcOp.walk([&](TilingInterface op) { computeOps.push_back(op); });
+ auto marker =
+ StringAttr::get(funcOp.getContext(), getCopyToWorkgroupMemoryMarker());
+
for (TilingInterface tilingOp : computeOps) {
+ auto attr = tilingOp->getAttr(
+ IREE::LinalgExt::LinalgTransforms::kLinalgTransformMarker);
+ if (attr == marker)
+ continue;
+
size_t numLoops = 0;
for (auto type : tilingOp.getLoopIteratorTypes()) {
if (type == utils::IteratorType::parallel)
@@ -317,10 +336,6 @@
if (!isEntryPoint(funcOp))
return;
- funcOp->walk([&](linalg::LinalgOp op) {
- op->removeAttr(IREE::LinalgExt::LinalgTransforms::kLinalgTransformMarker);
- });
-
auto workgroupSize = llvm::map_to_vector(
getEntryPoint(funcOp)->getWorkgroupSize().value(),
[&](Attribute attr) { return llvm::cast<IntegerAttr>(attr).getInt(); });
@@ -329,7 +344,7 @@
}
LLVM_DEBUG({
- llvm::dbgs() << "--- After second level of tiling";
+ llvm::dbgs() << "// --- After second level of tiling:\n";
funcOp.dump();
});
@@ -339,7 +354,7 @@
return signalPassFailure();
LLVM_DEBUG({
- llvm::dbgs() << "--- After tile reductions:";
+ llvm::dbgs() << "// --- After tile reductions:\n";
funcOp.dump();
});
@@ -348,7 +363,7 @@
}
LLVM_DEBUG({
- llvm::dbgs() << "--- After conv unrolling:";
+ llvm::dbgs() << "// --- After conv unrolling:\n";
funcOp.dump();
});
}
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matmul_fusion.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matmul_fusion.mlir
index 09d8f0f..39efad3 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matmul_fusion.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matmul_fusion.mlir
@@ -103,15 +103,38 @@
// CHECK: %[[ZP:.+]] = vector.transfer_read %[[ZP_BINDING]]
// CHECK: %[[ZP_EXT:.+]] = arith.extsi %[[ZP]] : vector<4xi4> to vector<4xi32>
// CHECK: scf.for %arg5 = %c0 to %c96 step %c32 iter_args({{.+}}) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>)
-// CHECK-COUNT-4: vector.transfer_read %[[WEIGHT_BINDING]]
-// CHECK-COUNT-4: arith.extsi %{{.+}} : vector<4xi4> to vector<4xi32>
-// CHECK-COUNT-4: arith.subi %{{.+}}, %[[ZP_EXT]] : vector<4xi32>
-// CHECK-COUNT-4: arith.sitofp %{{.+}} : vector<4xi32> to vector<4xf32>
-// CHECK-COUNT-4: arith.mulf %{{.+}}, %[[SCALE]] : vector<4xf32>
-// CHECK-COUNT-4: vector.transfer_write %{{.+}}, %[[B_ALLOC]]
-// CHECK: gpu.barrier
+
+// CHECK: vector.transfer_read %[[WEIGHT_BINDING]]
+// CHECK: arith.extsi %{{.+}} : vector<4xi4> to vector<4xi32>
+// CHECK: arith.subi %{{.+}}, %[[ZP_EXT]] : vector<4xi32>
+// CHECK: arith.sitofp %{{.+}} : vector<4xi32> to vector<4xf32>
+// CHECK: arith.mulf %{{.+}}, %[[SCALE]] : vector<4xf32>
+// CHECK: vector.transfer_write %{{.+}}, %[[B_ALLOC]]
+
+// CHECK: vector.transfer_read %[[WEIGHT_BINDING]]
+// CHECK: arith.extsi %{{.+}} : vector<4xi4> to vector<4xi32>
+// CHECK: arith.subi %{{.+}}, %[[ZP_EXT]] : vector<4xi32>
+// CHECK: arith.sitofp %{{.+}} : vector<4xi32> to vector<4xf32>
+// CHECK: arith.mulf %{{.+}}, %[[SCALE]] : vector<4xf32>
+// CHECK: vector.transfer_write %{{.+}}, %[[B_ALLOC]]
+
+// CHECK: vector.transfer_read %[[WEIGHT_BINDING]]
+// CHECK: arith.extsi %{{.+}} : vector<4xi4> to vector<4xi32>
+// CHECK: arith.subi %{{.+}}, %[[ZP_EXT]] : vector<4xi32>
+// CHECK: arith.sitofp %{{.+}} : vector<4xi32> to vector<4xf32>
+// CHECK: arith.mulf %{{.+}}, %[[SCALE]] : vector<4xf32>
+// CHECK: vector.transfer_write %{{.+}}, %[[B_ALLOC]]
+
+// CHECK: vector.transfer_read %[[WEIGHT_BINDING]]
+// CHECK: arith.extsi %{{.+}} : vector<4xi4> to vector<4xi32>
+// CHECK: arith.subi %{{.+}}, %[[ZP_EXT]] : vector<4xi32>
+// CHECK: arith.sitofp %{{.+}} : vector<4xi32> to vector<4xf32>
+// CHECK: arith.mulf %{{.+}}, %[[SCALE]] : vector<4xf32>
+// CHECK: vector.transfer_write %{{.+}}, %[[B_ALLOC]]
+
// CHECK: vector.transfer_write %{{.+}}, %[[A_ALLOC]]
// CHECK: gpu.barrier
+
// CHECK-COUNT-32: vector.transfer_read %[[A_ALLOC]]
// CHECK-COUNT-32: vector.transfer_read %[[B_ALLOC]]
// CHECK-COUNT-128: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<4xf32>