[VectorDistribution] Add patterns for distributing transfer_read/transfer_write (#16115)
This patch adds distribution patterns for vector.transfer_read and
vector.transfer_write that lower them to per-thread vector.load and
vector.store ops. These patterns perform the lowering in one shot,
unlike the upstream MLIR patterns, which unroll one dimension at a
time and apply themselves recursively. We can do the lowering in one
shot because the layout attribute already defines the full iteration
space for the lowering.
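
For illustration, here is a rough sketch of what the lowering produces
for the row-major layout used in the tests in this patch: a 16x16 read
becomes, per thread, one vector<8xf16> load per (BATCHX, BATCHY) batch
iteration. The value names, the memref shape, the distributed
accumulator type (vector<2x2x8xf16>), and the insertion offsets below
are illustrative assumptions rather than verbatim pass output; %mem is
a placeholder, and the per-thread indices %i0/%j0/... stand for affine
functions of gpu.thread_id built with affine.apply.

```mlir
// SIMD form:
%c0 = arith.constant 0 : index
%cst = arith.constant 0.0 : f16
%v = vector.transfer_read %mem[%c0, %c0], %cst {in_bounds = [true, true]}
       : memref<16x16xf16>, vector<16x16xf16>

// Per-thread SIMT form after distribution (sketch):
%zero = arith.constant dense<0.0> : vector<2x2x8xf16>
%l0 = vector.load %mem[%i0, %j0] : memref<16x16xf16>, vector<8xf16>
%a0 = vector.insert_strided_slice %l0, %zero {offsets = [0, 0, 0], strides = [1]}
       : vector<8xf16> into vector<2x2x8xf16>
%l1 = vector.load %mem[%i1, %j1] : memref<16x16xf16>, vector<8xf16>
%a1 = vector.insert_strided_slice %l1, %a0 {offsets = [0, 1, 0], strides = [1]}
       : vector<8xf16> into vector<2x2x8xf16>
// ...two more load + insert pairs for the remaining batch iterations.
```
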
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
index 15583b5..84dce8d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
@@ -8,7 +8,9 @@
#include "iree/compiler/Codegen/Common/GPU/GPUPatterns.h"
#include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
#include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Rewrite/PatternApplicator.h"
@@ -20,6 +22,93 @@
namespace {
+/// Given a LayoutAttr, find the shape of the given layout dimension. It is
+/// expected that the layout has at most one instance of the requested
+/// dimension. Example:
+/// LayoutAttr: <<BATCHX: 4>, <BATCHY: 4, LANEX: 4>>
+/// dim: BATCHX
+/// output: 4
+static std::optional<int64_t> findDimShape(LayoutAttr layout,
+ LayoutDimension dim) {
+ for (PerDimLayoutAttr dimLayout : layout.getLayouts()) {
+ if (std::optional<int64_t> shape = dimLayout.getShape(dim)) {
+ return shape;
+ }
+ }
+ return std::nullopt;
+}
+
+/// Given the current iterator state, compute the indices into the original
+/// (SIMD) vector that the state points to. These indices are parameterized by
+/// the thread grid.
+static SmallVector<Value> computeSIMDIndex(const LayoutIterator::State &state,
+ LayoutAttr layout,
+ ArrayRef<Value> threadGrid,
+ RewriterBase &rewriter) {
+ MLIRContext *ctx = layout.getContext();
+ AffineExpr threadX, threadY, threadZ;
+ bindSymbols(ctx, threadX, threadY, threadZ);
+
+ SmallVector<Value> simdIndex;
+ // Calculate the index for each dim separately.
+ for (PerDimLayoutAttr dimLayout : layout.getLayouts()) {
+ AffineExpr offset = getAffineConstantExpr(0, ctx);
+ AffineExpr stride = getAffineConstantExpr(1, ctx);
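+    // Walk this dim's labels from fastest-varying to slowest, accumulating
+    // offset = sum(index_i * stride_i), where each label's stride is the
+    // product of the shapes of all faster-varying labels.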
+ for (auto [label, shape] : llvm::reverse(
+ llvm::zip(dimLayout.getLabels(), dimLayout.getShapes()))) {
+ int64_t position = state.lookup(label.getValue()).getPosition();
+
+ switch (label.getValue()) {
+ case LayoutDimension::LANEX:
+ offset = offset + stride * threadX;
+ break;
+ case LayoutDimension::LANEY:
+ offset = offset + stride * threadY;
+ break;
+ case LayoutDimension::LANEZ:
+ offset = offset + stride * threadZ;
+ break;
+ default:
+ offset = offset + stride * getAffineConstantExpr(position, ctx);
+ break;
+ }
+ stride = stride * getAffineConstantExpr(shape, ctx);
+ }
+
+ // Compute the index for the dim.
+ AffineMap indexMap = AffineMap::get(0, 3, offset);
+ Value index = rewriter.create<affine::AffineApplyOp>(
+ rewriter.getUnknownLoc(), indexMap, threadGrid);
+ simdIndex.push_back(index);
+ }
+
+ return simdIndex;
+}
+
+/// Given the current iterator state, compute the indices into the distributed
+/// (SIMT) vector that the state points to. These indices are not parameterized
+/// by the thread grid and are expected to be the same for all threads.
+static SmallVector<int64_t> computeSIMTIndex(const LayoutIterator::State &state,
+ LayoutAttr layout) {
+ constexpr LayoutDimension labels[] = {
+ LayoutDimension::BATCHX, LayoutDimension::BATCHY,
+ LayoutDimension::VECTORZ, LayoutDimension::VECTORY,
+ LayoutDimension::VECTORX};
+
+ SmallVector<int64_t> offset;
+ for (LayoutDimension label : labels) {
+ std::optional shape = findDimShape(layout, label);
+ if (!shape) {
+ continue;
+ }
+ // Get current position for the label.
+ int64_t position = state.lookup(label).getPosition();
+ offset.push_back(position);
+ }
+ return offset;
+}
+
struct DistributeConstants final : OpDistributionPattern<arith::ConstantOp> {
using OpDistributionPattern::OpDistributionPattern;
@@ -99,6 +188,208 @@
}
};
+/// Given a projected permutation map, get the reduced permutation, i.e. the
+/// permutation with the projected-out dimensions removed.
+static SmallVector<int64_t> getReducedPermutation(AffineMap permutationMap) {
+ assert(permutationMap.isProjectedPermutation() &&
+ "permutation map should be a projected permutation.");
+ // TODO: The permutation map may also have broadcasting. Currently, we do not
+ // handle it. This can be fixed by adding a "BROADCAST" dimension in the
+ // layout.
+
+ SmallVector<int64_t> permutation;
+ permutation.reserve(permutationMap.getNumResults());
+
+ unsigned leadingUnitDims =
+ permutationMap.getNumDims() - permutationMap.getNumResults();
+ for (AffineExpr dim : permutationMap.getResults()) {
+ // Get this dim's position in the permutation map.
+ auto dimExpr = dyn_cast<AffineDimExpr>(dim);
+ if (!dimExpr) {
+ llvm::report_fatal_error("permutation map is not a projected "
+ "permutation.");
+ }
+
+ unsigned pos = dimExpr.getPosition();
+ assert(pos >= leadingUnitDims && "invalid permutation map");
+ pos -= leadingUnitDims;
+ permutation.push_back(pos);
+ }
+ return permutation;
+}
+
+template <typename OpTy>
+struct DistributeXferLayoutAttr : OpDistributionPattern<OpTy> {
+ static_assert(std::is_same<OpTy, vector::TransferReadOp>::value ||
+ std::is_same<OpTy, vector::TransferWriteOp>::value,
+ "expected vector::TransferReadOp or vector::TransferWriteOp");
+
+ DistributeXferLayoutAttr(MLIRContext *context, ArrayRef<Value> threadGrid,
+ PatternBenefit benefit = 1)
+ : OpDistributionPattern<OpTy>(context, benefit), threadGrid(threadGrid) {}
+
+ VectorValue accessMemory(OpTy xferOp, VectorValue accumulator,
+ LayoutAttr vectorLayout,
+ PatternRewriter &rewriter) const {
+    // The permutation map needs special consideration when lowering. Memory
+    // is addressed according to memoryLayout, because that is how the data is
+    // laid out in memory, while the data is held in the distributed vector
+    // according to vectorLayout.
+ SmallVector<int64_t> permutation =
+ getReducedPermutation(xferOp.getPermutationMap());
+ LayoutAttr memoryLayout =
+ cast<LayoutAttr>(vectorLayout.permute(permutation));
+
+ int loadWidth = getLoadStoreWidth(memoryLayout);
+ DenseMap<LayoutDimension, int64_t> steps;
+ steps[LayoutDimension::VECTORX] = loadWidth;
+ LayoutIterator iterator(vectorLayout, steps);
+
+ iterator.apply([&](const LayoutIterator::State &state) {
+ SmallVector<Value> memoryIndices =
+ getMemoryIndices(state, memoryLayout, xferOp.getIndices(), rewriter);
+ SmallVector<int64_t> accIndices = computeSIMTIndex(state, vectorLayout);
+ accumulator = accessUnit(xferOp, memoryIndices, accIndices, accumulator,
+ vectorLayout, memoryLayout, rewriter);
+ });
+
+ return accumulator;
+ }
+
+ SmallVector<Value> getMemoryIndices(const LayoutIterator::State &state,
+ LayoutAttr memoryLayout,
+ SmallVector<Value> indices,
+ RewriterBase &rewriter) const {
+ SmallVector<Value> simdIndices =
+ computeSIMDIndex(state, memoryLayout, threadGrid, rewriter);
+ SmallVector<Value> memoryIndices(indices);
+
+    // The transfer indices may cover leading dims that the permutation map
+    // projected away; only the trailing, non-projected dims get per-thread
+    // offsets added.
+ int leadingProjectedDims = memoryIndices.size() - simdIndices.size();
+ for (int i = leadingProjectedDims, e = memoryIndices.size(); i < e; ++i) {
+ memoryIndices[i] = rewriter.create<arith::AddIOp>(
+ rewriter.getUnknownLoc(), memoryIndices[i],
+ simdIndices[i - leadingProjectedDims]);
+ }
+
+ return memoryIndices;
+ }
+
+ virtual VectorValue accessUnit(OpTy xferOp, SmallVector<Value> &memoryIndices,
+ SmallVector<int64_t> &accIndices,
+ VectorValue accumulator,
+ LayoutAttr vectorLayout,
+ LayoutAttr memoryLayout,
+ PatternRewriter &rewriter) const = 0;
+
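+  /// The load/store width is the VECTORX size of the fastest-varying
+  /// dimension of the given layout; if that dimension has no VECTORX
+  /// component, fall back to unit-width (scalar) accesses.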
+ int getLoadStoreWidth(LayoutAttr layout) const {
+ PerDimLayoutAttr fastestChanging = layout.getLayouts().back();
+ if (std::optional<int64_t> width =
+ fastestChanging.getShape(LayoutDimension::VECTORX)) {
+ return *width;
+ }
+ return 1;
+ }
+
+ SmallVector<Value> threadGrid;
+};
+
+struct DistributeTransferReadLayoutAttr final
+ : DistributeXferLayoutAttr<vector::TransferReadOp> {
+ using DistributeXferLayoutAttr::DistributeXferLayoutAttr;
+
+ LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+ DistributionSignature &signature,
+ PatternRewriter &rewriter) const override {
+ VectorValue vector = readOp.getVector();
+ LayoutAttr vectorLayout = dyn_cast<LayoutAttr>(signature.results[0]);
+ if (!vectorLayout) {
+ return failure();
+ }
+
+ // TODO: Return failure if we need masking.
+
+ Type elementType = readOp.getSource().getType().getElementType();
+ auto vectorType = VectorType::get(
+ vectorLayout.getDistributedShape(vector.getType()), elementType);
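+    // Start from a zero-filled distributed accumulator; accessMemory fills in
+    // the per-thread loads.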
+ Value zero = rewriter.create<arith::ConstantOp>(
+ readOp.getLoc(), vectorType, rewriter.getZeroAttr(vectorType));
+ VectorValue acc = cast<VectorValue>(zero);
+
+ VectorValue readVec = accessMemory(readOp, acc, vectorLayout, rewriter);
+
+ replaceOpWithDistributedValues(rewriter, readOp, readVec);
+ return success();
+ }
+
+ VectorValue accessUnit(vector::TransferReadOp readOp,
+ SmallVector<Value> &memoryIndices,
+ SmallVector<int64_t> &accIndices,
+ VectorValue accumulator, LayoutAttr vectorLayout,
+ LayoutAttr memoryLayout,
+ PatternRewriter &rewriter) const override {
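+    // Load one contiguous slice of getLoadStoreWidth(memoryLayout) elements
+    // and insert it into the thread-local accumulator at accIndices.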
+ auto unitType = VectorType::get({getLoadStoreWidth(memoryLayout)},
+ accumulator.getType().getElementType());
+ VectorValue load = rewriter.create<vector::LoadOp>(
+ readOp.getLoc(), unitType, readOp.getSource(), memoryIndices);
+ return rewriter.create<vector::InsertStridedSliceOp>(
+ readOp.getLoc(), load, accumulator, accIndices,
+ SmallVector<int64_t>{1});
+ }
+};
+
+struct DistributeTransferWriteLayoutAttr final
+ : DistributeXferLayoutAttr<vector::TransferWriteOp> {
+ using DistributeXferLayoutAttr::DistributeXferLayoutAttr;
+
+ LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+ DistributionSignature &signature,
+ PatternRewriter &rewriter) const override {
+ VectorValue vector = writeOp.getVector();
+ LayoutAttr vectorLayout = dyn_cast<LayoutAttr>(signature.operands[0]);
+ if (!vectorLayout) {
+ return failure();
+ }
+
+ // TODO: Return failure if we need masking.
+
+ Type elementType = writeOp.getSource().getType().getElementType();
+ auto vectorType = VectorType::get(
+ vectorLayout.getDistributedShape(vector.getType()), elementType);
+ Value zero = rewriter.create<arith::ConstantOp>(
+ writeOp.getLoc(), vectorType, rewriter.getZeroAttr(vectorType));
+ VectorValue acc = cast<VectorValue>(zero);
+
+ accessMemory(writeOp, acc, vectorLayout, rewriter);
+
+ rewriter.eraseOp(writeOp);
+ return success();
+ }
+
+ VectorValue accessUnit(vector::TransferWriteOp writeOp,
+ SmallVector<Value> &memoryIndices,
+ SmallVector<int64_t> &accIndices,
+ VectorValue accumulator, LayoutAttr vectorLayout,
+ LayoutAttr memoryLayout,
+ PatternRewriter &rewriter) const override {
+ int width = getLoadStoreWidth(memoryLayout);
+
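+    // Extract a [1 x ... x 1 x width] slice of the distributed vector at
+    // accIndices, drop the leading unit dims, and store the resulting 1-D
+    // vector to memory.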
+ SmallVector<int64_t> strides(accIndices.size(), 1);
+ SmallVector<int64_t> shapes(accIndices.size(), 1);
+ shapes[shapes.size() - 1] = width;
+ Value result = rewriter.create<vector::ExtractStridedSliceOp>(
+ writeOp.getLoc(), getDistributed(rewriter, accumulator, vectorLayout),
+ accIndices, shapes, strides);
+ result = rewriter.create<vector::ExtractOp>(
+ writeOp.getLoc(), result,
+ SmallVector<int64_t>(accIndices.size() - 1, 0));
+ rewriter.create<vector::StoreOp>(writeOp.getLoc(), result,
+ writeOp.getSource(), memoryIndices);
+
+ return accumulator;
+ }
+};
+
}; // namespace
void populateGPUDistributionPatterns(RewritePatternSet &patterns) {
@@ -108,4 +399,11 @@
DistributeElementwise<arith::AddFOp>>(patterns.getContext());
}
+void populateGPUDistributionLayoutAttrPatterns(ArrayRef<Value> threadGrid,
+ RewritePatternSet &patterns) {
+ patterns
+ .add<DistributeTransferReadLayoutAttr, DistributeTransferWriteLayoutAttr>(
+ patterns.getContext(), threadGrid);
+}
+
}; // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.h b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.h
index c6da6ca..cdda85a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.h
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.h
@@ -31,6 +31,9 @@
void populateGPUDistributionPatterns(RewritePatternSet &patterns);
+void populateGPUDistributionLayoutAttrPatterns(ArrayRef<Value> threadGrid,
+ RewritePatternSet &patterns);
+
} // namespace mlir::iree_compiler
#endif // IREE_COMPILER_CODEGEN_COMMON_GPUPATTERNS_H_
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
index 53ddddf..a7ab386 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
@@ -41,3 +41,175 @@
transform.yield
}
}
+
+// -----
+
+#layout_row_major = #iree_vector_ext.layout<<[BATCHX, LANEY], [2, 8]>, <[BATCHY, LANEX, VECTORX], [2, 1, 8]>>
+#layout_col_major = #iree_vector_ext.layout<<[BATCHX, LANEY, VECTORX], [1, 4, 4]>, <[BATCHY, LANEX], [2, 8]>>
+
+builtin.module attributes { transform.with_named_sequence } {
+ // CHECK-LABEL: @distribute_transfer_read_row_major
+ func.func @distribute_transfer_read_row_major(%alloc: memref<4x4xf16>) -> vector<16x16xf16> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %root = vector.transfer_read %alloc[%c0, %c0], %cst
+ {in_bounds = [false, false],
+ "__vector_layout_test_anchor_result_0" = #layout_row_major}
+ : memref<4x4xf16>, vector<16x16xf16>
+ // CHECK-COUNT-4: vector.load {{.*}}, vector<8xf16>
+ func.return %root : vector<16x16xf16>
+ }
+
+ // CHECK-LABEL: @distribute_transfer_read_col_major
+ func.func @distribute_transfer_read_col_major(%alloc: memref<32x32xf16>) -> vector<16x16xf16> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %root = vector.transfer_read %alloc[%c0, %c0], %cst
+ {in_bounds = [true, true],
+ "__vector_layout_test_anchor_result_0" = #layout_col_major}
+ : memref<32x32xf16>, vector<16x16xf16>
+ // CHECK-COUNT-8: vector.load {{.*}}, vector<1xf16>
+ func.return %root : vector<16x16xf16>
+ }
+
+ // CHECK-LABEL: @distribute_transfer_read_row_major_with_broadcast
+ func.func @distribute_transfer_read_row_major_with_broadcast(%a: index, %b: index, %alloc: memref<32x32x32x32xf16>) -> vector<16x16xf16> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %root = vector.transfer_read %alloc[%c0, %c0, %a, %b], %cst
+ {in_bounds = [true, true],
+ permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
+ "__vector_layout_test_anchor_result_0" = #layout_row_major}
+ : memref<32x32x32x32xf16>, vector<16x16xf16>
+ // CHECK-COUNT-4: vector.load {{.*}}, vector<8xf16>
+ func.return %root : vector<16x16xf16>
+ }
+
+ // CHECK-LABEL: @distribute_transfer_read_col_major_with_broadcast
+ func.func @distribute_transfer_read_col_major_with_broadcast(%a: index, %b: index, %alloc: memref<32x32x32x32xf16>) -> vector<16x16xf16> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %root = vector.transfer_read %alloc[%c0, %c0, %a, %b], %cst
+ {in_bounds = [true, true],
+ permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
+ "__vector_layout_test_anchor_result_0" = #layout_col_major}
+ : memref<32x32x32x32xf16>, vector<16x16xf16>
+ // CHECK-COUNT-8: vector.load {{.*}}, vector<1xf16>
+ func.return %root : vector<16x16xf16>
+ }
+
+ // CHECK-LABEL: @distribute_transfer_read_row_major_transpose
+ func.func @distribute_transfer_read_row_major_transpose(%a: index, %b: index, %alloc: memref<32x32x32x32xf16>) -> vector<16x16xf16> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %root = vector.transfer_read %alloc[%c0, %c0, %a, %b], %cst
+ {in_bounds = [true, true],
+ permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+ "__vector_layout_test_anchor_result_0" = #layout_row_major}
+ : memref<32x32x32x32xf16>, vector<16x16xf16>
+ // CHECK-COUNT-32: vector.load {{.*}}, vector<1xf16>
+ func.return %root : vector<16x16xf16>
+ }
+
+ // CHECK-LABEL: @distribute_transfer_read_col_major_transpose
+ func.func @distribute_transfer_read_col_major_transpose(%a: index, %b: index, %alloc: memref<32x32x32x32xf16>) -> vector<16x16xf16> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %root = vector.transfer_read %alloc[%c0, %c0, %a, %b], %cst
+ {in_bounds = [true, true],
+ permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+ "__vector_layout_test_anchor_result_0" = #layout_col_major}
+ : memref<32x32x32x32xf16>, vector<16x16xf16>
+ // CHECK-COUNT-2: vector.load {{.*}}, vector<4xf16>
+ func.return %root : vector<16x16xf16>
+ }
+
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+#layout_row_major = #iree_vector_ext.layout<<[BATCHX, LANEY], [2, 8]>, <[BATCHY, LANEX, VECTORX], [2, 1, 8]>>
+#layout_col_major = #iree_vector_ext.layout<<[BATCHX, LANEY, VECTORX], [1, 4, 4]>, <[BATCHY, LANEX], [2, 8]>>
+
+builtin.module attributes { transform.with_named_sequence } {
+ // CHECK-LABEL: @distribute_transfer_write_row_major
+ func.func @distribute_transfer_write_row_major(%root: vector<16x16xf16>, %alloc: memref<64x64xf16>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %root, %alloc[%c0, %c0]
+ {in_bounds = [true, true],
+ "__vector_layout_test_anchor_operand_0" = #layout_row_major}
+ : vector<16x16xf16>, memref<64x64xf16>
+ // CHECK-COUNT-4: vector.store {{.*}}, vector<8xf16>
+ func.return
+ }
+
+ // CHECK-LABEL: @distribute_transfer_write_col_major
+ func.func @distribute_transfer_write_col_major(%root: vector<16x16xf16>, %alloc: memref<64x64xf16>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %root, %alloc[%c0, %c0]
+ {in_bounds = [true, true],
+ "__vector_layout_test_anchor_operand_0" = #layout_col_major}
+ : vector<16x16xf16>, memref<64x64xf16>
+ // CHECK-COUNT-8: vector.store {{.*}}, vector<1xf16>
+ func.return
+ }
+
+ // CHECK-LABEL: @distribute_transfer_write_row_major_with_broadcast
+ func.func @distribute_transfer_write_row_major_with_broadcast(%root: vector<16x16xf16>, %a: index, %b: index, %alloc: memref<32x32x32x32xf16>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %root, %alloc[%c0, %c0, %a, %b]
+ {in_bounds = [true, true],
+ permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
+ "__vector_layout_test_anchor_operand_0" = #layout_row_major}
+ : vector<16x16xf16>, memref<32x32x32x32xf16>
+ // CHECK-COUNT-4: vector.store {{.*}}, vector<8xf16>
+ func.return
+ }
+
+ // CHECK-LABEL: @distribute_transfer_write_col_major_with_broadcast
+ func.func @distribute_transfer_write_col_major_with_broadcast(%root: vector<16x16xf16>, %a: index, %b: index, %alloc: memref<32x32x32x32xf16>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %root, %alloc[%c0, %c0, %a, %b]
+ {in_bounds = [true, true],
+ permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
+ "__vector_layout_test_anchor_operand_0" = #layout_col_major}
+ : vector<16x16xf16>, memref<32x32x32x32xf16>
+ // CHECK-COUNT-8: vector.store {{.*}}, vector<1xf16>
+ func.return
+ }
+
+ // CHECK-LABEL: @distribute_transfer_write_row_major_transpose
+ func.func @distribute_transfer_write_row_major_transpose(%root: vector<16x16xf16>, %a: index, %b: index, %alloc: memref<32x32x32x32xf16>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %root, %alloc[%c0, %c0, %a, %b]
+ {in_bounds = [true, true],
+ permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+ "__vector_layout_test_anchor_operand_0" = #layout_row_major}
+ : vector<16x16xf16>, memref<32x32x32x32xf16>
+ // CHECK-COUNT-32: vector.store {{.*}}, vector<1xf16>
+ func.return
+ }
+
+ // CHECK-LABEL: @distribute_transfer_write_col_major_transpose
+ func.func @distribute_transfer_write_col_major_transpose(%root: vector<16x16xf16>, %a: index, %b: index, %alloc: memref<32x32x32x32xf16>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %root, %alloc[%c0, %c0, %a, %b]
+ {in_bounds = [true, true],
+ permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+ "__vector_layout_test_anchor_operand_0" = #layout_col_major}
+ : vector<16x16xf16>, memref<32x32x32x32xf16>
+ // CHECK-COUNT-2: vector.store {{.*}}, vector<4xf16>
+ func.return
+ }
+
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+ transform.yield
+ }
+}
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index 38af96b..4eec9cb 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -1018,7 +1018,16 @@
transform::TransformState &state) {
TestVectorLayoutOptions options(target);
RewritePatternSet patterns(target.getContext());
+
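+  // Materialize the thread grid (gpu.thread_id x/y/z) at the start of the
+  // function so the distribution patterns can parameterize indices by thread.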
+ rewriter.setInsertionPointToStart(&target.getBody().front());
+ SmallVector<Value> threadGrid = {
+ rewriter.create<gpu::ThreadIdOp>(target.getLoc(), gpu::Dimension::x),
+ rewriter.create<gpu::ThreadIdOp>(target.getLoc(), gpu::Dimension::y),
+ rewriter.create<gpu::ThreadIdOp>(target.getLoc(), gpu::Dimension::z),
+ };
+
populateGPUDistributionPatterns(patterns);
+ populateGPUDistributionLayoutAttrPatterns(threadGrid, patterns);
distributeVectorOps(target, patterns, options);
return DiagnosedSilenceableFailure::success();
}
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h
index e1d490b..f13d2f0 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h
@@ -42,6 +42,8 @@
return position < other.position;
}
+ int64_t getPosition() const { return position; }
+
private:
int64_t position, stride;
};