NFC: improve naming and doc for tiled and distributed loop info (#7558)
The logic behind rediscovering the loop tiling and distribution
information is quite dense. This makes at least the API clearer.
diff --git a/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index 6b2ad8e..06b12bc 100644
--- a/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -130,7 +130,7 @@
}
static SmallVector<int64_t> getDefaultWorkloadPerWorkgroup(
- ArrayRef<TiledLoopInfo> tiledLoops,
+ ArrayRef<LoopTilingAndDistributionInfo> tiledLoops,
ArrayRef<int64_t> nativeVectorSizeInElements) {
if (tiledLoops.empty()) {
return {};
@@ -138,7 +138,7 @@
assert(tiledLoops.size() == nativeVectorSizeInElements.size());
unsigned maxDim = 0;
for (auto tiledLoop : tiledLoops) {
- maxDim = std::max<unsigned>(tiledLoop.distributionDim, maxDim);
+ maxDim = std::max<unsigned>(tiledLoop.processorDistributionDim, maxDim);
}
SmallVector<int64_t> workloadPerWorkgroup(maxDim + 1, 1);
SmallVector<int64_t> numWorkgroupsPerDim(maxDim + 1, 1);
@@ -149,9 +149,9 @@
auto ceilFn = [](int64_t a, int64_t b) { return (a + b - 1) / b; };
for (auto tiledLoop : enumerate(tiledLoops)) {
- Optional<int64_t> lb = getStaticValue(tiledLoop.value().lb);
- Optional<int64_t> ub = getStaticValue(tiledLoop.value().ub);
- unsigned dim = tiledLoop.value().distributionDim;
+ Optional<int64_t> lb = getStaticValue(tiledLoop.value().untiledLowerBound);
+ Optional<int64_t> ub = getStaticValue(tiledLoop.value().untiledUpperBound);
+ unsigned dim = tiledLoop.value().processorDistributionDim;
if (!lb || !ub) {
workloadPerWorkgroup[dim] = defaultWorkgroupTileSize;
workload[dim] = ShapedType::kDynamicSize;
@@ -209,7 +209,7 @@
/// Sets the default launch configuration to use for a tiled + distributed
/// dispatch region based on the `tiledLoops` found.
static LogicalResult setDefaultLaunchConfig(
- FuncOp entryPointFn, ArrayRef<TiledLoopInfo> tiledLoops) {
+ FuncOp entryPointFn, ArrayRef<LoopTilingAndDistributionInfo> tiledLoops) {
unsigned typeWidthInBytes = getReferenceTypeLengthInBytes(entryPointFn);
SmallVector<int64_t> nativeVectorSizeInElements(tiledLoops.size(), 1);
if (!tiledLoops.empty()) {
@@ -232,9 +232,9 @@
/// Sets the lowering configuration for dispatch region with root op that
/// implements the contraction operation interface.
-static LogicalResult setRootConfig(FuncOp entryPointFn,
- linalg::ContractionOpInterface contractionOp,
- ArrayRef<TiledLoopInfo> tiledLoops) {
+static LogicalResult setRootConfig(
+ FuncOp entryPointFn, linalg::ContractionOpInterface contractionOp,
+ ArrayRef<LoopTilingAndDistributionInfo> tiledLoops) {
if (getLoweringConfig(contractionOp)) return success();
auto lhsShapedType = contractionOp.lhs().getType().cast<ShapedType>();
@@ -286,15 +286,17 @@
return maxSize;
};
for (unsigned i = tiledLoops.size() - 2; i < tiledLoops.size(); ++i) {
- if (!tiledLoops[i].lb.is<Attribute>() ||
- !tiledLoops[i].ub.is<Attribute>()) {
+ if (!tiledLoops[i].untiledLowerBound.is<Attribute>() ||
+ !tiledLoops[i].untiledUpperBound.is<Attribute>()) {
continue;
}
- int64_t lb = tiledLoops[i].lb.get<Attribute>().cast<IntegerAttr>().getInt();
- int64_t ub = tiledLoops[i].ub.get<Attribute>().cast<IntegerAttr>().getInt();
- workloadPerWorkgroup[tiledLoops.size() - 1 - i] =
- getTileSize(lb, ub, workloadPerWorkgroup[tiledLoops.size() - 1 - i],
- vectorSizeVals[i]);
+ auto lb =
+ tiledLoops[i].untiledLowerBound.get<Attribute>().cast<IntegerAttr>();
+ auto ub =
+ tiledLoops[i].untiledUpperBound.get<Attribute>().cast<IntegerAttr>();
+ workloadPerWorkgroup[tiledLoops.size() - 1 - i] = getTileSize(
+ lb.getInt(), ub.getInt(),
+ workloadPerWorkgroup[tiledLoops.size() - 1 - i], vectorSizeVals[i]);
}
setTranslationInfo(
entryPointFn,
@@ -328,8 +330,9 @@
/// Sets the lowering configuration for dispatch region for linalg.mmt4d root
/// op
-static LogicalResult setRootConfig(FuncOp entryPointFn, linalg::Mmt4DOp mmt4dOp,
- ArrayRef<TiledLoopInfo> tiledLoops) {
+static LogicalResult setRootConfig(
+ FuncOp entryPointFn, linalg::Mmt4DOp mmt4dOp,
+ ArrayRef<LoopTilingAndDistributionInfo> tiledLoops) {
// TODO(ataei): These are hand tuned for some performance benchmarks for
// now, we want to adapt the same strategy as matmul that dynamically sets
// tile size.
@@ -379,8 +382,9 @@
/// Sets the lowering configuration for dispatch region for linalg_ext.fft
/// root op.
-static LogicalResult setRootConfig(FuncOp entryPointFn, linalg_ext::FftOp fftOp,
- ArrayRef<TiledLoopInfo> tiledLoops) {
+static LogicalResult setRootConfig(
+ FuncOp entryPointFn, linalg_ext::FftOp fftOp,
+ ArrayRef<LoopTilingAndDistributionInfo> tiledLoops) {
auto partitionedLoops = getPartitionedLoops(fftOp);
unsigned maxDepth = partitionedLoops.back() + 1;
SmallVector<int64_t> workgroupTileSizes(maxDepth, defaultWorkgroupTileSize);
@@ -415,9 +419,9 @@
/// Finds the root operation in the given list of linalg operations and sets
/// its configuration. Returns error for multiple root operations.
-static LogicalResult setRootConfig(FuncOp entryPointFn,
- ArrayRef<Operation *> computeOps,
- ArrayRef<TiledLoopInfo> tiledLoops) {
+static LogicalResult setRootConfig(
+ FuncOp entryPointFn, ArrayRef<Operation *> computeOps,
+ ArrayRef<LoopTilingAndDistributionInfo> tiledLoops) {
Operation *rootOp = nullptr;
for (auto computeOp : computeOps) {
auto setRootConfigFn = [&](Operation *op) -> LogicalResult {
@@ -445,7 +449,7 @@
/// Sets the translation information to use for a dispatch region.
static LogicalResult setTranslationInfoAndRootConfig(
FuncOp entryPointFn, ArrayRef<Operation *> computeOps,
- ArrayRef<TiledLoopInfo> tiledLoops) {
+ ArrayRef<LoopTilingAndDistributionInfo> tiledLoops) {
// First check if the operations have a preset pipeline.
for (auto computeOp : computeOps) {
if (IREE::Codegen::CompilationInfoAttr compilationInfo =
@@ -486,7 +490,7 @@
if (!entryPointOp) continue;
if (getTranslationInfo(entryPointOp)) continue;
SmallVector<Operation *> computeOps;
- SmallVector<TiledLoopInfo> tiledLoops;
+ SmallVector<LoopTilingAndDistributionInfo> tiledLoops;
// If there are no linalg ops, not using Linalg based lowering.
if (failed(getComputeOps(funcOp, computeOps, tiledLoops))) {
diff --git a/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 0d1ae4e..d6bda1d 100644
--- a/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -263,7 +263,7 @@
if (!entryPointOp) continue;
if (getTranslationInfo(entryPointOp)) continue;
SmallVector<Operation *> computeOps;
- SmallVector<TiledLoopInfo> tiledLoops;
+ SmallVector<LoopTilingAndDistributionInfo> tiledLoops;
if (failed(getComputeOps(funcOp, computeOps, tiledLoops))) {
return funcOp.emitOpError("failed to get compute ops");
}
diff --git a/iree/compiler/Codegen/LLVMGPU/LLVMGPURemoveTrivialLoops.cpp b/iree/compiler/Codegen/LLVMGPU/LLVMGPURemoveTrivialLoops.cpp
index df0e622..5837ff4 100644
--- a/iree/compiler/Codegen/LLVMGPU/LLVMGPURemoveTrivialLoops.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/LLVMGPURemoveTrivialLoops.cpp
@@ -60,20 +60,27 @@
/// of element per workgroups.
static SmallVector<int64_t> getNumWorkgroup(
FuncOp funcOp, IREE::HAL::ExecutableEntryPointOp entryPointOp) {
- SmallVector<TiledLoopInfo> tiledLoopInfo = getTiledLoopInfo(funcOp);
+ SmallVector<LoopTilingAndDistributionInfo> tiledLoopInfo =
+ getTiledAndDistributedLoopInfo(funcOp);
SmallVector<int64_t> workloadSize(tiledLoopInfo.size());
- for (TiledLoopInfo &tileInfo : tiledLoopInfo) {
- if (tileInfo.distributionDim >= workloadSize.size())
+ for (LoopTilingAndDistributionInfo &tileInfo : tiledLoopInfo) {
+ if (tileInfo.processorDistributionDim >= workloadSize.size())
return SmallVector<int64_t>();
- if (!tileInfo.lb.is<Attribute>() || !tileInfo.ub.is<Attribute>() ||
- !tileInfo.step.is<Attribute>()) {
+ if (!tileInfo.untiledLowerBound.is<Attribute>() ||
+ !tileInfo.untiledUpperBound.is<Attribute>() ||
+ !tileInfo.untiledStep.is<Attribute>()) {
continue;
}
- int64_t lb = tileInfo.lb.get<Attribute>().cast<IntegerAttr>().getInt();
- int64_t ub = tileInfo.ub.get<Attribute>().cast<IntegerAttr>().getInt();
- int64_t step = tileInfo.step.get<Attribute>().cast<IntegerAttr>().getInt();
+ int64_t lb = tileInfo.untiledLowerBound.get<Attribute>()
+ .cast<IntegerAttr>()
+ .getInt();
+ int64_t ub = tileInfo.untiledUpperBound.get<Attribute>()
+ .cast<IntegerAttr>()
+ .getInt();
+ int64_t step =
+ tileInfo.untiledStep.get<Attribute>().cast<IntegerAttr>().getInt();
if (step == 0) return SmallVector<int64_t>();
- workloadSize[tileInfo.distributionDim] = (ub - lb) / step;
+ workloadSize[tileInfo.processorDistributionDim] = (ub - lb) / step;
}
auto translationInfo = getTranslationInfo(entryPointOp);
if (!translationInfo) return SmallVector<int64_t>();
diff --git a/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
index 3692cf6..1f59872 100644
--- a/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
+++ b/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -385,7 +385,8 @@
// 1) distributing to as many threads as possible, and 2) avoid assigning too
// many threads to handle out-of-bound elements (thus idle).
- SmallVector<TiledLoopInfo> tiledLoopInfo = getTiledLoopInfo(funcOp);
+ SmallVector<LoopTilingAndDistributionInfo> tiledLoopInfo =
+ getTiledAndDistributedLoopInfo(funcOp);
// The number of linalg implicit loops to partition and tiled loops
// surrounding the op should match. Otherwise, something is incorrect.
assert(partitionedLoops.size() == tiledLoopInfo.size());
@@ -395,8 +396,9 @@
// tiledLoopInfo uses the reverse order of partitionedLoops.
for (auto pair : llvm::zip(llvm::reverse(partitionedLoops), tiledLoopInfo)) {
unsigned loopIndex = std::get<0>(pair);
- const TiledLoopInfo &loopInfo = std::get<1>(pair);
- Optional<int64_t> attrValue = getConstantIntValue(loopInfo.ub);
+ const LoopTilingAndDistributionInfo &loopInfo = std::get<1>(pair);
+ Optional<int64_t> attrValue =
+ getConstantIntValue(loopInfo.untiledUpperBound);
if (attrValue) {
loopBounds[loopIndex] = *attrValue;
} else {
@@ -624,7 +626,7 @@
if (getTranslationInfo(entryPointOp)) continue;
SmallVector<Operation *> computeOps;
- SmallVector<TiledLoopInfo> tiledLoops;
+ SmallVector<LoopTilingAndDistributionInfo> tiledLoops;
if (failed(getComputeOps(funcOp, computeOps, tiledLoops))) {
return funcOp.emitOpError("failed to get compute ops");
}
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVRemoveOneTripTiledLoops.cpp b/iree/compiler/Codegen/SPIRV/SPIRVRemoveOneTripTiledLoops.cpp
index 21b7e62..3a67961 100644
--- a/iree/compiler/Codegen/SPIRV/SPIRVRemoveOneTripTiledLoops.cpp
+++ b/iree/compiler/Codegen/SPIRV/SPIRVRemoveOneTripTiledLoops.cpp
@@ -103,7 +103,7 @@
// This pass seems to be only needed for the convolution vectorization. So
// filter out the necessary conv ops.
SmallVector<Operation *> rootOp;
- SmallVector<TiledLoopInfo> tiledLoops;
+ SmallVector<LoopTilingAndDistributionInfo> tiledLoops;
auto isConvOp = [](Operation *op) {
return isa<linalg::DepthwiseConv2DNhwOp, linalg::Conv2DNhwcHwcfOp>(op);
};
diff --git a/iree/compiler/Codegen/Utils/Utils.cpp b/iree/compiler/Codegen/Utils/Utils.cpp
index 96fa627..c371447 100644
--- a/iree/compiler/Codegen/Utils/Utils.cpp
+++ b/iree/compiler/Codegen/Utils/Utils.cpp
@@ -12,10 +12,13 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/SymbolTable.h"
+#define DEBUG_TYPE "iree-codegen-utils"
+
namespace mlir {
namespace iree_compiler {
@@ -163,7 +166,7 @@
return isaAffineExprOfType<T2, T3...>(expr);
}
-/// Returns a Value that repreesnts the value for symbol or dim expr for the map
+/// Returns a Value that represents the value for symbol or dim expr for the map
/// in the `applyOp`.
static Value getValueForDimOrSymbol(AffineApplyOp applyOp, AffineExpr expr) {
unsigned numDims = applyOp.getAffineMap().getNumDims();
@@ -233,7 +236,8 @@
class LowerBoundExprVisitor
: public AffineExprVisitor<LowerBoundExprVisitor, LogicalResult> {
public:
- LowerBoundExprVisitor(AffineApplyOp applyOp, TiledLoopInfo &loopInfo)
+ LowerBoundExprVisitor(AffineApplyOp applyOp,
+ LoopTilingAndDistributionInfo &loopInfo)
: applyOp(applyOp), loopInfo(loopInfo) {}
LogicalResult visitSymbolExpr(AffineSymbolExpr /*expr*/) { return failure(); }
@@ -260,10 +264,10 @@
if (!v) {
return failure();
}
- loopInfo.lb = getAsOpFoldResult(v);
+ loopInfo.untiledLowerBound = getAsOpFoldResult(v);
} else if (auto constExpr = lbExpr.dyn_cast<AffineConstantExpr>()) {
- loopInfo.lb = IntegerAttr::get(IndexType::get(applyOp.getContext()),
- constExpr.getValue());
+ loopInfo.untiledLowerBound = IntegerAttr::get(
+ IndexType::get(applyOp.getContext()), constExpr.getValue());
} else {
return failure();
}
@@ -279,7 +283,7 @@
if (vals.size() != 1 || !vals[0]) {
return failure();
}
- loopInfo.workgroupSize = wgSize.getValue();
+ loopInfo.tileSize = wgSize.getValue();
dimension = checkDimensions<IREE::HAL::InterfaceWorkgroupIDOp>(vals);
} else {
vals = getValuesForDimsOrSymbols(applyOp, {expr.getLHS(), expr.getRHS()});
@@ -292,16 +296,17 @@
if (!dimension) {
return failure();
}
- loopInfo.distributionDim = dimension.getValue();
- if (!loopInfo.lb) {
- loopInfo.lb = IntegerAttr::get(IndexType::get(applyOp.getContext()), 0);
+ loopInfo.processorDistributionDim = dimension.getValue();
+ if (!loopInfo.untiledLowerBound) {
+ loopInfo.untiledLowerBound =
+ IntegerAttr::get(IndexType::get(applyOp.getContext()), 0);
}
return success();
}
private:
AffineApplyOp applyOp;
- TiledLoopInfo &loopInfo;
+ LoopTilingAndDistributionInfo &loopInfo;
};
/// Visitor to walk the `step` of a distributed loop. Expected the expression to
@@ -311,7 +316,8 @@
class StepExprVisitor
: public AffineExprVisitor<StepExprVisitor, LogicalResult> {
public:
- StepExprVisitor(AffineApplyOp applyOp, TiledLoopInfo &loopInfo)
+ StepExprVisitor(AffineApplyOp applyOp,
+ LoopTilingAndDistributionInfo &loopInfo)
: applyOp(applyOp), loopInfo(loopInfo) {}
LogicalResult visitSymbolExpr(AffineSymbolExpr /*expr*/) { return failure(); }
@@ -335,28 +341,28 @@
}
expr = e.cast<AffineBinaryOpExpr>();
} else {
- // Check if WorkgroupSizeOp is folded.
- if (loopInfo.workgroupSize) {
- if (auto stepBySize = expr.getRHS().dyn_cast<AffineConstantExpr>()) {
- loopInfo.step =
+ // Check if the workgroup tile size is folded into the affine map itself.
+ if (loopInfo.tileSize) {
+ if (auto stepCst = expr.getRHS().dyn_cast<AffineConstantExpr>()) {
+ loopInfo.untiledStep =
IntegerAttr::get(IndexType::get(applyOp.getContext()),
- stepBySize.getValue() / *loopInfo.workgroupSize);
+ stepCst.getValue() / *loopInfo.tileSize);
}
} else {
- loopInfo.step =
+ loopInfo.untiledStep =
IntegerAttr::get(IndexType::get(applyOp.getContext()), 1);
}
}
if (failed(processSentinel(expr.getLHS(), sentinels)) ||
- (!loopInfo.workgroupSize &&
+ (!loopInfo.tileSize &&
failed(processSentinel(expr.getRHS(), sentinels)))) {
return failure();
}
// Either there are 3 sentinels and step isnt set, or there are two
// sentinels and the step is set.
if (sentinels.size() == 3) {
- if (loopInfo.step) {
+ if (loopInfo.untiledStep) {
return failure();
}
auto it = sentinels.begin();
@@ -364,7 +370,7 @@
Value v = getValueForDimOrSymbol(applyOp, *it);
if (!v.getDefiningOp<IREE::HAL::InterfaceWorkgroupSizeOp>() &&
!v.getDefiningOp<IREE::HAL::InterfaceWorkgroupCountOp>()) {
- loopInfo.step = getAsOpFoldResult(v);
+ loopInfo.untiledStep = getAsOpFoldResult(v);
break;
}
}
@@ -373,18 +379,18 @@
}
}
- if ((sentinels.size() != 2 || !loopInfo.step) &&
- (sentinels.size() != 1 || !loopInfo.workgroupSize)) {
+ if ((sentinels.size() != 2 || !loopInfo.untiledStep) &&
+ (sentinels.size() != 1 || !loopInfo.tileSize)) {
return failure();
}
SmallVector<Value> vals = getValuesForDimsOrSymbols(applyOp, sentinels);
- if ((loopInfo.workgroupSize &&
+ if ((loopInfo.tileSize &&
!checkDimensions<IREE::HAL::InterfaceWorkgroupCountOp>(
- vals, loopInfo.distributionDim)) ||
- (!loopInfo.workgroupSize &&
+ vals, loopInfo.processorDistributionDim)) ||
+ (!loopInfo.tileSize &&
!checkDimensions<IREE::HAL::InterfaceWorkgroupCountOp,
IREE::HAL::InterfaceWorkgroupSizeOp>(
- vals, loopInfo.distributionDim))) {
+ vals, loopInfo.processorDistributionDim))) {
return failure();
}
return success();
@@ -397,18 +403,18 @@
sentinels.push_back(e);
return success();
} else if (auto constExpr = e.dyn_cast<AffineConstantExpr>()) {
- if (loopInfo.step) {
+ if (loopInfo.untiledStep) {
return failure();
}
- loopInfo.step = IntegerAttr::get(IndexType::get(applyOp.getContext()),
- constExpr.getValue());
+ loopInfo.untiledStep = IntegerAttr::get(
+ IndexType::get(applyOp.getContext()), constExpr.getValue());
return success();
}
return failure();
}
AffineApplyOp applyOp;
- TiledLoopInfo &loopInfo;
+ LoopTilingAndDistributionInfo &loopInfo;
};
} // namespace
@@ -419,17 +425,16 @@
/// %id = flow.dispatch.workgroup.id[%dim]
/// %count = flow.dispatch.workgroup.count[%dim]
/// %size = flow.dispatch.workgroup.size[%dim]
-/// %offset = affine.apply affine_map<(d0)[s0, s1] -> (d0 + s0 *
-/// s1)>(%lb)[%id, %size] %new_step = affine.apply affine_map<(d0)[s0, s1] ->
-/// (d0 * s0 * s1)>(%step)[%id, %size] scf.for %iv = %offset to %ub step
-/// %new_step {
-/// ...
-/// }
+/// %offset = affine.apply
+/// affine_map<(d0)[s0, s1] -> (d0 + s0 * s1)>(%lb)[%id, %size]
+/// %new_step = affine.apply
+/// affine_map<(d0)[s0, s1] -> (d0 * s0 * s1)>(%step)[%id, %size]
+/// scf.for %iv = %offset to %ub step %new_step { ... }
/// ```
-static Optional<TiledLoopInfo> isTiledLoop(MLIRContext *context,
- scf::ForOp forOp) {
- TiledLoopInfo loopInfo;
- loopInfo.tiledLoop = forOp;
+Optional<LoopTilingAndDistributionInfo> isTiledAndDistributedLoop(
+ scf::ForOp forOp) {
+ LoopTilingAndDistributionInfo loopInfo;
+ loopInfo.loop = forOp;
auto lbApplyOp = forOp.lowerBound().getDefiningOp<AffineApplyOp>();
if (!lbApplyOp) {
return llvm::None;
@@ -444,27 +449,27 @@
failed(stepVisitor.visit(stepApplyOp.getAffineMap().getResults()[0]))) {
return llvm::None;
}
- if (!loopInfo.lb || !loopInfo.step) {
+ if (!loopInfo.untiledLowerBound || !loopInfo.untiledStep) {
return llvm::None;
}
- loopInfo.ub = getAsOpFoldResult(forOp.upperBound());
+ loopInfo.untiledUpperBound = getAsOpFoldResult(forOp.upperBound());
return loopInfo;
}
-LogicalResult getFilteredOps(FuncOp funcOp, RootOpFilteringFn filteringFn,
- SmallVectorImpl<Operation *> &filteredOps,
- SmallVectorImpl<TiledLoopInfo> &tiledLoops) {
+LogicalResult getFilteredOps(
+ FuncOp funcOp, RootOpFilteringFn filteringFn,
+ SmallVectorImpl<Operation *> &filteredOps,
+ SmallVectorImpl<LoopTilingAndDistributionInfo> &tiledLoops) {
Region ®ion = funcOp.body();
if (!llvm::hasSingleElement(region)) {
return funcOp.emitError("unable dispatch function with multiple blocks");
}
Block *body = ®ion.front();
- MLIRContext *context = funcOp.getContext();
auto forOps = body->getOps<scf::ForOp>();
while (!forOps.empty()) {
if (!llvm::hasSingleElement(forOps)) return failure();
scf::ForOp forOp = *(forOps.begin());
- if (auto tiledLoopInfo = isTiledLoop(context, forOp)) {
+ if (auto tiledLoopInfo = isTiledAndDistributedLoop(forOp)) {
tiledLoops.emplace_back(std::move(tiledLoopInfo.getValue()));
}
body = forOp.getBody();
@@ -478,9 +483,9 @@
return success();
}
-LogicalResult getComputeOps(FuncOp funcOp,
- SmallVectorImpl<Operation *> &computeOps,
- SmallVectorImpl<TiledLoopInfo> &tiledLoops) {
+LogicalResult getComputeOps(
+ FuncOp funcOp, SmallVectorImpl<Operation *> &computeOps,
+ SmallVectorImpl<LoopTilingAndDistributionInfo> &tiledLoops) {
if (failed(getFilteredOps(
funcOp,
[](Operation *op) {
@@ -492,10 +497,11 @@
return success();
}
-SmallVector<TiledLoopInfo> getTiledLoopInfo(FuncOp funcOp) {
- SmallVector<TiledLoopInfo> info;
+SmallVector<LoopTilingAndDistributionInfo> getTiledAndDistributedLoopInfo(
+ FuncOp funcOp) {
+ SmallVector<LoopTilingAndDistributionInfo> info;
funcOp.walk([&](scf::ForOp forOp) {
- if (auto tiledLoopInfo = isTiledLoop(forOp.getContext(), forOp)) {
+ if (auto tiledLoopInfo = isTiledAndDistributedLoop(forOp)) {
info.emplace_back(std::move(tiledLoopInfo.getValue()));
}
});
diff --git a/iree/compiler/Codegen/Utils/Utils.h b/iree/compiler/Codegen/Utils/Utils.h
index f0c563a..e5203b0 100644
--- a/iree/compiler/Codegen/Utils/Utils.h
+++ b/iree/compiler/Codegen/Utils/Utils.h
@@ -18,7 +18,7 @@
static constexpr unsigned kNumMaxParallelDims = 3;
//===----------------------------------------------------------------------===//
-// Utility functions to get entry point(s)
+// Utility functions to get entry points
//===----------------------------------------------------------------------===//
/// Returns true if the given `func` is a kernel dispatch entry point.
@@ -32,7 +32,7 @@
IREE::HAL::ExecutableEntryPointOp getEntryPoint(FuncOp funcOp);
//===----------------------------------------------------------------------===//
-// Utility functions used in setting default configurations.
+// Utility functions to set configurations
//===----------------------------------------------------------------------===//
/// Returns the loops that are partitioned during dispatch region formations, in
@@ -57,6 +57,47 @@
ArrayRef<int64_t> getUntiledResultShape(linalg::LinalgOp linalgOp,
unsigned resultNum);
+/// Information about a tiled and distributed loop.
+///
+/// Right now distribution is happening at the same time as we tile the linalg
+/// op. 0) Given an original loop:
+///
+/// ```
+/// scf.for %iv = %init_lb to %init_ub step %init_step { ... }
+/// ```
+///
+/// 1) After tiling with tile size `%tile_size`, we have:
+///
+/// ```
+/// %tiled_step = %init_step * %tile_size
+/// scf.for %iv = %init_lb to %init_ub step %tiled_step { ... }
+/// ```
+///
+/// 2) After distribution with processor `%id` and `%count`, we have:
+///
+/// ```
+/// %dist_lb = %init_lb + %id * %tiled_step
+/// %dist_step = %init_step * %tile_size * %count
+/// scf.for %iv = %dist_lb to %init_ub step %dist_step { ... }
+/// ```
+///
+/// Given a loop already after 2), this struct contains recovered information
+/// about 0) and 1).
+struct LoopTilingAndDistributionInfo {
+ // The tiled and distributed loop.
+ Operation *loop;
+ // The lower bound for the original untiled loop.
+ OpFoldResult untiledLowerBound;
+ // The upper bound for the original untiled loop.
+ OpFoldResult untiledUpperBound;
+ // The step for the original untiled loop.
+ OpFoldResult untiledStep;
+ // The tile size used to tile (and not distribute) the original untiled loop.
+ Optional<int64_t> tileSize;
+ // The processor dimension this loop is distributed to.
+ unsigned processorDistributionDim;
+};
+
/// Assuming that `funcOp` contains a single nested scf.for that represented the
/// tiled+fused+distributed loops with the distribution being across workgroups,
/// i.e.
@@ -76,19 +117,11 @@
/// `scf.for` operations in the function return the linalg operations in the
/// body of the function if it has a single basic block. Return failure in all
/// other cases.
-
-struct TiledLoopInfo {
- Operation *tiledLoop;
- OpFoldResult lb;
- OpFoldResult ub;
- OpFoldResult step;
- Optional<int64_t> workgroupSize;
- unsigned distributionDim;
-};
using RootOpFilteringFn = std::function<bool(Operation *)>;
-LogicalResult getFilteredOps(FuncOp funcOp, RootOpFilteringFn filteringFn,
- SmallVectorImpl<Operation *> &filteredOps,
- SmallVectorImpl<TiledLoopInfo> &tiledLoops);
+LogicalResult getFilteredOps(
+ FuncOp funcOp, RootOpFilteringFn filteringFn,
+ SmallVectorImpl<Operation *> &filteredOps,
+ SmallVectorImpl<LoopTilingAndDistributionInfo> &tiledLoops);
/// Specialization of `getFilteredOps` for filtering `LinalgOp`s and
/// `LinagExtOp`s.
@@ -96,12 +129,13 @@
/// within the loop. The marker is the way to tie into rest of the
/// codegen. Refactor the downstream passes and get rid of the markers once and
/// for all.
-LogicalResult getComputeOps(FuncOp funcOp,
- SmallVectorImpl<Operation *> &computeOps,
- SmallVectorImpl<TiledLoopInfo> &tiledLoops);
+LogicalResult getComputeOps(
+ FuncOp funcOp, SmallVectorImpl<Operation *> &computeOps,
+ SmallVectorImpl<LoopTilingAndDistributionInfo> &tiledLoops);
-/// Collect information about loops matching tiled+distribute pattern.
-SmallVector<TiledLoopInfo> getTiledLoopInfo(FuncOp funcOp);
+/// Collects information about loops matching tiled+distribute pattern.
+SmallVector<LoopTilingAndDistributionInfo> getTiledAndDistributedLoopInfo(
+ FuncOp funcOp);
} // namespace iree_compiler
} // namespace mlir