[CUDA] Re-organize the set of vector transformations to improve codegen (#6628)
Improve the ordering of the transformations and enable transfer op
optimizations.
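The net effect of the diff below is a two-stage split of the vector work: LLVMGPUVectorization keeps the value-level transformations (unrolling, transfer hoisting, contract lowering) and runs early, while the transfer-to-SCF lowering moves into a new LLVMGPUVectorLowering pass that runs late, just before conversion to LLVM. A minimal sketch of the resulting pass ordering; the helper name buildCudaVectorPipeline is made up for illustration, but the pass factories and their order match the Passes.cpp hunk below:

    #include "iree/compiler/Codegen/Passes.h"
    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/Pass/PassManager.h"
    #include "mlir/Transforms/Passes.h"

    using namespace mlir;
    using namespace mlir::iree_compiler;

    // Sketch only: the helper name is hypothetical; the pass order matches
    // the Passes.cpp changes in this diff.
    void buildCudaVectorPipeline(OpPassManager &pm) {
      // Early: vectorize linalg ops, clean up, and optimize transfer ops
      // while the IR is still at the vector level.
      pm.addNestedPass<FuncOp>(createLLVMGPUVectorizationPass());
      pm.addNestedPass<FuncOp>(createCanonicalizerPass());
      pm.addNestedPass<FuncOp>(createCSEPass());
      pm.addNestedPass<FuncOp>(createOptimizeVectorTransferPass());
      // ... distribution and bufferization run in between ...
      // Late: lower the remaining vector ops (unrolled transfer-to-SCF,
      // subview folding) right before the LLVM conversion.
      pm.addNestedPass<FuncOp>(createLLVMGPUVectorLoweringPass());
    }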
diff --git a/iree/compiler/Codegen/LLVMGPU/BUILD b/iree/compiler/Codegen/LLVMGPU/BUILD
index c2bddc1..fd00b70 100644
--- a/iree/compiler/Codegen/LLVMGPU/BUILD
+++ b/iree/compiler/Codegen/LLVMGPU/BUILD
@@ -20,6 +20,7 @@
"LLVMGPULowerExecutableTarget.cpp",
"LLVMGPURemoveTrivialLoops.cpp",
"LLVMGPUTileAndDistribute.cpp",
+ "LLVMGPUVectorLowering.cpp",
"LLVMGPUVectorization.cpp",
"Passes.cpp",
],
diff --git a/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
index 074b92c..20ae473 100644
--- a/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
+++ b/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
@@ -24,6 +24,7 @@
"LLVMGPULowerExecutableTarget.cpp"
"LLVMGPURemoveTrivialLoops.cpp"
"LLVMGPUTileAndDistribute.cpp"
+ "LLVMGPUVectorLowering.cpp"
"LLVMGPUVectorization.cpp"
"Passes.cpp"
DEPS
diff --git a/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp b/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp
new file mode 100644
index 0000000..aac5f1c
--- /dev/null
+++ b/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp
@@ -0,0 +1,48 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/PassDetail.h"
+#include "iree/compiler/Codegen/Passes.h"
+#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+//====---------------------------------------------------------------------===//
+// Patterns for late vector op lowering.
+//====---------------------------------------------------------------------===//
+
+namespace {
+struct LLVMGPUVectorLoweringPass
+ : public LLVMGPUVectorLoweringBase<LLVMGPUVectorLoweringPass> {
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<vector::VectorDialect>();
+ }
+ void runOnOperation() override {
+ FuncOp funcOp = getOperation();
+ RewritePatternSet vectorToLoopsPatterns(&getContext());
+ VectorTransferToSCFOptions vectorToSCFOptions;
+ vectorToSCFOptions.setUnroll(true);
+ populateVectorToSCFConversionPatterns(vectorToLoopsPatterns,
+ vectorToSCFOptions);
+ memref::populateFoldSubViewOpPatterns(vectorToLoopsPatterns);
+ vector::populateVectorTransferLoweringPatterns(vectorToLoopsPatterns);
+ (void)applyPatternsAndFoldGreedily(funcOp,
+ std::move(vectorToLoopsPatterns));
+ }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>> createLLVMGPUVectorLoweringPass() {
+ return std::make_unique<LLVMGPUVectorLoweringPass>();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
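With setUnroll(true), the transfer-to-SCF lowering fully unrolls the leading transfer dimensions instead of materializing scf.for loops, which tends to produce the flat load/store sequences the NVVM backend handles well. To inspect the pass output in isolation, a minimal driver sketch (lowerVectorsStandalone is a made-up name; the factory is the one defined above):

    #include "iree/compiler/Codegen/Passes.h"
    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/Pass/PassManager.h"

    using namespace mlir;

    // Hypothetical standalone driver; the real pipeline nests the pass under
    // FuncOp the same way (see the Passes.cpp hunk further down).
    LogicalResult lowerVectorsStandalone(ModuleOp module) {
      PassManager pm(module.getContext());
      pm.addNestedPass<FuncOp>(
          iree_compiler::createLLVMGPUVectorLoweringPass());
      return pm.run(module);
    }

Equivalently, the pass can be driven from iree-opt through the -iree-llvmgpu-vector-lowering flag registered in Passes.td below.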
diff --git a/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorization.cpp b/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorization.cpp
index 7bef525..cdf55d6 100644
--- a/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorization.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorization.cpp
@@ -99,17 +99,22 @@
populateVectorizationPatterns(vectorizationPatterns);
(void)applyPatternsAndFoldGreedily(funcOp,
std::move(vectorizationPatterns));
- }
- // TODO: This should be a folding of Add into Contract in core but while
- // they live in different dialects, it is not possible without unnatural
- // dependencies.
- funcOp.walk([&](Operation *op) {
- if (auto contract = canonicalizeContractionAdd(op))
- op->replaceAllUsesWith(contract);
- });
+ // TODO: This should be a folding of Add into Contract in core but while
+ // they live in different dialects, it is not possible without unnatural
+ // dependencies.
+ funcOp.walk([&](Operation *op) {
+ if (auto contract = canonicalizeContractionAdd(op))
+ op->replaceAllUsesWith(contract);
+ });
+ RewritePatternSet vectorUnrollPatterns(context);
+ populateVectorUnrollPatterns(vectorUnrollPatterns);
+ (void)applyPatternsAndFoldGreedily(funcOp,
+ std::move(vectorUnrollPatterns));
+ linalg::hoistRedundantVectorTransfers(funcOp);
+ }
{
- // Lower transfer op to canonical form.
+ // Step 2. Lower transfer op to canonical form.
RewritePatternSet lowerTransferOpPatterns(funcOp.getContext());
vector::populateVectorToVectorCanonicalizationPatterns(
lowerTransferOpPatterns);
@@ -126,24 +131,21 @@
(void)applyPatternsAndFoldGreedily(funcOp,
std::move(vectorUnrollPatterns));
- RewritePatternSet canonicalizationPatterns1(funcOp.getContext());
+ RewritePatternSet canonicalizationPatterns(funcOp.getContext());
vector::populateVectorToVectorCanonicalizationPatterns(
- canonicalizationPatterns1);
+ canonicalizationPatterns);
(void)applyPatternsAndFoldGreedily(funcOp,
- std::move(canonicalizationPatterns1));
-
- linalg::hoistRedundantVectorTransfers(funcOp);
+ std::move(canonicalizationPatterns));
}
{
- // Step 3. Canonicalize the transfer ops generated.
- RewritePatternSet vectorToLoopsPatterns(context);
- VectorTransferToSCFOptions vectorToSCFOptions;
- vectorToSCFOptions.setUnroll(true);
- populateVectorToSCFConversionPatterns(vectorToLoopsPatterns,
- vectorToSCFOptions);
- memref::populateFoldSubViewOpPatterns(vectorToLoopsPatterns);
+ // Step 3. Lower contract op to outer product.
+ RewritePatternSet contractLoweringPatterns(funcOp.getContext());
+ vector::populateVectorContractLoweringPatterns(
+ contractLoweringPatterns,
+ vector::VectorTransformsOptions().setVectorTransformsOptions(
+ vector::VectorContractLowering::OuterProduct));
(void)applyPatternsAndFoldGreedily(funcOp,
- std::move(vectorToLoopsPatterns));
+ std::move(contractLoweringPatterns));
}
}
};
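Read together, the reordered steps in LLVMGPUVectorizationPass are: vectorize, fold the trailing addf into vector.contract, unroll to the target-native vector size, hoist redundant transfers out of loops, canonicalize the resulting transfer ops, and only then lower vector.contract to vector.outerproduct, which maps to simple per-thread FMA sequences. A condensed, non-verbatim sketch of that ordering, assuming the file-local helpers visible in the hunks above (populateVectorizationPatterns, populateVectorUnrollPatterns, canonicalizeContractionAdd):

    // Condensed sketch of runOnOperation(); elides the intermediate
    // transfer-op lowering details.
    void runReorderedSteps(FuncOp funcOp, MLIRContext *context) {
      {
        // Step 1. Vectorize, fold add into contract, unroll, hoist.
        RewritePatternSet vectorizationPatterns(context);
        populateVectorizationPatterns(vectorizationPatterns);
        (void)applyPatternsAndFoldGreedily(funcOp,
                                           std::move(vectorizationPatterns));
        funcOp.walk([&](Operation *op) {
          if (auto contract = canonicalizeContractionAdd(op))
            op->replaceAllUsesWith(contract);
        });
        RewritePatternSet unrollPatterns(context);
        populateVectorUnrollPatterns(unrollPatterns);
        (void)applyPatternsAndFoldGreedily(funcOp, std::move(unrollPatterns));
        linalg::hoistRedundantVectorTransfers(funcOp);
      }
      {
        // Step 2. Canonicalize the transfer ops created by unrolling.
        RewritePatternSet canonPatterns(context);
        vector::populateVectorToVectorCanonicalizationPatterns(canonPatterns);
        (void)applyPatternsAndFoldGreedily(funcOp, std::move(canonPatterns));
      }
      {
        // Step 3. Lower vector.contract to outer products.
        RewritePatternSet contractPatterns(context);
        vector::populateVectorContractLoweringPatterns(
            contractPatterns,
            vector::VectorTransformsOptions().setVectorTransformsOptions(
                vector::VectorContractLowering::OuterProduct));
        (void)applyPatternsAndFoldGreedily(funcOp,
                                           std::move(contractPatterns));
      }
    }

Hoisting now happens inside Step 1, before the transfer ops are rewritten, so redundant loop-carried transfers are caught while they are still easy to match.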
diff --git a/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index ec13ad3..c64e6fa 100644
--- a/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -49,6 +49,7 @@
pm.addNestedPass<FuncOp>(createLLVMGPUVectorizationPass());
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
pm.addNestedPass<FuncOp>(createCSEPass());
+ pm.addNestedPass<FuncOp>(createOptimizeVectorTransferPass());
}
void addGPUSimpleDistributePassPipeline(OpPassManager &pm) {
@@ -91,6 +92,7 @@
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
pm.addNestedPass<FuncOp>(createCSEPass());
+ pm.addNestedPass<FuncOp>(createLLVMGPUVectorLoweringPass());
pm.addPass(createLowerAffinePass());
// Strip out the debug info for the kernel as CUDA driver doesn't digest PTX
diff --git a/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir b/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir
index 4e2de4f..73bd058 100644
--- a/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir
+++ b/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir
@@ -101,6 +101,7 @@
// CHECK-LABEL: hal.executable @dot_dispatch_0
// CHECK: hal.executable.variant @cuda
+// CHECK-NOT: llvm.store
// CHECK-COUNT-2: llvm.load {{.*}} : !llvm.ptr<vector<4xf32>>
// CHECK: llvm.br
// CHECK-COUNT-6: llvm.load {{.*}} : !llvm.ptr<vector<4xf32>, 3>
diff --git a/iree/compiler/Codegen/LLVMGPU/test/vectorization.mlir b/iree/compiler/Codegen/LLVMGPU/test/vectorization.mlir
index f0574b6..480d0c4 100644
--- a/iree/compiler/Codegen/LLVMGPU/test/vectorization.mlir
+++ b/iree/compiler/Codegen/LLVMGPU/test/vectorization.mlir
@@ -50,7 +50,13 @@
return
}
// CHECK-LABEL: func @add_dispatch_0()
-// CHECK: vector.transfer_read {{.*}} : memref<1024x1024x1024xf32>, vector<4xf32>
-// CHECK: vector.transfer_read {{.*}} : memref<1024x1024x1024xf32>, vector<4xf32>
-// CHECK: addf %{{.*}}, %{{.*}} : vector<1x1x4xf32>
-// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, memref<1024x1024x1024xf32>
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: vector.transfer_read {{.*}} : memref<1x1x4xf32, {{.*}}>, vector<1x1x4xf32>
+// CHECK: vector.transfer_read {{.*}} : memref<1x1x4xf32, {{.*}}>, vector<1x1x4xf32>
+// CHECK: addf %{{.*}}, %{{.*}} : vector<1x1x4xf32>
+// CHECK: vector.transfer_write {{.*}} : vector<1x1x4xf32>, memref<1x1x4xf32, {{.*}}>
diff --git a/iree/compiler/Codegen/Passes.h b/iree/compiler/Codegen/Passes.h
index fc69b16..6d9ae4b 100644
--- a/iree/compiler/Codegen/Passes.h
+++ b/iree/compiler/Codegen/Passes.h
@@ -233,6 +233,9 @@
/// Convert Linalg ops to Vector.
std::unique_ptr<OperationPass<FuncOp>> createLLVMGPUVectorizationPass();
+/// Lower vector ops before conversion to LLVM.
+std::unique_ptr<OperationPass<FuncOp>> createLLVMGPUVectorLoweringPass();
+
//------------------------------------------------------------------------------
// SPIRV Passes
//------------------------------------------------------------------------------
diff --git a/iree/compiler/Codegen/Passes.td b/iree/compiler/Codegen/Passes.td
index 3716f0e..328515c 100644
--- a/iree/compiler/Codegen/Passes.td
+++ b/iree/compiler/Codegen/Passes.td
@@ -187,6 +187,12 @@
let constructor = "mlir::iree_compiler::createLLVMGPUVectorizationPass()";
}
+def LLVMGPUVectorLowering :
+ Pass<"iree-llvmgpu-vector-lowering", "FuncOp"> {
+ let summary = "Pass to lower Vector ops before conversion to LLVM.";
+ let constructor = "mlir::iree_compiler::createLLVMGPUVectorLoweringPass()";
+}
+
//------------------------------------------------------------------------------
// SPIRV
//------------------------------------------------------------------------------