blob: aa3ac013eead2bc7ebf7761929b1f20ac99fa5c4 [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/Stream/Transforms/Passes.h"
#include <memory>
#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
#include "mlir/Pass/PassOptions.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/Passes.h"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace Stream {
//===----------------------------------------------------------------------===//
// Utilities
//===----------------------------------------------------------------------===//
static void addCleanupPatterns(OpPassManager &passManager) {
// Standard MLIR cleanup.
passManager.addPass(mlir::createCanonicalizerPass());
passManager.addPass(mlir::createCSEPass());
// Cleanup and canonicalization of util.global (and other util ops).
passManager.addPass(IREE::Util::createApplyPatternsPass());
passManager.addPass(IREE::Util::createFoldGlobalsPass());
passManager.addPass(IREE::Util::createFuseGlobalsPass());
// Simplify util.global accesses; this can help with data flow tracking as
// redundant store-loads are removed.
passManager.addNestedPass<IREE::Util::InitializerOp>(
IREE::Util::createSimplifyGlobalAccessesPass());
passManager.addNestedPass<mlir::FuncOp>(
IREE::Util::createSimplifyGlobalAccessesPass());
}
//===----------------------------------------------------------------------===//
// -iree-stream-tensor-transformation-pipeline
//===----------------------------------------------------------------------===//
void buildStreamTensorPassPipeline(OpPassManager &passManager,
const TransformOptions &transformOptions) {
//----------------------------------------------------------------------------
// Input cleanup and simplification
//----------------------------------------------------------------------------
// Verify we support the program.
passManager.addPass(IREE::Stream::createVerifyInputPass());
// Turn all constant ops into global variables and fix up the IR.
// As many locations change and constants are deduplicated we'll end up with
// a lot of extraneous IR (mostly global loads) and clean those up here.
passManager.addPass(IREE::Stream::createOutlineConstantsPass());
// Perform cleanup after constnat simplification as more canonicalizers may be
// able to kick in.
addCleanupPatterns(passManager);
//----------------------------------------------------------------------------
// Conversion
//----------------------------------------------------------------------------
// Converts from all input dialects into various levels of the stream dialect.
// Tensor-like things go to stream.tensor.* ops while lower level buffer-like
// things will go to stream.async.* ops.
passManager.addPass(IREE::Stream::createConvertToStreamPass());
// No more tensor.*/etc ops are allowed. This is conservative - there may be
// a lot of ops we convert but this will catch the majority of stragglers.
passManager.addPass(IREE::Stream::createVerifyLoweringToTensorsPass());
//----------------------------------------------------------------------------
// Constant/variable optimization
//----------------------------------------------------------------------------
// Cleanup globals that were created during conversion.
addCleanupPatterns(passManager);
// Bring all initializers together so that we can schedule them.
passManager.addPass(IREE::Util::createCombineInitializersPass());
//----------------------------------------------------------------------------
// Stream affinity/assignment
//----------------------------------------------------------------------------
// TODO(benvanik): pin based on target backends here.
// TODO(benvanik): compute affinities for executables.
// TODO(benvanik): annotate all dispatches with preferred executable affinity.
// TODO(benvanik): DFA to specify all value affinities and pin dispatches.
}
//===----------------------------------------------------------------------===//
// -iree-stream-async-transformation-pipeline
//===----------------------------------------------------------------------===//
void buildStreamAsyncPassPipeline(OpPassManager &passManager,
const TransformOptions &transformOptions) {
//----------------------------------------------------------------------------
// Tensor lowering and resource management
//----------------------------------------------------------------------------
// Lower stream.tensor.* ops to stream.async.* ops based on
// affinity/configuration assigned during placement.
passManager.addNestedPass<IREE::Util::InitializerOp>(
IREE::Stream::createEncodeTensorsPass());
passManager.addNestedPass<mlir::FuncOp>(
IREE::Stream::createEncodeTensorsPass());
addCleanupPatterns(passManager);
// Materialize copy-on-write behavior with explicit stream.async.* ops.
// This will insert a lot of copies, so follow it up with a pass that elides
// ones that aren't needed. This is easier to verify than if there was one
// pass attempting to do both. Note that copy-on-write materialization is
// required for correct execution while copy elision is for performance only
// (though it's critical enough that it is not optional).
passManager.addNestedPass<IREE::Util::InitializerOp>(
IREE::Stream::createMaterializeCopyOnWritePass());
passManager.addNestedPass<mlir::FuncOp>(
IREE::Stream::createMaterializeCopyOnWritePass());
passManager.addPass(IREE::Stream::createElideAsyncCopiesPass());
// Refine lifetime of all resources across the module.
// We do this after scheduling execution so that we know how the resources
// move across devices. We do it before scheduling waves as lifetime doesn't
// change and it makes the IR cleaner.
passManager.addPass(IREE::Stream::createRefineUsagePass());
//----------------------------------------------------------------------------
// Stream formation and scheduling
//----------------------------------------------------------------------------
// Combine async work into execution regions.
passManager.addNestedPass<IREE::Util::InitializerOp>(
IREE::Stream::createScheduleExecutionPass());
passManager.addNestedPass<mlir::FuncOp>(
IREE::Stream::createScheduleExecutionPass());
// Group concurrently executable work into waves.
passManager.addNestedPass<IREE::Util::InitializerOp>(
IREE::Stream::createScheduleConcurrencyPass());
passManager.addNestedPass<mlir::FuncOp>(
IREE::Stream::createScheduleConcurrencyPass());
}
//===----------------------------------------------------------------------===//
// -iree-stream-cmd-transformation-pipeline
//===----------------------------------------------------------------------===//
void buildStreamCmdPassPipeline(OpPassManager &passManager,
const TransformOptions &transformOptions) {}
//===----------------------------------------------------------------------===//
// -iree-stream-optimization-pipeline
//===----------------------------------------------------------------------===//
void buildStreamOptimizationPassPipeline(
OpPassManager &passManager, const TransformOptions &transformOptions) {}
//===----------------------------------------------------------------------===//
// -iree-stream-transformation-pipeline
//===----------------------------------------------------------------------===//
void buildStreamTransformPassPipeline(
OpPassManager &passManager, const TransformOptions &transformOptions) {
//----------------------------------------------------------------------------
// Primary pipeline stages (required)
//----------------------------------------------------------------------------
buildStreamTensorPassPipeline(passManager, transformOptions);
buildStreamAsyncPassPipeline(passManager, transformOptions);
buildStreamCmdPassPipeline(passManager, transformOptions);
//----------------------------------------------------------------------------
// Optimizations (may be required by some targets)
//----------------------------------------------------------------------------
buildStreamOptimizationPassPipeline(passManager, transformOptions);
//----------------------------------------------------------------------------
// Post-pipeline cleanup
//----------------------------------------------------------------------------
// Forming streams involves a fair amount of subgraph stitching, which can
// cause duplication. Run CSE to collapse.
addCleanupPatterns(passManager);
// Symbol DCE any remaining variables/functions that are now no longer
// required.
passManager.addPass(mlir::createSymbolDCEPass());
}
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
void registerStreamTransformPassPipelines() {
PassPipelineRegistration<TransformOptions> tensorPassPipeline(
"iree-stream-tensor-transformation-pipeline",
"Lowers source dialects into stream.tensor.* IR.",
[](OpPassManager &passManager, const TransformOptions &transformOptions) {
buildStreamTensorPassPipeline(passManager, transformOptions);
});
PassPipelineRegistration<TransformOptions> asyncPassPipeline(
"iree-stream-async-transformation-pipeline",
"Lowers stream.tensor.* to stream.async.* IR.",
[](OpPassManager &passManager, const TransformOptions &transformOptions) {
buildStreamAsyncPassPipeline(passManager, transformOptions);
});
PassPipelineRegistration<TransformOptions> cmdPassPipeline(
"iree-stream-cmd-transformation-pipeline",
"Lowers stream.async.* to stream.cmd.* IR.",
[](OpPassManager &passManager, const TransformOptions &transformOptions) {
buildStreamCmdPassPipeline(passManager, transformOptions);
});
PassPipelineRegistration<TransformOptions> optimizationPassPipeline(
"iree-stream-optimization-pipeline",
"Optimizes stream commands and resources (may be required for some "
"targets).",
[](OpPassManager &passManager, const TransformOptions &transformOptions) {
buildStreamOptimizationPassPipeline(passManager, transformOptions);
});
PassPipelineRegistration<TransformOptions> transformPassPipeline(
"iree-stream-transformation-pipeline",
"Runs the full IREE stream dialect transformation pipeline.",
[](OpPassManager &passManager, const TransformOptions &transformOptions) {
buildStreamTransformPassPipeline(passManager, transformOptions);
});
}
namespace {
#define GEN_PASS_REGISTRATION
#include "iree/compiler/Dialect/Stream/Transforms/Passes.h.inc" // IWYU pragma: export
} // namespace
void registerStreamPasses() {
// Generated.
registerPasses();
// Pipelines.
registerStreamTransformPassPipelines();
}
} // namespace Stream
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir