Add patterns to VectorToGPU pass (#3587)

Add patterns to lower vector transfer_read, transfer_write, contract, and
extract_strided_slice operations to a mix of standard operations and
vector operations that have a direct mapping to SPIR-V.
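
As a sketch (mirroring the transfer_read case added to the
vector_to_gpu.mlir test in this change), a single-element transfer_read
is now rewritten to a plain load followed by a broadcast back to the
vector type:

  %0 = vector.transfer_read %arg0[%c0, %c0], %cst
         : memref<32x32xf32>, vector<1x1xf32>

becomes

  %l = load %arg0[%c0, %c0] : memref<32x32xf32>
  %0 = vector.broadcast %l : f32 to vector<1x1xf32>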
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
index 7721f41..eb2a32e 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
@@ -64,6 +64,7 @@
 
  private:
   void tileAndVectorizeLinalgCopy(FuncOp funcOp, MLIRContext *context);
+  void lowerVectorOps(FuncOp funcOp, MLIRContext *context);
 };
 
 // Common class for all vector to GPU patterns.
@@ -149,8 +150,7 @@
         loc, rewriter.getIndexType(), rewriter.getStringAttr("x"));
     Value index = rewriter.create<AddIOp>(loc, ThreadIndex, indices.back());
     indices.back() = index;
-    rewriter.create<StoreOp>(op.getLoc(), operands[0], operands[1], indices);
-    rewriter.eraseOp(op);
+    rewriter.replaceOpWithNewOp<StoreOp>(op, operands[0], operands[1], indices);
     return success();
   }
 };
@@ -204,11 +204,141 @@
   applyPatternsAndFoldGreedily(funcOp, vectorizationPatterns);
 }
 
+// Convert vector transfer_read to a load if possible. This is the case only if
+// the element type of the memref matches the element type we want to load.
+class VectorTransferReadToLoad
+    : public OpRewritePattern<vector::TransferReadOp> {
+ public:
+  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getVectorType().getNumElements() != 1 ||
+        op.getMemRefType().getElementType() !=
+            op.getVectorType().getElementType()) {
+      return failure();
+    }
+    auto loc = op.getLoc();
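+    // Load the scalar and broadcast it back to the 1-element vector type.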
+    Value newOp = rewriter.create<LoadOp>(loc, op.memref(), op.indices());
+    newOp =
+        rewriter.create<vector::BroadcastOp>(loc, op.getVectorType(), newOp);
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
+};
+
+// Convert vector transfer_write to a store if possible. This is the case only
+// if the element type of the memref matches the element type we want to store.
+class VectorTransferWriteToStore
+    : public OpRewritePattern<vector::TransferWriteOp> {
+ public:
+  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getVectorType().getNumElements() != 1 ||
+        op.getMemRefType().getElementType() !=
+            op.getVectorType().getElementType()) {
+      return failure();
+    }
+    auto loc = op.getLoc();
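+    // Extract the single element and store it as a scalar.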
+    SmallVector<int64_t, 2> zero(op.getVectorType().getRank(), 0);
+    Value scalarValue =
+        rewriter.create<vector::ExtractOp>(loc, op.vector(), zero);
+    rewriter.create<StoreOp>(loc, scalarValue, op.memref(), op.indices());
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+// Lower a vector.contract op to a single scalar or vector mulf+addf. Insert
+// casts to convert between the 2-D vector operands and 1-D vectors or scalars.
+class VectorContractLowering : public OpRewritePattern<vector::ContractionOp> {
+ public:
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+    auto iteratorTypes = op.iterator_types().getValue();
+    if (iteratorTypes.size() != 3 || !isParallelIterator(iteratorTypes[0]) ||
+        !isParallelIterator(iteratorTypes[1]) ||
+        !isReductionIterator(iteratorTypes[2]) ||
+        !isRowMajorMatmul(op.indexing_maps())) {
+      return failure();
+    }
+    if (op.getLhsType().getNumElements() != 1) return failure();
+    unsigned vecSize = op.getAccType().cast<VectorType>().getNumElements();
+    if (!(vecSize >= 1 && vecSize <= 4)) return failure();
+    auto loc = op.getLoc();
+    VectorType vecType = VectorType::get(
+        vecSize, op.getResultType().cast<VectorType>().getElementType());
+    std::array<int64_t, 2> zero = {0, 0};
+    Value lhs = rewriter.create<vector::ExtractOp>(loc, op.lhs(), zero);
+    Value rhs, acc;
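+    // For a single-element accumulator operate on scalars; otherwise splat
+    // the scalar lhs and flatten rhs/acc into 1-D vectors.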
+    if (vecSize == 1) {
+      rhs = rewriter.create<vector::ExtractOp>(loc, op.rhs(), zero);
+      acc = rewriter.create<vector::ExtractOp>(loc, op.acc(), zero);
+    } else {
+      lhs = rewriter.create<vector::BroadcastOp>(loc, vecType, lhs);
+      rhs = rewriter.create<vector::ShapeCastOp>(loc, vecType, op.rhs());
+      acc = rewriter.create<vector::ShapeCastOp>(loc, vecType, op.acc());
+    }
+    Value newOp = rewriter.create<MulFOp>(loc, lhs, rhs);
+    newOp = rewriter.create<AddFOp>(loc, newOp, acc);
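+    // Cast the result back to the 2-D result vector type.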
+    if (vecSize == 1)
+      newOp =
+          rewriter.create<vector::BroadcastOp>(loc, op.getResultType(), newOp);
+    else
+      newOp =
+          rewriter.create<vector::ShapeCastOp>(loc, op.getResultType(), newOp);
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
+};
+
+// Lower ExtractStridedSliceOp to an ExtractOp that can be natively converted
+// to SPIR-V. Add a BroadcastOp to keep the type consistent; the broadcast is
+// expected to be removed by canonicalization.
+class ExtractStridedLowering
+    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
+ public:
+  using OpRewritePattern<vector::ExtractStridedSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only handle cases extracting a single-element (degenerate) vector so
+    // that we can generate an ExtractOp with a scalar destination.
+    if (op.getResult().getType().cast<VectorType>().getNumElements() != 1)
+      return failure();
+    auto loc = op.getLoc();
+    SmallVector<int64_t, 4> offsets = llvm::to_vector<4>(
+        llvm::map_range(op.offsets().getAsRange<IntegerAttr>(),
+                        [](IntegerAttr attr) { return attr.getInt(); }));
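+    // Pad the offsets with zeros up to the source rank so that the extract
+    // produces a scalar.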
+    offsets.resize(op.getVectorType().getRank(), 0);
+    Value newOp = rewriter.create<vector::ExtractOp>(loc, op.vector(), offsets);
+    newOp = rewriter.create<vector::BroadcastOp>(loc, op.getResult().getType(),
+                                                 newOp);
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
+};
+
+// Lower vector ops to operations that can later be converted to SPIR-V.
+void ConvertVectorToGPUPass::lowerVectorOps(FuncOp funcOp,
+                                            MLIRContext *context) {
+  OwningRewritePatternList patterns;
+  patterns.insert<VectorContractLowering, VectorTransferReadToLoad,
+                  VectorTransferWriteToStore, ExtractStridedLowering>(context);
+  applyPatternsAndFoldGreedily(funcOp, patterns);
+}
+
 void ConvertVectorToGPUPass::runOnOperation() {
   MLIRContext *context = &getContext();
   FuncOp funcOp = getOperation();
   tileAndVectorizeLinalgCopy(funcOp, context);
 
+  lowerVectorOps(funcOp, context);
+
   auto &cooperativeMatrixAnalysis = getAnalysis<CooperativeMatrixAnalysis>();
   OwningRewritePatternList patterns;
   patterns.insert<UnaryAndBinaryOpPattern<AddFOp>, VectorTransferReadConversion,
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/vector_to_gpu.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/vector_to_gpu.mlir
index 30d94c5..99ec424 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/vector_to_gpu.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/vector_to_gpu.mlir
@@ -59,3 +59,75 @@
     // CHECK:   %[[LOAD:.+]] = vector.transfer_read %[[SVs]][%c0, %c0], %cst {{.*}} : memref<1x4xf32, {{.*}}>, vector<1x4xf32>
     // CHECK:   vector.transfer_write %[[LOAD]], %[[SVd]][%[[C0]], %[[C0]]] {{.*}} : vector<1x4xf32>, memref<1x4xf32
 }
+
+// -----
+
+module attributes {gpu.container_module, spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+  func @transfer_ops(%arg0: memref<32x32xf32>, %arg1 : vector<1x1xf32>) -> vector<1x1xf32> attributes {spv.entry_point_abi = {local_size = dense<[128, 1, 1]> : vector<3xi32>}} {
+  %c0 = constant 0 : index
+  %cst = constant 0.0 : f32
+  %0 = vector.transfer_read %arg0[%c0, %c0], %cst : memref<32x32xf32>, vector<1x1xf32>
+  vector.transfer_write %arg1, %arg0[%c0, %c0] : vector<1x1xf32>, memref<32x32xf32>
+  return %0 : vector<1x1xf32>
+  }
+  // CHECK-LABEL: func @transfer_ops
+  //  CHECK-SAME: (%[[ARG0:.*]]: memref<32x32xf32>, %[[ARG1:.*]]: vector<1x1xf32>
+  //       CHECK:   %[[C0:.*]] = constant 0 : index
+  //       CHECK:   %[[LOAD:.*]] = load %[[ARG0]][%[[C0]], %[[C0]]] : memref<32x32xf32>
+  //       CHECK:   %[[B:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<1x1xf32>
+  //       CHECK:   %[[EXT:.*]] = vector.extract %[[ARG1]][0, 0] : vector<1x1xf32>
+  //       CHECK:   store %[[EXT]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<32x32xf32>
+  //       CHECK:   return %[[B]] : vector<1x1xf32>
+}
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+module attributes {gpu.container_module, spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+  func @contract_ops(%arg0 : vector<1x1xf32>, %arg1 : vector<1x4xf32>,
+                    %arg2 : vector<1x4xf32>, %arg3 : vector<1x1xf32>,
+                    %arg4 : vector<1x1xf32>) -> (vector<1x1xf32>, vector<1x4xf32>) attributes {spv.entry_point_abi = {local_size = dense<[128, 1, 1]> : vector<3xi32>}} {
+  %0 = vector.contract {indexing_maps = [#map0, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"]} %arg0, %arg3, %arg4
+                      : vector<1x1xf32>, vector<1x1xf32> into vector<1x1xf32>
+  %1 = vector.contract {indexing_maps = [#map0, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"]} %arg0, %arg1, %arg2
+                      : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+  return %0, %1 : vector<1x1xf32>, vector<1x4xf32>
+  }
+  // CHECK-LABEL: func @contract_ops
+  //  CHECK-SAME: (%[[ARG0:.*]]: vector<1x1xf32>, %[[ARG1:.*]]: vector<1x4xf32>, %[[ARG2:.*]]: vector<1x4xf32>, %[[ARG3:.*]]: vector<1x1xf32>, %[[ARG4:.*]]: vector<1x1xf32>)
+  //       CHECK:   %[[A:.*]] = vector.extract %[[ARG0]][0, 0] : vector<1x1xf32>
+  //       CHECK:   %[[B:.*]] = vector.extract %[[ARG3]][0, 0] : vector<1x1xf32>
+  //       CHECK:   %[[C:.*]] = vector.extract %[[ARG4]][0, 0] : vector<1x1xf32>
+  //       CHECK:   %[[MUL:.*]] = mulf %[[A]], %[[B]] : f32
+  //       CHECK:   %[[ADD:.*]] = addf %[[MUL]], %[[C]] : f32
+  //       CHECK:   %[[R0:.*]] = vector.broadcast %[[ADD]] : f32 to vector<1x1xf32>
+  //       CHECK:   %[[A:.*]] = vector.extract %[[ARG0]][0, 0] : vector<1x1xf32>
+  //       CHECK:   %[[VA:.*]] = vector.broadcast %[[A]] : f32 to vector<4xf32>
+  //       CHECK:   %[[VB:.*]] = vector.shape_cast %[[ARG1]] : vector<1x4xf32> to vector<4xf32>
+  //       CHECK:   %[[VC:.*]] = vector.shape_cast %[[ARG2]] : vector<1x4xf32> to vector<4xf32>
+  //       CHECK:   %[[VMUL:.*]] = mulf %[[VA]], %[[VB]] : vector<4xf32>
+  //       CHECK:   %[[VADD:.*]] = addf %[[VMUL]], %[[VC]] : vector<4xf32>
+  //       CHECK:   %[[R1:.*]] = vector.shape_cast %[[VADD]] : vector<4xf32> to vector<1x4xf32>
+  //       CHECK:   return %[[R0]], %[[R1]] : vector<1x1xf32>, vector<1x4xf32>
+}
+
+// -----
+
+module attributes {gpu.container_module, spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+  func @extract(%arg0 : vector<1x4xf32>) -> vector<1x1xf32> attributes {spv.entry_point_abi = {local_size = dense<[128, 1, 1]> : vector<3xi32>}} {
+    %0 = vector.extract_strided_slice %arg0
+      {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]}
+        : vector<1x4xf32> to vector<1x1xf32>
+    return %0 : vector<1x1xf32>
+  }
+  // CHECK-LABEL: func @extract
+  //  CHECK-SAME: (%[[ARG0:.*]]: vector<1x4xf32>
+  //       CHECK:   %[[A:.*]] = vector.extract %[[ARG0]][0, 2] : vector<1x4xf32>
+  //       CHECK:   %[[B:.*]] = vector.broadcast %[[A]] : f32 to vector<1x1xf32>
+  //       CHECK:   return %[[B]] : vector<1x1xf32>
+}