// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//===- ReductionFunctionLowering.cpp ---------------------------*- C++ -*-===//
//
// Lowering for reduction function body
//
//===----------------------------------------------------------------------===//
#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Function.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"

namespace mlir {
namespace iree_compiler {
namespace {
/// Type converter for legalization of the reduction apply function.
class SPIRVReductionTypeConverter : public TypeConverter {
public:
Type convertType(Type t) override;
};
/// Base class for legalization of operations within the reduction apply
/// function (and the function itself).
template <typename OpTy>
class SPIRVReductionConversion : public OpConversionPattern<OpTy> {
public:
SPIRVReductionConversion(MLIRContext *context,
SPIRVReductionTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: OpConversionPattern<OpTy>(context, benefit),
typeConverter(typeConverter) {}
protected:
SPIRVReductionTypeConverter &typeConverter;
};
/// The apply function has signature (lhs, rhs) -> output, with all three
/// values of the same type t. It is converted to a function with signature
/// (t, !spv.ptr<t, StorageBuffer>) -> (), where the first argument is the
/// update value and the second is the buffer that holds the reduction result.
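///
/// For example, with element type f32 the signature
///   (f32, f32) -> f32
/// becomes
///   (f32, !spv.ptr<f32, StorageBuffer>) -> ()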
// TODO(ravishankarm): This is assuming storage class is StorageBuffer. This
// needs to be generalized.
class ReductionApplyFnConversion final
: public SPIRVReductionConversion<FuncOp> {
public:
using SPIRVReductionConversion<FuncOp>::SPIRVReductionConversion;
PatternMatchResult matchAndRewrite(
FuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
/// Return operation conversion. Converts the return operation (standard
/// ReturnOp or IREE::ReturnOp) to spirv::ReturnOp.
// TODO: This can be moved into DRR.
template <typename ReturnOpTy>
class ReturnOpConversion final : public SPIRVReductionConversion<ReturnOpTy> {
public:
using SPIRVReductionConversion<ReturnOpTy>::SPIRVReductionConversion;
PatternMatchResult matchAndRewrite(
ReturnOpTy op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<spirv::ReturnOp>(op);
return this->matchSuccess();
}
};
/// Operations within the apply function need to be converted to an atomic
/// update.
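/// For example, xla_hlo::MaxOp is replaced by spirv::AtomicSMaxOp, which
/// atomically updates the result buffer pointer with the incoming value.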
template <typename OpTy, typename ReplacementOpTy>
class ReductionOpConversion final : public SPIRVReductionConversion<OpTy> {
public:
using SPIRVReductionConversion<OpTy>::SPIRVReductionConversion;
PatternMatchResult matchAndRewrite(
OpTy op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
} // namespace
//===----------------------------------------------------------------------===//
// Type Conversion
//===----------------------------------------------------------------------===//
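// Keeps valid SPIR-V types as is and converts zero-ranked tensor types (e.g.
// tensor<f32>) to their element type; all other types fail conversion and a
// null type is returned.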
Type SPIRVReductionTypeConverter::convertType(Type t) {
if (spirv::SPIRVDialect::isValidType(t)) {
return t;
}
if (auto tensorType = t.dyn_cast<RankedTensorType>()) {
if (tensorType.getRank() == 0) {
return tensorType.getElementType();
}
}
return nullptr;
}
//===----------------------------------------------------------------------===//
// Apply fn conversion.
//===----------------------------------------------------------------------===//
PatternMatchResult ReductionApplyFnConversion::matchAndRewrite(
FuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
auto fnType = funcOp.getType();
if (fnType.getNumInputs() != 2 || fnType.getNumResults() != 1) {
return matchFailure();
}
if (fnType.getInput(0) != fnType.getInput(1) ||
fnType.getInput(0) != fnType.getResult(0)) {
return matchFailure();
}
TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
auto convertedType = typeConverter.convertType(fnType.getInput(0));
if (!convertedType) {
return matchFailure();
}
signatureConverter.addInputs(0, convertedType);
signatureConverter.addInputs(
1, spirv::PointerType::get(convertedType,
spirv::StorageClass::StorageBuffer));
auto newFn = rewriter.cloneWithoutRegions(funcOp);
rewriter.inlineRegionBefore(funcOp.getBody(), newFn.getBody(), newFn.end());
newFn.setType(rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
llvm::None));
rewriter.applySignatureConversion(&newFn.getBody(), signatureConverter);
rewriter.eraseOp(funcOp);
return matchSuccess();
}
//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//
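// The replacement atomic op yields the pointee type of the result buffer and
// uses Device scope with AcquireRelease memory semantics.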
template <typename OpTy, typename ReplacementOpTy>
PatternMatchResult
ReductionOpConversion<OpTy, ReplacementOpTy>::matchAndRewrite(
OpTy op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
if (operands.size() != 2) {
return this->matchFailure();
}
  // One of the replacement operands will be a pointer type and the other a
  // value type.
Value ptr = operands[0];
Value value = operands[1];
if (!ptr.getType().isa<spirv::PointerType>()) std::swap(ptr, value);
if (!ptr.getType().isa<spirv::PointerType>()) return this->matchFailure();
rewriter.replaceOpWithNewOp<ReplacementOpTy>(
op, ptr.getType().cast<spirv::PointerType>().getPointeeType(), ptr,
spirv::Scope::Device, spirv::MemorySemantics::AcquireRelease, value);
return this->matchSuccess();
}
//===----------------------------------------------------------------------===//
// Pattern builder
//===----------------------------------------------------------------------===//
LogicalResult lowerReductionApplyFunction(MLIRContext *context,
ArrayRef<Operation *> fns) {
OwningRewritePatternList patterns;
SPIRVReductionTypeConverter typeConverter;
patterns
.insert<ReductionApplyFnConversion,
ReductionOpConversion<xla_hlo::MinOp, spirv::AtomicSMinOp>,
ReductionOpConversion<xla_hlo::MaxOp, spirv::AtomicSMaxOp>,
ReductionOpConversion<AddIOp, spirv::AtomicIAddOp>,
ReturnOpConversion<IREE::ReturnOp>, ReturnOpConversion<ReturnOp>>(
context, typeConverter);
ConversionTarget target(*context);
target.addLegalDialect<spirv::SPIRVDialect>();
target.addDynamicallyLegalOp<FuncOp>(
[&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); });
if (failed(applyPartialConversion(fns, target, patterns))) {
return failure();
}
return success();
}
//===----------------------------------------------------------------------===//
// Pass for invoking the conversion.
//===----------------------------------------------------------------------===//
namespace {
// Pass to invoke the reduction fn lowering from the command line.
class ReduceFnSPIRVLoweringPass final
: public OperationPass<ReduceFnSPIRVLoweringPass, ModuleOp> {
private:
void runOnOperation() override;
};
} // namespace
void ReduceFnSPIRVLoweringPass::runOnOperation() {
ModuleOp module = getOperation();
MLIRContext *context = &getContext();
if (failed(lowerReductionApplyFunction(context, module.getOperation()))) {
return signalPassFailure();
}
}
static PassRegistration<ReduceFnSPIRVLoweringPass> pass(
"iree-spirv-reduction-fn-lowering",
"Convert the reduction apply function within reduction dispatches to "
"SPIR-V");
} // namespace iree_compiler
} // namespace mlir