// blob: 20d51a2831f5742c205bd190231266ae4bbac952
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "compiler/IR/Sequencer/LLOps.h"
#include "compiler/IR/StructureOps.h"
#include "compiler/Utils/OpUtils.h"
#include "llvm/ADT/StringMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Utils.h"
namespace mlir {
namespace iree_compiler {
namespace {
// Workloads collected from the dispatch ops that target a single executable
// entry point.
struct WorkloadInfo {
// Distinct workload attributes seen on static dispatches; a workload that
// compares equal to one already stored is not added again.
SmallVector<ElementsAttr, 4> staticWorkloads;
// Workload values seen on dynamic dispatches, one per dispatch op (these are
// not deduplicated).
SmallVector<Value *, 4> dynamicWorkloads;
};
// Collects the workloads of every dispatch op nested within the functions of
// |moduleOp|.
//
// The result is a two-level map: the outer key is the dispatch's
// getExecutable() value and the inner key is its getEntryPoint() value. Each
// dynamic dispatch contributes its workload value unconditionally, while the
// workload attribute of a static dispatch is only stored the first time it is
// seen for a given (executable, entry point) pair.
llvm::StringMap<llvm::StringMap<WorkloadInfo>> gatherExecutableWorkloadInfos(
    ModuleOp moduleOp) {
  llvm::StringMap<llvm::StringMap<WorkloadInfo>> infosByExecutable;
  for (auto funcOp : moduleOp.getOps<FuncOp>()) {
    // Dynamic dispatches: record every workload value as-is.
    funcOp.walk([&](IREESeq::LL::DynamicDispatchOp dispatchOp) {
      auto &entryPointInfo =
          infosByExecutable[dispatchOp.getExecutable()]
                           [dispatchOp.getEntryPoint()];
      entryPointInfo.dynamicWorkloads.push_back(dispatchOp.getWorkload());
    });

    // Static dispatches: record each distinct workload attribute once.
    funcOp.walk([&](IREESeq::LL::StaticDispatchOp dispatchOp) {
      auto &entryPointInfo =
          infosByExecutable[dispatchOp.getExecutable()]
                           [dispatchOp.getEntryPoint()];
      bool alreadyRecorded = false;
      for (auto knownWorkloadAttr : entryPointInfo.staticWorkloads) {
        if (knownWorkloadAttr == dispatchOp.getWorkload()) {
          alreadyRecorded = true;
          break;
        }
      }
      if (!alreadyRecorded) {
        entryPointInfo.staticWorkloads.push_back(dispatchOp.getWorkload());
      }
    });
  }
  return infosByExecutable;
}
// Attaches the workload gathered for an entry point to |entryPointOp| so that
// the backends processing the function can read it.
//
// Only the fully-static case is handled: |workloadInfo| must contain no
// dynamic workloads and exactly one static workload. Any other combination
// emits an error on |entryPointOp| and that diagnostic is returned as the
// result. On success the single static workload is stored in the
// "iree.executable.workload" attribute.
LogicalResult attributeExecutableEntryPointWorkload(
    FuncOp entryPointOp, const WorkloadInfo &workloadInfo) {
  const bool hasDynamicWorkloads = !workloadInfo.dynamicWorkloads.empty();
  if (hasDynamicWorkloads) {
    return entryPointOp.emitError() << "Dynamic workloads not yet supported";
  }

  const auto &staticWorkloads = workloadInfo.staticWorkloads;
  if (staticWorkloads.size() != 1) {
    return entryPointOp.emitError() << "Static workload sizes differ in shape";
  }

  // With static-only support a single attribute is sufficient. Supporting
  // dynamic workloads will need a pair of attributes describing which
  // dimensions may be static and which arguments carry the dynamic values.
  entryPointOp.setAttr("iree.executable.workload", staticWorkloads[0]);
  return success();
}
} // namespace
// Module pass that annotates executable entry point functions with the
// workload used by the dispatches that target them.
//
// The pass gathers the workloads of all dispatch ops in the module, then for
// each referenced multi-arch executable visits every nested executable and
// attributes each dispatched entry point function. The pass is signalled as
// failed when a referenced executable or entry point cannot be resolved, or
// when the gathered workload info for an entry point is rejected.
class AssignExecutableWorkloadAttrsPass
    : public ModulePass<AssignExecutableWorkloadAttrsPass> {
 public:
  void runOnModule() override {
    auto moduleOp = getModule();

    // Find all dispatches and capture their workload information.
    // The information is keyed by executable name and then entry point name.
    auto executableWorkloadInfos = gatherExecutableWorkloadInfos(moduleOp);

    // Process each executable with the workload information.
    for (auto &executableIt : executableWorkloadInfos) {
      // lookupSymbol may find nothing (or an op of another kind); report that
      // as a diagnostic instead of asserting inside cast<>.
      auto multiArchExecutableOp =
          dyn_cast_or_null<IREE::MultiArchExecutableOp>(
              moduleOp.lookupSymbol(executableIt.first()));
      if (!multiArchExecutableOp) {
        moduleOp.emitError() << "dispatch references unknown executable '"
                             << executableIt.first() << "'";
        return signalPassFailure();
      }

      for (auto executableOp :
           multiArchExecutableOp.getBlock().getOps<IREE::ExecutableOp>()) {
        for (auto &entryPointIt : executableIt.second) {
          // Same null/kind guard for the entry point function lookup.
          auto funcOp = dyn_cast_or_null<FuncOp>(
              executableOp.getInnerModule().lookupSymbol(
                  entryPointIt.first()));
          if (!funcOp) {
            executableOp.emitError()
                << "executable is missing dispatched entry point '"
                << entryPointIt.first() << "'";
            return signalPassFailure();
          }
          if (failed(attributeExecutableEntryPointWorkload(
                  funcOp, entryPointIt.second))) {
            return signalPassFailure();
          }
        }
      }
    }
  }
};
// Creates a new AssignExecutableWorkloadAttrsPass instance and returns it as
// an owning pointer to the module-level pass base type.
std::unique_ptr<OpPassBase<ModuleOp>>
createAssignExecutableWorkloadAttrsPass() {
return std::make_unique<AssignExecutableWorkloadAttrsPass>();
}
// Static registration object for the pass: associates it with the
// "iree-assign-executable-workload-attrs" argument and a short description.
static PassRegistration<AssignExecutableWorkloadAttrsPass> pass(
"iree-assign-executable-workload-attrs",
"Assigns executable entrypoint workload attributes");
} // namespace iree_compiler
} // namespace mlir