[spirv] Replace SPIRVOpLowering with a local class (#4383)

SPIRVOpLowering is going away in upstream:
https://reviews.llvm.org/D94080
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
index 65f9994..48e8f91 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
@@ -244,8 +244,8 @@
 /// in SPIR-V lowering, linalg.reshape becomes a no-op.
 // TODO(ravishankarm): Move this into MLIR Core.
 struct LinalgReshapeConverter final
-    : public SPIRVOpLowering<linalg::ReshapeOp> {
-  using SPIRVOpLowering<linalg::ReshapeOp>::SPIRVOpLowering;
+    : public OpConversionPattern<linalg::ReshapeOp> {
+  using OpConversionPattern<linalg::ReshapeOp>::OpConversionPattern;
   LogicalResult matchAndRewrite(
       linalg::ReshapeOp reshapeOp, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const override {
@@ -254,16 +254,36 @@
   }
 };
 
+/// Base class for lowering to SPIR-V cooperative matrix ops.
+template <typename SourceOp>
+class CoopMatOpLowering : public OpConversionPattern<SourceOp> {
+ public:
+  CoopMatOpLowering(MLIRContext *context, SPIRVTypeConverter &converter,
+                    PatternBenefit benefit = 1)
+      : OpConversionPattern<SourceOp>(context, benefit), converter(converter) {}
+
+ protected:
+  // TODO: We explicitly keep a reference of the type converter instead of
+  // passing it to OpConversionPattern during construction. This effectively
+  // bypasses the dialect conversion framework's automation over type
+  // conversion. This is needed for now because upstream SPIRVTypeConverter does
+  // not support cooperative matrix well yet so the framework won't know how to
+  // generate cooperative matrix. We are manually constructing the cooperative
+  // matrix in patterns. This should be fixed when we upstream all cooperative
+  // matrix related code.
+  SPIRVTypeConverter &converter;
+};
+
 /// Convert subgroup level vector transfert to SPIR-V cooperative
 /// matrix load/store if those are supported.
 /// TODO(thomasraoux): Move to MLIR core once this is stable.
 template <typename OpTy>
-class TransferToCoopMatLoadStore final : public SPIRVOpLowering<OpTy> {
+class TransferToCoopMatLoadStore final : public CoopMatOpLowering<OpTy> {
  public:
   TransferToCoopMatLoadStore(
       MLIRContext *context, SPIRVTypeConverter &converter,
       const CooperativeMatrixAnalysis &cooperativeMatrixAnalysis)
-      : SPIRVOpLowering<OpTy>(context, converter),
+      : CoopMatOpLowering<OpTy>(context, converter),
         cooperativeMatrixAnalysis(cooperativeMatrixAnalysis) {}
 
   LogicalResult matchAndRewrite(
@@ -292,7 +312,7 @@
     for (auto i : op.indices())
       remappedIndices.push_back(rewriter.getRemappedValue(i));
     Value ptr = spirv::getElementPtr(
-        SPIRVOpLowering<OpTy>::typeConverter, memrefType,
+        CoopMatOpLowering<OpTy>::converter, memrefType,
         rewriter.getRemappedValue(op.source()), remappedIndices, loc, rewriter);
     int64_t offset = 0;
     SmallVector<int64_t, 2> strides;
@@ -342,12 +362,12 @@
 /// Convert subgroup level vector contract to SPIR-V cooperative
 /// matrix matmuladd.
 class VectorContractToCoopMatmul final
-    : public SPIRVOpLowering<vector::ContractionOp> {
+    : public CoopMatOpLowering<vector::ContractionOp> {
  public:
   VectorContractToCoopMatmul(
       MLIRContext *context, SPIRVTypeConverter &converter,
       const CooperativeMatrixAnalysis &cooperativeMatrixAnalysis)
-      : SPIRVOpLowering<vector::ContractionOp>(context, converter),
+      : CoopMatOpLowering<vector::ContractionOp>(context, converter),
         cooperativeMatrixAnalysis(cooperativeMatrixAnalysis) {}
 
   LogicalResult matchAndRewrite(
@@ -491,7 +511,7 @@
   auto aliasedResources = getAliasedResources(moduleOp);
   patterns.insert<IREEPlaceholderConverter>(typeConverter, context,
                                             std::move(aliasedResources));
-  patterns.insert<LinalgReshapeConverter>(context, typeConverter);
+  patterns.insert<LinalgReshapeConverter>(typeConverter, context);
 
   std::unique_ptr<ConversionTarget> target =
       spirv::SPIRVConversionTarget::get(targetAttr);