Add patterns to VectorToGPU pass (#3587)
Add patterns to lower some vector operations to a mix of standard
operations and vector operations that have a direct mapping to SPIR-V.
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
index 7721f41..eb2a32e 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
@@ -64,6 +64,7 @@
private:
void tileAndVectorizeLinalgCopy(FuncOp funcOp, MLIRContext *context);
+ void lowerVectorOps(FuncOp funcOp, MLIRContext *context);
};
// Common class for all vector to GPU patterns.
@@ -149,8 +150,7 @@
loc, rewriter.getIndexType(), rewriter.getStringAttr("x"));
Value index = rewriter.create<AddIOp>(loc, ThreadIndex, indices.back());
indices.back() = index;
- rewriter.create<StoreOp>(op.getLoc(), operands[0], operands[1], indices);
- rewriter.eraseOp(op);
+ rewriter.replaceOpWithNewOp<StoreOp>(op, operands[0], operands[1], indices);
return success();
}
};
@@ -204,11 +204,141 @@
applyPatternsAndFoldGreedily(funcOp, vectorizationPatterns);
}
+// Convert a vector transfer_read to a scalar load when possible. This is the
+// case only when the vector has a single element and the memref element type
+// matches the vector element type.
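+//
+// A minimal sketch of the rewrite (illustrative IR; value names are made up,
+// the shapes mirror the transfer_ops test below):
+//   %0 = vector.transfer_read %m[%i, %j], %cst : memref<32x32xf32>, vector<1x1xf32>
+// becomes:
+//   %s = load %m[%i, %j] : memref<32x32xf32>
+//   %0 = vector.broadcast %s : f32 to vector<1x1xf32>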
+class VectorTransferReadToLoad
+ : public OpRewritePattern<vector::TransferReadOp> {
+ public:
+ using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransferReadOp op,
+ PatternRewriter &rewriter) const override {
+ if (op.getVectorType().getNumElements() != 1 ||
+ op.getMemRefType().getElementType() !=
+ op.getVectorType().getElementType()) {
+ return failure();
+ }
+ auto loc = op.getLoc();
+ Value newOp = rewriter.create<LoadOp>(loc, op.memref(), op.indices());
+ newOp =
+ rewriter.create<vector::BroadcastOp>(loc, op.getVectorType(), newOp);
+ rewriter.replaceOp(op, newOp);
+ return success();
+ }
+};
+
+// Convert a vector transfer_write to a scalar store when possible. This is the
+// case only when the vector has a single element and the memref element type
+// matches the vector element type.
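+//
+// A minimal sketch of the rewrite (illustrative IR; value names are made up,
+// the shapes mirror the transfer_ops test below):
+//   vector.transfer_write %v, %m[%i, %j] : vector<1x1xf32>, memref<32x32xf32>
+// becomes:
+//   %s = vector.extract %v[0, 0] : vector<1x1xf32>
+//   store %s, %m[%i, %j] : memref<32x32xf32>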
+class VectorTransferWriteToStore
+ : public OpRewritePattern<vector::TransferWriteOp> {
+ public:
+ using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransferWriteOp op,
+ PatternRewriter &rewriter) const override {
+ if (op.getVectorType().getNumElements() != 1 ||
+ op.getMemRefType().getElementType() !=
+ op.getVectorType().getElementType()) {
+ return failure();
+ }
+ auto loc = op.getLoc();
+ SmallVector<int64_t, 2> zero(op.getVectorType().getRank(), 0);
+ Value scalarValue =
+ rewriter.create<vector::ExtractOp>(loc, op.vector(), zero);
+ rewriter.create<StoreOp>(loc, scalarValue, op.memref(), op.indices());
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
+// Lower a vector.contract to a single scalar or vector mulf + addf. Insert
+// casts to convert from the 2-D vectors to 1-D vectors or scalars.
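+//
+// A minimal sketch of the vector case (illustrative IR; value names are made
+// up, the shapes mirror the contract_ops test below), for a
+// vector<1x1xf32> x vector<1x4xf32> contraction accumulating into a
+// vector<1x4xf32>:
+//   %a   = vector.extract %lhs[0, 0] : vector<1x1xf32>
+//   %va  = vector.broadcast %a : f32 to vector<4xf32>
+//   %vb  = vector.shape_cast %rhs : vector<1x4xf32> to vector<4xf32>
+//   %vc  = vector.shape_cast %acc : vector<1x4xf32> to vector<4xf32>
+//   %mul = mulf %va, %vb : vector<4xf32>
+//   %add = addf %mul, %vc : vector<4xf32>
+//   %res = vector.shape_cast %add : vector<4xf32> to vector<1x4xf32>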
+class VectorContractLowering : public OpRewritePattern<vector::ContractionOp> {
+ public:
+ using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ContractionOp op,
+ PatternRewriter &rewriter) const override {
+ auto iteratorTypes = op.iterator_types().getValue();
+ if (iteratorTypes.size() != 3 || !isParallelIterator(iteratorTypes[0]) ||
+ !isParallelIterator(iteratorTypes[1]) ||
+ !isReductionIterator(iteratorTypes[2]) ||
+ !isRowMajorMatmul(op.indexing_maps())) {
+ return failure();
+ }
+ if (op.getLhsType().getNumElements() != 1) return failure();
+ unsigned vecSize = op.getAccType().cast<VectorType>().getNumElements();
+ if (!(vecSize >= 1 && vecSize <= 4)) return failure();
+ auto loc = op.getLoc();
+ VectorType vecType = VectorType::get(
+ vecSize, op.getResultType().cast<VectorType>().getElementType());
+ std::array<int64_t, 2> zero = {0, 0};
+ Value lhs = rewriter.create<vector::ExtractOp>(loc, op.lhs(), zero);
+ Value rhs, acc;
+ if (vecSize == 1) {
+ rhs = rewriter.create<vector::ExtractOp>(loc, op.rhs(), zero);
+ acc = rewriter.create<vector::ExtractOp>(loc, op.acc(), zero);
+ } else {
+ lhs = rewriter.create<vector::BroadcastOp>(loc, vecType, lhs);
+ rhs = rewriter.create<vector::ShapeCastOp>(loc, vecType, op.rhs());
+ acc = rewriter.create<vector::ShapeCastOp>(loc, vecType, op.acc());
+ }
+ Value newOp = rewriter.create<MulFOp>(loc, lhs, rhs);
+ newOp = rewriter.create<AddFOp>(loc, newOp, acc);
+ if (vecSize == 1)
+ newOp =
+ rewriter.create<vector::BroadcastOp>(loc, op.getResultType(), newOp);
+ else
+ newOp =
+ rewriter.create<vector::ShapeCastOp>(loc, op.getResultType(), newOp);
+ rewriter.replaceOp(op, newOp);
+ return success();
+ }
+};
+
+// Lower ExtractStridedSliceOp to an ExtractOp that can be natively converted
+// to SPIR-V. Add a BroadcastOp to keep the type consistent; the broadcast is
+// expected to be removed by canonicalization.
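+//
+// A minimal sketch of the rewrite (illustrative IR; value names are made up,
+// the shapes mirror the extract test below):
+//   %0 = vector.extract_strided_slice %v
+//          {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]}
+//          : vector<1x4xf32> to vector<1x1xf32>
+// becomes:
+//   %s = vector.extract %v[0, 2] : vector<1x4xf32>
+//   %0 = vector.broadcast %s : f32 to vector<1x1xf32>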
+class ExtractStridedLowering
+ : public OpRewritePattern<vector::ExtractStridedSliceOp> {
+ public:
+ using OpRewritePattern<vector::ExtractStridedSliceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp op,
+ PatternRewriter &rewriter) const override {
+ // Only handle cases that extract a degenerate (single-element) vector so
+ // that we can generate an ExtractOp with a scalar destination.
+ if (op.getResult().getType().cast<VectorType>().getNumElements() != 1)
+ return failure();
+ auto loc = op.getLoc();
+ SmallVector<int64_t, 4> offsets = llvm::to_vector<4>(
+ llvm::map_range(op.offsets().getAsRange<IntegerAttr>(),
+ [](IntegerAttr attr) { return attr.getInt(); }));
+ offsets.resize(op.getVectorType().getRank(), 0);
+ Value newOp = rewriter.create<vector::ExtractOp>(loc, op.vector(), offsets);
+ newOp = rewriter.create<vector::BroadcastOp>(loc, op.getResult().getType(),
+ newOp);
+ rewriter.replaceOp(op, newOp);
+ return success();
+ }
+};
+
+// Lower vector ops to ops that can later be converted to SPIR-V.
+void ConvertVectorToGPUPass::lowerVectorOps(FuncOp funcOp,
+ MLIRContext *context) {
+ OwningRewritePatternList patterns;
+ patterns.insert<VectorContractLowering, VectorTransferReadToLoad,
+ VectorTransferWriteToStore, ExtractStridedLowering>(context);
+ applyPatternsAndFoldGreedily(funcOp, patterns);
+}
+
void ConvertVectorToGPUPass::runOnOperation() {
MLIRContext *context = &getContext();
FuncOp funcOp = getOperation();
tileAndVectorizeLinalgCopy(funcOp, context);
+ lowerVectorOps(funcOp, context);
+
auto &cooperativeMatrixAnalysis = getAnalysis<CooperativeMatrixAnalysis>();
OwningRewritePatternList patterns;
patterns.insert<UnaryAndBinaryOpPattern<AddFOp>, VectorTransferReadConversion,
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/vector_to_gpu.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/vector_to_gpu.mlir
index 30d94c5..99ec424 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/vector_to_gpu.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/vector_to_gpu.mlir
@@ -59,3 +59,75 @@
// CHECK: %[[LOAD:.+]] = vector.transfer_read %[[SVs]][%c0, %c0], %cst {{.*}} : memref<1x4xf32, {{.*}}>, vector<1x4xf32>
// CHECK: vector.transfer_write %[[LOAD]], %[[SVd]][%[[C0]], %[[C0]]] {{.*}} : vector<1x4xf32>, memref<1x4xf32
}
+
+// -----
+
+module attributes {gpu.container_module, spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+ func @transfer_ops(%arg0: memref<32x32xf32>, %arg1 : vector<1x1xf32>) -> vector<1x1xf32> attributes {spv.entry_point_abi = {local_size = dense<[128, 1, 1]> : vector<3xi32>}} {
+ %c0 = constant 0 : index
+ %cst = constant 0.0 : f32
+ %0 = vector.transfer_read %arg0[%c0, %c0], %cst : memref<32x32xf32>, vector<1x1xf32>
+ vector.transfer_write %arg1, %arg0[%c0, %c0] : vector<1x1xf32>, memref<32x32xf32>
+ return %0 : vector<1x1xf32>
+ }
+ // CHECK-LABEL: func @transfer_ops
+ // CHECK-SAME: (%[[ARG0:.*]]: memref<32x32xf32>, %[[ARG1:.*]]: vector<1x1xf32>
+ // CHECK: %[[C0:.*]] = constant 0 : index
+ // CHECK: %[[LOAD:.*]] = load %[[ARG0]][%[[C0]], %[[C0]]] : memref<32x32xf32>
+ // CHECK: %[[B:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<1x1xf32>
+ // CHECK: %[[EXT:.*]] = vector.extract %[[ARG1]][0, 0] : vector<1x1xf32>
+ // CHECK: store %[[EXT]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<32x32xf32>
+ // CHECK: return %[[B]] : vector<1x1xf32>
+}
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+module attributes {gpu.container_module, spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+ func @contract_ops(%arg0 : vector<1x1xf32>, %arg1 : vector<1x4xf32>,
+ %arg2 : vector<1x4xf32>, %arg3 : vector<1x1xf32>,
+ %arg4 : vector<1x1xf32>) -> (vector<1x1xf32>, vector<1x4xf32>) attributes {spv.entry_point_abi = {local_size = dense<[128, 1, 1]> : vector<3xi32>}} {
+ %0 = vector.contract {indexing_maps = [#map0, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"]} %arg0, %arg3, %arg4
+ : vector<1x1xf32>, vector<1x1xf32> into vector<1x1xf32>
+ %1 = vector.contract {indexing_maps = [#map0, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"]} %arg0, %arg1, %arg2
+ : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+ return %0, %1 : vector<1x1xf32>, vector<1x4xf32>
+ }
+ // CHECK-LABEL: func @contract_ops
+ // CHECK-SAME: (%[[ARG0:.*]]: vector<1x1xf32>, %[[ARG1:.*]]: vector<1x4xf32>, %[[ARG2:.*]]: vector<1x4xf32>, %[[ARG3:.*]]: vector<1x1xf32>, %[[ARG4:.*]]: vector<1x1xf32>)
+ // CHECK: %[[A:.*]] = vector.extract %[[ARG0]][0, 0] : vector<1x1xf32>
+ // CHECK: %[[B:.*]] = vector.extract %[[ARG3]][0, 0] : vector<1x1xf32>
+ // CHECK: %[[C:.*]] = vector.extract %[[ARG4]][0, 0] : vector<1x1xf32>
+ // CHECK: %[[MUL:.*]] = mulf %[[A]], %[[B]] : f32
+ // CHECK: %[[ADD:.*]] = addf %[[MUL]], %[[C]] : f32
+ // CHECK: %[[R0:.*]] = vector.broadcast %[[ADD]] : f32 to vector<1x1xf32>
+ // CHECK: %[[A:.*]] = vector.extract %[[ARG0]][0, 0] : vector<1x1xf32>
+ // CHECK: %[[VA:.*]] = vector.broadcast %[[A]] : f32 to vector<4xf32>
+ // CHECK: %[[VB:.*]] = vector.shape_cast %[[ARG1]] : vector<1x4xf32> to vector<4xf32>
+ // CHECK: %[[VC:.*]] = vector.shape_cast %[[ARG2]] : vector<1x4xf32> to vector<4xf32>
+ // CHECK: %[[VMUL:.*]] = mulf %[[VA]], %[[VB]] : vector<4xf32>
+ // CHECK: %[[VADD:.*]] = addf %[[VMUL]], %[[VC]] : vector<4xf32>
+ // CHECK: %[[R1:.*]] = vector.shape_cast %[[VADD]] : vector<4xf32> to vector<1x4xf32>
+ // CHECK: return %[[R0]], %[[R1]] : vector<1x1xf32>, vector<1x4xf32>
+}
+
+// -----
+
+module attributes {gpu.container_module, spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+ func @extract(%arg0 : vector<1x4xf32>) -> vector<1x1xf32> attributes {spv.entry_point_abi = {local_size = dense<[128, 1, 1]> : vector<3xi32>}} {
+ %0 = vector.extract_strided_slice %arg0
+ {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]}
+ : vector<1x4xf32> to vector<1x1xf32>
+ return %0 : vector<1x1xf32>
+ }
+ // CHECK-LABEL: func @extract
+ // CHECK-SAME: (%[[ARG0:.*]]: vector<1x4xf32>
+ // CHECK: %[[A:.*]] = vector.extract %[[ARG0]][0, 2] : vector<1x4xf32>
+ // CHECK: %[[B:.*]] = vector.broadcast %[[A]] : f32 to vector<1x1xf32>
+ // CHECK: return %[[B]] : vector<1x1xf32>
+}