[spirv] Use convolution output type bitwidth for tiling (#13049)
For int8 convolutions, the input and filter can be int8 while the
accumulator is int32, so we need to use the output type's bitwidth
when deducing the tiling scheme.
This should address the perf regression for int8 EfficientNet.
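
To illustrate, here is a minimal standalone sketch (not IREE code) of the
bitwidth-to-multiplier reasoning the diffs below touch; deduceTileMultiplier
is a hypothetical helper that only mirrors the "32 / bitwidth" computation
from the patch:

// Sketch: the tiling multiplier depends on which type's bitwidth we key
// off. deduceTileMultiplier is a hypothetical helper, not an IREE API.
#include <cassert>

int deduceTileMultiplier(int bitwidth) {
  assert(bitwidth <= 32 && "wider element types are rejected upstream");
  return 32 / bitwidth;
}

int main() {
  // int8 input/filter with an int32 accumulator: keying off the input
  // type would yield a 4x multiplier, but the accumulator actually
  // occupies 32 bits, so the output type gives the right answer of 1x.
  assert(deduceTileMultiplier(/*input i8*/ 8) == 4);
  assert(deduceTileMultiplier(/*output i32*/ 32) == 1);
  return 0;
}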
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/AMDConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/AMDConfig.cpp
index f0ff478..cfa6c97 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/AMDConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/AMDConfig.cpp
@@ -82,7 +82,8 @@
   }
   if (auto convOp = dyn_cast<linalg::ConvolutionOpInterface>(rootOp)) {
-    auto type = cast<ShapedType>(convOp.image().getType());
+    // Use the result type in case of larger bitwidth for accumulators.
+    auto type = cast<ShapedType>(convOp->getResult(0).getType());
     const int bitwidth = type.getElementTypeBitWidth();
     if (bitwidth > 32) return failure();
     const int multipler = 32 / bitwidth;
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/AdrenoConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/AdrenoConfig.cpp
index 130ebb1..6c1294c 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/AdrenoConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/AdrenoConfig.cpp
@@ -50,7 +50,8 @@
   }
   if (auto convOp = dyn_cast<linalg::ConvolutionOpInterface>(rootOp)) {
-    auto type = cast<ShapedType>(convOp.image().getType());
+    // Use the result type in case of larger bitwidth for accumulators.
+    auto type = cast<ShapedType>(convOp->getResult(0).getType());
     const int bitwidth = type.getElementTypeBitWidth();
     if (bitwidth > 32) return failure();
     const int multipler = 32 / bitwidth;
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/AppleConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/AppleConfig.cpp
index add1aa7..39fe1d3 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/AppleConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/AppleConfig.cpp
@@ -50,7 +50,8 @@
   }
   if (auto convOp = dyn_cast<linalg::ConvolutionOpInterface>(rootOp)) {
-    auto type = cast<ShapedType>(convOp.image().getType());
+    // Use the result type in case of larger bitwidth for accumulators.
+    auto type = cast<ShapedType>(convOp->getResult(0).getType());
     const int bitwidth = type.getElementTypeBitWidth();
     if (bitwidth > 32) return failure();
     const int multipler = 32 / bitwidth;
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
index 9828e77..2f332bf 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -1451,7 +1451,8 @@
         return setDefaultOpConfig(limits, op);
       })
       .Case<linalg::ConvolutionOpInterface>([limits](auto op) {
-        auto type = cast<ShapedType>(op.image().getType());
+        // Use the result type in case of larger bitwidth for accumulators.
+        auto type = cast<ShapedType>(op->getResult(0).getType());
         const int bitwidth = type.getElementTypeBitWidth();
         if (bitwidth <= 32) {
           const int multipler = 32 / bitwidth;
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/MaliConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/MaliConfig.cpp
index f5cb3eb..4cdbbe0 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/MaliConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/MaliConfig.cpp
@@ -53,7 +53,8 @@
   }
   if (auto convOp = dyn_cast<linalg::ConvolutionOpInterface>(rootOp)) {
-    auto type = cast<ShapedType>(convOp.image().getType());
+    // Use the result type in case of larger bitwidth for accumulators.
+    auto type = cast<ShapedType>(convOp->getResult(0).getType());
     const int bitwidth = type.getElementTypeBitWidth();
     if (bitwidth > 32) return failure();
     const int multipler = 32 / bitwidth;