[VectorDistribution] Remove iteration on non-distributed levels (#17096)
This patch removes the iteration order for non-distributed levels from the
layout. The layout now only represents how the register is tiled, and
does not define an iteration order over it. The patterns themselves
define an iteration order over the tiled layout.
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp
index efe43d3..4088008 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp
@@ -4,6 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include <numeric>
#include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Utils/VectorOpUtils.h"
@@ -97,12 +98,12 @@
LLVM_DEBUG(llvm::dbgs() << "init tile: " << finalTile << "\n");
// Offsets into the LHS/RHS batches.
- SmallVector<int64_t, 2> lhsBatchOffsets(lhsLayout.getRank(), 0);
- SmallVector<int64_t, 2> rhsBatchOffsets(rhsLayout.getRank(), 0);
+ SmallVector<int64_t> lhsBatchOffsets(rank, 0);
+ SmallVector<int64_t> rhsBatchOffsets(rank, 0);
// Offsets into the result batches.
ArrayRef<int64_t> resultBatches = resultLayout.getBatchesPerSubgroup();
- SmallVector<int64_t, 2> resultBatchTileSizes(rank, 1);
+ SmallVector<int64_t> resultBatchTileSizes(rank, 1);
LLVM_DEBUG({
llvm::dbgs() << "result batches: [";
llvm::interleaveComma(resultBatches, llvm::dbgs());
@@ -114,18 +115,22 @@
Value lhs = getDistributed(rewriter, contractOp.getLhs(), lhsLayout);
Value rhs = getDistributed(rewriter, contractOp.getRhs(), rhsLayout);
+ SmallVector<AffineMap> indexingMaps = contractOp.getIndexingMapsArray();
+ AffineMap lhsMap = compressUnusedDims(indexingMaps[0]);
+ AffineMap rhsMap = compressUnusedDims(indexingMaps[1]);
+ AffineMap resMap = compressUnusedDims(indexingMaps[2]);
+
+ SmallVector<int64_t> resBatchOrder(resMap.getNumResults());
+ std::iota(resBatchOrder.begin(), resBatchOrder.end(), 0);
+ resBatchOrder = applyPermutationMap(resMap, ArrayRef(resBatchOrder));
+
// Iterate over all result batches and unroll computation to direct MFMA
// intrinsic ops.
Location loc = contractOp.getLoc();
auto resultTiles = StaticTileOffsetRange(
- resultBatches, resultBatchTileSizes, resultLayout.getBatchOrder());
+ resultBatches, resultBatchTileSizes, resBatchOrder);
SmallVector<int64_t, 2> resultBatchOffsets;
- for (SmallVector<int64_t, 2> originalResultBatchOffsets : resultTiles) {
- // Permute the result batch offsets first to match the distributed shape
- // dim order for indexing.
- resultBatchOffsets = originalResultBatchOffsets;
- applyPermutationToVector(resultBatchOffsets,
- resultLayout.getBatchOrder());
+ for (SmallVector<int64_t, 2> resultBatchOffsets : resultTiles) {
LLVM_DEBUG({
llvm::dbgs() << "current result batch offsets: [";
llvm::interleaveComma(resultBatchOffsets, llvm::dbgs());
@@ -151,9 +156,9 @@
// Fills the batch offsets for LHS and RHS. For the K dimension it's the
// induction variable; for the M/N dimension we need to extract from the
// result batch offsets.
- fillOperandBatchOffsets(opDetail, k, originalResultBatchOffsets,
- resultLayout, lhsBatchOffsets, rhsBatchOffsets,
- lhsLayout, rhsLayout);
+ fillOperandBatchOffsets(opDetail, k, resultBatchOffsets,
+ lhsBatchOffsets, rhsBatchOffsets, lhsMap,
+ rhsMap);
LLVM_DEBUG({
llvm::dbgs() << "current lhs batch offsets: [";
llvm::interleaveComma(lhsBatchOffsets, llvm::dbgs());
@@ -196,12 +201,10 @@
// both LHS and RHS.
void fillOperandBatchOffsets(const VectorContractOpInfo &opDetail,
int64_t kOffset, ArrayRef<int64_t> resultOffsets,
- NestedLayoutAttr resultLayout,
- SmallVector<int64_t, 2> &lhsOffsets,
- SmallVector<int64_t, 2> &rhsOffsets,
- NestedLayoutAttr lhsLayout,
- NestedLayoutAttr rhsLayout) const {
-
+ SmallVector<int64_t> &lhsOffsets,
+ SmallVector<int64_t> &rhsOffsets,
+ AffineMap lhsMap, AffineMap rhsMap) const {
+ auto [lhsK, rhsK] = opDetail.getOperandKIndex();
// resultOffsets contains batch indices into the C/D vector. It is a 2-D
// index for both M and N. We need to split out for M and N, and add index
// for K.
@@ -214,13 +217,8 @@
rhsOffsets[rhsN] = resultOffsets[resultN];
}
- auto [lhsK, rhsK] = opDetail.getOperandKIndex();
lhsOffsets[lhsK] = kOffset;
rhsOffsets[rhsK] = kOffset;
-
- // Now apply permutation on LHS/RHS according to their batch order.
- applyPermutationToVector(lhsOffsets, lhsLayout.getBatchOrder());
- applyPermutationToVector(rhsOffsets, rhsLayout.getBatchOrder());
}
struct AMDMMAParameters {
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
index 1887c28..94a046a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
@@ -616,17 +616,18 @@
}
};
-struct DistributeTranspose final : OpDistributionPattern<vector::TransposeOp> {
+struct DistributeTransposeLayoutAttr final
+ : OpDistributionPattern<vector::TransposeOp> {
using OpDistributionPattern::OpDistributionPattern;
LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
DistributionSignature &signature,
PatternRewriter &rewriter) const override {
VectorValue value = transposeOp.getVector();
- VectorLayoutInterface layout =
- dyn_cast<VectorLayoutInterface>(signature[value]);
+ VectorLayoutInterface layout = dyn_cast<LayoutAttr>(signature[value]);
if (!layout) {
- return failure();
+ return rewriter.notifyMatchFailure(transposeOp,
+ "layout must be LayoutAttr");
}
/// Transpose only changes the notion of where the data carried by each
@@ -849,7 +850,7 @@
patterns
.add<DistributeTransferReadLayoutAttr, DistributeTransferWriteLayoutAttr>(
patterns.getContext(), laneId);
- patterns.add<DistributeBroadcastLayoutAttr, DistributeTranspose>(
+ patterns.add<DistributeBroadcastLayoutAttr, DistributeTransposeLayoutAttr>(
patterns.getContext());
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
index f0859e8..f8122da 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
@@ -81,12 +81,8 @@
int64_t rank = vectorLayout.getRank();
// Permute the batch and outer vector offsets to match the order of
// the vector dimensions using the inverse of the batch/offset order.
- SmallVector<int64_t> batchOffsets =
- applyPermutation(ArrayRef<int64_t>(offsets.begin(), rank),
- invertPermutationVector(vectorLayout.getBatchOrder()));
- SmallVector<int64_t> outerVectorOffsets =
- applyPermutation(ArrayRef<int64_t>(offsets.begin() + rank, rank),
- invertPermutationVector(vectorLayout.getOuterOrder()));
+ ArrayRef<int64_t> batchOffsets(offsets.begin(), rank);
+ ArrayRef<int64_t> outerVectorOffsets(offsets.begin() + rank, rank);
SmallVector<Value> slicedIndices(indices.begin(), indices.end());
for (const auto &[i, dim] : llvm::enumerate(permutationMap.getResults())) {
@@ -111,29 +107,6 @@
return slicedIndices;
}
-static SmallVector<int64_t> getLoopOrder(NestedLayoutAttr vectorLayout) {
- int64_t rank = vectorLayout.getRank();
- // Let the unroll order first unroll the batch dimensions, then the
- // outer vector dimensions. We unroll in the order specified by the
- // layout.
- SmallVector<int64_t> loopOrder;
- int64_t base = 0;
- for (auto b : vectorLayout.getBatchOrder()) {
- loopOrder.push_back(base + b);
- }
- base += rank;
- // We must unroll along the outer dimensions as well to match the rank
- // requirements of vector transfer ops (<= memref rank up to broadcasts).
- for (auto o : vectorLayout.getOuterOrder()) {
- loopOrder.push_back(base + o);
- }
- base += rank;
- for (int i = 0, e = rank; i < e; ++i) {
- loopOrder.push_back(base + i);
- }
- return loopOrder;
-}
-
static SmallVector<int64_t>
getElementVectorTileShape(NestedLayoutAttr vectorLayout) {
int64_t rank = vectorLayout.getRank();
@@ -215,7 +188,6 @@
SmallVector<int64_t> distShape = vectorLayout.getDistributedShape();
SmallVector<int64_t> tileShape = getElementVectorTileShape(vectorLayout);
- SmallVector<int64_t> loopOrder = getLoopOrder(vectorLayout);
int64_t rank = vectorLayout.getRank();
Type elementType = readOp.getSource().getType().getElementType();
@@ -238,7 +210,7 @@
ValueRange indices = readOp.getIndices();
SmallVector<int64_t> strides(rank, 1);
for (SmallVector<int64_t> offsets :
- StaticTileOffsetRange(distShape, tileShape, loopOrder)) {
+ StaticTileOffsetRange(distShape, tileShape)) {
SmallVector<Value> slicedIndices = getTransferIndicesFromNestedLayout(
rewriter, indices, offsets, vectorLayout, readOp.getPermutationMap(),
warpIndices, threadIndices);
@@ -247,19 +219,6 @@
readOp.getLoc(), innerVectorType, readOp.getSource(), slicedIndices,
readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
readOp.getInBoundsAttr());
- // Transpose to the element order.
- //
- // A = transfer_read
- // B = transpose A
- //
- // P(A) = I
- //
- // P(A) * perm = P(B)
- // perm = P(B)
- if (!isIdentityPermutation(vectorLayout.getElementOrder())) {
- slicedRead = rewriter.create<vector::TransposeOp>(
- slicedRead.getLoc(), slicedRead, vectorLayout.getElementOrder());
- }
acc = rewriter.create<vector::InsertStridedSliceOp>(
readOp.getLoc(), slicedRead, acc, offsets, strides);
@@ -302,7 +261,6 @@
SmallVector<int64_t> distShape = vectorLayout.getDistributedShape();
SmallVector<int64_t> tileShape = getElementVectorTileShape(vectorLayout);
- SmallVector<int64_t> loopOrder = getLoopOrder(vectorLayout);
int64_t rank = vectorLayout.getRank();
SmallVector<Value> warpIndices, threadIndices;
@@ -314,7 +272,7 @@
ValueRange indices = writeOp.getIndices();
for (SmallVector<int64_t> offsets :
- StaticTileOffsetRange(distShape, tileShape, loopOrder)) {
+ StaticTileOffsetRange(distShape, tileShape)) {
SmallVector<Value> slicedIndices = getTransferIndicesFromNestedLayout(
rewriter, indices, offsets, vectorLayout, writeOp.getPermutationMap(),
warpIndices, threadIndices);
@@ -326,20 +284,6 @@
Value slicedVector = rewriter.create<vector::ExtractOp>(
writeOp.getLoc(), distributedVector,
offsetArray.take_front(rank * 2));
- // Transpose to the native dimension order.
- // B = transpose(A)
- // transfer_write B
- //
- // P(B) = I
- //
- // P(A) * perm = P(B)
- // P(A) * perm = I
- // perm = P(A) ^ -1
- if (!isIdentityPermutation(vectorLayout.getElementOrder())) {
- slicedVector = rewriter.create<vector::TransposeOp>(
- slicedVector.getLoc(), slicedVector,
- invertPermutationVector(vectorLayout.getElementOrder()));
- }
rewriter.create<vector::TransferWriteOp>(
writeOp.getLoc(), slicedVector, writeOp.getSource(), slicedIndices,
writeOp.getPermutationMapAttr(), writeOp.getMask(),
@@ -388,53 +332,39 @@
loc, vectorType, rewriter.getZeroAttr(vectorType));
int64_t rank = vectorLayout.getRank();
- int64_t sourceRank = sourceLayout.getRank();
// We unroll along both the batch and outer dimensions for a similar reason
// to the transfer ops. `vector.broadcast` can only broadcast along outer
// dims, so mixing broadcasted and un-broadcasted element/outer dims can't
// be represented with a single `vector.broadcast`.
SmallVector<int64_t> resultVectorUnrollShape =
getElementVectorTileShape(vectorLayout);
- SmallVector<int64_t> loopOrder = getLoopOrder(vectorLayout);
Value distributedSource = getDistributed(rewriter, srcVector, sourceLayout);
VectorType broadcastTargetType =
- VectorType::get(applyPermutation(vectorLayout.getElementsPerThread(),
- vectorLayout.getElementOrder()),
- elementType);
+ VectorType::get(vectorLayout.getElementsPerThread(), elementType);
+
+ int64_t sourceRank = sourceLayout.getRank();
+
for (SmallVector<int64_t> offsets :
- StaticTileOffsetRange(distShape, resultVectorUnrollShape, loopOrder)) {
+ StaticTileOffsetRange(distShape, resultVectorUnrollShape)) {
ArrayRef<int64_t> offsetsRef(offsets);
- // Invert the permutations on the batch/outer offsets to get the offsets
- // in the order of the vector dimensions. We are iterating over each
- // (batch x outer) tile, and the offsets for those tiles are already
- // permuted by the layout batch/outer orders. Hence why we apply the
- // inverse permutation here.
- SmallVector<int64_t> permutedBatchOffsets = applyPermutation(
- offsetsRef.slice(0, rank),
- invertPermutationVector(vectorLayout.getBatchOrder()));
- SmallVector<int64_t> permutedOuterOffsets = applyPermutation(
- offsetsRef.slice(rank, rank),
- invertPermutationVector(vectorLayout.getOuterOrder()));
// Slice out the last |sourceRank| dimensions which is the inner
// broadcasted shape.
ArrayRef<int64_t> batchSourceOffsets =
- ArrayRef<int64_t>(permutedBatchOffsets)
- .slice(rank - sourceRank, sourceRank);
+ offsetsRef.slice(rank - sourceRank, sourceRank);
ArrayRef<int64_t> outerSourceOffsets =
- ArrayRef<int64_t>(permutedOuterOffsets)
- .slice(rank - sourceRank, sourceRank);
+ offsetsRef.slice(2 * rank - sourceRank, sourceRank);
// Construct the list of source offsets based on the batch/outer order of
// the broadcasted vector. This is because we need to compute the offsets
// into the distributed source vector with the distributed permutation.
SmallVector<int64_t> sourceOffsets;
- sourceOffsets.append(
- applyPermutation(batchSourceOffsets, sourceLayout.getBatchOrder()));
- sourceOffsets.append(
- applyPermutation(outerSourceOffsets, sourceLayout.getOuterOrder()));
+ sourceOffsets.append(batchSourceOffsets.begin(),
+ batchSourceOffsets.end());
+ sourceOffsets.append(outerSourceOffsets.begin(),
+ outerSourceOffsets.end());
// Extract a slice of the input to be broadcasted.
Value slice = rewriter.create<vector::ExtractOp>(loc, distributedSource,
@@ -538,14 +468,13 @@
// Do thread local reduce.
+ // The distributed reduction mask is simply the same mask appended
+ // thrice.
SmallVector<bool> distributedReductionMask;
distributedReductionMask.reserve(3 * rank);
- distributedReductionMask.append(
- applyPermutation(reducedDims, srcLayout.getBatchOrder()));
- distributedReductionMask.append(
- applyPermutation(reducedDims, srcLayout.getOuterOrder()));
- distributedReductionMask.append(
- applyPermutation(reducedDims, srcLayout.getElementOrder()));
+ for (int i = 0; i < 3; ++i) {
+ distributedReductionMask.append(reducedDims.begin(), reducedDims.end());
+ }
auto localReduction = rewriter.create<vector::MultiDimReductionOp>(
loc, disSrc, disAcc, distributedReductionMask, multiReduceOp.getKind());
@@ -627,6 +556,60 @@
int64_t maxBitsPerShuffle;
};
+struct DistributeTranspose final : OpDistributionPattern<vector::TransposeOp> {
+ using OpDistributionPattern::OpDistributionPattern;
+
+ LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+ DistributionSignature &signature,
+ PatternRewriter &rewriter) const override {
+ VectorValue value = transposeOp.getVector();
+ VectorLayoutInterface layout = dyn_cast<NestedLayoutAttr>(signature[value]);
+ if (!layout) {
+ return rewriter.notifyMatchFailure(transposeOp,
+ "layout must be NestedLayoutAttr");
+ }
+
+ /// Transpose only changes the notion of where the data carried by each
+ /// thread comes from in the transposed SIMD vector. The data carried by
+ /// each thread is still the same, transposed as requested by the operation.
+ /// So, for distributed dimensions (thread and subgroup) transpose is a
+ /// no-op.
+ ///
+ /// Example (indices [0-3] represent ids of the threads carrying the data):
+ ///
+ /// input: vector<2x4xf16>
+ ///
+ /// 0 0 1 1
+ /// 2 2 3 3
+ ///
+ /// after transpose,
+ ///
+ /// transp: vector<4x2xf16>
+ ///
+ /// 0 2
+ /// 0 2
+ /// 1 3
+ /// 1 3
+ ///
+ /// As it can be seen, each thread is still carrying the same data but
+ /// just holds a transposed version of it.
+
+ VectorValue input = getDistributed(rewriter, value, layout);
+ // Permute batch, outer and element based on the given permutation.
+ int64_t rank = value.getType().getRank();
+ SmallVector<int64_t> permutation;
+ for (int i = 0; i < 3; ++i) {
+ for (auto it : transposeOp.getPermutation()) {
+ permutation.push_back(it + (i * rank));
+ }
+ }
+ VectorValue transposed = rewriter.create<vector::TransposeOp>(
+ transposeOp.getLoc(), input, permutation);
+ replaceOpWithDistributedValues(rewriter, transposeOp, transposed);
+ return success();
+ }
+};
+
} // namespace
void populateGPUDistributeNestedLayoutAttrPatterns(RewritePatternSet &patterns,
@@ -635,7 +618,7 @@
int64_t maxBitsPerShuffle) {
patterns.add<DistributeTransferRead, DistributeTransferWrite>(
patterns.getContext(), threadId);
- patterns.add<DistributeBroadcast>(patterns.getContext());
+ patterns.add<DistributeBroadcast, DistributeTranspose>(patterns.getContext());
patterns.add<DistributeMultiReduction>(patterns.getContext(), subgroupSize,
maxBitsPerShuffle);
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir
index 9141916..3c26d7f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --iree-transform-dialect-interpreter --cse %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-transform-dialect-interpreter --canonicalize --cse %s | FileCheck %s
// CDNA3 V_MFMA_F32_32x32x8_F16
@@ -28,8 +28,6 @@
threads_per_outer = [2, 32],
elements_per_thread = [4, 1],
- element_order = [1, 0],
-
subgroup_basis = [1, 1],
thread_basis = [2, 32]
>
@@ -42,8 +40,6 @@
threads_per_outer = [2, 32],
elements_per_thread = [4, 1],
- element_order = [1, 0],
-
subgroup_basis = [1, 1],
thread_basis = [2, 32]
>
@@ -71,24 +67,23 @@
}
// CHECK-LABEL: func @contract_to_mfma_32x32x8_mm
-// CHECK-SAME: (%[[A:.+]]: vector<32x8xf16>, %[[B:.+]]: vector<8x32xf16>, %[[C:.+]]: vector<32x32xf32>)
-// CHECK: %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<1x1x4x1x1x4xf32>
-// CHECK: %[[C_SIMT:.+]] = iree_vector_ext.to_simt %[[C]] : vector<32x32xf32> -> vector<1x1x4x1x1x4xf32>
-// CHECK: %[[A_SIMT:.+]] = iree_vector_ext.to_simt %[[A]] : vector<32x8xf16> -> vector<1x1x1x1x1x4xf16>
-// CHECK: %[[B_SIMT:.+]] = iree_vector_ext.to_simt %[[B]] : vector<8x32xf16> -> vector<1x1x1x1x1x4xf16>
-// CHECK: %[[C_VEC:.+]] = vector.extract %[[C_SIMT]][0, 0] : vector<4x1x1x4xf32> from vector<1x1x4x1x1x4xf32>
-// CHECK: %[[A_VEC:.+]] = vector.extract %[[A_SIMT]][0, 0] : vector<1x1x1x4xf16> from vector<1x1x1x1x1x4xf16>
-// CHECK: %[[B_VEC:.+]] = vector.extract %[[B_SIMT]][0, 0] : vector<1x1x1x4xf16> from vector<1x1x1x1x1x4xf16>
-// CHECK: %[[A_CAST:.+]] = vector.shape_cast %[[A_VEC]] : vector<1x1x1x4xf16> to vector<4xf16>
-// CHECK: %[[B_CAST:.+]] = vector.shape_cast %[[B_VEC]] : vector<1x1x1x4xf16> to vector<4xf16>
-// CHECK: %[[C_CAST:.+]] = vector.shape_cast %[[C_VEC]] : vector<4x1x1x4xf32> to vector<16xf32>
-// CHECK: %[[MFMA:.+]] = amdgpu.mfma %[[A_CAST]] * %[[B_CAST]] + %[[C_CAST]]
-// CHECK-SAME: {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp = none
-// CHECK-SAME: : vector<4xf16>, vector<4xf16>, vector<16xf32>
-// CHECK: %[[R_CAST:.+]] = vector.shape_cast %[[MFMA]] : vector<16xf32> to vector<4x1x1x4xf32>
-// CHECK: %[[INSERT:.+]] = vector.insert %[[R_CAST]], %[[INIT]] [0, 0] : vector<4x1x1x4xf32> into vector<1x1x4x1x1x4xf32>
-// CHECK: %[[R_SIMD:.+]] = iree_vector_ext.to_simd %[[INSERT]] : vector<1x1x4x1x1x4xf32> -> vector<32x32xf32>
-// CHECK: return {{.*}} %[[R_SIMD]]
+// CHECK-SAME: (%[[A:.+]]: vector<32x8xf16>, %[[B:.+]]: vector<8x32xf16>, %[[C:.+]]: vector<32x32xf32>)
+// CHECK: %[[C_SIMT:.+]] = iree_vector_ext.to_simt %[[C]] : vector<32x32xf32> -> vector<1x1x4x1x4x1xf32>
+// CHECK: %[[A_SIMT:.+]] = iree_vector_ext.to_simt %[[A]] : vector<32x8xf16> -> vector<1x1x1x1x1x4xf16>
+// CHECK: %[[B_SIMT:.+]] = iree_vector_ext.to_simt %[[B]] : vector<8x32xf16> -> vector<1x1x1x1x4x1xf16>
+// CHECK: %[[C_VEC:.+]] = vector.extract %[[C_SIMT]][0, 0] : vector<4x1x4x1xf32> from vector<1x1x4x1x4x1xf32>
+// CHECK: %[[A_VEC:.+]] = vector.extract %[[A_SIMT]][0, 0] : vector<1x1x1x4xf16> from vector<1x1x1x1x1x4xf16>
+// CHECK: %[[B_VEC:.+]] = vector.extract %[[B_SIMT]][0, 0] : vector<1x1x4x1xf16> from vector<1x1x1x1x4x1xf16>
+// CHECK: %[[A_CAST:.+]] = vector.shape_cast %[[A_VEC]] : vector<1x1x1x4xf16> to vector<4xf16>
+// CHECK: %[[B_CAST:.+]] = vector.shape_cast %[[B_VEC]] : vector<1x1x4x1xf16> to vector<4xf16>
+// CHECK: %[[C_CAST:.+]] = vector.shape_cast %[[C_VEC]] : vector<4x1x4x1xf32> to vector<16xf32>
+// CHECK: %[[MFMA:.+]] = amdgpu.mfma %[[A_CAST]] * %[[B_CAST]] + %[[C_CAST]]
+// CHECK-SAME: {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp = none
+// CHECK-SAME: : vector<4xf16>, vector<4xf16>, vector<16xf32>
+// CHECK: %[[R_CAST:.+]] = vector.shape_cast %[[MFMA]] : vector<16xf32> to vector<4x1x4x1xf32>
+// CHECK: %[[B_OUT:.*]] = vector.broadcast %[[R_CAST]] : vector<4x1x4x1xf32> to vector<1x1x4x1x4x1xf32>
+// CHECK: %[[R_SIMD:.+]] = iree_vector_ext.to_simd %[[B_OUT]] : vector<1x1x4x1x4x1xf32> -> vector<32x32xf32>
+// CHECK: return {{.*}} %[[R_SIMD]]
// -----
@@ -120,8 +115,6 @@
threads_per_outer = [4, 16],
elements_per_thread = [4, 1],
- element_order = [1, 0],
-
subgroup_basis = [1, 1],
thread_basis = [4, 16]
>
@@ -152,22 +145,22 @@
// CHECK-LABEL: func @contract_to_mfma_16x16x16_mm
// CHECK-SAME: (%[[A:.+]]: vector<16x16xf16>, %[[B:.+]]: vector<16x16xf16>, %[[C:.+]]: vector<16x16xf32>)
-// CHECK: %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<1x1x1x1x1x4xf32>
-// CHECK: %[[C_SIMT:.+]] = iree_vector_ext.to_simt %[[C]] : vector<16x16xf32> -> vector<1x1x1x1x1x4xf32>
+// CHECK: %[[C_SIMT:.+]] = iree_vector_ext.to_simt %[[C]] : vector<16x16xf32> -> vector<1x1x1x1x4x1xf32>
// CHECK: %[[A_SIMT:.+]] = iree_vector_ext.to_simt %[[A]] : vector<16x16xf16> -> vector<1x1x1x1x1x4xf16>
-// CHECK: %[[B_SIMT:.+]] = iree_vector_ext.to_simt %[[B]] : vector<16x16xf16> -> vector<1x1x1x1x1x4xf16>
-// CHECK: %[[C_VEC:.+]] = vector.extract %[[C_SIMT]][0, 0] : vector<1x1x1x4xf32> from vector<1x1x1x1x1x4xf32>
+// CHECK: %[[B_SIMT:.+]] = iree_vector_ext.to_simt %[[B]] : vector<16x16xf16> -> vector<1x1x1x1x4x1xf16>
+// CHECK: %[[C_VEC:.+]] = vector.extract %[[C_SIMT]][0, 0] : vector<1x1x4x1xf32> from vector<1x1x1x1x4x1xf32>
// CHECK: %[[A_VEC:.+]] = vector.extract %[[A_SIMT]][0, 0] : vector<1x1x1x4xf16> from vector<1x1x1x1x1x4xf16>
-// CHECK: %[[B_VEC:.+]] = vector.extract %[[B_SIMT]][0, 0] : vector<1x1x1x4xf16> from vector<1x1x1x1x1x4xf16>
+// CHECK: %[[B_VEC:.+]] = vector.extract %[[B_SIMT]][0, 0] : vector<1x1x4x1xf16> from vector<1x1x1x1x4x1xf16>
// CHECK: %[[A_CAST:.+]] = vector.shape_cast %[[A_VEC]] : vector<1x1x1x4xf16> to vector<4xf16>
-// CHECK: %[[B_CAST:.+]] = vector.shape_cast %[[B_VEC]] : vector<1x1x1x4xf16> to vector<4xf16>
-// CHECK: %[[C_CAST:.+]] = vector.shape_cast %[[C_VEC]] : vector<1x1x1x4xf32> to vector<4xf32>
+// CHECK: %[[B_CAST:.+]] = vector.shape_cast %[[B_VEC]] : vector<1x1x4x1xf16> to vector<4xf16>
+// CHECK: %[[C_CAST:.+]] = vector.shape_cast %[[C_VEC]] : vector<1x1x4x1xf32> to vector<4xf32>
// CHECK: %[[MFMA:.+]] = amdgpu.mfma %[[A_CAST]] * %[[B_CAST]] + %[[C_CAST]]
// CHECK-SAME: {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none
// CHECK-SAME: : vector<4xf16>, vector<4xf16>, vector<4xf32>
-// CHECK: %[[R_CAST:.+]] = vector.shape_cast %[[MFMA]] : vector<4xf32> to vector<1x1x1x4xf32>
-// CHECK: %[[INSERT:.+]] = vector.insert %[[R_CAST]], %[[INIT]] [0, 0] : vector<1x1x1x4xf32> into vector<1x1x1x1x1x4xf32>
-// CHECK: %[[R_SIMD:.+]] = iree_vector_ext.to_simd %[[INSERT]] : vector<1x1x1x1x1x4xf32> -> vector<16x16xf32>
+
+// CHECK: %[[R_CAST:.+]] = vector.shape_cast %[[MFMA]] : vector<4xf32> to vector<1x1x4x1xf32>
+// CHECK: %[[B_OUT:.*]] = vector.broadcast %[[R_CAST]] : vector<1x1x4x1xf32> to vector<1x1x1x1x4x1xf32>
+// CHECK: %[[R_SIMD:.+]] = iree_vector_ext.to_simd %[[B_OUT]] : vector<1x1x1x1x4x1xf32> -> vector<16x16xf32>
// CHECK: return {{.*}} %[[R_SIMD]]
// -----
@@ -200,8 +193,6 @@
threads_per_outer = [2, 32],
elements_per_thread = [4, 1],
- element_order = [1, 0],
-
subgroup_basis = [1, 1],
thread_basis = [2, 32]
>
@@ -214,8 +205,6 @@
threads_per_outer = [2, 32],
elements_per_thread = [4, 1],
- element_order = [1, 0],
-
subgroup_basis = [1, 1],
thread_basis = [2, 32]
>
@@ -244,23 +233,23 @@
// CHECK-LABEL: func.func @contract_to_mfma_32x32x8_mm_mnbatch
// CHECK: %[[INIT:.+]] = arith.constant dense<0.000000e+00>
-// CHECK: %[[C_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<64x32xf32> -> vector<2x1x4x1x1x4xf32>
+// CHECK: %[[C_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<64x32xf32> -> vector<2x1x4x1x4x1xf32>
// CHECK: %[[A_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<64x8xf16> -> vector<2x1x1x1x1x4xf16>
-// CHECK: %[[C_SLICE0:.+]] = vector.extract %[[C_SIMT]][0, 0] : vector<4x1x1x4xf32> from vector<2x1x4x1x1x4xf32>
+// CHECK: %[[C_SLICE0:.+]] = vector.extract %[[C_SIMT]][0, 0] : vector<4x1x4x1xf32> from vector<2x1x4x1x4x1xf32>
// CHECK: %[[A_SLICE0:.+]] = vector.extract %[[A_SIMT]][0, 0] : vector<1x1x1x4xf16> from vector<2x1x1x1x1x4xf16>
// CHECK: %[[A0_CAST:.+]] = vector.shape_cast %[[A_SLICE0]] : vector<1x1x1x4xf16> to vector<4xf16>
-// CHECK: %[[C0_CAST:.+]] = vector.shape_cast %[[C_SLICE0]] : vector<4x1x1x4xf32> to vector<16xf32>
+// CHECK: %[[C0_CAST:.+]] = vector.shape_cast %[[C_SLICE0]] : vector<4x1x4x1xf32> to vector<16xf32>
// CHECK: %[[MFMA0:.+]] = amdgpu.mfma %[[A0_CAST]] * %{{.+}} + %[[C0_CAST]]
-// CHECK: %[[R0_CAST:.+]] = vector.shape_cast %[[MFMA0]] : vector<16xf32> to vector<4x1x1x4xf32>
-// CHECK: %[[C0_INS:.+]] = vector.insert %[[R0_CAST]], %[[INIT]] [0, 0] : vector<4x1x1x4xf32> into vector<2x1x4x1x1x4xf32>
-// CHECK: %[[C_SLICE1:.+]] = vector.extract %[[C_SIMT]][1, 0] : vector<4x1x1x4xf32> from vector<2x1x4x1x1x4xf32>
+// CHECK: %[[R0_CAST:.+]] = vector.shape_cast %[[MFMA0]] : vector<16xf32> to vector<4x1x4x1xf32>
+// CHECK: %[[C0_INS:.+]] = vector.insert %[[R0_CAST]], %[[INIT]] [0, 0] : vector<4x1x4x1xf32> into vector<2x1x4x1x4x1xf32>
+// CHECK: %[[C_SLICE1:.+]] = vector.extract %[[C_SIMT]][1, 0] : vector<4x1x4x1xf32> from vector<2x1x4x1x4x1xf32>
// CHECK: %[[A_SLICE1:.+]] = vector.extract %[[A_SIMT]][1, 0] : vector<1x1x1x4xf16> from vector<2x1x1x1x1x4xf16>
// CHECK: %[[A1_CAST:.+]] = vector.shape_cast %[[A_SLICE1]] : vector<1x1x1x4xf16> to vector<4xf16>
-// CHECK: %[[C1_CAST:.+]] = vector.shape_cast %[[C_SLICE1]] : vector<4x1x1x4xf32> to vector<16xf32>
+// CHECK: %[[C1_CAST:.+]] = vector.shape_cast %[[C_SLICE1]] : vector<4x1x4x1xf32> to vector<16xf32>
// CHECK: %[[MFMA1:.+]] = amdgpu.mfma %[[A1_CAST]] * %{{.+}} + %[[C1_CAST]]
-// CHECK: %[[R1_CAST:.+]] = vector.shape_cast %[[MFMA1]] : vector<16xf32> to vector<4x1x1x4xf32>
-// CHECK: %[[C1_INS:.+]] = vector.insert %[[R1_CAST]], %[[C0_INS]] [1, 0] : vector<4x1x1x4xf32> into vector<2x1x4x1x1x4xf32>
-// CHECK: %[[R:.+]] = iree_vector_ext.to_simd %[[C1_INS]] : vector<2x1x4x1x1x4xf32> -> vector<64x32xf32>
+// CHECK: %[[R1_CAST:.+]] = vector.shape_cast %[[MFMA1]] : vector<16xf32> to vector<4x1x4x1xf32>
+// CHECK: %[[C1_INS:.+]] = vector.insert %[[R1_CAST]], %[[C0_INS]] [1, 0] : vector<4x1x4x1xf32> into vector<2x1x4x1x4x1xf32>
+// CHECK: %[[R:.+]] = iree_vector_ext.to_simd %[[C1_INS]] : vector<2x1x4x1x4x1xf32> -> vector<64x32xf32>
// CHECK: return {{.*}}} %[[R]]
// -----
@@ -293,8 +282,6 @@
threads_per_outer = [2, 32],
elements_per_thread = [4, 1],
- element_order = [1, 0],
-
subgroup_basis = [1, 1],
thread_basis = [2, 32]
>
@@ -307,8 +294,6 @@
threads_per_outer = [2, 32],
elements_per_thread = [4, 1],
- element_order = [1, 0],
-
subgroup_basis = [1, 1],
thread_basis = [2, 32]
>
@@ -336,22 +321,19 @@
}
// CHECK-LABEL: func.func @contract_to_mfma_32x32x8_mm_kbatch(%arg0: vector<32x16xf16>, %arg1: vector<16x32xf16>, %arg2: vector<32x32xf32>) -> vector<32x32xf32> {
-// CHECK: %[[A_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<32x16xf16> -> vector<1x2x1x1x1x4xf16>
-// CHECK: %[[B_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<16x32xf16> -> vector<2x1x1x1x1x4xf16>
-// CHECK: %[[A_SLICE0:.+]] = vector.extract %[[A_SIMT]][0, 0] : vector<1x1x1x4xf16> from vector<1x2x1x1x1x4xf16>
-// CHECK: %[[B_SLICE0:.+]] = vector.extract %[[B_SIMT]][0, 0] : vector<1x1x1x4xf16> from vector<2x1x1x1x1x4xf16>
-// CHECK: %[[A0_CAST:.+]] = vector.shape_cast %[[A_SLICE0]] : vector<1x1x1x4xf16> to vector<4xf16>
-// CHECK: %[[B0_CAST:.+]] = vector.shape_cast %[[B_SLICE0]] : vector<1x1x1x4xf16> to vector<4xf16>
+// CHECK: %[[A_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<32x16xf16>
+// CHECK: %[[B_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<16x32xf16>
+// CHECK: %[[A_SLICE0:.+]] = vector.extract %[[A_SIMT]][0, 0]
+// CHECK: %[[B_SLICE0:.+]] = vector.extract %[[B_SIMT]][0, 0]
+// CHECK: %[[A0_CAST:.+]] = vector.shape_cast %[[A_SLICE0]]
+// CHECK: %[[B0_CAST:.+]] = vector.shape_cast %[[B_SLICE0]]
// CHECK: %[[MFMA0:.+]] = amdgpu.mfma %[[A0_CAST]] * %[[B0_CAST]] + %{{.+}}
-// CHECK: %[[A_SLICE1:.+]] = vector.extract %[[A_SIMT]][0, 1] : vector<1x1x1x4xf16> from vector<1x2x1x1x1x4xf16>
-// CHECK: %[[B_SLICE1:.+]] = vector.extract %[[B_SIMT]][1, 0] : vector<1x1x1x4xf16> from vector<2x1x1x1x1x4xf16>
-// CHECK: %[[A1_CAST:.+]] = vector.shape_cast %[[A_SLICE1]] : vector<1x1x1x4xf16> to vector<4xf16>
-// CHECK: %[[B1_CAST:.+]] = vector.shape_cast %[[B_SLICE1]] : vector<1x1x1x4xf16> to vector<4xf16>
+// CHECK: %[[A_SLICE1:.+]] = vector.extract %[[A_SIMT]][0, 1]
+// CHECK: %[[B_SLICE1:.+]] = vector.extract %[[B_SIMT]][1, 0]
+// CHECK: %[[A1_CAST:.+]] = vector.shape_cast %[[A_SLICE1]]
+// CHECK: %[[B1_CAST:.+]] = vector.shape_cast %[[B_SLICE1]]
// CHECK: %[[MFMA1:.+]] = amdgpu.mfma %[[A1_CAST]] * %[[B1_CAST]] + %[[MFMA0]]
-// CHECK: %[[R_CAST:.+]] = vector.shape_cast %[[MFMA1]] : vector<16xf32> to vector<4x1x1x4xf32>
-// CHECK: %[[INSERT:.+]] = vector.insert %[[R_CAST]], %{{.+}} [0, 0] : vector<4x1x1x4xf32> into vector<1x1x4x1x1x4xf32>
-// CHECK: %[[R:.+]] = iree_vector_ext.to_simd %[[INSERT]] : vector<1x1x4x1x1x4xf32> -> vector<32x32xf32>
-// CHECK: return {{.*}} %[[R]] : vector<32x32xf32>
+// CHECK: %[[R_CAST:.+]] = vector.shape_cast %[[MFMA1]]
// -----
@@ -383,8 +365,6 @@
threads_per_outer = [2, 32],
elements_per_thread = [4, 1],
- element_order = [1, 0],
-
subgroup_basis = [1, 1],
thread_basis = [2, 32]
>
@@ -397,9 +377,6 @@
threads_per_outer = [2, 32],
elements_per_thread = [4, 1],
- batch_order = [1, 0],
- element_order = [1, 0],
-
subgroup_basis = [1, 1],
thread_basis = [2, 32]
>
@@ -427,27 +404,27 @@
}
// CHECK-LABEL: func.func @contract_to_mfma_32x32x8_mm_mnbatch_order
-// CHECK: %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<3x2x4x1x1x4xf32>
-// CHECK: %[[C_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<64x96xf32> -> vector<3x2x4x1x1x4xf32>
-// CHECK: vector.extract %[[C_SIMT]][0, 0] : vector<4x1x1x4xf32> from vector<3x2x4x1x1x4xf32>
+// CHECK: %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<2x3x4x1x4x1xf32>
+// CHECK: %[[C_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<64x96xf32> -> vector<2x3x4x1x4x1xf32>
+// CHECK: vector.extract %[[C_SIMT]][0, 0]
// CHECK: amdgpu.mfma
-// CHECK: %[[INS0:.+]] = vector.insert %{{.+}}, %[[INIT]] [0, 0] : vector<4x1x1x4xf32> into vector<3x2x4x1x1x4xf32>
-// CHECK: vector.extract %[[C_SIMT]][0, 1] : vector<4x1x1x4xf32> from vector<3x2x4x1x1x4xf32>
+// CHECK: %[[INS0:.+]] = vector.insert %{{.+}}, %[[INIT]] [0, 0]
+// CHECK: vector.extract %[[C_SIMT]][0, 1]
// CHECK: amdgpu.mfma
-// CHECK: %[[INS1:.+]] = vector.insert %{{.+}}, %[[INS0]] [0, 1] : vector<4x1x1x4xf32> into vector<3x2x4x1x1x4xf32>
-// CHECK: vector.extract %[[C_SIMT]][1, 0] : vector<4x1x1x4xf32> from vector<3x2x4x1x1x4xf32>
+// CHECK: %[[INS1:.+]] = vector.insert %{{.+}}, %[[INS0]] [0, 1]
+// CHECK: vector.extract %[[C_SIMT]][0, 2]
// CHECK: amdgpu.mfma
-// CHECK: %[[INS2:.+]] = vector.insert %{{.+}}, %[[INS1]] [1, 0] : vector<4x1x1x4xf32> into vector<3x2x4x1x1x4xf32>
-// CHECK: vector.extract %[[C_SIMT]][1, 1] : vector<4x1x1x4xf32> from vector<3x2x4x1x1x4xf32>
+// CHECK: %[[INS2:.+]] = vector.insert %{{.+}}, %[[INS1]] [0, 2]
+// CHECK: vector.extract %[[C_SIMT]][1, 0]
// CHECK: amdgpu.mfma
-// CHECK: %[[INS3:.+]] = vector.insert %{{.+}}, %[[INS2]] [1, 1] : vector<4x1x1x4xf32> into vector<3x2x4x1x1x4xf32>
-// CHECK: vector.extract %[[C_SIMT]][2, 0] : vector<4x1x1x4xf32> from vector<3x2x4x1x1x4xf32>
+// CHECK: %[[INS3:.+]] = vector.insert %{{.+}}, %[[INS2]] [1, 0]
+// CHECK: vector.extract %[[C_SIMT]][1, 1]
// CHECK: amdgpu.mfma
-// CHECK: %[[INS4:.+]] = vector.insert %{{.+}}, %[[INS3]] [2, 0] : vector<4x1x1x4xf32> into vector<3x2x4x1x1x4xf32>
-// CHECK: vector.extract %[[C_SIMT]][2, 1] : vector<4x1x1x4xf32> from vector<3x2x4x1x1x4xf32>
+// CHECK: %[[INS4:.+]] = vector.insert %{{.+}}, %[[INS3]] [1, 1]
+// CHECK: vector.extract %[[C_SIMT]][1, 2]
// CHECK: amdgpu.mfma
-// CHECK: %[[INS5:.+]] = vector.insert %{{.+}}, %[[INS4]] [2, 1] : vector<4x1x1x4xf32> into vector<3x2x4x1x1x4xf32>
-// CHECK: iree_vector_ext.to_simd %[[INS5]] : vector<3x2x4x1x1x4xf32> -> vector<64x96xf32>
+// CHECK: %[[INS5:.+]] = vector.insert %{{.+}}, %[[INS4]] [1, 2]
+// CHECK: iree_vector_ext.to_simd %[[INS5]]
// -----
@@ -493,8 +470,6 @@
threads_per_outer = [2, 32],
elements_per_thread = [4, 1],
- element_order = [1, 0],
-
subgroup_basis = [1, 1],
thread_basis = [2, 32]
>
@@ -522,15 +497,15 @@
}
// CHECK-LABEL: func.func @contract_to_mfma_32x32x8_mmt
-// CHECK: %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<1x2x4x1x1x4xf32>
+// CHECK: %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<1x2x4x1x4x1xf32>
// CHECK: %[[B_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<64x8xf16> -> vector<2x1x1x1x1x4xf16>
-// CHECK: vector.extract %[[B_SIMT]][0, 0] : vector<1x1x1x4xf16> from vector<2x1x1x1x1x4xf16>
+// CHECK: vector.extract %[[B_SIMT]][0, 0]
// CHECK: amdgpu.mfma
-// CHECK: %[[INS0:.+]] = vector.insert %{{.+}}, %[[INIT]] [0, 0] : vector<4x1x1x4xf32> into vector<1x2x4x1x1x4xf32>
-// CHECK: vector.extract %[[B_SIMT]][1, 0] : vector<1x1x1x4xf16> from vector<2x1x1x1x1x4xf16>
+// CHECK: %[[INS0:.+]] = vector.insert %{{.+}}, %[[INIT]] [0, 0]
+// CHECK: vector.extract %[[B_SIMT]][1, 0]
// CHECK: amdgpu.mfma
-// CHECK: %[[INS1:.+]] = vector.insert %17, %[[INS0]] [0, 1] : vector<4x1x1x4xf32> into vector<1x2x4x1x1x4xf32>
-// CHECK: iree_vector_ext.to_simd %[[INS1]] : vector<1x2x4x1x1x4xf32> -> vector<32x64xf32>
+// CHECK: %[[INS1:.+]] = vector.insert %17, %[[INS0]] [0, 1]
+// CHECK: iree_vector_ext.to_simd %[[INS1]] : vector<1x2x4x1x4x1xf32> -> vector<32x64xf32>
// -----
@@ -562,8 +537,6 @@
threads_per_outer = [1, 16],
elements_per_thread = [16, 1],
- element_order = [1, 0],
-
subgroup_basis = [1, 1],
thread_basis = [1, 32]
>
@@ -576,8 +549,6 @@
threads_per_outer = [2, 16],
elements_per_thread = [1, 1],
- element_order = [1, 0],
-
subgroup_basis = [1, 1],
thread_basis = [2, 16]
>
@@ -607,19 +578,17 @@
// CHECK-LABEL: func.func @contract_to_wmma_16x16x16_mm
// CHECK-SAME: (%[[A:.+]]: vector<16x16xf16>, %[[B:.+]]: vector<16x16xf16>, %[[C:.+]]: vector<16x16xf32>)
-// CHECK: %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<1x1x8x1x1x1xf32>
// CHECK: %[[C_SIMT:.+]] = iree_vector_ext.to_simt %[[C]] : vector<16x16xf32> -> vector<1x1x8x1x1x1xf32>
// CHECK: %[[A_SIMT:.+]] = iree_vector_ext.to_simt %[[A]] : vector<16x16xf16> -> vector<1x1x1x1x1x16xf16>
-// CHECK: %[[B_SIMT:.+]] = iree_vector_ext.to_simt %[[B]] : vector<16x16xf16> -> vector<1x1x1x1x1x16xf16>
+// CHECK: %[[B_SIMT:.+]] = iree_vector_ext.to_simt %[[B]] : vector<16x16xf16> -> vector<1x1x1x1x16x1xf16>
// CHECK: %[[C_VEC:.+]] = vector.extract %[[C_SIMT]][0, 0] : vector<8x1x1x1xf32> from vector<1x1x8x1x1x1xf32>
// CHECK: %[[A_VEC:.+]] = vector.extract %[[A_SIMT]][0, 0] : vector<1x1x1x16xf16> from vector<1x1x1x1x1x16xf16>
-// CHECK: %[[B_VEC:.+]] = vector.extract %[[B_SIMT]][0, 0] : vector<1x1x1x16xf16> from vector<1x1x1x1x1x16xf16>
-// CHECK: %[[A_CAST:.+]] = vector.shape_cast %[[A_VEC]] : vector<1x1x1x16xf16> to vector<16xf16>
-// CHECK: %[[B_CAST:.+]] = vector.shape_cast %[[B_VEC]] : vector<1x1x1x16xf16> to vector<16xf16>
+// CHECK: %[[B_VEC:.+]] = vector.extract %[[B_SIMT]][0, 0] : vector<1x1x16x1xf16> from vector<1x1x1x1x16x1xf16>
+// CHECK: %[[A_CAST:.+]] = vector.shape_cast %[[A_VEC]] : vector<1x1x1x16xf16> to vector<16xf16>
+// CHECK: %[[B_CAST:.+]] = vector.shape_cast %[[B_VEC]] : vector<1x1x16x1xf16> to vector<16xf16>
// CHECK: %[[C_CAST:.+]] = vector.shape_cast %[[C_VEC]] : vector<8x1x1x1xf32> to vector<8xf32>
// CHECK: %[[WMMA:.+]] = amdgpu.wmma %[[A_CAST]] * %[[B_CAST]] + %[[C_CAST]]
// CHECK: %[[R_CAST:.+]] = vector.shape_cast %[[WMMA]] : vector<8xf32> to vector<8x1x1x1xf32>
-// CHECK: %[[INSERT:.+]] = vector.insert %[[R_CAST]], %[[INIT]] [0, 0] : vector<8x1x1x1xf32> into vector<1x1x8x1x1x1xf32>
-// CHECK: %[[R_SIMD:.+]] = iree_vector_ext.to_simd %[[INSERT]] : vector<1x1x8x1x1x1xf32> -> vector<16x16xf32>
+// CHECK: %[[B_OUT:.*]] = vector.broadcast %[[R_CAST]] : vector<8x1x1x1xf32> to vector<1x1x8x1x1x1xf32>
+// CHECK: %[[R_SIMD:.+]] = iree_vector_ext.to_simd %[[B_OUT]] : vector<1x1x8x1x1x1xf32> -> vector<16x16xf32>
// CHECK: return {{.*}} %[[R_SIMD]]
-
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
index d93b3c2..41002f8 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
@@ -1,54 +1,5 @@
// RUN: iree-opt --iree-transform-dialect-interpreter --split-input-file --canonicalize --cse %s | FileCheck %s
-#layout_row_major = #iree_vector_ext.nested_layout<
- subgroups_per_workgroup = [1, 1],
- batches_per_subgroup = [2, 2],
- outers_per_batch = [1, 1],
- threads_per_outer = [8, 1],
- elements_per_thread = [1, 8],
-
- batch_order = [1, 0],
-
- subgroup_basis = [1, 1],
- thread_basis = [8, 1]
->
-
-// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 8)>
-// CHECK-LABEL: @distribute_transfer_read_row_major
-func.func @distribute_transfer_read_row_major(%arg0: memref<4x4xf16>) -> vector<16x16xf16> {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0.0 : f16
- %root = vector.transfer_read %arg0[%c0, %c0], %cst
- {in_bounds = [false, false],
- "__vector_layout_test_anchor_result_0" = #layout_row_major}
- : memref<4x4xf16>, vector<16x16xf16>
- func.return %root : vector<16x16xf16>
-}
-
-builtin.module attributes { transform.with_named_sequence } {
- transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
- %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
- transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
- transform.yield
- }
-}
-
-// CHECK: %[[ACC:.+]] = arith.constant dense<0.000000e+00> : vector<2x2x1x1x1x8xf16>
-// CHECK: %[[IDX:.+]] = gpu.thread_id x
-// CHECK: %[[IDS:.+]]:4 = affine.delinearize_index %[[IDX]] into (%c1, %c1, %c8, %c1) : index, index, index, index
-// CHECK: vector.transfer_read %arg0[%[[IDS]]#2, %c0], {{.*}} : memref<4x4xf16>, vector<1x8xf16>
-// CHECK: vector.insert_strided_slice %{{.*}}, %[[ACC]] {offsets = [0, 0, 0, 0, 0, 0], strides = [1, 1]} : vector<1x8xf16> into vector<2x2x1x1x1x8xf16>
-// CHECK: vector.transfer_read %arg0[%[[IDS]]#2, %c8]
-// CHECK: vector.insert_strided_slice {{.*}} {offsets = [1, 0, 0, 0, 0, 0]
-// CHECK: %[[ID_PLUS_BATCH1:.+]] = affine.apply #[[$MAP]]()[%[[IDS]]#2]
-// CHECK: vector.transfer_read %arg0[%[[ID_PLUS_BATCH1]], %c0]
-// CHECK: vector.insert_strided_slice {{.*}} {offsets = [0, 1, 0, 0, 0, 0]
-// CHECK: vector.transfer_read %arg0[%[[ID_PLUS_BATCH1]], %c8]
-// CHECK: vector.insert_strided_slice {{.*}} {offsets = [1, 1, 0, 0, 0, 0]
-// CHECK: iree_vector_ext.to_simd %{{.*}} : vector<2x2x1x1x1x8xf16> -> vector<16x16xf16>
-
-// -----
-
#layout_col_major = #iree_vector_ext.nested_layout<
subgroups_per_workgroup = [1, 1],
batches_per_subgroup = [1, 2],
@@ -56,9 +7,6 @@
threads_per_outer = [4, 8],
elements_per_thread = [4, 1],
- batch_order = [1, 0],
- element_order = [1, 0],
-
subgroup_basis = [1, 1],
thread_basis = [4, 8]
>
@@ -88,13 +36,11 @@
// CHECK: %[[IDS:.+]]:4 = affine.delinearize_index %{{.*}} into (%c1, %c1, %c4, %c8) : index, index, index, index
// CHECK: %[[LANEY:.+]] = affine.apply #[[$MAP]]()[%[[IDS]]#2]
// CHECK: %[[RD00:.+]] = vector.transfer_read %arg0[%[[LANEY:.+]], %[[IDS]]#3], {{.*}} : memref<32x32xf16>, vector<4x1xf16>
-// CHECK: %[[ELEM_ORDER:.+]] = vector.transpose %[[RD00]], [1, 0] : vector<4x1xf16> to vector<1x4xf16>
-// CHECK: vector.insert_strided_slice %[[ELEM_ORDER]], %{{.*}} {offsets = [0, 0, 0, 0, 0, 0], strides = [1, 1]} : vector<1x4xf16> into vector<2x1x1x1x1x4xf16>
+// CHECK: vector.insert_strided_slice %[[RD00]], %{{.*}} {offsets = [0, 0, 0, 0, 0, 0], strides = [1, 1]} : vector<4x1xf16> into vector<1x2x1x1x4x1xf16>
// CHECK: %[[LANEX_PLUS_BATCH:.+]] = affine.apply #[[$MAP1]]()[%[[IDS]]#3]
// CHECK: vector.transfer_read %arg0[%[[LANEY]], %[[LANEX_PLUS_BATCH]]], %{{.*}} {in_bounds = [true, true]} : memref<32x32xf16>, vector<4x1xf16>
-// CHECK: vector.transpose %{{.*}}, [1, 0] : vector<4x1xf16> to vector<1x4xf16>
-// CHECK: vector.insert_strided_slice {{.*}} {offsets = [1, 0, 0, 0, 0, 0]
-// CHECK: iree_vector_ext.to_simd %{{.*}} : vector<2x1x1x1x1x4xf16> -> vector<16x16xf16>
+// CHECK: vector.insert_strided_slice {{.*}} {offsets = [0, 1, 0, 0, 0, 0]
+// CHECK: iree_vector_ext.to_simd %{{.*}} : vector<1x2x1x1x4x1xf16> -> vector<16x16xf16>
// -----
@@ -105,8 +51,6 @@
threads_per_outer = [8, 1],
elements_per_thread = [1, 8],
- batch_order = [1, 0],
-
subgroup_basis = [1, 1],
thread_basis = [8, 1]
>
@@ -154,9 +98,6 @@
threads_per_outer = [4, 8],
elements_per_thread = [4, 1],
- batch_order = [1, 0],
- element_order = [1, 0],
-
subgroup_basis = [1, 1],
thread_basis = [4, 8]
>
@@ -186,9 +127,8 @@
// CHECK-SAME: %[[I0:.+]]: index, %[[I1:.+]]: index
// CHECK: %[[BROADCAST_READ:.+]] = vector.transfer_read %{{.*}}[%c0, %c0, %[[I0]], %[[I1]]], %{{.*}} permutation_map = #[[$MAP]]
-// CHECK: %[[UNIT:.+]] = vector.transpose %[[BROADCAST_READ]], [1, 0] : vector<4x1xf16> to vector<1x4xf16>
-// CHECK: vector.insert_strided_slice %[[UNIT]], %{{.*}} {offsets = [0, 0, 0, 0, 0, 0]
-// CHECK: vector.insert_strided_slice %[[UNIT]], %{{.*}} {offsets = [1, 0, 0, 0, 0, 0]
+// CHECK: vector.insert_strided_slice %[[BROADCAST_READ]], %{{.*}} {offsets = [0, 0, 0, 0, 0, 0]
+// CHECK: vector.insert_strided_slice %[[BROADCAST_READ]], %{{.*}} {offsets = [0, 1, 0, 0, 0, 0]
// -----
@@ -199,8 +139,6 @@
threads_per_outer = [8, 1],
elements_per_thread = [1, 8],
- batch_order = [1, 0],
-
subgroup_basis = [1, 1],
thread_basis = [8, 1]
>
@@ -289,8 +227,6 @@
elements_per_thread = [1, 1, 1, 2],
subgroup_order = [1, 0, 2, 3],
- batch_order = [1, 2, 3, 0],
- outer_order = [0, 3, 1, 2],
thread_order = [0, 1, 3, 2],
subgroup_basis = [7, 3, 1, 1],
@@ -410,8 +346,6 @@
threads_per_outer = [8, 1],
elements_per_thread = [1, 8],
- batch_order = [1, 0],
-
subgroup_basis = [1, 1],
thread_basis = [8, 1]
>
@@ -439,10 +373,10 @@
// CHECK: %[[IDS:.+]]:4 = affine.delinearize_index %{{.*}} into (%c1, %c1, %c8, %c1) : index, index, index, index
// CHECK: %[[SLICE:.+]] = vector.extract %{{.*}}[0, 0, 0, 0] : vector<1x8xf16> from vector<2x2x1x1x1x8xf16>
// CHECK: vector.transfer_write %[[SLICE]], %{{.*}}[%[[IDS]]#2, %c0] {in_bounds = [true, true]} : vector<1x8xf16>, memref<64x64xf16>
-// CHECK: vector.extract %{{.*}}[1, 0, 0, 0]
+// CHECK: vector.extract %{{.*}}[0, 1, 0, 0]
// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[IDS]]#2, %c8]
// CHECK: %[[LANEX_PLUS_VECDIMX:.+]] = affine.apply #[[$MAP]]()[%[[IDS]]#2]
-// CHECK: vector.extract %{{.*}}[0, 1, 0, 0]
+// CHECK: vector.extract %{{.*}}[1, 0, 0, 0]
// CHECK: vector.transfer_write %{{.*}}[%[[LANEX_PLUS_VECDIMX]], %c0]
// CHECK: vector.extract %{{.*}}[1, 1, 0, 0]
// CHECK: vector.transfer_write %{{.*}}[%[[LANEX_PLUS_VECDIMX]], %c8]
@@ -456,9 +390,6 @@
threads_per_outer = [4, 8],
elements_per_thread = [4, 1],
- batch_order = [1, 0],
- element_order = [1, 0],
-
subgroup_basis = [1, 1],
thread_basis = [4, 8]
>
@@ -488,11 +419,9 @@
// CHECK: %[[IDS:.+]]:4 = affine.delinearize_index %[[TIDX]] into (%c1, %c1, %c4, %c8) : index, index, index, index
// CHECK: %[[LANEY:.+]] = affine.apply #map()[%[[IDS]]#2]
// CHECK: vector.extract %{{.*}}[0, 0, 0, 0]
-// CHECK: vector.transpose %{{.*}}, [1, 0] : vector<1x4xf16> to vector<4x1xf16>
// CHECK: vector.transfer_write %{{.*}}[%[[LANEY]], %[[IDS]]#3]
// CHECK: %[[LANEX:.+]] = affine.apply #[[$MAP1]]()[%[[IDS]]#3]
-// CHECK: vector.extract %{{.*}}[1, 0, 0, 0]
-// CHECK: vector.transpose %{{.*}}, [1, 0] : vector<1x4xf16> to vector<4x1xf16>
+// CHECK: vector.extract %{{.*}}[0, 1, 0, 0]
// CHECK: vector.transfer_write {{.*}}[%[[LANEY]], %[[LANEX]]]
// -----
@@ -504,8 +433,6 @@
threads_per_outer = [8, 1],
elements_per_thread = [1, 8],
- batch_order = [1, 0],
-
subgroup_basis = [1, 1],
thread_basis = [8, 1]
>
@@ -541,10 +468,10 @@
// CHECK: vector.extract %{{.*}}[0, 0, 0, 0]
// CHECK: vector.transfer_write %{{.*}}[%c0, %c0, %[[I0]], %[[LIN_ID0]]] {{.*}} permutation_map = #[[$MAP1]]
// CHECK: %[[LIN_ID1:.+]] = affine.apply #[[$MAP2]]()[%[[I0]]]
-// CHECK: vector.extract %{{.*}}[1, 0, 0, 0]
+// CHECK: vector.extract %{{.*}}[0, 1, 0, 0]
// CHECK: vector.transfer_write %{{.*}}[%c0, %c0, %[[LIN_ID1]], %[[LIN_ID0]]] {{.*}} permutation_map = #[[$MAP1]]
// CHECK: %[[LIN_ID2:.+]] = affine.apply #[[$MAP3]]()[%[[IDS]]#2, %[[I1]]]
-// CHECK: vector.extract %{{.*}}[0, 1, 0, 0]
+// CHECK: vector.extract %{{.*}}[1, 0, 0, 0]
// CHECK: vector.transfer_write %{{.*}}[%c0, %c0, %[[I0]], %[[LIN_ID2]]] {{.*}} permutation_map = #[[$MAP1]]
// CHECK: vector.extract %{{.*}}[1, 1, 0, 0]
// CHECK: vector.transfer_write %{{.*}}[%c0, %c0, %[[LIN_ID1]], %[[LIN_ID2]]] {{.*}} permutation_map = #[[$MAP1]]
@@ -558,8 +485,6 @@
threads_per_outer = [8, 1],
elements_per_thread = [1, 8],
- batch_order = [1, 0],
-
subgroup_basis = [1, 1],
thread_basis = [8, 1]
>
@@ -623,8 +548,6 @@
threads_per_outer = [2, 32],
elements_per_thread = [4, 1],
- element_order = [1, 0],
-
subgroup_basis = [4, 2],
thread_basis = [2, 32]
>
@@ -637,8 +560,6 @@
threads_per_outer = [2, 32],
elements_per_thread = [4, 1],
- element_order = [1, 0],
-
subgroup_basis = [4, 2],
thread_basis = [2, 32]
>
@@ -715,10 +636,7 @@
elements_per_thread = [1, 4],
subgroup_order = [1, 0],
- batch_order = [1, 0],
- outer_order = [1, 0],
thread_order = [1, 0],
- element_order = [0, 1],
subgroup_basis = [4, 2],
thread_basis = [2, 32]
@@ -767,7 +685,6 @@
outers_per_batch = [1, 1],
threads_per_outer = [4, 16],
elements_per_thread = [4, 1],
- element_order = [1, 0],
subgroup_basis = [2, 2],
thread_basis = [4, 16]
>
@@ -787,21 +704,21 @@
}
// CHECK: vector.extract {{.*}}[0, 0] : vector<1xf16> from vector<4x1x1xf16>
-// CHECK: vector.broadcast {{.*}} : vector<1xf16> to vector<1x4xf16>
-// CHECK: vector.insert {{.*}} [0, 0, 0, 0] : vector<1x4xf16> into vector<2x4x1x1x1x4xf16>
+// CHECK: vector.broadcast {{.*}} : vector<1xf16> to vector<4x1xf16>
+// CHECK: vector.insert {{.*}} [0, 0, 0, 0] : vector<4x1xf16> into vector<2x4x1x1x4x1xf16>
// CHECK: vector.extract {{.*}}[1, 0] : vector<1xf16> from vector<4x1x1xf16>
-// CHECK: vector.broadcast {{.*}} : vector<1xf16> to vector<1x4xf16>
-// CHECK: vector.insert {{.*}} [0, 1, 0, 0] : vector<1x4xf16> into vector<2x4x1x1x1x4xf16>
+// CHECK: vector.broadcast {{.*}} : vector<1xf16> to vector<4x1xf16>
+// CHECK: vector.insert {{.*}} [0, 1, 0, 0] : vector<4x1xf16> into vector<2x4x1x1x4x1xf16>
// CHECK: vector.extract {{.*}}[2, 0] : vector<1xf16> from vector<4x1x1xf16>
-// CHECK: vector.broadcast {{.*}} : vector<1xf16> to vector<1x4xf16>
-// CHECK: vector.insert {{.*}} [0, 2, 0, 0] : vector<1x4xf16> into vector<2x4x1x1x1x4xf16>
+// CHECK: vector.broadcast {{.*}} : vector<1xf16> to vector<4x1xf16>
+// CHECK: vector.insert {{.*}} [0, 2, 0, 0] : vector<4x1xf16> into vector<2x4x1x1x4x1xf16>
// CHECK: vector.extract {{.*}}[3, 0] : vector<1xf16> from vector<4x1x1xf16>
-// CHECK: vector.broadcast {{.*}} : vector<1xf16> to vector<1x4xf16>
+// CHECK: vector.broadcast {{.*}} : vector<1xf16> to vector<4x1xf16>
-// CHECK: vector.insert {{.*}} [1, 0, 0, 0] : vector<1x4xf16> into vector<2x4x1x1x1x4xf16>
-// CHECK: vector.insert {{.*}} [1, 1, 0, 0] : vector<1x4xf16> into vector<2x4x1x1x1x4xf16>
-// CHECK: vector.insert {{.*}} [1, 2, 0, 0] : vector<1x4xf16> into vector<2x4x1x1x1x4xf16>
-// CHECK: vector.insert {{.*}} [1, 3, 0, 0] : vector<1x4xf16> into vector<2x4x1x1x1x4xf16>
+// CHECK: vector.insert {{.*}} [1, 0, 0, 0] : vector<4x1xf16> into vector<2x4x1x1x4x1xf16>
+// CHECK: vector.insert {{.*}} [1, 1, 0, 0] : vector<4x1xf16> into vector<2x4x1x1x4x1xf16>
+// CHECK: vector.insert {{.*}} [1, 2, 0, 0] : vector<4x1xf16> into vector<2x4x1x1x4x1xf16>
+// CHECK: vector.insert {{.*}} [1, 3, 0, 0] : vector<4x1xf16> into vector<2x4x1x1x4x1xf16>
// -----
@@ -811,7 +728,6 @@
outers_per_batch = [2, 1, 1],
threads_per_outer = [4, 16, 8],
elements_per_thread = [1, 4, 4],
- batch_order = [2, 1, 0],
subgroup_basis = [2, 2, 2],
thread_basis = [4, 16, 8]
>
@@ -832,14 +748,14 @@
// CHECK: %[[EXTRACT:.+]] = vector.extract %{{.*}}[0, 0] : vector<4xf16> from vector<1x1x4xf16>
// CHECK: %[[BCAST:.+]] = vector.broadcast %[[EXTRACT]] : vector<4xf16> to vector<1x4x4xf16>
-// CHECK: vector.insert %[[BCAST]], %{{.*}} [0, 0, 0, 0, 0, 0] : vector<1x4x4xf16> into vector<1x2x2x2x1x1x1x4x4xf16>
-// CHECK: vector.insert %[[BCAST]], %{{.*}} [0, 0, 0, 1, 0, 0] : vector<1x4x4xf16> into vector<1x2x2x2x1x1x1x4x4xf16>
-// CHECK: vector.insert %[[BCAST]], %{{.*}} [0, 1, 0, 0, 0, 0] : vector<1x4x4xf16> into vector<1x2x2x2x1x1x1x4x4xf16>
-// CHECK: vector.insert %[[BCAST]], %{{.*}} [0, 1, 0, 1, 0, 0] : vector<1x4x4xf16> into vector<1x2x2x2x1x1x1x4x4xf16>
-// CHECK: vector.insert %[[BCAST]], %{{.*}} [0, 0, 1, 0, 0, 0] : vector<1x4x4xf16> into vector<1x2x2x2x1x1x1x4x4xf16>
-// CHECK: vector.insert %[[BCAST]], %{{.*}} [0, 0, 1, 1, 0, 0] : vector<1x4x4xf16> into vector<1x2x2x2x1x1x1x4x4xf16>
-// CHECK: vector.insert %[[BCAST]], %{{.*}} [0, 1, 1, 0, 0, 0] : vector<1x4x4xf16> into vector<1x2x2x2x1x1x1x4x4xf16>
-// CHECK: vector.insert %[[BCAST]], %{{.*}} [0, 1, 1, 1, 0, 0] : vector<1x4x4xf16> into vector<1x2x2x2x1x1x1x4x4xf16>
+// CHECK: vector.insert %[[BCAST]], %{{.*}} [0, 0, 0, 0, 0, 0] : vector<1x4x4xf16> into vector<2x2x1x2x1x1x1x4x4xf16>
+// CHECK: vector.insert %[[BCAST]], %{{.*}} [0, 0, 0, 1, 0, 0] : vector<1x4x4xf16> into vector<2x2x1x2x1x1x1x4x4xf16>
+// CHECK: vector.insert %[[BCAST]], %{{.*}} [0, 1, 0, 0, 0, 0] : vector<1x4x4xf16> into vector<2x2x1x2x1x1x1x4x4xf16>
+// CHECK: vector.insert %[[BCAST]], %{{.*}} [0, 1, 0, 1, 0, 0] : vector<1x4x4xf16> into vector<2x2x1x2x1x1x1x4x4xf16>
+// CHECK: vector.insert %[[BCAST]], %{{.*}} [1, 0, 0, 0, 0, 0] : vector<1x4x4xf16> into vector<2x2x1x2x1x1x1x4x4xf16>
+// CHECK: vector.insert %[[BCAST]], %{{.*}} [1, 0, 0, 1, 0, 0] : vector<1x4x4xf16> into vector<2x2x1x2x1x1x1x4x4xf16>
+// CHECK: vector.insert %[[BCAST]], %{{.*}} [1, 1, 0, 0, 0, 0] : vector<1x4x4xf16> into vector<2x2x1x2x1x1x1x4x4xf16>
+// CHECK: vector.insert %[[BCAST]], %{{.*}} [1, 1, 0, 1, 0, 0] : vector<1x4x4xf16> into vector<2x2x1x2x1x1x1x4x4xf16>
// -----
@@ -869,7 +785,8 @@
}
// CHECK-LABEL: func @transpose
-// CHECK: iree_vector_ext.to_simt %{{.*}} : vector<256x64xf16> -> vector<2x4x2x1x2x2xf16>
+// CHECK: iree_vector_ext.to_simt %{{.*}} : vector<256x64xf16> -> vector<4x2x1x2x2x2xf16>
+// CHECK: vector.transpose %{{.*}}, [1, 0, 3, 2, 5, 4] : vector<4x2x1x2x2x2xf16> to vector<2x4x2x1x2x2xf16>
// CHECK: math.sqrt %{{.*}} : vector<2x4x2x1x2x2xf16>
// CHECK: iree_vector_ext.to_simd %{{.*}} : vector<2x4x2x1x2x2xf16> -> vector<64x256xf16>
@@ -907,17 +824,15 @@
// CHECK-SAME: threads_per_outer = [16, 4]
// CHECK-SAME: elements_per_thread = [2, 2]
// CHECK-SAME: subgroup_order = [1, 0]
-// CHECK-SAME: batch_order = [1, 0],
-// CHECK-SAME: outer_order = [1, 0]
// CHECK-SAME: thread_order = [1, 0]
-// CHECK-SAME: element_order = [1, 0]
// CHECK-SAME: subgroup_basis = [2, 2]
// CHECK-SAME: thread_basis = [4, 16]
// CHECK-LABEL: func @transpose
// CHECK: iree_vector_ext.to_simt %{{.*}} : vector<64x256xf16> -> vector<2x4x2x1x2x2xf16>
-// CHECK: math.sqrt %{{.*}} : vector<2x4x2x1x2x2xf16>
-// CHECK: iree_vector_ext.to_simd %{{.*}} : vector<2x4x2x1x2x2xf16> -> vector<256x64xf16>
+// CHECK: vector.transpose %{{.*}}, [1, 0, 3, 2, 5, 4] : vector<2x4x2x1x2x2xf16> to vector<4x2x1x2x2x2xf16>
+// CHECK: math.sqrt %{{.*}} : vector<4x2x1x2x2x2xf16>
+// CHECK: iree_vector_ext.to_simd %{{.*}} : vector<4x2x1x2x2x2xf16> -> vector<256x64xf16>
// CHECK: return {{.*}}#[[$LAYOUT]]
// -----
@@ -964,69 +879,54 @@
// CHECK-LABEL: func @transpose_3d
// CHECK: %[[IDS:.+]]:6 = affine.delinearize_index %{{.*}} into (%c2, %c1, %c1, %c4, %c8, %c2)
-// CHECK: %[[DIM0_ID:.+]] = affine.apply #[[$MAP]]()[%[[IDS]]#0, %[[IDS]]#3]
-// CHECK: %[[DIM2_ID0:.+]] = affine.apply #[[$MAP1]]()[%[[IDS]]#5]
-// CHECK: %[[RD0:.+]] = vector.transfer_read %arg0[%[[DIM0_ID]], %[[IDS]]#4, %[[DIM2_ID0]]], {{.*}} : memref<32x32x32xf16>, vector<4x1x2xf16>
-// CHECK: %[[DIM2_ID1:.+]] = affine.apply #[[$MAP2]]()[%[[IDS]]#5]
-// CHECK: %[[RD1:.+]] = vector.transfer_read %arg0[%[[DIM0_ID]], %[[IDS]]#4, %[[DIM2_ID1]]]
-// CHECK: %[[DIM2_ID2:.+]] = affine.apply #[[$MAP3]]()[%[[IDS]]#5]
-// CHECK: %[[RD2:.+]] = vector.transfer_read %arg0[%[[DIM0_ID]], %[[IDS]]#4, %[[DIM2_ID2]]]
-// CHECK: %[[DIM2_ID3:.+]] = affine.apply #[[$MAP4]]()[%[[IDS]]#5]
-// CHECK: %[[RD3:.+]] = vector.transfer_read %arg0[%[[DIM0_ID]], %[[IDS]]#4, %[[DIM2_ID3]]]
-// CHECK: %[[DIM2_ID4:.+]] = affine.apply #[[$MAP5]]()[%[[IDS]]#4]
-// CHECK: %[[RD4:.+]] = vector.transfer_read %arg0[%[[DIM0_ID]], %[[DIM2_ID4]], %[[DIM2_ID0]]]
-// CHECK: %[[RD5:.+]] = vector.transfer_read %arg0[%[[DIM0_ID]], %[[DIM2_ID4]], %[[DIM2_ID1]]]
-// CHECK: %[[RD6:.+]] = vector.transfer_read %arg0[%[[DIM0_ID]], %[[DIM2_ID4]], %[[DIM2_ID2]]]
-// CHECK: %[[RD7:.+]] = vector.transfer_read %arg0[%[[DIM0_ID]], %[[DIM2_ID4]], %[[DIM2_ID3]]]
+// CHECK-DAG: %[[DIM0_ID:.+]] = affine.apply #[[$MAP]]()[%[[IDS]]#0, %[[IDS]]#3]
+// CHECK-DAG: %[[DIM2_ID0:.+]] = affine.apply #[[$MAP1]]()[%[[IDS]]#5]
+// CHECK-DAG: %[[RD0:.+]] = vector.transfer_read %arg0[%[[DIM0_ID]], %[[IDS]]#4, %[[DIM2_ID0]]], {{.*}} : memref<32x32x32xf16>, vector<4x1x2xf16>
+// CHECK-DAG: %[[DIM2_ID1:.+]] = affine.apply #[[$MAP2]]()[%[[IDS]]#5]
+// CHECK-DAG: %[[RD1:.+]] = vector.transfer_read %arg0[%[[DIM0_ID]], %[[IDS]]#4, %[[DIM2_ID1]]]
+// CHECK-DAG: %[[DIM2_ID2:.+]] = affine.apply #[[$MAP3]]()[%[[IDS]]#5]
+// CHECK-DAG: %[[RD2:.+]] = vector.transfer_read %arg0[%[[DIM0_ID]], %[[IDS]]#4, %[[DIM2_ID2]]]
+// CHECK-DAG: %[[DIM2_ID3:.+]] = affine.apply #[[$MAP4]]()[%[[IDS]]#5]
+// CHECK-DAG: %[[RD3:.+]] = vector.transfer_read %arg0[%[[DIM0_ID]], %[[IDS]]#4, %[[DIM2_ID3]]]
+// CHECK-DAG: %[[DIM2_ID4:.+]] = affine.apply #[[$MAP5]]()[%[[IDS]]#4]
+// CHECK-DAG: %[[RD4:.+]] = vector.transfer_read %arg0[%[[DIM0_ID]], %[[DIM2_ID4]], %[[DIM2_ID0]]]
+// CHECK-DAG: %[[RD5:.+]] = vector.transfer_read %arg0[%[[DIM0_ID]], %[[DIM2_ID4]], %[[DIM2_ID1]]]
+// CHECK-DAG: %[[RD6:.+]] = vector.transfer_read %arg0[%[[DIM0_ID]], %[[DIM2_ID4]], %[[DIM2_ID2]]]
+// CHECK-DAG: %[[RD7:.+]] = vector.transfer_read %arg0[%[[DIM0_ID]], %[[DIM2_ID4]], %[[DIM2_ID3]]]
-// CHECK: %[[T0:.+]] = vector.transpose %[[RD0]], [1, 2, 0] : vector<4x1x2xf16> to vector<1x2x4xf16>
-// CHECK: vector.transfer_write %[[T0]], %arg0[%[[IDS]]#4, %[[DIM2_ID0]], %[[DIM0_ID]]] {{.*}} : vector<1x2x4xf16>, memref<32x32x32xf16>
+// CHECK: vector.transpose %{{.*}}, [1, 2, 0, 4, 5, 3, 7, 8, 6] : vector<1x2x4x1x1x1x4x1x2xf16> to vector<2x4x1x1x1x1x1x2x4xf16>
-// CHECK: %[[T1:.+]] = vector.transpose %[[RD4]], [1, 2, 0] : vector<4x1x2xf16> to vector<1x2x4xf16>
-// CHECK: vector.transfer_write %[[T1]], %arg0[%[[DIM2_ID4]], %[[DIM2_ID0]], %[[DIM0_ID]]]
-
-// CHECK: %[[T2:.+]] = vector.transpose %[[RD1]], [1, 2, 0] : vector<4x1x2xf16> to vector<1x2x4xf16>
-// CHECK: vector.transfer_write %[[T2]], %arg0[%[[IDS]]#4, %[[DIM2_ID1]], %[[DIM0_ID]]]
-
-// CHECK: %[[T3:.+]] = vector.transpose %[[RD5]], [1, 2, 0] : vector<4x1x2xf16> to vector<1x2x4xf16>
-// CHECK: vector.transfer_write %[[T3]], %arg0[%[[DIM2_ID4]], %[[DIM2_ID1]], %[[DIM0_ID]]]
-
-// CHECK: %[[T4:.+]] = vector.transpose %[[RD2]], [1, 2, 0] : vector<4x1x2xf16> to vector<1x2x4xf16>
-// CHECK: vector.transfer_write %[[T4]], %arg0[%[[IDS]]#4, %[[DIM2_ID2]], %[[DIM0_ID]]]
-
-// CHECK: %[[T5:.+]] = vector.transpose %[[RD6]], [1, 2, 0] : vector<4x1x2xf16> to vector<1x2x4xf16>
-// CHECK: vector.transfer_write %[[T5]], %arg0[%[[DIM2_ID4]], %[[DIM2_ID2]], %[[DIM0_ID]]]
-
-// CHECK: %[[T6:.+]] = vector.transpose %[[RD3]], [1, 2, 0] : vector<4x1x2xf16> to vector<1x2x4xf16>
-// CHECK: vector.transfer_write %[[T6]], %arg0[%[[IDS]]#4, %[[DIM2_ID3]], %[[DIM0_ID]]]
-
-// CHECK: %[[T7:.+]] = vector.transpose %[[RD7]], [1, 2, 0] : vector<4x1x2xf16> to vector<1x2x4xf16>
-// CHECK: vector.transfer_write %[[T7]], %arg0[%[[DIM2_ID4]], %[[DIM2_ID3]], %[[DIM0_ID]]]
+// CHECK-DAG: vector.transfer_write %{{.*}}, %arg0[%[[IDS]]#4, %[[DIM2_ID0]], %[[DIM0_ID]]] {{.*}} : vector<1x2x4xf16>, memref<32x32x32xf16>
+// CHECK-DAG: vector.transfer_write %{{.*}}, %arg0[%[[DIM2_ID4]], %[[DIM2_ID0]], %[[DIM0_ID]]]
+// CHECK-DAG: vector.transfer_write %{{.*}}, %arg0[%[[IDS]]#4, %[[DIM2_ID1]], %[[DIM0_ID]]]
+// CHECK-DAG: vector.transfer_write %{{.*}}, %arg0[%[[DIM2_ID4]], %[[DIM2_ID1]], %[[DIM0_ID]]]
+// CHECK-DAG: vector.transfer_write %{{.*}}, %arg0[%[[IDS]]#4, %[[DIM2_ID2]], %[[DIM0_ID]]]
+// CHECK-DAG: vector.transfer_write %{{.*}}, %arg0[%[[DIM2_ID4]], %[[DIM2_ID2]], %[[DIM0_ID]]]
+// CHECK-DAG: vector.transfer_write %{{.*}}, %arg0[%[[IDS]]#4, %[[DIM2_ID3]], %[[DIM0_ID]]]
+// CHECK-DAG: vector.transfer_write %{{.*}}, %arg0[%[[DIM2_ID4]], %[[DIM2_ID3]], %[[DIM0_ID]]]
// -----
#nested = #iree_vector_ext.nested_layout<
- subgroups_per_workgroup = [1, 1],
- // We are reducing along dim=1, so each thread will reduce
+ subgroups_per_workgroup = [1, 1],
+ // We are reducing along dim=1, so each thread will reduce
// 2 batches x 4 elements = 8 elements.
- batches_per_subgroup = [2, 2],
- outers_per_batch = [1, 1],
+ batches_per_subgroup = [2, 2],
+ outers_per_batch = [1, 1],
// We are reducing on dim=1, which is distributed over 4 threads. Based
// on the subgroup basis and thread order, the shuffle offset is 16.
- threads_per_outer = [16, 4],
- elements_per_thread = [1, 4],
+ threads_per_outer = [16, 4],
+ elements_per_thread = [1, 4],
- subgroup_order = [1, 0],
- batch_order = [1, 0],
- outer_order = [1, 0],
- thread_order = [1, 0],
+ subgroup_order = [1, 0],
+ thread_order = [1, 0],
- subgroup_basis = [1, 1],
+ subgroup_basis = [1, 1],
thread_basis = [4, 16]
>
func.func @mfma_16x16x16_out_reduced_dim1(%arg0: vector<32x32xf32>, %arg1: vector<32xf32>) -> vector<32xf32> {
- %0 = vector.multi_reduction <maximumf>, %arg0, %arg1
+ %0 = vector.multi_reduction <maximumf>, %arg0, %arg1
{
__vector_layout_test_anchor_operand_0 = #nested
} [1] : vector<32x32xf32> to vector<32xf32>
@@ -1048,7 +948,7 @@
// CHECK-DAG: %[[DARG0:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<32x32xf32> -> vector<2x2x1x1x1x4xf32>
// CHECK-DAG: %[[DARG1:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<32xf32> -> vector<2x1x1xf32>
// Local reduction
-// CHECK: vector.multi_reduction <maximumf>, %[[DARG0]], %[[DARG1]] [0, 2, 5] : vector<2x2x1x1x1x4xf32> to vector<2x1x1xf32>
+// CHECK: vector.multi_reduction <maximumf>, %[[DARG0]], %[[DARG1]] [1, 3, 5] : vector<2x2x1x1x1x4xf32> to vector<2x1x1xf32>
// Global reduction
// CHECK: gpu.shuffle xor %{{.*}}, %[[C16]], %[[C64]] : f32
// CHECK: gpu.shuffle xor %{{.*}}, %[[C32]], %[[C64]] : f32
@@ -1060,7 +960,7 @@
#nested = #iree_vector_ext.nested_layout<
subgroups_per_workgroup = [1, 1],
- // We are reducing along dim=1, so each thread will reduce
+ // We are reducing along dim=1, so each thread will reduce
// 4 batches x 4 elements = 16 elements.
batches_per_subgroup = [1, 4],
outers_per_batch = [1, 1],
@@ -1076,7 +976,7 @@
>
func.func @mfma_32x32x8_out_reduced_dim1(%arg0: vector<32x32xf32>, %arg1: vector<32xf32>) -> vector<32xf32> {
- %0 = vector.multi_reduction <maximumf>, %arg0, %arg1
+ %0 = vector.multi_reduction <maximumf>, %arg0, %arg1
{
__vector_layout_test_anchor_operand_0 = #nested
} [1] : vector<32x32xf32> to vector<32xf32>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
index 0a03966..f36e8bd 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
@@ -45,10 +45,7 @@
elements_per_thread = [1, 8, 2],
subgroup_order = [0, 1, 2],
- batch_order = [0, 1, 2],
- outer_order = [0, 1, 2],
thread_order = [0, 1, 2],
- element_order = [0, 2, 1],
subgroup_basis = [2, 1, 1],
thread_basis = [8, 2, 4]
@@ -58,15 +55,15 @@
func.func @distribute_elementwise_nested_layout_f16(%a: vector<128x128x128xf16>, %b: vector<128x128x128xf16>) -> vector<128x128x128xf16> {
%c0 = arith.constant 0 : index
%cst_0 = arith.constant 0.0 : f16
- // CHECK: %[[ROOT:.*]] = arith.constant dense<0.000000e+00> : vector<8x2x4x1x4x4x1x2x8xf16>
+ // CHECK: %[[ROOT:.*]] = arith.constant dense<0.000000e+00> : vector<8x2x4x1x4x4x1x8x2xf16>
%root = arith.constant {"__vector_layout_test_anchor_result_0" = #nested} dense<0.0> : vector<128x128x128xf16>
- // CHECK-DAG: %[[B:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<128x128x128xf16> -> vector<8x2x4x1x4x4x1x2x8xf16>
- // CHECK-DAG: %[[C:.*]] = arith.mulf %[[B]], %[[ROOT]] : vector<8x2x4x1x4x4x1x2x8xf16>
+ // CHECK-DAG: %[[B:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<128x128x128xf16> -> vector<8x2x4x1x4x4x1x8x2xf16>
+ // CHECK-DAG: %[[C:.*]] = arith.mulf %[[B]], %[[ROOT]] : vector<8x2x4x1x4x4x1x8x2xf16>
%c = arith.mulf %root, %b : vector<128x128x128xf16>
- // CHECK-DAG: %[[A:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<128x128x128xf16> -> vector<8x2x4x1x4x4x1x2x8xf16>
- // CHECK-DAG: %[[D:.*]] = arith.addf %[[C]], %[[A]] fastmath<reassoc,nnan> : vector<8x2x4x1x4x4x1x2x8xf16>
+ // CHECK-DAG: %[[A:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<128x128x128xf16> -> vector<8x2x4x1x4x4x1x8x2xf16>
+ // CHECK-DAG: %[[D:.*]] = arith.addf %[[C]], %[[A]] fastmath<reassoc,nnan> : vector<8x2x4x1x4x4x1x8x2xf16>
%d = arith.addf %c, %a fastmath<reassoc,nnan> : vector<128x128x128xf16>
- // CHECK: iree_vector_ext.to_simd %[[D]] : vector<8x2x4x1x4x4x1x2x8xf16> -> vector<128x128x128xf16>
+ // CHECK: iree_vector_ext.to_simd %[[D]] : vector<8x2x4x1x4x4x1x8x2xf16> -> vector<128x128x128xf16>
return %d : vector<128x128x128xf16>
}
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index f27f67d..417f8ee 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -602,10 +602,9 @@
NestedLayoutAttr permuteAndCreateNestedLayout(
MLIRContext *context, int64_t rank, int64_t outerDim, int64_t innerDim,
SmallVector<int64_t> subgroupCount, SmallVector<int64_t> subgroupOrder,
- SmallVector<int64_t> batchCount, SmallVector<int64_t> batchOrder,
- MMAAttr::SingleSubgroupLayout counts, MMAAttr::SingleSubgroupLayout orders,
- ArrayRef<int64_t> dataDuplicate, ArrayRef<int64_t> subgroupBasis,
- ArrayRef<bool> subgroupActiveIds) {
+ SmallVector<int64_t> batchCount, MMAAttr::SingleSubgroupLayout counts,
+ MMAAttr::SingleSubgroupLayout orders, ArrayRef<int64_t> dataDuplicate,
+ ArrayRef<int64_t> subgroupBasis, ArrayRef<bool> subgroupActiveIds) {
LLVM_DEBUG({
llvm::errs() << "Given:";
@@ -617,20 +616,14 @@
llvm::interleaveComma(subgroupOrder, llvm::errs());
llvm::errs() << "\n batchCount: ";
llvm::interleaveComma(batchCount, llvm::errs());
- llvm::errs() << "\n batchOrder: ";
- llvm::interleaveComma(batchOrder, llvm::errs());
llvm::errs() << "\n counts.outer: ";
llvm::interleaveComma(counts.outer, llvm::errs());
- llvm::errs() << "\n orders.outer: ";
- llvm::interleaveComma(orders.outer, llvm::errs());
llvm::errs() << "\n counts.thread: ";
llvm::interleaveComma(counts.thread, llvm::errs());
llvm::errs() << "\n orders.thread: ";
llvm::interleaveComma(orders.thread, llvm::errs());
llvm::errs() << "\n counts.element: ";
llvm::interleaveComma(counts.element, llvm::errs());
- llvm::errs() << "\n orders.element: ";
- llvm::interleaveComma(orders.element, llvm::errs());
llvm::errs() << "\n subgroupBasis: ";
llvm::interleaveComma(subgroupBasis, llvm::errs());
llvm::errs() << "\n subgroupActiveIds: ";
@@ -638,12 +631,8 @@
llvm::errs() << "\n";
});
- SmallVector<int64_t> outerOrder =
- getIdentityPermWithSwap(rank, orders.outer, outerDim, innerDim);
SmallVector<int64_t> threadOrder =
getIdentityPermWithSwap(rank, orders.thread, outerDim, innerDim);
- SmallVector<int64_t> elementOrder =
- getIdentityPermWithSwap(rank, orders.element, outerDim, innerDim);
SmallVector<int64_t> threadBasis =
getUnitOfRankWithDims(rank, counts.thread, outerDim, innerDim);
@@ -666,20 +655,14 @@
llvm::interleaveComma(subgroupOrder, llvm::errs());
llvm::errs() << "\n batchCount: ";
llvm::interleaveComma(batchCount, llvm::errs());
- llvm::errs() << "\n batchOrder: ";
- llvm::interleaveComma(batchOrder, llvm::errs());
llvm::errs() << "\n outerCount: ";
llvm::interleaveComma(outerCount, llvm::errs());
- llvm::errs() << "\n outerOrder: ";
- llvm::interleaveComma(outerOrder, llvm::errs());
llvm::errs() << "\n threadCount: ";
llvm::interleaveComma(threadCount, llvm::errs());
llvm::errs() << "\n threadOrder: ";
llvm::interleaveComma(threadOrder, llvm::errs());
llvm::errs() << "\n elementCount: ";
llvm::interleaveComma(elementCount, llvm::errs());
- llvm::errs() << "\n elementOrder: ";
- llvm::interleaveComma(elementOrder, llvm::errs());
llvm::errs() << "\n subgroupBasis: ";
llvm::interleaveComma(subgroupBasis, llvm::errs());
llvm::errs() << "\n subgroupActiveIds: ";
@@ -690,10 +673,9 @@
});
auto layoutAttr = NestedLayoutAttr::get(
- context, subgroupCount, subgroupOrder, batchCount, batchOrder, outerCount,
- outerOrder, threadCount, threadOrder, elementCount, elementOrder,
- subgroupBasis, subgroupActiveIds, threadBasis,
- SmallVector<bool>(threadBasis.size(), true));
+ context, subgroupCount, subgroupOrder, batchCount, outerCount,
+ threadCount, threadOrder, elementCount, subgroupBasis, subgroupActiveIds,
+ threadBasis, SmallVector<bool>(threadBasis.size(), true));
return layoutAttr;
}
@@ -847,8 +829,7 @@
context, cRank, m, n,
/*subgroupCount=*/cSubgroupSizes,
/*subgroupOrder=*/cOverallOrder,
- /*batchCount=*/cBatchSizes,
- /*batchOrder=*/cOverallOrder, cCounts, cOrders,
+ /*batchCount=*/cBatchSizes, cCounts, cOrders,
/*dataDuplicate=*/mmaAttr.getCDataDuplicate(), subgroupBasis,
cActiveSubgroups);
LLVM_DEBUG({ llvm::errs() << "C layout: " << cLayout << "\n"; });
@@ -868,15 +849,12 @@
SmallVector<int64_t> aBatchSizes(aRank, 1);
SmallVector<int64_t> aSubgroupSizes(aRank, 1);
SmallVector<int64_t> aSubgroupOrder(aRank, 0);
- SmallVector<int64_t> aBatchOrder(aRank, 0);
for (auto [i, dim] : llvm::enumerate(aMDims)) {
aBatchSizes[dim] = batchMSizes[i];
aSubgroupSizes[dim] = subgroupMBasis[i];
aSubgroupOrder[dim] = i;
- aBatchOrder[dim] = i >= afk ? i + 1 : i;
}
aSubgroupOrder[afk] = aRank - 1;
- aBatchOrder[afk] = afk;
aBatchSizes[afk] = getSubgroupKTileCount();
SmallVector<bool> aActiveSubgroups(subgroupBasis.size(), false);
@@ -889,8 +867,7 @@
context, aRank, afm, afk,
/*subgroupCount=*/aSubgroupSizes,
/*subgroupOrder=*/aSubgroupOrder,
- /*batchCount=*/aBatchSizes,
- /*batchOrder=*/getIdentityPerm(aRank), aCounts, aOrders,
+ /*batchCount=*/aBatchSizes, aCounts, aOrders,
/*dataDuplicate=*/mmaAttr.getADataDuplicate(), subgroupBasis,
aActiveSubgroups);
LLVM_DEBUG({ llvm::errs() << "A layout: " << aLayout << "\n"; });
@@ -907,15 +884,12 @@
SmallVector<int64_t> bBatchSizes(bRank, 1);
SmallVector<int64_t> bSubgroupSizes(bRank, 1);
SmallVector<int64_t> bSubgroupOrder(bRank, 0);
- SmallVector<int64_t> bBatchOrder(bRank, 0);
for (auto [i, dim] : llvm::enumerate(bNDims)) {
bBatchSizes[dim] = batchNSizes[i];
bSubgroupSizes[dim] = subgroupNBasis[i];
bSubgroupOrder[dim] = i;
- bBatchOrder[dim] = i >= bfk ? i + 1 : i;
}
bSubgroupOrder[bfk] = bRank - 1;
- bBatchOrder[bfk] = bfk;
bBatchSizes[bfk] = getSubgroupKTileCount();
SmallVector<bool> bActiveSubgroups(subgroupBasis.size(), false);
@@ -928,8 +902,7 @@
context, bRank, bfk, bfn,
/*subgroupCount=*/bSubgroupSizes,
/*subgroupOrder=*/bSubgroupOrder,
- /*batchCount=*/bBatchSizes,
- /*batchOrder=*/bBatchOrder, bCounts, bOrders,
+ /*batchCount=*/bBatchSizes, bCounts, bOrders,
/*dataDuplicate=*/mmaAttr.getBDataDuplicate(), subgroupBasis,
bActiveSubgroups);
LLVM_DEBUG({ llvm::errs() << "B layout: " << bLayout << "\n"; });
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp
index ed14df5..fa01423 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp
@@ -326,8 +326,8 @@
SmallVector<int64_t> threadBasis = threadCounts;
auto layout = IREE::VectorExt::NestedLayoutAttr::get(
- context, subgroupCounts, order, batchSizes, order, outerSizes, order,
- threadCounts, order, elementSizes, order, subgroupBasis,
+ context, subgroupCounts, order, batchSizes, outerSizes, threadCounts,
+ order, elementSizes, subgroupBasis,
SmallVector<bool>(subgroupBasis.size(), true), threadBasis,
SmallVector<bool>(threadBasis.size(), true));
if (analysis.setAnchor(transfer.getResult(), layout).failed()) {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
index 99a864b..83ae1be 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
@@ -49,11 +49,11 @@
// CHECK-LABEL: func.func @matmul_256x256x256_f16_f32()
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
-// CHECK: scf.for {{.*}} = %c0 to %c256 step %c128 iter_args({{.*}}) -> (vector<2x2x1x1x1x4xf32>)
+// CHECK: scf.for {{.*}} = %c0 to %c256 step %c128 iter_args({{.*}}) -> (vector<2x2x1x1x4x1xf32>)
// Each subgroup handles 2 * 2 tiles, and for each tile we accumulate 8 times
// along the K dimension. So in total 32 mfma ops.
// CHECK-COUNT-32: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
-// CHECK: scf.yield %{{.+}} : vector<2x2x1x1x1x4xf32>
+// CHECK: scf.yield %{{.+}} : vector<2x2x1x1x4x1xf32>
// CHECK-COUNT-4: vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<4x1xf32>, memref<256x256xf32, #hal.descriptor_type<storage_buffer>>
// -----
@@ -100,11 +100,11 @@
// CHECK-LABEL: func.func @matmul_256x256x256_f16_f16()
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
-// CHECK: scf.for {{.*}} = %c0 to %c256 step %c128 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<2x2x1x1x1x4xf16>)
-// CHECK: arith.extf %[[ARG]] : vector<2x2x1x1x1x4xf16> to vector<2x2x1x1x1x4xf32>
+// CHECK: scf.for {{.*}} = %c0 to %c256 step %c128 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<2x2x1x1x4x1xf16>)
+// CHECK: arith.extf %[[ARG]] : vector<2x2x1x1x4x1xf16> to vector<2x2x1x1x4x1xf32>
// CHECK-COUNT-32: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
-// CHECK: %[[TRUNC:.+]] = arith.truncf %{{.*}} : vector<2x2x1x1x1x4xf32> to vector<2x2x1x1x1x4xf16>
-// CHECK: scf.yield %[[TRUNC]] : vector<2x2x1x1x1x4xf16>
+// CHECK: %[[TRUNC:.+]] = arith.truncf %{{.*}} : vector<2x2x1x1x4x1xf32> to vector<2x2x1x1x4x1xf16>
+// CHECK: scf.yield %[[TRUNC]] : vector<2x2x1x1x4x1xf16>
// CHECK-COUNT-4: vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<4x1xf16>, memref<256x256xf16, #hal.descriptor_type<storage_buffer>>
// -----
@@ -170,11 +170,11 @@
// CHECK-SAME: translation_info = #[[TRANSLATION]]
// This has more than 2 iterations. So we have prefetching enabled for this case. Due to
// prefetching, we have one iteration peeled off so upper bound is 2048 - 128 = 1920.
-// CHECK: scf.for {{.*}} = %c0 to %c1920 step %c128 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<4x1x1x1x1x4xf16>)
-// CHECK: arith.extf %[[ARG]] : vector<4x1x1x1x1x4xf16> to vector<4x1x1x1x1x4xf32>
+// CHECK: scf.for {{.*}} = %c0 to %c1920 step %c128 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<4x1x1x1x4x1xf16>)
+// CHECK: arith.extf %[[ARG]] : vector<4x1x1x1x4x1xf16> to vector<4x1x1x1x4x1xf32>
// CHECK-COUNT-32: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
-// CHECK: %[[TRUNC:.+]] = arith.truncf %{{.*}} : vector<4x1x1x1x1x4xf32> to vector<4x1x1x1x1x4xf16>
-// CHECK: scf.yield %[[TRUNC]] : vector<4x1x1x1x1x4xf16>
+// CHECK: %[[TRUNC:.+]] = arith.truncf %{{.*}} : vector<4x1x1x1x4x1xf32> to vector<4x1x1x1x4x1xf16>
+// CHECK: scf.yield %[[TRUNC]] : vector<4x1x1x1x4x1xf16>
// CHECK-COUNT-32: amdgpu.mfma
// CHECK-COUNT-4: vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<4x1xf16>, memref<2x10x64x64xf16, #hal.descriptor_type<storage_buffer>>
@@ -221,7 +221,7 @@
// CHECK: scf.for {{.*}} = %c0 to %c3
// This has more than 2 iterations. So we have prefetching enabled for this case. Due to
// prefetching, we have one iteration peeled off so upper bound is 768 - 32 = 736.
-// CHECK: scf.for {{.*}} = %c0 to %c736 step %c32 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<2x4x1x1x1x4xf32>)
+// CHECK: scf.for {{.*}} = %c0 to %c736 step %c32 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<2x4x1x1x4x1xf32>)
// CHECK-COUNT-16: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
// CHECK: scf.yield
// CHECK-COUNT-16: amdgpu.mfma
@@ -300,11 +300,11 @@
// CHECK-LABEL: func.func @generic_2x1024x20x64x1280_f16
// This has more than 2 iterations. So we have prefetching enabled for this case. Due to
// prefetching, we have one iteration peeled off so upper bound is 1280 - 128 = 1152.
-// CHECK: scf.for {{.*}} = %c0 to %c1152 step %c128 iter_args({{.*}}) -> (vector<2x2x1x1x1x4xf16>)
+// CHECK: scf.for {{.*}} = %c0 to %c1152 step %c128 iter_args({{.*}}) -> (vector<2x2x1x1x4x1xf16>)
// Each subgroup handles 2 * 2 tiles, and for each tile we accumulate 8 times
// along the K dimension. So in total 32 mfma ops.
// CHECK-COUNT-32: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
-// CHECK: scf.yield %{{.+}} : vector<2x2x1x1x1x4xf16>
+// CHECK: scf.yield %{{.+}} : vector<2x2x1x1x4x1xf16>
// CHECK-COUNT-32: amdgpu.mfma
// CHECK-COUNT-4: vector.transfer_write {{.+}} : vector<4x1xf16>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_conversion.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_conversion.mlir
index bcd47c4..244fef4 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_conversion.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_conversion.mlir
@@ -38,10 +38,10 @@
}
// CHECK-LABEL: func.func @mfma_matmul_256x256x256
-// CHECK: %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<1x1x1x1x1x4xf32>
+// CHECK: %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<1x1x1x1x4x1xf32>
// CHECK: %[[RHS_ALLOC:.+]] = memref.alloc() : memref<32x16xf16, #gpu.address_space<workgroup>>
// CHECK: %[[LHS_ALLOC:.+]] = memref.alloc() : memref<16x32xf16, #gpu.address_space<workgroup>>
-// CHECK: scf.for {{.*}} = %c0 to %c256 step %c32 iter_args({{.*}} = %[[INIT]]) -> (vector<1x1x1x1x1x4xf32>)
+// CHECK: scf.for {{.*}} = %c0 to %c256 step %c32 iter_args({{.*}} = %[[INIT]])
// CHECK: %[[LLOAD:.+]] = vector.transfer_read {{.*}} : memref<16x256xf16, {{.*}}>, vector<1x8xf16>
// CHECK: %[[RLOAD:.+]] = vector.transfer_read {{.*}} : memref<256x16xf16, {{.*}}>, vector<1x8xf16>
// CHECK: vector.transfer_write %[[LLOAD]], %[[LHS_ALLOC]]{{.*}} : vector<1x8xf16>, memref<16x32xf16, #gpu.address_space<workgroup>>
@@ -50,8 +50,8 @@
// CHECK-COUNT-2: vector.transfer_read %[[LHS_ALLOC]][{{.+}}], %{{.+}} {in_bounds = [true, true]} : memref<16x32xf16, #gpu.address_space<workgroup>>, vector<1x4xf16>
// CHECK-COUNT-2: vector.transfer_read %[[RHS_ALLOC]][{{.+}}], %{{.+}} {in_bounds = [true, true]} : memref<32x16xf16, #gpu.address_space<workgroup>>, vector<4x1xf16>
// CHECK-COUNT-2: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
-// CHECK: %[[BCAST:.+]] = vector.broadcast {{.*}} : vector<1x1x1x4xf32> to vector<1x1x1x1x1x4xf32>
-// CHECK: scf.yield %[[BCAST]] : vector<1x1x1x1x1x4xf32>
+// CHECK: %[[BCAST:.+]] = vector.broadcast {{.*}} : vector<1x1x4x1xf32> to vector<1x1x1x1x4x1xf32>
+// CHECK: scf.yield %[[BCAST]]
// CHECK: vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<4x1xf32>, memref<16x16xf32{{.*}}>
// -----
@@ -107,9 +107,8 @@
// CHECK: %[[LHS_ALLOC:.+]] = memref.alloc() : memref<16x32xf16, #gpu.address_space<workgroup>>
// CHECK: affine.delinearize_index %[[LIN_ID]]
// CHECK: %[[INIT_READ:.+]] = vector.transfer_read %{{.*}} memref<16x16xf32, {{.*}}>, vector<4x1xf32>
-// CHECK: %[[INIT_TRANSP:.+]] = vector.transpose %[[INIT_READ]], [1, 0]
-// CHECK: %[[INIT:.+]] = vector.insert_strided_slice %[[INIT_TRANSP]]
-// CHECK: scf.for {{.*}} = %c0 to %c256 step %c32 iter_args({{.*}} = %[[INIT]]) -> (vector<1x1x1x1x1x4xf32>)
+// CHECK: %[[INIT:.+]] = vector.insert_strided_slice %[[INIT_READ]]
+// CHECK: scf.for {{.*}} = %c0 to %c256 step %c32 iter_args({{.*}} = %[[INIT]]) -> (vector<1x1x1x1x4x1xf32>)
// CHECK: %[[LLOAD:.+]] = vector.transfer_read {{.*}} : memref<16x256xf16, {{.*}}>, vector<1x8xf16>
// CHECK: %[[RLOAD:.+]] = vector.transfer_read {{.*}} permutation_map = #[[$MAP1]]} : memref<16x256xf16, {{.*}}>, vector<8x1xf16>
// CHECK: vector.transfer_write %[[LLOAD]], %[[LHS_ALLOC]]{{.*}} : vector<1x8xf16>, memref<16x32xf16, #gpu.address_space<workgroup>>
@@ -118,8 +117,8 @@
// CHECK-COUNT-2: vector.transfer_read %[[LHS_ALLOC]][{{.+}}], %{{.+}} {in_bounds = [true, true]} : memref<16x32xf16, #gpu.address_space<workgroup>>, vector<1x4xf16>
// CHECK-COUNT-2: vector.transfer_read %[[RHS_ALLOC]][{{.+}}], %{{.+}} {in_bounds = [true, true]} : memref<32x16xf16, #gpu.address_space<workgroup>>, vector<4x1xf16>
// CHECK-COUNT-2: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
-// CHECK: %[[BCAST:.+]] = vector.broadcast {{.*}} : vector<1x1x1x4xf32> to vector<1x1x1x1x1x4xf32>
-// CHECK: scf.yield %[[BCAST]] : vector<1x1x1x1x1x4xf32>
+// CHECK: %[[BCAST:.+]] = vector.broadcast {{.*}} : vector<1x1x4x1xf32> to vector<1x1x1x1x4x1xf32>
+// CHECK: scf.yield %[[BCAST]]
// CHECK: vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<4x1xf32>, memref<16x16xf32{{.*}}>
// -----
@@ -237,29 +236,21 @@
// CHECK: %[[LHS_ALLOC:.+]] = memref.alloc() : memref<16x32xf16, #gpu.address_space<workgroup>>
// CHECK: affine.delinearize_index %[[LIN_ID]]
// CHECK: %[[INIT_READ0:.+]] = vector.transfer_read %{{.*}} memref<16x16xf32, {{.*}}>, vector<1x1xf32>
-// CHECK: %[[INIT_TRANSP0:.+]] = vector.transpose %[[INIT_READ0]], [1, 0]
-// CHECK: %[[INIT0:.+]] = vector.insert_strided_slice %[[INIT_TRANSP0]]
+// CHECK: %[[INIT0:.+]] = vector.insert_strided_slice %[[INIT_READ0]]
// CHECK: %[[INIT_READ1:.+]] = vector.transfer_read %{{.*}} memref<16x16xf32, {{.*}}>, vector<1x1xf32>
-// CHECK: %[[INIT_TRANSP1:.+]] = vector.transpose %[[INIT_READ1]], [1, 0]
-// CHECK: %[[INIT1:.+]] = vector.insert_strided_slice %[[INIT_TRANSP1]]
+// CHECK: %[[INIT1:.+]] = vector.insert_strided_slice %[[INIT_READ1]]
// CHECK: %[[INIT_READ2:.+]] = vector.transfer_read %{{.*}} memref<16x16xf32, {{.*}}>, vector<1x1xf32>
-// CHECK: %[[INIT_TRANSP2:.+]] = vector.transpose %[[INIT_READ2]], [1, 0]
-// CHECK: %[[INIT2:.+]] = vector.insert_strided_slice %[[INIT_TRANSP2]]
+// CHECK: %[[INIT2:.+]] = vector.insert_strided_slice %[[INIT_READ2]]
// CHECK: %[[INIT_READ3:.+]] = vector.transfer_read %{{.*}} memref<16x16xf32, {{.*}}>, vector<1x1xf32>
-// CHECK: %[[INIT_TRANSP3:.+]] = vector.transpose %[[INIT_READ3]], [1, 0]
-// CHECK: %[[INIT3:.+]] = vector.insert_strided_slice %[[INIT_TRANSP3]]
+// CHECK: %[[INIT3:.+]] = vector.insert_strided_slice %[[INIT_READ3]]
// CHECK: %[[INIT_READ4:.+]] = vector.transfer_read %{{.*}} memref<16x16xf32, {{.*}}>, vector<1x1xf32>
-// CHECK: %[[INIT_TRANSP4:.+]] = vector.transpose %[[INIT_READ4]], [1, 0]
-// CHECK: %[[INIT4:.+]] = vector.insert_strided_slice %[[INIT_TRANSP4]]
+// CHECK: %[[INIT4:.+]] = vector.insert_strided_slice %[[INIT_READ4]]
// CHECK: %[[INIT_READ5:.+]] = vector.transfer_read %{{.*}} memref<16x16xf32, {{.*}}>, vector<1x1xf32>
-// CHECK: %[[INIT_TRANSP5:.+]] = vector.transpose %[[INIT_READ5]], [1, 0]
-// CHECK: %[[INIT5:.+]] = vector.insert_strided_slice %[[INIT_TRANSP5]]
+// CHECK: %[[INIT5:.+]] = vector.insert_strided_slice %[[INIT_READ5]]
// CHECK: %[[INIT_READ6:.+]] = vector.transfer_read %{{.*}} memref<16x16xf32, {{.*}}>, vector<1x1xf32>
-// CHECK: %[[INIT_TRANSP6:.+]] = vector.transpose %[[INIT_READ6]], [1, 0]
-// CHECK: %[[INIT6:.+]] = vector.insert_strided_slice %[[INIT_TRANSP6]]
+// CHECK: %[[INIT6:.+]] = vector.insert_strided_slice %[[INIT_READ6]]
// CHECK: %[[INIT_READ7:.+]] = vector.transfer_read %{{.*}} memref<16x16xf32, {{.*}}>, vector<1x1xf32>
-// CHECK: %[[INIT_TRANSP7:.+]] = vector.transpose %[[INIT_READ7]], [1, 0]
-// CHECK: %[[INIT7:.+]] = vector.insert_strided_slice %[[INIT_TRANSP7]]
+// CHECK: %[[INIT7:.+]] = vector.insert_strided_slice %[[INIT_READ7]]
// CHECK: scf.for {{.*}} = %c0 to %c256 step %c32 iter_args({{.*}} = %[[INIT7]]) -> (vector<1x1x8x1x1x1xf32>)
// CHECK: %[[LLOAD0:.+]] = vector.transfer_read {{.*}} : memref<16x256xf16, {{.*}}>, vector<1x8xf16>
// CHECK: %[[LLOAD1:.+]] = vector.transfer_read {{.*}} : memref<16x256xf16, {{.*}}>, vector<1x8xf16>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir
index 520f7f9..76ca1f2 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir
@@ -19,11 +19,9 @@
// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, false, true], thread_basis = [2, 32]>
// CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [2, 32], elements_per_thread = [4, 1],
-// CHECK-SAME: element_order = [1, 0],
// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [false, true, true], thread_basis = [2, 32]>
// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [4, 1], threads_per_outer = [2, 32], elements_per_thread = [4, 1],
-// CHECK-SAME: element_order = [1, 0],
// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, true, false], thread_basis = [2, 32]>
// -----
@@ -47,11 +45,10 @@
// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, false, true], thread_basis = [2, 32]>
// CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 4],
-// CHECK-SAME: outer_order = [1, 0], thread_order = [1, 0]
+// CHECK-SAME: thread_order = [1, 0]
// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [false, true, true], thread_basis = [2, 32]>
// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [4, 1], threads_per_outer = [2, 32], elements_per_thread = [4, 1],
-// CHECK-SAME: element_order = [1, 0]
// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, true, false], thread_basis = [2, 32]>
// -----
@@ -125,11 +122,10 @@
// CHECK-SAME: subgroup_active_ids = [true, false, true], thread_basis = [4, 16]>
// CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [4, 1],
-// CHECK-SAME: subgroup_order = [1, 0], element_order = [1, 0],
+// CHECK-SAME: subgroup_order = [1, 0]
// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [false, true, true], thread_basis = [4, 16]>
// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [4, 1],
-// CHECK-SAME: element_order = [1, 0],
// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, true, false], thread_basis = [4, 16]>
// -----
@@ -180,7 +176,7 @@
// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [16, 4]>
// CHECK: transfer '{{.+}} memref<16x256xf16{{.+}}storage_buffer>>, vector<32x16xf16>' vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [8, 1],
-// CHECK-SAME: subgroup_order = [1, 0], batch_order = [1, 0], outer_order = [1, 0], thread_order = [1, 0], element_order = [1, 0],
+// CHECK-SAME: subgroup_order = [1, 0], thread_order = [1, 0],
// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [4, 16]>
// CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
@@ -189,11 +185,9 @@
// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, false, true], thread_basis = [4, 16]>
// CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [4, 1],
-// CHECK-SAME: element_order = [1, 0],
// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [false, true, true], thread_basis = [4, 16]>
// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [4, 1],
-// CHECK-SAME: element_order = [1, 0],
// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, true, false], thread_basis = [4, 16]>
// -----
@@ -257,11 +251,9 @@
// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, false, true], thread_basis = [1, 32]>
// CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [1, 16], elements_per_thread = [16, 1],
-// CHECK-SAME: element_order = [1, 0],
// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [false, true, true], thread_basis = [1, 32]>
// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [8, 1], threads_per_outer = [2, 16], elements_per_thread = [1, 1],
-// CHECK-SAME: element_order = [1, 0],
// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, true, false], thread_basis = [2, 16]>
// -----
@@ -285,11 +277,10 @@
// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, false, true], thread_basis = [1, 32]>
// CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [16, 1], elements_per_thread = [1, 16],
-// CHECK-SAME: outer_order = [1, 0], thread_order = [1, 0],
+// CHECK-SAME: thread_order = [1, 0],
// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [false, true, true], thread_basis = [1, 32]>
// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [8, 1], threads_per_outer = [2, 16], elements_per_thread = [1, 1],
-// CHECK-SAME: element_order = [1, 0],
// CHECK-SAME: subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, true, false], thread_basis = [2, 16]>
// -----
@@ -325,7 +316,6 @@
// CHECK-SAME: threads_per_outer = [4, 16],
// CHECK-SAME: elements_per_thread = [4, 1],
// CHECK-SAME: subgroup_order = [1, 0],
-// CHECK-SAME: element_order = [1, 0],
// CHECK-SAME: subgroup_basis = [2, 1, 1, 1],
// CHECK-SAME: subgroup_active_ids = [false, false, true, true],
// CHECK-SAME: thread_basis = [4, 16]>
@@ -335,7 +325,6 @@
// CHECK-SAME: outers_per_batch = [1, 1, 1],
// CHECK-SAME: threads_per_outer = [1, 4, 16],
// CHECK-SAME: elements_per_thread = [1, 4, 1],
-// CHECK-SAME: element_order = [0, 2, 1],
// CHECK-SAME: subgroup_basis = [2, 1, 1, 1],
// CHECK-SAME: subgroup_active_ids = [true, true, true, false],
// CHECK-SAME: thread_basis = [1, 4, 16]>
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td
index 7096682..61b5710 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td
@@ -130,8 +130,7 @@
would have 2-D tile sizes per compute hierarchy level.
We now describe each level of tiling. Each level of tiling represents a
- count of tiles over the next level (rather than a list of tile sizes) and
- an ordering over the tiles:
+ count of tiles over the next level (rather than a list of tile sizes).
1. Subgroups per Workgroup
@@ -178,9 +177,7 @@
for b_1 in range(batch_1):
...
- Batches are represented using two attributes:
- - batches_per_subgroup: Ranges of each loop
- - batch_order: Ordering of each loop, from outermost to innermost
+ `batches_per_subgroup` represents the range of each loop.
The second level, outers, is a way to represent thread layout duplication
required by a particular intrinsic. For example, some AMDGPU matrix
@@ -193,9 +190,7 @@
0 1 2 3 4 outers_per_batch=[2, 1]
5 6 7 8 9 threads_per_outer=[2, 5]
- Outers is represented using two attributes:
- - outers_per_batch: Number of outers in a batch
- - outer_order: Ordering of outers, from outermost to innermost
+ `outers_per_batch` represents the number of outers in a batch.
Finally, threads are distributed in a single outer. The thread
distribution is represented by:
@@ -210,8 +205,6 @@
outers_per_batch = [2, 2]
threads_per_outer = [2, 2]
- batch_order = [0, 1]
- outer_order = [0, 1]
thread_order = [1, 0]
}
@@ -259,10 +252,7 @@
The final level of tiling, representing the minimum shape of vector that
is treated as an atom.
- The elements are placed contigiously with their shape and ordering
- determined by:
- - `elements_per_thread`: Sizes of this level of tiling
- - `element_order`: Ordering of dimensions, from outermost to innermost
+ `elements_per_thread` represents the native size of the vector.
}];
let parameters = (ins
@@ -270,16 +260,13 @@
ArrayRefParameter<"int64_t", "subgroup_order">:$subgroupOrder,
ArrayRefParameter<"int64_t", "batches_per_subgroup">:$batchesPerSubgroup,
- ArrayRefParameter<"int64_t", "batch_order">:$batchOrder,
ArrayRefParameter<"int64_t", "outers_per_batch">:$outersPerBatch,
- ArrayRefParameter<"int64_t", "outer_order">:$outerOrder,
ArrayRefParameter<"int64_t", "threads_per_outer">:$threadsPerOuter,
ArrayRefParameter<"int64_t", "thread_order">:$threadOrder,
ArrayRefParameter<"int64_t", "elements_per_thread">:$elementsPerThread,
- ArrayRefParameter<"int64_t", "element_order">:$elementOrder,
ArrayRefParameter<"int64_t", "subgroup_basis">:$subgroupBasis,
ArrayRefParameter<"bool", "subgroup_active_ids">:$subgroupActiveIds,
@@ -298,10 +285,7 @@
`elements_per_thread` `=` `[` $elementsPerThread `]` `,`
custom<Permutation>("\"subgroup_order\"", ref($subgroupsPerWorkgroup), "true", $subgroupOrder) ``
- custom<Permutation>("\"batch_order\"", ref($batchesPerSubgroup), "true", $batchOrder) ``
- custom<Permutation>("\"outer_order\"", ref($outersPerBatch), "true", $outerOrder) ``
custom<Permutation>("\"thread_order\"", ref($threadsPerOuter), "true", $threadOrder) ``
- custom<Permutation>("\"element_order\"", ref($elementsPerThread), "true", $elementOrder) ``
custom<Basis>("\"subgroup_basis\"", "\"subgroup_active_ids\"", "true", $subgroupBasis, $subgroupActiveIds) ``
custom<Basis>("\"thread_basis\"", "\"thread_active_ids\"", "false", $threadBasis, $threadActiveIds) ``
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
index 0234765..b481964 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
@@ -285,12 +285,8 @@
SmallVector<int64_t> subgroupOrder =
getRankReducedPermutation(getSubgroupOrder());
- SmallVector<int64_t> batchOrder = getRankReducedPermutation(getBatchOrder());
- SmallVector<int64_t> outerOrder = getRankReducedPermutation(getOuterOrder());
SmallVector<int64_t> threadOrder =
getRankReducedPermutation(getThreadOrder());
- SmallVector<int64_t> elementOrder =
- getRankReducedPermutation(getElementOrder());
// Compose the projected dims with the basis mask to get the new active
// ids. Active ids indicates that we should use the ids marked as true, and
@@ -322,9 +318,8 @@
composeMasks(threadMask, invertedDroppedSubgroupMask);
return NestedLayoutAttr::get(getContext(), subgroupCount, subgroupOrder,
- batchCount, batchOrder, outerCount, outerOrder,
- threadCount, threadOrder, elementCount,
- elementOrder, getSubgroupBasis(), subgroupMask,
+ batchCount, outerCount, threadCount, threadOrder,
+ elementCount, getSubgroupBasis(), subgroupMask,
getThreadBasis(), threadMask);
}
@@ -360,33 +355,28 @@
applyPermutation(invPerm, getSubgroupOrder());
SmallVector<int64_t> batchCount =
applyPermutation(getBatchesPerSubgroup(), permutation);
- SmallVector<int64_t> batchOrder = applyPermutation(invPerm, getBatchOrder());
SmallVector<int64_t> outerCount =
applyPermutation(getOutersPerBatch(), permutation);
- SmallVector<int64_t> outerOrder = applyPermutation(invPerm, getOuterOrder());
SmallVector<int64_t> threadCount =
applyPermutation(getThreadsPerOuter(), permutation);
SmallVector<int64_t> threadOrder =
applyPermutation(invPerm, getThreadOrder());
SmallVector<int64_t> elementCount =
applyPermutation(getElementsPerThread(), permutation);
- SmallVector<int64_t> elementOrder =
- applyPermutation(invPerm, getElementOrder());
return NestedLayoutAttr::get(
- getContext(), subgroupCount, subgroupOrder, batchCount, batchOrder,
- outerCount, outerOrder, threadCount, threadOrder, elementCount,
- elementOrder, getSubgroupBasis(), getSubgroupActiveIds(),
- getThreadBasis(), getThreadActiveIds());
+ getContext(), subgroupCount, subgroupOrder, batchCount, outerCount,
+ threadCount, threadOrder, elementCount, getSubgroupBasis(),
+ getSubgroupActiveIds(), getThreadBasis(), getThreadActiveIds());
}
/// We distribute to:
/// <BATCH x OUTER x ELEMENT>
SmallVector<int64_t> NestedLayoutAttr::getDistributedShape() const {
SmallVector<int64_t> shape;
- shape.append(applyPermutation(getBatchesPerSubgroup(), getBatchOrder()));
- shape.append(applyPermutation(getOutersPerBatch(), getOuterOrder()));
- shape.append(applyPermutation(getElementsPerThread(), getElementOrder()));
+ shape.append(getBatchesPerSubgroup().begin(), getBatchesPerSubgroup().end());
+ shape.append(getOutersPerBatch().begin(), getOutersPerBatch().end());
+ shape.append(getElementsPerThread().begin(), getElementsPerThread().end());
return shape;
}
@@ -437,19 +427,26 @@
LogicalResult NestedLayoutAttr::verify(
llvm::function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> subgroupsPerWorkgroup, ArrayRef<int64_t> subgroupOrder,
- ArrayRef<int64_t> batchesPerSubgroup, ArrayRef<int64_t> batchOrder,
- ArrayRef<int64_t> outersPerBatch, ArrayRef<int64_t> outerOrder,
+ ArrayRef<int64_t> batchesPerSubgroup, ArrayRef<int64_t> outersPerBatch,
ArrayRef<int64_t> threadsPerOuter, ArrayRef<int64_t> threadOrder,
- ArrayRef<int64_t> elementsPerThread, ArrayRef<int64_t> elementOrder,
- ArrayRef<int64_t> subgroupBasis, ArrayRef<bool> subgroupActiveIds,
- ArrayRef<int64_t> threadBasis, ArrayRef<bool> threadActiveIds) {
+ ArrayRef<int64_t> elementsPerThread, ArrayRef<int64_t> subgroupBasis,
+ ArrayRef<bool> subgroupActiveIds, ArrayRef<int64_t> threadBasis,
+ ArrayRef<bool> threadActiveIds) {
size_t rank = subgroupsPerWorkgroup.size();
- auto checkTile = [&](ArrayRef<int64_t> tileShape, ArrayRef<int64_t> order) {
- if (tileShape.size() != rank || order.size() != rank) {
+ auto checkTile = [&](ArrayRef<int64_t> tileShape) {
+ if (tileShape.size() != rank) {
emitError() << "all tiles must have the same rank as the layout";
return failure();
}
+ return success();
+ };
+
+ auto checkOrder = [&](ArrayRef<int64_t> order) {
+ if (order.size() != rank) {
+ emitError() << "all orders must have the same rank as the layout";
+ return failure();
+ }
if (!mlir::isPermutationVector(order)) {
emitError() << "all orderings must be permutation vectors";
return failure();
@@ -457,11 +454,14 @@
return success();
};
- if (failed(checkTile(subgroupsPerWorkgroup, subgroupOrder)) ||
- failed(checkTile(batchesPerSubgroup, batchOrder)) ||
- failed(checkTile(outersPerBatch, outerOrder)) ||
- failed(checkTile(threadsPerOuter, threadOrder)) ||
- failed(checkTile(elementsPerThread, elementOrder))) {
+ if (failed(checkTile(subgroupsPerWorkgroup)) ||
+ failed(checkTile(batchesPerSubgroup)) ||
+ failed(checkTile(outersPerBatch)) || failed(checkTile(threadsPerOuter)) ||
+ failed(checkTile(elementsPerThread))) {
+ return failure();
+ }
+
+ if (failed(checkOrder(subgroupOrder)) || failed(checkOrder(threadOrder))) {
return failure();
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
index 5600905..275eac0 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
@@ -29,10 +29,7 @@
elements_per_thread = [1, 4],
subgroup_order = [0, 1],
- batch_order = [0, 1],
- outer_order = [1, 0],
thread_order = [0, 1],
- element_order = [0, 1],
subgroup_basis = [1, 1],
thread_basis = [4, 2]
@@ -46,9 +43,7 @@
elements_per_thread = [4, 1],
subgroup_order = [1, 0],
- batch_order = [1, 0],
thread_order = [1, 0],
- element_order = [1, 0],
subgroup_basis = [1, 1],
thread_basis = [2, 4]
@@ -62,9 +57,7 @@
elements_per_thread = [4, 1],
subgroup_order = [1, 0],
- batch_order = [1, 0],
thread_order = [1, 0],
- element_order = [1, 0],
subgroup_basis = [2, 4, 8],
subgroup_active_ids = [true, true, false],
@@ -79,9 +72,7 @@
elements_per_thread = [4, 1],
subgroup_order = [1, 0],
- batch_order = [1, 0],
thread_order = [1, 0],
- element_order = [1, 0],
subgroup_basis = [2, 4, 8],
subgroup_active_ids = [true, true, false],
@@ -97,9 +88,7 @@
elements_per_thread = [4, 1],
subgroup_order = [1, 0],
- batch_order = [1, 0],
thread_order = [1, 0],
- element_order = [1, 0],
subgroup_basis = [2, 4],
subgroup_active_ids = [true, true],
@@ -127,7 +116,6 @@
// CHECK-SAME: outers_per_batch = [4, 1],
// CHECK-SAME: threads_per_outer = [4, 2],
// CHECK-SAME: elements_per_thread = [1, 4],
-// CHECK-SAME: outer_order = [1, 0],
// CHECK-SAME: subgroup_basis = [1, 1],
// CHECK-SAME: thread_basis = [4, 2]>
@@ -138,9 +126,7 @@
// CHECK-SAME: threads_per_outer = [2, 4],
// CHECK-SAME: elements_per_thread = [4, 1],
// CHECK-SAME: subgroup_order = [1, 0],
-// CHECK-SAME: batch_order = [1, 0],
// CHECK-SAME: thread_order = [1, 0],
-// CHECK-SAME: element_order = [1, 0],
// CHECK-SAME: subgroup_basis = [1, 1],
// CHECK-SAME: thread_basis = [2, 4]>
@@ -151,9 +137,7 @@
// CHECK-SAME: threads_per_outer = [2, 4],
// CHECK-SAME: elements_per_thread = [4, 1],
// CHECK-SAME: subgroup_order = [1, 0],
-// CHECK-SAME: batch_order = [1, 0],
// CHECK-SAME: thread_order = [1, 0],
-// CHECK-SAME: element_order = [1, 0],
// CHECK-SAME: subgroup_basis = [2, 4, 8],
// CHECK-SAME: subgroup_active_ids = [true, true, false],
// CHECK-SAME: thread_basis = [2, 4]>
@@ -165,9 +149,7 @@
// CHECK-SAME: threads_per_outer = [2, 4],
// CHECK-SAME: elements_per_thread = [4, 1],
// CHECK-SAME: subgroup_order = [1, 0],
-// CHECK-SAME: batch_order = [1, 0],
// CHECK-SAME: thread_order = [1, 0],
-// CHECK-SAME: element_order = [1, 0],
// CHECK-SAME: subgroup_basis = [2, 4, 8],
// CHECK-SAME: subgroup_active_ids = [true, true, false],
// CHECK-SAME: thread_basis = [2, 4, 2],
@@ -180,9 +162,7 @@
// CHECK-SAME: threads_per_outer = [2, 4],
// CHECK-SAME: elements_per_thread = [4, 1],
// CHECK-SAME: subgroup_order = [1, 0],
-// CHECK-SAME: batch_order = [1, 0],
// CHECK-SAME: thread_order = [1, 0],
-// CHECK-SAME: element_order = [1, 0],
// CHECK-SAME: subgroup_basis = [2, 4],
// CHECK-SAME: thread_basis = [4, 2]>