blob: 00f806e1c3b4760a68e895a67cf02cb4f0b5894b [file] [log] [blame]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/Shape/Transforms/Patterns.h"
#include "iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.h"
#include "iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/ConvertHALToVMLA.h"
#include "iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.h"
#include "iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/ConvertStandardToVMLA.h"
#include "iree/compiler/Dialect/VMLA/Conversion/TypeConverter.h"
#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
#include "iree/compiler/Dialect/VMLA/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/DialectConversion.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace VMLA {
// Appends a !vmla.interface argument to every public (exported) function in
// |moduleOp|. Exported functions are required to start with an empty
// signature; all real I/O flows through the interface bindings instead.
static LogicalResult insertInterfacesToEntryPoints(mlir::ModuleOp moduleOp) {
  auto *context = moduleOp.getContext();
  // The interface type is context-uniqued; construct it once up front.
  auto interfaceType = IREE::VMLA::InterfaceType::get(context);
  for (auto funcOp : moduleOp.getOps<FuncOp>()) {
    // Only exported (public) functions receive the interface argument.
    bool isExported = SymbolTable::getSymbolVisibility(funcOp) ==
                      SymbolTable::Visibility::Public;
    if (!isExported) continue;
    auto funcType = funcOp.getType();
    if (funcType.getNumInputs() > 0 || funcType.getNumResults() > 0) {
      return funcOp.emitError() << "exported functions must have no I/O";
    }
    // Rewrite the signature to (interface) -> () and mirror the change on
    // the entry block so the argument is usable by conversion patterns.
    funcOp.setType(FunctionType::get({interfaceType}, {}, context));
    funcOp.front().addArgument(interfaceType);
  }
  return success();
}
// Runs conversion with registered input dialects.
// Module-level pass that lowers the registered input dialects (xla_hlo, HAL,
// std, shape queries) into the VMLA dialect.
class ConversionPass
    : public PassWrapper<ConversionPass, OperationPass<mlir::ModuleOp>> {
 public:
  void runOnOperation() override {
    // Exported functions must carry a !vmla.interface argument before any
    // conversion runs; the patterns query it to resolve buffer bindings.
    if (failed(insertInterfacesToEntryPoints(getOperation()))) {
      return signalPassFailure();
    }

    auto *ctx = &getContext();
    VMLATypeConverter vmlaTypeConverter;
    VMLAConversionTarget target(ctx, vmlaTypeConverter);

    // All input dialects must be fully erased by this pass.
    target.addIllegalDialect<xla_hlo::XlaHloDialect>();
    target.addIllegalDialect<IREE::HAL::HALDialect>();

    OwningRewritePatternList patterns;
    populateStandardToVMLAPatterns(ctx, patterns, vmlaTypeConverter);
    populateHLOToVMLAPatterns(ctx, patterns, vmlaTypeConverter);
    populateHALToVMLAPatterns(ctx, patterns, vmlaTypeConverter);

    // FuncOp signatures must also be rewritten to the converted types.
    populateFuncOpTypeConversionPattern(patterns, ctx, vmlaTypeConverter);

    // The shape dialect stays legal overall; only specific dim queries are
    // made illegal so that they fold away. These patterns let dimension
    // queries convert without introducing new shaped tensors.
    Shape::populateFoldConversionPatterns(ctx, patterns);
    target.addLegalDialect<ShapeDialect>();

    // Inputs all become buffers, so tie_shape ops whose result is not yet a
    // buffer must go through type conversion; buffer-typed ones are done.
    target.addDynamicallyLegalOp<Shape::TieShapeOp>([](Shape::TieShapeOp op) {
      return op.result().getType().isa<BufferType>();
    });
    target.addIllegalOp<Shape::RankedDimOp>();
    target.addIllegalOp<Shape::RankedDimsOp>();

    // XLA ops use tensors of extents while most shape handling here is in
    // terms of !shapex.ranked_shape, and there is no lowering for the ops
    // bridging the two representations. Marking from_extent_tensor illegal
    // triggers the `from_extent_tensor(to_extent_tensor(x)) -> x` fold that
    // eliminates the round-trips. This is skating on thin ice.
    // TODO(silvasean): Legalize ToExtentTensorOp and FromExtentTensorOp.
    target.addIllegalOp<Shape::FromExtentTensorOp>();

    // RankedBroadcastInDimOp logically belongs at the xla_hlo level of
    // abstraction but lives in the shape dialect, so it needs an explicit
    // illegal marking here.
    // TODO(silvasean): Reconcile the dialect layering here.
    target.addIllegalOp<Shape::RankedBroadcastInDimOp>();

    if (failed(applyPartialConversion(getOperation(), target, patterns,
                                      &vmlaTypeConverter))) {
      getOperation().emitError() << "conversion to the VMLA dialect failed";
      return signalPassFailure();
    }
  }
};
// Factory for the VMLA conversion pass; ownership transfers to the caller
// (typically a pass manager).
std::unique_ptr<OperationPass<mlir::ModuleOp>> createConversionPass() {
  return std::make_unique<ConversionPass>();
}
// Registers the pass with the global MLIR pass registry under the
// -iree-vmla-conversion command-line flag.
static PassRegistration<ConversionPass> pass(
    "iree-vmla-conversion",
    "Converts from various dialects to the VMLA dialect");
} // namespace VMLA
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir