blob: 62e211481f55170231ee13fa3f65161d666867c9 [file] [log] [blame]
// Copyright 2020 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace HAL {
/// Replaces `hal.interface.workgroup.size` queries with constant index values
/// when the matching `hal.executable.entry_point` declares a static
/// `workgroup_size`.
///
/// Runs on a `hal.executable.target` op. Each function in the target's inner
/// module is rewritten only if an entry point with the same name exists in the
/// target's symbol table and that entry point carries a `workgroup_size`.
class PropagateConstantWorkgroupInfoPass
    : public PassWrapper<PropagateConstantWorkgroupInfoPass,
                         OperationPass<IREE::HAL::ExecutableTargetOp>> {
 public:
  void runOnOperation() override {
    auto targetOp = getOperation();
    SymbolTable targetSymbolTable(targetOp);
    for (auto funcOp : targetOp.getInnerModule().getOps<FuncOp>()) {
      auto entryPointOp =
          targetSymbolTable.lookup<IREE::HAL::ExecutableEntryPointOp>(
              funcOp.getName());
      if (!entryPointOp) continue;
      if (!entryPointOp.workgroup_size().hasValue()) continue;
      auto workgroupSizeAttr = entryPointOp.workgroup_sizeAttr();

      // Collect the queries before rewriting so that erasing them does not
      // invalidate the traversal. A full walk (instead of only the top-level
      // ops of the function body) also reaches queries nested inside
      // region-holding ops such as loops and conditionals.
      SmallVector<IREE::HAL::InterfaceWorkgroupSizeOp, 4> workgroupSizeOps;
      funcOp.walk([&](IREE::HAL::InterfaceWorkgroupSizeOp op) {
        workgroupSizeOps.push_back(op);
      });

      for (auto workgroupSizeOp : workgroupSizeOps) {
        uint64_t dim = workgroupSizeOp.dimension().getZExtValue();
        // The static size may not cover every queried dimension; indexing
        // past the end of the array attribute would be out of bounds, so
        // leave such queries untouched.
        if (dim >= workgroupSizeAttr.size()) continue;
        // Likewise skip entries that are not integers instead of asserting.
        auto dimSizeAttr = workgroupSizeAttr[dim].dyn_cast<IntegerAttr>();
        if (!dimSizeAttr) continue;

        OpBuilder builder(workgroupSizeOp);
        auto dimValue = builder.createOrFold<ConstantIndexOp>(
            workgroupSizeOp.getLoc(), dimSizeAttr.getInt());
        workgroupSizeOp.replaceAllUsesWith(dimValue);
        workgroupSizeOp.erase();
      }
    }
  }
};
/// Returns a new instance of the pass that replaces
/// `hal.interface.workgroup.size` queries in a `hal.executable.target` with
/// constants taken from the entry point's static `workgroup_size`.
std::unique_ptr<OperationPass<IREE::HAL::ExecutableTargetOp>>
createPropagateConstantWorkgroupInfoPass() {
  auto pass = std::make_unique<PropagateConstantWorkgroupInfoPass>();
  return pass;
}
// Static registration of the pass under the command-line argument
// "iree-hal-propagate-constant-workgroup-info".
static PassRegistration<PropagateConstantWorkgroupInfoPass> registration{
    "iree-hal-propagate-constant-workgroup-info",
    "Propagates constant hal.interface.workgroup.* queries when known"};
} // namespace HAL
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir