| // Copyright 2020 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| //===- Transforms.h - Transformations common to all backends --------------===// |
| // |
// Defines transformations that are common to all backends.
| // |
| //===----------------------------------------------------------------------===// |
| #ifndef IREE_COMPILER_CODEGEN_COMMON_TRANSFORMS_H_ |
| #define IREE_COMPILER_CODEGEN_COMMON_TRANSFORMS_H_ |
| |
#include <array>
#include <functional>

#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
| |
| namespace mlir { |
| namespace iree_compiler { |
| |
| /// Specifies the number of workgroups to use for a particular entry point |
| /// function, by updating the `worgroup_count` region in the |
| /// `hal.executable.entry_point` op for this operation. The method takes a |
| /// callback function, which computes the workgroup count (x,y,z) given the |
| /// workload along (x,y,z). |
| using WorkgroupCountRegionBuilder = std::function<std::array<Value, 3>( |
| OpBuilder &b, Location loc, std::array<Value, 3> workload)>; |
| LogicalResult defineWorkgroupCountRegion( |
| OpBuilder &builder, FuncOp funcOp, |
| WorkgroupCountRegionBuilder regionBuilder); |
| |
| /// Using linalg on tensors for dispatch region creation does first-level of |
| /// tile (fuse and distribute) during dispatch region formation. At that point |
| /// the workload per workgroup is set to the dynamic value represented by |
| /// `flow.dispatch.workgroup.size` and is later lowered to |
| /// `hal.dispatch.workgroup.size`. This method is to materialize the static |
| /// information of the workload per workgroup determined based on target |
| /// architecture. Note that the value of hal.dispatch.workgroup.size is now |
| /// different after this function is called and represents the actual value used |
| /// at runtime. |
| LogicalResult materializeStaticLaunchInformation( |
| FuncOp funcOp, ArrayRef<int64_t> workloadPerWorkgroup); |
| |
| /// Return a fused vector::ContractionOp which represents a patterns such as: |
| /// |
| /// ```mlir |
| /// %c0 = vector.constant 0: ... |
| /// %c = vector.contract %a, %b, %c0: ... |
| /// %e = add %c, %d: ... |
| /// ``` |
| /// |
| /// by: |
| /// |
| /// ```mlir |
| /// %e = vector.contract %a, %b, %d: ... |
| /// ``` |
| /// |
| /// Return null if the canonicalization does not apply. |
| // TODO: This should be a folding of Add into Contract in core but while they |
| // live in different dialects, it is not possible without unnatural |
| // dependencies. |
| vector::ContractionOp canonicalizeContractionAdd(Operation *op); |
| |
| /// Insert patterns to perform folding of AffineMinOp by matching the pattern |
| /// generated by tile and distribute. Try to fold a affine.min op by matching |
| /// the following form: |
| /// ``` |
| /// scf.for %iv = %lb to %ub step %step |
| /// %affine.min affine_map<(d0, d1) -> (N, d0 - d1)>(%ub, %iv) |
| /// ``` |
| /// With N a compile time constant. This operations can be replace by |
| /// `%cN = constant N : index` if we can prove that %lb, %step and %ub are |
| /// divisible by N. |
| void populateAffineMinSCFCanonicalizationPattern(RewritePatternSet &patterns); |
| |
| using GetMinMaxExprFn = |
| std::function<Optional<std::pair<AffineExpr, AffineExpr>>( |
| Value value, SmallVectorImpl<Value> &dims, |
| SmallVectorImpl<Value> &symbols)>; |
| |
| /// Insert pattern to remove single iteration loop. The pattern will detect |
| /// single iteration loops based on the range returned by the lambda |
| /// |getMinMaxFn| for some know values. |
| void populateRemoveSingleIterationLoopPattern(RewritePatternSet &patterns, |
| GetMinMaxExprFn getMinMaxFn); |
| |
| /// Insert pattern to fold chains of `affine.min` operations. |
| // TODO: It is not clear what this pattern is doing and should be deprecated. |
| void populateAffineMinCanonicalizationPattern(RewritePatternSet &patterns); |
| |
| } // namespace iree_compiler |
| } // namespace mlir |
| |
| #endif // IREE_COMPILER_CODEGEN_COMMON_TRANSFORMS_H_ |