[spirv] Replace SPIRVOpLowering with a local class (#4383)
SPIRVOpLowering is going away in upstream:
https://reviews.llvm.org/D94080
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
index 65f9994..48e8f91 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
@@ -244,8 +244,8 @@
/// in SPIR-V lowering, linalg.reshape becomes a no-op.
// TODO(ravishankarm): Move this into MLIR Core.
struct LinalgReshapeConverter final
- : public SPIRVOpLowering<linalg::ReshapeOp> {
- using SPIRVOpLowering<linalg::ReshapeOp>::SPIRVOpLowering;
+ : public OpConversionPattern<linalg::ReshapeOp> {
+ using OpConversionPattern<linalg::ReshapeOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
linalg::ReshapeOp reshapeOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
@@ -254,16 +254,36 @@
}
};
+/// Base class for lowering to SPIR-V cooperative matrix ops.
+template <typename SourceOp>
+class CoopMatOpLowering : public OpConversionPattern<SourceOp> {
+ public:
+ CoopMatOpLowering(MLIRContext *context, SPIRVTypeConverter &converter,
+ PatternBenefit benefit = 1)
+ : OpConversionPattern<SourceOp>(context, benefit), converter(converter) {}
+
+ protected:
+ // TODO: We explicitly keep a reference of the type converter instead of
+ // passing it to OpConversionPattern during construction. This effectively
+ // bypasses the dialect conversion framework's automation over type
+ // conversion. This is needed for now because upstream SPIRVTypeConverter does
+ // not support cooperative matrix well yet so the framework won't know how to
+ // generate cooperative matrix. We are manually constructing the cooperative
+ // matrix in patterns. This should be fixed when we upstream all cooperative
+ // matrix related code.
+ SPIRVTypeConverter &converter;
+};
+
/// Convert subgroup level vector transfert to SPIR-V cooperative
/// matrix load/store if those are supported.
/// TODO(thomasraoux): Move to MLIR core once this is stable.
template <typename OpTy>
-class TransferToCoopMatLoadStore final : public SPIRVOpLowering<OpTy> {
+class TransferToCoopMatLoadStore final : public CoopMatOpLowering<OpTy> {
public:
TransferToCoopMatLoadStore(
MLIRContext *context, SPIRVTypeConverter &converter,
const CooperativeMatrixAnalysis &cooperativeMatrixAnalysis)
- : SPIRVOpLowering<OpTy>(context, converter),
+ : CoopMatOpLowering<OpTy>(context, converter),
cooperativeMatrixAnalysis(cooperativeMatrixAnalysis) {}
LogicalResult matchAndRewrite(
@@ -292,7 +312,7 @@
for (auto i : op.indices())
remappedIndices.push_back(rewriter.getRemappedValue(i));
Value ptr = spirv::getElementPtr(
- SPIRVOpLowering<OpTy>::typeConverter, memrefType,
+ CoopMatOpLowering<OpTy>::converter, memrefType,
rewriter.getRemappedValue(op.source()), remappedIndices, loc, rewriter);
int64_t offset = 0;
SmallVector<int64_t, 2> strides;
@@ -342,12 +362,12 @@
/// Convert subgroup level vector contract to SPIR-V cooperative
/// matrix matmuladd.
class VectorContractToCoopMatmul final
- : public SPIRVOpLowering<vector::ContractionOp> {
+ : public CoopMatOpLowering<vector::ContractionOp> {
public:
VectorContractToCoopMatmul(
MLIRContext *context, SPIRVTypeConverter &converter,
const CooperativeMatrixAnalysis &cooperativeMatrixAnalysis)
- : SPIRVOpLowering<vector::ContractionOp>(context, converter),
+ : CoopMatOpLowering<vector::ContractionOp>(context, converter),
cooperativeMatrixAnalysis(cooperativeMatrixAnalysis) {}
LogicalResult matchAndRewrite(
@@ -491,7 +511,7 @@
auto aliasedResources = getAliasedResources(moduleOp);
patterns.insert<IREEPlaceholderConverter>(typeConverter, context,
std::move(aliasedResources));
- patterns.insert<LinalgReshapeConverter>(context, typeConverter);
+ patterns.insert<LinalgReshapeConverter>(typeConverter, context);
std::unique_ptr<ConversionTarget> target =
spirv::SPIRVConversionTarget::get(targetAttr);