Adjust default mmt4d tiling (#7071)
* Adjust default mmt4d tiling
- Tile to M0=4, N0=4, K0=1.
- Add microbenchmark for that case.
* Set M0, N0, K0 from the operands
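For context: linalg.mmt4d consumes pre-tiled operands, so M0, N0, and K0 can be read directly off the operand types. The lhs is tiled as M1xK1xM0xK0 and the rhs as N1xK1xN0xK0, which is why the change below uses M0 = lhsShape[2], K0 = lhsShape[3], and N0 = rhsShape[2]. A minimal sketch using the shapes from the new 4x1x4 benchmark (the function name @mmt4d_layout_example is illustrative, not part of the change):

func @mmt4d_layout_example() -> tensor<96x128x4x4xf32> {
  // lhs is M1xK1xM0xK0: M1=96, K1=384, M0=4, K0=1.
  %lhs = util.unfoldable_constant dense<1.0> : tensor<96x384x4x1xf32>
  // rhs is N1xK1xN0xK0: N1=128, K1=384, N0=4, K0=1.
  %rhs = util.unfoldable_constant dense<1.0> : tensor<128x384x4x1xf32>
  // The accumulator is M1xN1xM0xN0.
  %dst = util.unfoldable_constant dense<1.0> : tensor<96x128x4x4xf32>
  %0 = linalg.mmt4d ins(%lhs, %rhs : tensor<96x384x4x1xf32>, tensor<128x384x4x1xf32>) outs(%dst : tensor<96x128x4x4xf32>) -> tensor<96x128x4x4xf32>
  return %0 : tensor<96x128x4x4xf32>
}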
diff --git a/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index 89d61d0..e9d69ee 100644
--- a/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -183,23 +183,33 @@
return SmallVector<int64_t>(mmt4dWorkgroupTileSizes.begin(),
mmt4dWorkgroupTileSizes.end());
}
- return {64, 32};
+ return {48, 32};
};
auto getL1TileSizes = [&]() -> SmallVector<int64_t> {
+ auto lhsShape = getUntiledShape(mmt4dOp.inputs()[0]);
+ auto rhsShape = getUntiledShape(mmt4dOp.inputs()[1]);
+ int M0 = lhsShape[2];
+ int N0 = rhsShape[2];
+ int K0 = lhsShape[3];
if (!mmt4dL1TileSizes.empty()) {
return SmallVector<int64_t>(mmt4dL1TileSizes.begin(),
mmt4dL1TileSizes.end());
}
- return {1, 1, 4, 4, 1, 4};
+ return {1, 1, 1, M0, N0, K0};
};
auto getVectorSizes = [&]() -> SmallVector<int64_t> {
+ auto lhsShape = getUntiledShape(mmt4dOp.inputs()[0]);
+ auto rhsShape = getUntiledShape(mmt4dOp.inputs()[1]);
+ int M0 = lhsShape[2];
+ int N0 = rhsShape[2];
+ int K0 = lhsShape[3];
if (!mmt4dVectorSizes.empty()) {
return SmallVector<int64_t>(mmt4dVectorSizes.begin(),
mmt4dVectorSizes.end());
}
- return {1, 1, 4, 4, 1, 4};
+ return {1, 1, 1, M0, N0, K0};
};
SmallVector<int64_t, 4> nativeVectorSize = getVectorSizes();
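Worked example (for reference, assuming the 4x1x4 benchmark below): getUntiledShape() on an lhs of type tensor<96x384x4x1xf32> yields (96, 384, 4, 1) and on an rhs of type tensor<128x384x4x1xf32> yields (128, 384, 4, 1), so the lambdas pick M0 = 4, N0 = 4, K0 = 1, and both getL1TileSizes() and getVectorSizes() default to {1, 1, 1, 4, 4, 1}.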
diff --git a/iree/test/microbenchmarks/linalg_mmt4d.mlir b/iree/test/microbenchmarks/linalg_mmt4d.mlir
index c8a9d10..543d6bd 100644
--- a/iree/test/microbenchmarks/linalg_mmt4d.mlir
+++ b/iree/test/microbenchmarks/linalg_mmt4d.mlir
@@ -10,10 +10,18 @@
return %0 : tensor<384x512xf32>
}
-func @mmt4d_384x384x512() -> tensor<96x128x4x4xf32> {
- %lhs = util.unfoldable_constant dense<1.0> : tensor<96x96x4x4xf32>
- %rhs = util.unfoldable_constant dense<1.0> : tensor<128x96x4x4xf32>
+func @mmt4d_384x384x512_4x1x4() -> tensor<96x128x4x4xf32> {
+ %lhs = util.unfoldable_constant dense<1.0> : tensor<96x384x4x1xf32>
+ %rhs = util.unfoldable_constant dense<1.0> : tensor<128x384x4x1xf32>
%dst = util.unfoldable_constant dense<1.0> : tensor<96x128x4x4xf32>
- %0 = linalg.mmt4d ins(%lhs, %rhs : tensor<96x96x4x4xf32>, tensor<128x96x4x4xf32>) outs(%dst : tensor<96x128x4x4xf32>) -> tensor<96x128x4x4xf32>
+ %0 = linalg.mmt4d ins(%lhs, %rhs : tensor<96x384x4x1xf32>, tensor<128x384x4x1xf32>) outs(%dst : tensor<96x128x4x4xf32>) -> tensor<96x128x4x4xf32>
return %0 : tensor<96x128x4x4xf32>
}
+
+func @mmt4d_384x384x512_8x1x8() -> tensor<48x64x8x8xf32> {
+ %lhs = util.unfoldable_constant dense<1.0> : tensor<48x384x8x1xf32>
+ %rhs = util.unfoldable_constant dense<1.0> : tensor<64x384x8x1xf32>
+ %dst = util.unfoldable_constant dense<1.0> : tensor<48x64x8x8xf32>
+ %0 = linalg.mmt4d ins(%lhs, %rhs : tensor<48x384x8x1xf32>, tensor<64x384x8x1xf32>) outs(%dst : tensor<48x64x8x8xf32>) -> tensor<48x64x8x8xf32>
+ return %0 : tensor<48x64x8x8xf32>
+}
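Both microbenchmarks compute the same 384x384x512 matmul and differ only in the inner tiling: the 4x1x4 variant produces default L1/vector tile sizes of {1, 1, 1, 4, 4, 1}, while the 8x1x8 variant (M0 = 8, N0 = 8, K0 = 1) produces {1, 1, 1, 8, 8, 1}.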