blob: 030bb3db87220a1ecd1cf79126cbdeb56d8fac43 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <map>
#include <string>

#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/Utils/TypeUtils.h"
#include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h"
#include "iree/compiler/Dialect/VM/IR/VMOps.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace iree_compiler {
namespace {
/// Lowers `hal.allocator.allocate.const` into a `vm.call.variadic` of the
/// import named at construction time.
///
/// The constant payload (`op.value()`) is hoisted into a `vm.rodata` op that
/// is inserted immediately before the enclosing `vm.func`. The call receives
/// the allocator, memory types, buffer usage, the static shape dimensions,
/// the rounded element byte width and a `vm.const.ref.rodata` reference to
/// the hoisted data.
class AllocatorAllocateConstOpConversion
    : public OpConversionPattern<IREE::HAL::AllocatorAllocateConstOp> {
 public:
  /// Resolves |importName| in |importSymbols|; the import is expected to
  /// exist. |typeConverter| is accepted for signature parity with the other
  /// HAL->VM patterns but is not used here.
  AllocatorAllocateConstOpConversion(MLIRContext *context,
                                     SymbolTable &importSymbols,
                                     TypeConverter &typeConverter,
                                     StringRef importName)
      : OpConversionPattern(context) {
    importOp = importSymbols.lookup<IREE::VM::ImportOp>(importName);
    assert(importOp);
  }

  PatternMatchResult matchAndRewrite(
      IREE::HAL::AllocatorAllocateConstOp op, llvm::ArrayRef<ValuePtr> operands,
      ConversionPatternRewriter &rewriter) const override {
    // The constructor's assert is compiled out in release builds; fail the
    // match (leaving the op illegal) rather than dereferencing a null import.
    if (!importOp) return matchFailure();

    // The rodata segment is emitted as a sibling placed before the enclosing
    // function, so the op must live inside a vm.func. Bail out before any IR
    // is created if it does not.
    auto parentFuncOp = op.getParentOfType<IREE::VM::FuncOp>();
    if (!parentFuncOp) return matchFailure();

    // Encode constant data into a rodata segment. These will eventually get
    // deduped and combined.
    auto ip = rewriter.saveInsertionPoint();
    rewriter.setInsertionPoint(parentFuncOp);
    auto constName =
        (parentFuncOp.getName() + "_const_" +
         std::to_string(allocateUniqueId(parentFuncOp.getName())))
            .str();
    auto rodataOp =
        rewriter.create<IREE::VM::RodataOp>(op.getLoc(), constName, op.value());
    rewriter.restoreInsertionPoint(ip);

    // Reference the hoisted data at the original op's position.
    auto loadRodataOp =
        rewriter.create<IREE::VM::ConstRefRodataOp>(op.getLoc(), rodataOp);

    IREE::HAL::AllocatorAllocateConstOpOperandAdaptor opAdaptor(operands);
    auto shape = IREE::HAL::getStaticShapeDims(op.getLoc(),
                                               op.value().getType(), rewriter);

    // Operand order: allocator, memory_types, buffer_usage, shape...,
    // element_size, value.
    SmallVector<ValuePtr, 8> callOperands = {
        opAdaptor.allocator(),
        rewriter.create<mlir::ConstantOp>(
            op.getLoc(), rewriter.getI32IntegerAttr(
                             static_cast<int32_t>(op.memory_types()))),
        rewriter.create<mlir::ConstantOp>(
            op.getLoc(), rewriter.getI32IntegerAttr(
                             static_cast<int32_t>(op.buffer_usage()))),
    };
    callOperands.append(shape.begin(), shape.end());
    callOperands.push_back(rewriter.create<mlir::ConstantOp>(
        op.getLoc(),
        rewriter.getI32IntegerAttr(IREE::HAL::getRoundedElementByteWidth(
            op.value().getType().getElementType()))));
    callOperands.push_back(loadRodataOp.getResult());

    // One entry per import operand group; only the shape group carries an
    // explicit count (-1 presumably marks a single, non-variadic operand —
    // confirm against vm.call.variadic).
    SmallVector<int8_t, 6> segmentSizes = {
        /*allocator=*/-1,
        /*memory_types=*/-1,
        /*buffer_usage=*/-1,
        /*shape=*/static_cast<int8_t>(shape.size()),
        /*element_size=*/-1,
        /*value=*/-1,
    };
    auto importType = importOp.getType();
    rewriter.replaceOpWithNewOp<IREE::VM::CallVariadicOp>(
        op, rewriter.getSymbolRefAttr(importOp), importType.getResults(),
        segmentSizes, importType.getInputs(), callOperands);
    return matchSuccess();
  }

 private:
  // Returns the next per-function suffix for a rodata name: 0, 1, 2, ... for
  // each distinct |funcName|. Keying by function name (rather than resetting
  // a single counter whenever the current function changes) keeps names
  // unique even when ops from different functions are visited interleaved.
  // TODO(b/145839814): find a name that's unique or make the rewriter support
  // assigning unique names (this does not guard against collisions with
  // unrelated symbols already in the module).
  int allocateUniqueId(StringRef funcName) const {
    return nextConstIds[funcName.str()]++;
  }

  mutable std::map<std::string, int> nextConstIds;
  mutable IREE::VM::ImportOp importOp;
};
} // namespace
/// Adds the patterns that lower `hal.allocator.*` ops to calls of the
/// similarly named imports found in |importSymbols|.
void populateHALAllocatorToVMPatterns(MLIRContext *context,
                                      SymbolTable &importSymbols,
                                      TypeConverter &typeConverter,
                                      OwningRewritePatternList &patterns) {
  using IREE::HAL::AllocatorAllocateOp;
  using IREE::HAL::AllocatorAllocateShapedOp;
  using IREE::HAL::AllocatorComputeSizeOp;

  // Generic import-call lowering (VMImportOpConversion from ImportUtils.h).
  patterns.insert<VMImportOpConversion<AllocatorComputeSizeOp>>(
      context, importSymbols, typeConverter, "hal.allocator.compute_size");
  patterns.insert<VMImportOpConversion<AllocatorAllocateOp>>(
      context, importSymbols, typeConverter, "hal.allocator.allocate");

  // allocate.const hoists its payload into vm.rodata before calling the
  // import, so it uses the dedicated pattern defined in this file.
  patterns.insert<AllocatorAllocateConstOpConversion>(
      context, importSymbols, typeConverter, "hal.allocator.allocate.const");

  patterns.insert<VMImportOpConversion<AllocatorAllocateShapedOp>>(
      context, importSymbols, typeConverter, "hal.allocator.allocate.shaped");
}
} // namespace iree_compiler
} // namespace mlir