// blob: efa1a1bd425ec261ccdaa2ea2ed6b9dc0f5c4db8 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include "iree/compiler/IR/Ops.h"
#include "iree/compiler/IR/Types.h"
#include "iree/compiler/Utils/DispatchUtils.h"
#include "third_party/llvm/llvm/include/llvm/ADT/ArrayRef.h"
#include "third_party/llvm/llvm/include/llvm/ADT/DenseMap.h"
#include "third_party/llvm/llvm/include/llvm/ADT/DenseSet.h"
#include "third_party/llvm/llvm/include/llvm/ADT/STLExtras.h"
#include "third_party/llvm/llvm/include/llvm/ADT/SetVector.h"
#include "third_party/llvm/llvm/include/llvm/ADT/SmallVector.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Dialect/StandardOps/Ops.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/BlockAndValueMapping.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/Builders.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/Location.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Pass/Pass.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Pass/PassRegistry.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Support/LLVM.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Support/LogicalResult.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Transforms/Utils.h"
#include "third_party/tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
namespace {
// Builds a new iree.reduction_region whose body is a clone of
// |invocationRegion| and redirects all uses of |originalOp|'s results to it.
//
// The new region op is created with a builder positioned at |originalOp|, so
// it is inserted immediately before |originalOp| in the parent block.
// |originalOp| itself is left in place (with no remaining uses of its
// results); erasing it is the caller's responsibility.
//
// Parameters:
//   originalOp       - op being replaced; supplies the location, the result
//                      types and the value the workload is computed from.
//   operands         - values forwarded as the region op's operands.
//   initialValues    - values forwarded as the region op's initial values.
//   dimensions       - dimension indices forwarded to the region op.
//   invocationRegion - region cloned into the new op's body; it is cloned,
//                      not moved, so the source region is left untouched.
//
// The workload is derived from the first result of |originalOp|, which means
// every op in the cloned body is dispatched with that same workload.
//
// Currently always returns success().
LogicalResult buildReductionRegion(Operation *originalOp,
                                   ArrayRef<Value *> operands,
                                   ArrayRef<Value *> initialValues,
                                   ArrayRef<int64_t> dimensions,
                                   Region &invocationRegion) {
  OpBuilder parentBuilder(originalOp);

  // Compute the workload based on the output shape.
  // When variadic all output shapes match so we can just take the first.
  auto *workload = calculateWorkload(originalOp, originalOp->getResult(0));

  // Build the region op and add it to the parent block.
  SmallVector<Type, 4> resultTypes{originalOp->getResultTypes()};
  auto reductionRegionOp = parentBuilder.create<IREE::ReductionRegionOp>(
      originalOp->getLoc(), resultTypes, workload, operands, initialValues,
      dimensions);

  // Clone the invocation body (blocks and their arguments) into the new op.
  BlockAndValueMapping mapping;
  invocationRegion.cloneInto(&reductionRegionOp.getBody(), mapping);

  // Replace xla_hlo.return -> iree.return inside the cloned body. Each new
  // return is created right before the one it replaces, which is then erased.
  OpBuilder regionBuilder(reductionRegionOp.getBody());
  reductionRegionOp.walk([&](xla_hlo::ReturnOp returnOp) {
    regionBuilder.setInsertionPoint(returnOp);
    SmallVector<Value *, 4> returnValues(returnOp.getOperands());
    regionBuilder.create<IREE::ReturnOp>(returnOp.getLoc(), returnValues);
    returnOp.erase();
  });

  // Replace usage of values with the results of the region. The index is
  // unsigned to match getNumResults() and the bound is evaluated once.
  for (unsigned i = 0, e = originalOp->getNumResults(); i < e; ++i) {
    originalOp->getResult(i)->replaceAllUsesWith(
        reductionRegionOp.getResult(i));
  }

  return success();
}
// Rewrites |reduceOp| as an iree.reduction_region: the reduce body is cloned
// into the new region, uses of the reduce results are redirected to the
// region's results, and the original xla_hlo.reduce is erased.
//
// Returns failure() (leaving |reduceOp| in place) if the region could not be
// built, and success() otherwise.
LogicalResult buildReductionRegionFromXLAReduceOp(xla_hlo::ReduceOp reduceOp) {
  // Unpack the reduction dimensions attribute into plain signed integers.
  SmallVector<int64_t, 4> reducedDims;
  for (auto dimValue : reduceOp.dimensions().getIntValues()) {
    reducedDims.push_back(dimValue.getSExtValue());
  }

  // The adaptor splits the flat operand list into inputs and init values.
  SmallVector<Value *, 4> allOperands(reduceOp.getOperands());
  OperandAdaptor<xla_hlo::ReduceOp> operandAdaptor(allOperands);

  // Create the iree.reduction_region from the reduce op's body.
  LogicalResult buildStatus = buildReductionRegion(
      reduceOp, operandAdaptor.operands(), operandAdaptor.init_values(),
      reducedDims, reduceOp.body());
  if (failed(buildStatus)) {
    return failure();
  }

  // The region now provides all results, so the XLA op can be dropped.
  reduceOp.erase();
  return success();
}
// Converts every xla_hlo.reduce directly contained in |block| into an
// iree.reduction_region.
//
// The block is rescanned from the back after each conversion because the
// conversion inserts and erases ops, which invalidates the iteration in
// progress. Scanning stops once a full pass finds no reduce op.
//
// |funcOp| is currently unused. Returns failure() as soon as one conversion
// fails and success() once no reduce ops remain.
LogicalResult identifyBlockReductionRegions(FuncOp funcOp, Block *block) {
  for (;;) {
    bool convertedReduceOp = false;

    // Walk backwards so the op furthest along in the block is handled first.
    for (auto &candidateOp : llvm::reverse(*block)) {
      auto reduceOp = dyn_cast<xla_hlo::ReduceOp>(candidateOp);
      if (!reduceOp) {
        continue;
      }
      if (failed(buildReductionRegionFromXLAReduceOp(reduceOp))) {
        return failure();
      }
      // A reduction region replaced the op; the block's op list has changed
      // underneath this loop, so abandon it and rescan from scratch.
      convertedReduceOp = true;
      break;
    }

    if (!convertedReduceOp) {
      return success();
    }
  }
}
} // namespace
// Module pass that visits every block of every function in the module and
// converts the reduction ops found there into iree.reduction_regions. The pass
// is marked as failed, and stops immediately, on the first block that cannot
// be converted.
class IdentifyReductionRegionsPass
    : public ModulePass<IdentifyReductionRegionsPass> {
 public:
  void runOnModule() override {
    auto moduleOp = getModule();
    for (auto funcOp : moduleOp.getOps<FuncOp>()) {
      for (auto &funcBlock : funcOp) {
        if (succeeded(identifyBlockReductionRegions(funcOp, &funcBlock))) {
          continue;
        }
        // Conversion failed: flag the pass and skip the remaining work.
        signalPassFailure();
        return;
      }
    }
  }
};
// Creates a new instance of the IdentifyReductionRegionsPass, returned through
// the generic module pass base type so callers need not see the concrete
// class.
std::unique_ptr<OpPassBase<ModuleOp>> createIdentifyReductionRegionsPass() {
  std::unique_ptr<OpPassBase<ModuleOp>> reductionPass =
      std::make_unique<IdentifyReductionRegionsPass>();
  return reductionPass;
}
// Static registration object: registers IdentifyReductionRegionsPass under the
// pass argument "iree-identify-reduction-regions" together with its one-line
// description. Registration happens as a side effect of constructing this
// file-scope object.
static PassRegistration<IdentifyReductionRegionsPass> pass(
"iree-identify-reduction-regions",
"Identifies reduction regions based on input reduction ops.");
} // namespace iree_compiler
} // namespace mlir