[Codegen] Add transfer read distribution pattern for nested layout (#16393)
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
index 808469c..fe92b60 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
@@ -56,6 +56,7 @@
"GPUGeneralizeNamedOps.cpp",
"GPULowerToUKernels.cpp",
"GPUMultiBuffering.cpp",
+ "GPUNestedLayoutDistributionPatterns.cpp",
"GPUPatterns.cpp",
"GPUPipelining.cpp",
"GPUReduceBankConflicts.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
index 7214b6c..cfd05aa 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
@@ -54,6 +54,7 @@
"GPUGeneralizeNamedOps.cpp"
"GPULowerToUKernels.cpp"
"GPUMultiBuffering.cpp"
+ "GPUNestedLayoutDistributionPatterns.cpp"
"GPUPatterns.cpp"
"GPUPipelining.cpp"
"GPUReduceBankConflicts.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
new file mode 100644
index 0000000..a516148
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
@@ -0,0 +1,215 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h"
+#include "iree/compiler/Codegen/Common/GPU/GPUPatterns.h"
+#include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
+#include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Rewrite/PatternApplicator.h"
+
+namespace mlir::iree_compiler {
+
+using namespace mlir::iree_compiler::IREE::VectorExt;
+using VectorValue = TypedValue<VectorType>;
+
+/// Helper to linearize the given |ids| with maximum values given as |sizes|.
+/// The linearized index is scaled by |elementCount| and the element |offset|
+/// is added to produce the final element index. For example,
+///
+/// IDs = [d0, d1, d2, d3]
+/// sizes = [s0, s1, s2, s3]
+/// linear_index = d0 * (s1 * s2 * s3)
+/// + d1 * (s2 * s3)
+/// + d2 * (s3)
+/// + d3
+/// return element_index = linear_index * |elementCount| + |offset|;
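+///
+/// As a concrete example (values chosen purely for illustration):
+///   ids = [1, 2], sizes = [3, 4], elementCount = 8, offset = 5
+///   linear_index  = 1 * 4 + 2 = 6
+///   element_index = 6 * 8 + 5 = 53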
+static Value linearizeIndex(OpBuilder &builder, Value offset,
+ ArrayRef<OpFoldResult> ids, ArrayRef<int64_t> sizes,
+ int64_t elementCount) {
+ SmallVector<AffineExpr> exprs(ids.size() + 1);
+ bindSymbolsList(builder.getContext(), MutableArrayRef{exprs});
+ AffineExpr idExpr = builder.getAffineConstantExpr(0);
+
+ for (int i = 0, e = ids.size(); i < e; ++i) {
+ if (sizes[i] > 1) {
+      // Multiply the accumulated index by the size of this dimension (which
+      // must vary faster than all previous dimensions) and add the id for
+      // this dimension.
+ idExpr = idExpr * builder.getAffineConstantExpr(sizes[i]) + exprs[i];
+ }
+ }
+ idExpr = idExpr * builder.getAffineConstantExpr(elementCount);
+ idExpr = idExpr + exprs.back();
+ SmallVector<OpFoldResult> mapArgs(ids);
+ mapArgs.push_back(offset);
+ return affine::makeComposedAffineApply(
+ builder, offset.getLoc(),
+ AffineMap::get(0, mapArgs.size(), idExpr), mapArgs)
+ .getResult();
+}
+
+namespace {
+
+/// Pattern to distribute `vector.transfer_read` ops with nested layouts.
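+///
+/// For illustration (mirroring the row-major case in
+/// gpu_nested_layout_vector_distribution.mlir): a vector<16x16xf16> read with
+/// batches_per_subgroup = [2, 2], outers_per_batch = [1, 1],
+/// threads_per_outer = [8, 1], and elements_per_thread = [1, 8] is
+/// distributed into a per-thread vector<2x2x1x1x1x8xf16>, assembled from
+/// unrolled vector<1x8xf16> transfer reads.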
+struct DistributeTransferReadNestedLayoutAttr final
+ : OpDistributionPattern<vector::TransferReadOp> {
+ using OpDistributionPattern::OpDistributionPattern;
+
+ DistributeTransferReadNestedLayoutAttr(MLIRContext *context, Value threadId)
+ : OpDistributionPattern(context), threadId(threadId) {}
+
+ LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+ DistributionSignature &signature,
+ PatternRewriter &rewriter) const override {
+ // TODO: Support masking.
+ if (readOp.getMask()) {
+ return failure();
+ }
+ NestedLayoutAttr vectorLayout =
+ dyn_cast<NestedLayoutAttr>(signature[readOp.getResult()]);
+ if (!vectorLayout) {
+ return failure();
+ }
+
+    // Guard on memref sources for distribution. In isolation this pattern is
+    // agnostic to tensors vs. memrefs.
+ if (!isa<MemRefType>(readOp.getSource().getType())) {
+ return failure();
+ }
+
+    // The delinearized thread IDs are returned from outermost to innermost,
+    // i.e. before applying the dimension ordering described by the layout.
+ ValueRange threadIds = vectorLayout.computeThreadIds(threadId, rewriter);
+
+ SmallVector<int64_t> distShape = vectorLayout.getDistributedShape();
+ SmallVector<int64_t> tileShape = vectorLayout.getDistributedShape();
+ int64_t rank = vectorLayout.getBatchOrder().size();
+
+    // Unroll the batch dimensions first, then the outer vector dimensions,
+    // in the order specified by the layout.
+ SmallVector<int64_t> loopOrder;
+ int64_t base = 0;
+ for (int64_t b : vectorLayout.getBatchOrder()) {
+ loopOrder.push_back(base + b);
+ tileShape[base + b] = 1;
+ }
+ base += rank;
+    // We must unroll along the outer dimensions as well so that each slice
+    // satisfies the rank requirements of vector transfer ops (vector rank
+    // <= memref rank, up to broadcasts).
+ for (int64_t o : vectorLayout.getOuterOrder()) {
+ loopOrder.push_back(base + o);
+ tileShape[base + o] = 1;
+ }
+ base += rank;
+ for (int i = 0, e = rank; i < e; ++i) {
+ loopOrder.push_back(base + i);
+ }
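+    // For the 2-D row-major layout in the accompanying tests
+    // (batches_per_subgroup = [2, 2], batch_order = [1, 0], outers_per_batch
+    // = [1, 1], elements_per_thread = [1, 8]), the above yields
+    // distShape = [2, 2, 1, 1, 1, 8], tileShape = [1, 1, 1, 1, 1, 8], and
+    // loopOrder = [1, 0, 2, 3, 4, 5].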
+
+ Type elementType = readOp.getSource().getType().getElementType();
+ auto vectorType = VectorType::get(distShape, elementType);
+    // The shape of the vector we read is the pre-permutation shape; the
+    // element permutation is applied as a transpose of the resulting read
+    // vector.
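+    // For example, with elements_per_thread = [4, 1] and element_order =
+    // [1, 0] (the col-major test case), each slice is read as vector<4x1xf16>
+    // and then transposed to vector<1x4xf16>.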
+ auto innerVectorType =
+ VectorType::get(vectorLayout.getElementsPerThread(), elementType);
+
+ // Initialize the full distributed vector for unrolling the batch/outer
+ // vector dimensions.
+ Value zero = rewriter.create<arith::ConstantOp>(
+ readOp.getLoc(), vectorType, rewriter.getZeroAttr(vectorType));
+ VectorValue acc = cast<VectorValue>(zero);
+
+ // Subgroup and thread (lane) indices normalized to the order in which
+ // they are used by each dimension.
+ SmallVector<Value> warpIndices = llvm::to_vector(
+ llvm::map_range(vectorLayout.getSubgroupOrder(),
+ [&](int64_t i) { return threadIds[i]; }));
+ SmallVector<Value> threadIndices = llvm::to_vector(
+ llvm::map_range(vectorLayout.getThreadOrder(),
+ [&](int64_t i) { return threadIds[i + rank]; }));
+
+ auto isBroadcast = [](AffineExpr expr) {
+ if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
+ return constExpr.getValue() == 0;
+ return false;
+ };
+
+ ValueRange indices = readOp.getIndices();
+ SmallVector<int64_t> strides(rank, 1);
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(distShape, tileShape, loopOrder)) {
+      // Permute the batch and outer vector offsets to match the order of
+      // the vector dimensions using the inverse of the batch/outer order.
+ SmallVector<int64_t> batchOffsets = applyPermutation(
+ ArrayRef<int64_t>(offsets.begin(), rank),
+ invertPermutationVector(vectorLayout.getBatchOrder()));
+ SmallVector<int64_t> outerVectorOffsets = applyPermutation(
+ ArrayRef<int64_t>(offsets.begin() + rank, rank),
+ invertPermutationVector(vectorLayout.getOuterOrder()));
+
+ SmallVector<Value> slicedIndices(indices.begin(), indices.end());
+ for (const auto &[i, dim] :
+ llvm::enumerate(readOp.getPermutationMap().getResults())) {
+ if (isBroadcast(dim))
+ continue;
+ unsigned pos = cast<AffineDimExpr>(dim).getPosition();
+ SmallVector<OpFoldResult> ids = {
+ warpIndices[i], rewriter.getIndexAttr(batchOffsets[i]),
+ rewriter.getIndexAttr(outerVectorOffsets[i]), threadIndices[i]};
+ // The order in which a vector dimension is "tiled" is
+ // subgroups -> batches -> outer vectors -> threads -> elements
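+        // For example, for dimension 0 of the row-major test layout
+        // (sizes = [1, 2, 1, 8], elements_per_thread = 1), this computes
+        // batchOffset * 8 + laneId + baseIndex, which is the `s0 + 8`
+        // offset map seen in the tests.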
+ SmallVector<int64_t> sizes = {
+ vectorLayout.getSubgroupsPerWorkgroup()[i],
+ vectorLayout.getBatchesPerSubgroup()[i],
+ vectorLayout.getOutersPerBatch()[i],
+ vectorLayout.getThreadsPerOuter()[i]};
+ slicedIndices[pos] =
+ linearizeIndex(rewriter, indices[pos], ids, sizes,
+ vectorLayout.getElementsPerThread()[i]);
+ }
+ Value slicedRead = rewriter.create<vector::TransferReadOp>(
+ readOp.getLoc(), innerVectorType, readOp.getSource(), slicedIndices,
+ readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
+ readOp.getInBoundsAttr());
+ // Transpose to the element order.
+ if (!isIdentityPermutation(vectorLayout.getElementOrder())) {
+ slicedRead = rewriter.create<vector::TransposeOp>(
+ slicedRead.getLoc(), slicedRead, vectorLayout.getElementOrder());
+ }
+
+ acc = rewriter.create<vector::InsertStridedSliceOp>(
+ readOp.getLoc(), slicedRead, acc, offsets, strides);
+ }
+
+ replaceOpWithDistributedValues(rewriter, readOp, acc);
+ return success();
+ }
+
+ Value threadId;
+};
+
+} // namespace
+
+void populateGPUDistributeNestedLayoutAttrPatterns(
+ Value threadId, RewritePatternSet &patterns) {
+ patterns.add<DistributeTransferReadNestedLayoutAttr>(patterns.getContext(),
+ threadId);
+}
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.h b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.h
index 727eaaa..5475f03 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.h
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.h
@@ -37,6 +37,9 @@
void populateGPUReductionDistributionPatterns(RewritePatternSet &patterns,
int64_t maxBitsPerShuffle = 32);
+void populateGPUDistributeNestedLayoutAttrPatterns(Value threadId,
+ RewritePatternSet &patterns);
+
} // namespace mlir::iree_compiler
#endif // IREE_COMPILER_CODEGEN_COMMON_GPUPATTERNS_H_
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
index 42a3c0a..d21e6a7 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
@@ -23,6 +23,7 @@
"gpu_distribute_shared_memory.mlir",
"gpu_generalize_named_ops.mlir",
"gpu_lower_to_ukernels.mlir",
+ "gpu_nested_layout_vector_distribution.mlir",
"gpu_pipeline.mlir",
"gpu_tensor_alloc.mlir",
"gpu_tensor_tile.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
index ddc8f47..2891b8a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
@@ -19,6 +19,7 @@
"gpu_distribute_shared_memory.mlir"
"gpu_generalize_named_ops.mlir"
"gpu_lower_to_ukernels.mlir"
+ "gpu_nested_layout_vector_distribution.mlir"
"gpu_pipeline.mlir"
"gpu_tensor_alloc.mlir"
"gpu_tensor_tile.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
new file mode 100644
index 0000000..d11526b
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
@@ -0,0 +1,327 @@
+// RUN: iree-opt --iree-transform-dialect-interpreter --split-input-file --canonicalize --cse %s | FileCheck %s
+
+#layout_row_major = #iree_vector_ext.nested_layout<
+ subgroups_per_workgroup = [1, 1],
+ batches_per_subgroup = [2, 2],
+ outers_per_batch = [1, 1],
+ threads_per_outer = [8, 1],
+ elements_per_thread = [1, 8],
+
+ subgroup_order = [0, 1],
+ batch_order = [1, 0],
+ outer_order = [0, 1],
+ thread_order = [0, 1],
+ element_order = [0, 1]
+>
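+
+// With this layout, the subgroup's threads form an 8x1 grid and each thread
+// owns a 1x8 element slice per batch tile: thread `tid` reads rows tid and
+// tid + 8, columns [0, 8) and [8, 16), of the 16x16 read. This explains the
+// `s0 + 8` map and the four transfer reads checked below.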
+
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 8)>
+// CHECK-LABEL: @distribute_transfer_read_row_major
+func.func @distribute_transfer_read_row_major(%arg0: memref<4x4xf16>) -> vector<16x16xf16> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %root = vector.transfer_read %arg0[%c0, %c0], %cst
+ {in_bounds = [false, false],
+ "__vector_layout_test_anchor_result_0" = #layout_row_major}
+ : memref<4x4xf16>, vector<16x16xf16>
+ func.return %root : vector<16x16xf16>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK: %[[ACC:.+]] = arith.constant dense<0.000000e+00> : vector<2x2x1x1x1x8xf16>
+// CHECK: %[[IDX:.+]] = gpu.thread_id x
+// CHECK: %[[IDS:.+]]:4 = affine.delinearize_index %[[IDX]] into (%c1, %c1, %c8, %c1) : index, index, index, index
+// CHECK: vector.transfer_read %arg0[%[[IDS]]#2, %c0], {{.*}} : memref<4x4xf16>, vector<1x8xf16>
+// CHECK: vector.insert_strided_slice %{{.*}}, %[[ACC]] {offsets = [0, 0, 0, 0, 0, 0], strides = [1, 1]} : vector<1x8xf16> into vector<2x2x1x1x1x8xf16>
+// CHECK: vector.transfer_read %arg0[%[[IDS]]#2, %c8]
+// CHECK: vector.insert_strided_slice {{.*}} {offsets = [1, 0, 0, 0, 0, 0]
+// CHECK: %[[ID_PLUS_BATCH1:.+]] = affine.apply #[[$MAP]]()[%[[IDS]]#2]
+// CHECK: vector.transfer_read %arg0[%[[ID_PLUS_BATCH1]], %c0]
+// CHECK: vector.insert_strided_slice {{.*}} {offsets = [0, 1, 0, 0, 0, 0]
+// CHECK: vector.transfer_read %arg0[%[[ID_PLUS_BATCH1]], %c8]
+// CHECK: vector.insert_strided_slice {{.*}} {offsets = [1, 1, 0, 0, 0, 0]
+// CHECK: iree_vector_ext.to_simd %{{.*}} : vector<2x2x1x1x1x8xf16> -> vector<16x16xf16>
+
+// -----
+
+#layout_col_major = #iree_vector_ext.nested_layout<
+ subgroups_per_workgroup = [1, 1],
+ batches_per_subgroup = [1, 2],
+ outers_per_batch = [1, 1],
+ threads_per_outer = [4, 8],
+ elements_per_thread = [4, 1],
+
+ subgroup_order = [0, 1],
+ batch_order = [1, 0],
+ outer_order = [0, 1],
+ thread_order = [0, 1],
+ element_order = [1, 0]
+>
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 4)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 + 8)>
+
+// CHECK-LABEL: @distribute_transfer_read_col_major
+func.func @distribute_transfer_read_col_major(%arg0: memref<32x32xf16>) -> vector<16x16xf16> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %root = vector.transfer_read %arg0[%c0, %c0], %cst
+ {in_bounds = [true, true],
+ "__vector_layout_test_anchor_result_0" = #layout_col_major}
+ : memref<32x32xf16>, vector<16x16xf16>
+ func.return %root : vector<16x16xf16>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK: %[[IDS:.+]]:4 = affine.delinearize_index %{{.*}} into (%c1, %c1, %c4, %c8) : index, index, index, index
+// CHECK: %[[LANEY:.+]] = affine.apply #[[$MAP]]()[%[[IDS]]#2]
+// CHECK: %[[RD00:.+]] = vector.transfer_read %arg0[%[[LANEY:.+]], %[[IDS]]#3], {{.*}} : memref<32x32xf16>, vector<4x1xf16>
+// CHECK: %[[ELEM_ORDER:.+]] = vector.transpose %[[RD00]], [1, 0] : vector<4x1xf16> to vector<1x4xf16>
+// CHECK: vector.insert_strided_slice %[[ELEM_ORDER]], %{{.*}} {offsets = [0, 0, 0, 0, 0, 0], strides = [1, 1]} : vector<1x4xf16> into vector<2x1x1x1x1x4xf16>
+// CHECK: %[[LANEX_PLUS_BATCH:.+]] = affine.apply #[[$MAP1]]()[%[[IDS]]#3]
+// CHECK: vector.transfer_read %arg0[%[[LANEY]], %[[LANEX_PLUS_BATCH]]], %{{.*}} {in_bounds = [true, true]} : memref<32x32xf16>, vector<4x1xf16>
+// CHECK: vector.transpose %{{.*}}, [1, 0] : vector<4x1xf16> to vector<1x4xf16>
+// CHECK: vector.insert_strided_slice {{.*}} {offsets = [1, 0, 0, 0, 0, 0]
+// CHECK: iree_vector_ext.to_simd %{{.*}} : vector<2x1x1x1x1x4xf16> -> vector<16x16xf16>
+
+// -----
+
+#layout_row_major = #iree_vector_ext.nested_layout<
+ subgroups_per_workgroup = [1, 1],
+ batches_per_subgroup = [2, 2],
+ outers_per_batch = [1, 1],
+ threads_per_outer = [8, 1],
+ elements_per_thread = [1, 8],
+
+ subgroup_order = [0, 1],
+ batch_order = [1, 0],
+ outer_order = [0, 1],
+ thread_order = [0, 1],
+ element_order = [0, 1]
+>
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 + 8)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 8)>
+
+func.func @distribute_transfer_read_row_major_with_nontrivial_index(%a: index, %b: index, %arg0: memref<32x32x32x32xf16>) -> vector<16x16xf16> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %root = vector.transfer_read %arg0[%c0, %c0, %a, %b], %cst
+ {in_bounds = [true, true],
+ permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
+ "__vector_layout_test_anchor_result_0" = #layout_row_major}
+ : memref<32x32x32x32xf16>, vector<16x16xf16>
+ func.return %root : vector<16x16xf16>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+ transform.yield
+ }
+}
+// CHECK-LABEL: @distribute_transfer_read_row_major_with_nontrivial_index
+// CHECK-SAME: %[[I0:.+]]: index, %[[I1:.+]]: index
+
+// CHECK: %[[IDS:.+]]:4 = affine.delinearize_index %{{.*}} into (%c1, %c1, %c8, %c1) : index, index, index, index
+// CHECK: %[[OFF0:.+]] = affine.apply #[[$MAP]]()[%[[IDS]]#2, %[[I0]]]
+// CHECK: vector.transfer_read %{{.*}}[%c0, %c0, %[[OFF0]], %[[I1]]]
+// CHECK: %[[OFF1:.+]] = affine.apply #[[$MAP1]]()[%[[I1]]]
+// CHECK: vector.transfer_read %{{.*}}[%c0, %c0, %[[OFF0]], %[[OFF1]]]
+// CHECK: %[[OFF2:.+]] = affine.apply #[[$MAP2]]()[%[[IDS]]#2, %[[I0]]]
+// CHECK: vector.transfer_read %{{.*}}[%c0, %c0, %[[OFF2]], %[[I1]]]
+// CHECK: vector.transfer_read %{{.*}}[%c0, %c0, %[[OFF2]], %[[OFF1]]]
+
+// -----
+
+#layout_col_major = #iree_vector_ext.nested_layout<
+ subgroups_per_workgroup = [1, 1],
+ batches_per_subgroup = [1, 2],
+ outers_per_batch = [1, 1],
+ threads_per_outer = [4, 8],
+ elements_per_thread = [4, 1],
+
+ subgroup_order = [0, 1],
+ batch_order = [1, 0],
+ outer_order = [0, 1],
+ thread_order = [0, 1],
+ element_order = [1, 0]
+>
+
+func.func @distribute_transfer_read_col_major_with_broadcast(%a: index, %b: index, %arg0: memref<32x32x32x32xf16>) -> vector<16x16xf16> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %root = vector.transfer_read %arg0[%c0, %c0, %a, %b], %cst
+ {in_bounds = [true, true],
+ permutation_map = affine_map<(d0, d1, d2, d3) -> (0, 0)>,
+ "__vector_layout_test_anchor_result_0" = #layout_col_major}
+ : memref<32x32x32x32xf16>, vector<16x16xf16>
+ func.return %root : vector<16x16xf16>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (0, 0)>
+
+// CHECK-LABEL: @distribute_transfer_read_col_major_with_broadcast
+// CHECK-SAME: %[[I0:.+]]: index, %[[I1:.+]]: index
+
+// CHECK: %[[BROADCAST_READ:.+]] = vector.transfer_read %{{.*}}[%c0, %c0, %[[I0]], %[[I1]]], %{{.*}} permutation_map = #[[$MAP]]
+// CHECK: %[[UNIT:.+]] = vector.transpose %[[BROADCAST_READ]], [1, 0] : vector<4x1xf16> to vector<1x4xf16>
+// CHECK: vector.insert_strided_slice %[[UNIT]], %{{.*}} {offsets = [0, 0, 0, 0, 0, 0]
+// CHECK: vector.insert_strided_slice %[[UNIT]], %{{.*}} {offsets = [1, 0, 0, 0, 0, 0]
+
+// -----
+
+#layout_row_major = #iree_vector_ext.nested_layout<
+ subgroups_per_workgroup = [1, 1],
+ batches_per_subgroup = [2, 2],
+ outers_per_batch = [1, 1],
+ threads_per_outer = [8, 1],
+ elements_per_thread = [1, 8],
+
+ subgroup_order = [0, 1],
+ batch_order = [1, 0],
+ outer_order = [0, 1],
+ thread_order = [0, 1],
+ element_order = [0, 1]
+>
+
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<()[s0] -> (s0 + 8)>
+// CHECK-DAG: #[[$MAP3:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 8)>
+
+func.func @distribute_transfer_read_row_major_transpose(%a: index, %b: index, %arg0: memref<32x32x32x32xf16>) -> vector<16x16xf16> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %root = vector.transfer_read %arg0[%c0, %c0, %a, %b], %cst
+ {in_bounds = [true, true],
+ permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+ "__vector_layout_test_anchor_result_0" = #layout_row_major}
+ : memref<32x32x32x32xf16>, vector<16x16xf16>
+ func.return %root : vector<16x16xf16>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: @distribute_transfer_read_row_major_transpose
+// CHECK-SAME: %[[I0:.+]]: index, %[[I1:.+]]: index
+
+// CHECK: %[[IDS:.+]]:4 = affine.delinearize_index %{{.*}} into (%c1, %c1, %c8, %c1) : index, index, index, index
+// CHECK: %[[LIN_ID0:.+]] = affine.apply #[[$MAP0]]()[%[[IDS]]#2, %[[I1]]]
+// CHECK: vector.transfer_read %{{.*}}[%c0, %c0, %[[I0]], %[[LIN_ID0]]], {{.*}} permutation_map = #[[$MAP1]]
+// CHECK: %[[I0_PLUS_8:.+]] = affine.apply #[[$MAP2]]()[%[[I0]]]
+// CHECK: vector.transfer_read %arg2[%c0, %c0, %[[I0_PLUS_8]], %[[LIN_ID0]]], {{.*}} permutation_map = #[[$MAP1]]
+// CHECK: %[[LIN_ID1:.+]] = affine.apply #[[$MAP3]]()[%[[IDS]]#2, %[[I1]]]
+// CHECK: vector.transfer_read %arg2[%c0, %c0, %[[I0]], %[[LIN_ID1]]], {{.*}} permutation_map = #[[$MAP1]]
+// CHECK: vector.transfer_read %arg2[%c0, %c0, %[[I0_PLUS_8]], %[[LIN_ID1]]], %{{.*}} {in_bounds = [true, true], permutation_map = #[[$MAP1]]} : memref<32x32x32x32xf16>, vector<1x8xf16>
+
+// -----
+
+#layout_col_major = #iree_vector_ext.nested_layout<
+ subgroups_per_workgroup = [1, 1],
+ batches_per_subgroup = [1, 2],
+ outers_per_batch = [1, 1],
+ threads_per_outer = [4, 8],
+ elements_per_thread = [4, 1],
+
+ subgroup_order = [0, 1],
+ batch_order = [0, 1],
+ outer_order = [0, 1],
+ thread_order = [0, 1],
+ element_order = [0, 1]
+>
+
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+
+// CHECK-LABEL: @distribute_transfer_read_col_major_transpose
+func.func @distribute_transfer_read_col_major_transpose(%a: index, %b: index, %arg0: memref<32x32x32x32xf16>) -> vector<16x16xf16> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %root = vector.transfer_read %arg0[%c0, %c0, %a, %b], %cst
+ {in_bounds = [true, true],
+ permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+ "__vector_layout_test_anchor_result_0" = #layout_col_major}
+ : memref<32x32x32x32xf16>, vector<16x16xf16>
+ func.return %root : vector<16x16xf16>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK: vector.transfer_read {{.*}} permutation_map = #[[$MAP2]]
+// CHECK: vector.transfer_read {{.*}} permutation_map = #[[$MAP2]]
+
+// -----
+
+#layout = #iree_vector_ext.nested_layout<
+ subgroups_per_workgroup = [7, 3, 1, 1],
+ batches_per_subgroup = [3, 5, 2, 1],
+ outers_per_batch = [1, 1, 2, 4],
+ threads_per_outer = [1, 1, 2, 2],
+ elements_per_thread = [1, 1, 1, 2],
+
+ subgroup_order = [1, 0, 2, 3],
+ batch_order = [1, 2, 3, 0],
+ outer_order = [0, 3, 1, 2],
+ thread_order = [0, 1, 3, 2],
+ element_order = [0, 1, 2, 3]
+>
+
+func.func @distribute_transfer_read_row_major_with_permutations(%a: index, %b: index, %arg0: memref<32x32x32x32xf16>) -> vector<21x15x8x16xf16> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %root = vector.transfer_read %arg0[%c0, %c0, %a, %b], %cst
+ {in_bounds = [true, true, true, true],
+ permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d3, 0, d1)>,
+ "__vector_layout_test_anchor_result_0" = #layout}
+ : memref<32x32x32x32xf16>, vector<21x15x8x16xf16>
+ func.return %root : vector<21x15x8x16xf16>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+ transform.yield
+ }
+}
+// CHECK-LABEL: @distribute_transfer_read_row_major_with_permutations
+
+// Verify that there are (batch0: 3) * (batch1: 5) * (outer3: 4) = 60 total
+// unique transfer read ops. Reads that differ only along the broadcasted
+// dimension (2) are deduplicated by CSE.
+// CHECK-COUNT-60: vector.transfer_read
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index c801ad2..1ca1dd2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -966,6 +966,7 @@
populateGPUDistributionPatterns(patterns);
populateGPUDistributionLayoutAttrPatterns(laneId, patterns);
populateGPUReductionDistributionPatterns(patterns);
+ populateGPUDistributeNestedLayoutAttrPatterns(laneId, patterns);
distributeVectorOps(target, patterns, options);
return DiagnosedSilenceableFailure::success();
}
diff --git a/llvm-external-projects/iree-dialects/BUILD.bazel b/llvm-external-projects/iree-dialects/BUILD.bazel
index 82ea1fb..31044bc 100644
--- a/llvm-external-projects/iree-dialects/BUILD.bazel
+++ b/llvm-external-projects/iree-dialects/BUILD.bazel
@@ -733,6 +733,7 @@
":IREEVectorExtIncGen",
":IREEVectorExtInterfacesIncGen",
"@llvm-project//llvm:Support",
+ "@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:IR",
],
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td
index 05e331c..c12f9eb 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td
@@ -293,6 +293,12 @@
`>`
}];
+ let extraClassDeclaration = [{
+ // Returns the subgroup/lane ids delinearized from a single linearized
+ // thread ID.
+ ValueRange computeThreadIds(Value threadId, RewriterBase &rewriter) const;
+ }];
+
let genVerifyDecl = 1;
}
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtInterfaces.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtInterfaces.h
index 71e9f0c..59bcd1a 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtInterfaces.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtInterfaces.h
@@ -9,6 +9,8 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
/// Include the generated interface declarations.
#include "iree-dialects/Dialect/VectorExt/IR/VectorExtAttrInterfaces.h.inc"
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/CMakeLists.txt
index 6e965d1..82d810b 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/CMakeLists.txt
@@ -12,6 +12,7 @@
LINK_LIBS PUBLIC
MLIRIR
+ MLIRAffineDialect
)
iree_dialects_target_includes(IREEVectorExtDialect)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
index 6b80529..033e0ff 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
@@ -4,12 +4,16 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include <numeric>
+
#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h"
#include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/TypeSwitch.h"
-#include <numeric>
using namespace mlir;
@@ -264,6 +268,23 @@
return success();
}
+/// Given a single flat thread ID, compute the indices of the distributed
+/// dimensions (subgroup and thread ids).
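+/// For example, for a 2-D layout with subgroups_per_workgroup = [1, 1] and
+/// threads_per_outer = [8, 1], the basis is (1, 1, 8, 1) and the flat thread
+/// ID is delinearized into two subgroup ids followed by two lane ids.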
+ValueRange NestedLayoutAttr::computeThreadIds(Value threadId,
+ RewriterBase &rewriter) const {
+ SmallVector<OpFoldResult> basis;
+ for (auto warpTy : getSubgroupOrder()) {
+ basis.push_back(rewriter.getIndexAttr(getSubgroupsPerWorkgroup()[warpTy]));
+ }
+ for (auto threadTy : getThreadOrder()) {
+ basis.push_back(rewriter.getIndexAttr(getThreadsPerOuter()[threadTy]));
+ }
+
+ auto delinearized = rewriter.create<mlir::affine::AffineDelinearizeIndexOp>(
+ threadId.getLoc(), threadId, basis);
+ return delinearized->getResults();
+}
+
} // namespace mlir::iree_compiler::IREE::VectorExt
using namespace mlir::iree_compiler::IREE::VectorExt;