Changing tflite binding generation to use hal.tensor.cast. (#7661)
This allows buffers to be specified directly (just as we specify
buffer_views in the normal IREE bindings) and lets us drop all of the
shapex dialect ops from this layer.
Previously this was done in two passes that relied on some shady
assumptions about when and how shapes were propagated. Now, as a single
pass, all of the globals for expanded shape dimensions, the ops tying
them to values, and the query/update logic are explicitly specified and
robust to further transformation.
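
For reference, a condensed sketch of the IR the pass now generates for a
dynamically shaped entry point (abbreviated from the updated
wrap_entry_points.mlir test below; one input/output shown, exact op
ordering simplified):

  util.global private mutable @_tflite_dynamicEntry_input0_shape_dim0 : index
  util.global private mutable @_tflite_dynamicEntry_output0_shape_dim0 : index
  util.global private mutable @_tflite_dynamicEntry_shapes_dirty = true

  func @_tflite_main(%arg0: !hal.buffer) -> !hal.buffer {
    // Tie the stateful input dims directly to the buffer argument.
    %in_dim0 = util.global.load @_tflite_dynamicEntry_input0_shape_dim0 : index
    %input = hal.tensor.cast %arg0 : !hal.buffer -> tensor<?x8x8x3xf32>{%in_dim0}
    %result = call @dynamicEntry(%input) : (tensor<?x8x8x3xf32>) -> tensor<?x8x8x3xf32>
    // Publish the computed output dims and clear the dirty flag.
    %c0 = arith.constant 0 : index
    %out_dim0 = tensor.dim %result, %c0 : tensor<?x8x8x3xf32>
    %output = hal.tensor.cast %result : tensor<?x8x8x3xf32>{%out_dim0} -> !hal.buffer
    util.global.store %out_dim0, @_tflite_dynamicEntry_output0_shape_dim0 : index
    %false = arith.constant false
    util.global.store %false, @_tflite_dynamicEntry_shapes_dirty : i1
    return %output : !hal.buffer
  }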
diff --git a/bindings/tflite/testdata/BUILD b/bindings/tflite/testdata/BUILD
index 64496d3..bc37cf5 100644
--- a/bindings/tflite/testdata/BUILD
+++ b/bindings/tflite/testdata/BUILD
@@ -18,7 +18,7 @@
src = "add_dynamic.mlir",
c_identifier = "iree_tflite_testdata_add_dynamic",
flags = [
- "--iree-input-type=mhlo",
+ "-iree-input-type=mhlo",
"-iree-native-bindings-support=false",
"-iree-tflite-bindings-support",
"-iree-mlir-to-vm-bytecode-module",
@@ -32,7 +32,7 @@
src = "add_multi.mlir",
c_identifier = "iree_tflite_testdata_add_multi",
flags = [
- "--iree-input-type=mhlo",
+ "-iree-input-type=mhlo",
"-iree-native-bindings-support=false",
"-iree-tflite-bindings-support",
"-iree-mlir-to-vm-bytecode-module",
@@ -46,7 +46,7 @@
src = "add_static.mlir",
c_identifier = "iree_tflite_testdata_add_static",
flags = [
- "--iree-input-type=mhlo",
+ "-iree-input-type=mhlo",
"-iree-native-bindings-support=false",
"-iree-tflite-bindings-support",
"-iree-mlir-to-vm-bytecode-module",
diff --git a/bindings/tflite/testdata/CMakeLists.txt b/bindings/tflite/testdata/CMakeLists.txt
index a2e4c93..69ffd6f 100644
--- a/bindings/tflite/testdata/CMakeLists.txt
+++ b/bindings/tflite/testdata/CMakeLists.txt
@@ -18,7 +18,7 @@
C_IDENTIFIER
"iree_tflite_testdata_add_dynamic"
FLAGS
- "--iree-input-type=mhlo"
+ "-iree-input-type=mhlo"
"-iree-native-bindings-support=false"
"-iree-tflite-bindings-support"
"-iree-mlir-to-vm-bytecode-module"
@@ -35,7 +35,7 @@
C_IDENTIFIER
"iree_tflite_testdata_add_multi"
FLAGS
- "--iree-input-type=mhlo"
+ "-iree-input-type=mhlo"
"-iree-native-bindings-support=false"
"-iree-tflite-bindings-support"
"-iree-mlir-to-vm-bytecode-module"
@@ -52,7 +52,7 @@
C_IDENTIFIER
"iree_tflite_testdata_add_static"
FLAGS
- "--iree-input-type=mhlo"
+ "-iree-input-type=mhlo"
"-iree-native-bindings-support=false"
"-iree-tflite-bindings-support"
"-iree-mlir-to-vm-bytecode-module"
diff --git a/iree/compiler/Bindings/TFLite/Transforms/BUILD b/iree/compiler/Bindings/TFLite/Transforms/BUILD
index f123eb3..0fee820 100644
--- a/iree/compiler/Bindings/TFLite/Transforms/BUILD
+++ b/iree/compiler/Bindings/TFLite/Transforms/BUILD
@@ -13,7 +13,6 @@
cc_library(
name = "Transforms",
srcs = [
- "MaterializeShapeSupport.cpp",
"Passes.cpp",
"WrapEntryPoints.cpp",
],
@@ -22,9 +21,7 @@
],
deps = [
"//iree/compiler/Dialect/Flow/IR",
- "//iree/compiler/Dialect/Shape/IR",
- "//iree/compiler/Dialect/Shape/Transforms",
- "//iree/compiler/Dialect/Shape/Utils:TypeConversion",
+ "//iree/compiler/Dialect/HAL/IR",
"//iree/compiler/Dialect/Util/IR",
"//iree/compiler/Utils",
"@llvm-project//llvm:Support",
diff --git a/iree/compiler/Bindings/TFLite/Transforms/CMakeLists.txt b/iree/compiler/Bindings/TFLite/Transforms/CMakeLists.txt
index 7fc31bf..9ceda22 100644
--- a/iree/compiler/Bindings/TFLite/Transforms/CMakeLists.txt
+++ b/iree/compiler/Bindings/TFLite/Transforms/CMakeLists.txt
@@ -16,7 +16,6 @@
HDRS
"Passes.h"
SRCS
- "MaterializeShapeSupport.cpp"
"Passes.cpp"
"WrapEntryPoints.cpp"
DEPS
@@ -31,9 +30,7 @@
MLIRTransformUtils
MLIRTransforms
iree::compiler::Dialect::Flow::IR
- iree::compiler::Dialect::Shape::IR
- iree::compiler::Dialect::Shape::Transforms
- iree::compiler::Dialect::Shape::Utils::TypeConversion
+ iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::Util::IR
iree::compiler::Utils
PUBLIC
diff --git a/iree/compiler/Bindings/TFLite/Transforms/MaterializeShapeSupport.cpp b/iree/compiler/Bindings/TFLite/Transforms/MaterializeShapeSupport.cpp
deleted file mode 100644
index 5f14a58..0000000
--- a/iree/compiler/Bindings/TFLite/Transforms/MaterializeShapeSupport.cpp
+++ /dev/null
@@ -1,462 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
-#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
-#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
-#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
-#include "llvm/ADT/STLExtras.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/SymbolTable.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/Utils.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREE {
-namespace TFLite {
-
-// Materializes the shape query and manipulation functions used by the
-// bindings. In tflite the specification of input shapes is performed
-// independently of execution and the input shapes are stateful.
-//
-// We do this by adding one variable for each I/O shape that stores the
-// current shape and functions to allow the bindings to query/manipulate those
-// variables. We then generate a function to perform the same kind of shape
-// propagation that the tflite runtime would have performed, only we need no
-// runtime support and can do so much more efficiently :)
-class MaterializeShapeSupportPass
- : public PassWrapper<MaterializeShapeSupportPass, OperationPass<ModuleOp>> {
- public:
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<iree_compiler::IREE::Flow::FlowDialect>();
- registry.insert<iree_compiler::IREE::Util::UtilDialect>();
- registry.insert<iree_compiler::ShapeDialect>();
- registry.insert<StandardOpsDialect, mlir::arith::ArithmeticDialect>();
- }
-
- StringRef getArgument() const override {
- return "iree-tflite-materialize-shape-support";
- }
-
- StringRef getDescription() const override {
- return "Materializes support functions for the tflite runtime bindings";
- }
-
- void runOnOperation() override {
- auto moduleOp = getOperation();
- auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp.getBody());
- for (auto funcOp : llvm::to_vector<4>(moduleOp.getOps<FuncOp>())) {
- if (!funcOp.isPublic()) {
- continue;
- }
- if (failed(materializeShapeSupport(funcOp, moduleBuilder))) {
- signalPassFailure();
- return;
- }
- }
- }
-
- private:
- // Materializes all of the state and supporting functions to track the shapes
- // of the inputs and outputs of |funcOp|.
- LogicalResult materializeShapeSupport(FuncOp funcOp,
- OpBuilder &moduleBuilder) {
- auto loc = funcOp.getLoc();
- auto namePrefix = funcOp.getName();
-
- // Create a variable for each input and output to store the ranked shape.
- // These variables may represent fully static shapes - in which case they'll
- // get constant propagated - or dynamic shapes that will eventually get
- // turned into dynamic runtime values.
- SmallVector<IREE::Util::GlobalOp, 4> inputGlobalOps;
- SmallVector<IREE::Util::GlobalOp, 4> outputGlobalOps;
- createShapeVariables(loc, namePrefix, funcOp, inputGlobalOps,
- outputGlobalOps, moduleBuilder);
-
- // Create internal shape calculation function that updates output shapes if
- // needed. This is only required if there are dynamic shapes.
- auto dirtyGlobalOp = moduleBuilder.create<IREE::Util::GlobalOp>(
- loc, funcOp.getName().str() + "_shapes_dirty",
- /*isMutable=*/true, moduleBuilder.getI1Type(),
- moduleBuilder.getIntegerAttr(moduleBuilder.getI1Type(), 1));
- dirtyGlobalOp.setPrivate();
- auto calculateShapeFuncOp = createShapeCalculationFunc(
- loc, namePrefix, funcOp, inputGlobalOps, outputGlobalOps, dirtyGlobalOp,
- moduleBuilder);
-
- // Create input query function (just reads variables).
- createQueryInputShapeFunc(loc, namePrefix, inputGlobalOps, moduleBuilder);
-
- // Create input resize function (updates variables, set dirty flag).
- createResizeInputShapeFunc(loc, namePrefix, inputGlobalOps, dirtyGlobalOp,
- moduleBuilder);
-
- // Create output query function (if dirty recalculates shapes).
- createQueryOutputShapeFunc(loc, namePrefix, outputGlobalOps,
- calculateShapeFuncOp, moduleBuilder);
-
- return success();
- }
-
- // Creates and initializes to default values one util.global for each I/O
- // shape of the given |funcOp|. |inputGlobalOps| and |outputGlobalOps| will be
- // populated with the created variables.
- void createShapeVariables(
- Location loc, StringRef namePrefix, FuncOp funcOp,
- SmallVectorImpl<IREE::Util::GlobalOp> &inputGlobalOps,
- SmallVectorImpl<IREE::Util::GlobalOp> &outputGlobalOps,
- OpBuilder &moduleBuilder) {
- auto funcType = funcOp.getType();
-
- // TFLite requires the tensor names at runtime. If they've previously been
- // extracted into iree.identifiers we use those and otherwise fallback to
- // a generic naming scheme that matches the IR (somewhat).
- SmallVector<std::string, 4> inputNames;
- SmallVector<std::string, 4> outputNames;
- for (unsigned i = 0; i < funcType.getNumInputs(); ++i) {
- auto identifier =
- funcOp.getArgAttrOfType<StringAttr>(i, "iree.identifier");
- if (identifier) {
- inputNames.push_back(identifier.getValue().str());
- } else {
- inputNames.push_back(std::string("arg") + std::to_string(i));
- }
- }
- for (unsigned i = 0; i < funcType.getNumResults(); ++i) {
- auto identifier =
- funcOp.getResultAttrOfType<StringAttr>(i, "iree.identifier");
- if (identifier) {
- outputNames.push_back(identifier.getValue().str());
- } else {
- outputNames.push_back(std::string("ret") + std::to_string(i));
- }
- }
-
- for (auto input : llvm::zip(inputNames, funcType.getInputs())) {
- auto name = funcOp.getName().str() + "_" + std::get<0>(input) + "_shape";
- auto type = std::get<1>(input);
- auto tensorType = type.dyn_cast<TensorType>();
- assert(tensorType && "expecting only tensors in tflite function I/O");
- inputGlobalOps.push_back(
- createShapeVariable(loc, name, tensorType, moduleBuilder));
- }
- for (auto output : llvm::zip(outputNames, funcType.getResults())) {
- auto name = funcOp.getName().str() + "_" + std::get<0>(output) + "_shape";
- auto type = std::get<1>(output);
- auto tensorType = type.dyn_cast<TensorType>();
- assert(tensorType && "expecting only tensors in tflite function I/O");
- outputGlobalOps.push_back(
- createShapeVariable(loc, name, tensorType, moduleBuilder));
- }
- }
-
- // Declares a global variable that holds a shape for the given |tensorType|.
- IREE::Util::GlobalOp createShapeVariable(Location loc, StringRef name,
- TensorType tensorType,
- OpBuilder &moduleBuilder) {
- auto shapeType = Shape::RankedShapeType::get(tensorType.getShape(),
- moduleBuilder.getContext());
- auto globalOp = moduleBuilder.create<IREE::Util::GlobalOp>(
- loc, name, /*isMutable=*/true, shapeType);
- globalOp.setPrivate();
- return globalOp;
- }
-
- // Derives a shape calculation function from the given entry point |funcOp|.
- FuncOp createShapeCalculationFunc(
- Location loc, StringRef namePrefix, FuncOp funcOp,
- ArrayRef<IREE::Util::GlobalOp> inputGlobalOps,
- ArrayRef<IREE::Util::GlobalOp> outputGlobalOps,
- IREE::Util::GlobalOp dirtyGlobalOp, OpBuilder &moduleBuilder) {
- // Clone the entire entry function with all its IR.
- auto calcFuncOp = cast<FuncOp>(moduleBuilder.clone(*funcOp.getOperation()));
- mlir::StringAttr nameAttr = mlir::StringAttr::get(
- loc.getContext(), namePrefix.str() + "_calculate_shapes");
- calcFuncOp.setName(nameAttr);
- calcFuncOp.setPrivate();
- // TODO(benvanik): find a better way to strip these attributes.
- calcFuncOp->removeAttr("iree.abi.stub");
- calcFuncOp->removeAttr("iree.reflection");
- auto &entryBlock = calcFuncOp.front();
- auto entryBuilder = OpBuilder::atBlockBegin(&entryBlock);
-
- // Go back and insert a check for the dirty flag.
- auto dirtyValue = entryBuilder.createOrFold<IREE::Util::GlobalLoadOp>(
- loc, dirtyGlobalOp.type(), dirtyGlobalOp.getName());
- auto *recalculateBlock = calcFuncOp.addBlock();
- auto *returnBlock = calcFuncOp.addBlock();
- entryBuilder.create<CondBranchOp>(loc, dirtyValue, recalculateBlock,
- returnBlock);
- auto *followBlock = entryBlock.splitBlock(entryBuilder.getInsertionPoint());
-
- // Turn inputs into placeholder values and kill all return values.
- // DCE then has an easy time ripping the tensor values all out.
- // We need to tie the input variable shapes to the placeholders so shape
- // propagation can use them.
- auto recalculateBuilder = OpBuilder::atBlockBegin(recalculateBlock);
- calcFuncOp.setType(
- recalculateBuilder.getFunctionType(/*inputs=*/TypeRange{},
- /*outputs=*/TypeRange{}));
- for (auto inputValueVar :
- llvm::zip(entryBlock.getArguments(), inputGlobalOps)) {
- auto inputValue = std::get<0>(inputValueVar);
- auto inputGlobalOp = std::get<1>(inputValueVar);
- auto inputPlaceholder =
- recalculateBuilder.createOrFold<IREE::Util::NullOp>(
- loc, inputValue.getType());
- auto inputShapeValue =
- recalculateBuilder.createOrFold<IREE::Util::GlobalLoadOp>(
- loc, inputGlobalOp.type(), inputGlobalOp.getName());
- auto tiedValue = recalculateBuilder.create<Shape::TieShapeOp>(
- loc, inputPlaceholder, inputShapeValue);
- inputValue.replaceAllUsesWith(tiedValue);
- }
- while (entryBlock.getNumArguments() > 0) {
- entryBlock.eraseArgument(entryBlock.getNumArguments() - 1);
- }
- recalculateBuilder.create<BranchOp>(loc, followBlock);
- recalculateBlock->moveBefore(followBlock);
-
- // Replace each exit from the function with a storage back to the shape
- // variables.
- for (auto returnOp : llvm::to_vector<4>(calcFuncOp.getOps<ReturnOp>())) {
- auto exitLoc = returnOp.getLoc();
- OpBuilder exitBuilder(returnOp);
-
- // Store the derived shape values into the output shape variables.
- // We do this per exit-site so that if the function has multiple code
- // paths that may return different shape sizes we capture them all.
- for (auto outputValueVar :
- llvm::zip(returnOp.getOperands(), outputGlobalOps)) {
- auto outputValue = std::get<0>(outputValueVar);
- auto outputGlobalOp = std::get<1>(outputValueVar);
- auto outputShapeValue =
- exitBuilder.createOrFold<Shape::GetRankedShapeOp>(exitLoc,
- outputValue);
- exitBuilder.create<IREE::Util::GlobalStoreOp>(exitLoc, outputShapeValue,
- outputGlobalOp.getName());
- }
-
- // Clear the dirty flag now that the shapes have been updated.
- auto falseValue =
- exitBuilder.createOrFold<arith::ConstantIntOp>(exitLoc, 0, 1);
- exitBuilder.create<IREE::Util::GlobalStoreOp>(exitLoc, falseValue,
- dirtyGlobalOp.getName());
- exitBuilder.create<ReturnOp>(exitLoc);
- returnOp.erase();
- }
-
- OpBuilder::atBlockBegin(returnBlock).create<ReturnOp>(loc);
-
- return calcFuncOp;
- }
-
- // Builds a switch-statement-like chain of blocks starting at |builder|.
- // Returns a block that execution resumes at after the switch.
- Block *buildSwitch(
- Location loc, Value indexValue, size_t caseCount,
- std::function<void(size_t i, OpBuilder &caseBuilder)> caseGenerator,
- OpBuilder &builder) {
- auto *entryBlock = builder.getBlock();
- auto ip = builder.saveInsertionPoint();
- auto *exitBlock = builder.createBlock(entryBlock->getParent(),
- ++Region::iterator(entryBlock));
- if (caseCount == 0) {
- builder.create<BranchOp>(loc, exitBlock);
- return exitBlock;
- }
- SmallVector<Block *, 4> compareBlocks;
- SmallVector<Block *, 4> caseBlocks;
- for (size_t i = 0; i < caseCount; ++i) {
- compareBlocks.push_back(builder.createBlock(exitBlock));
- caseBlocks.push_back(builder.createBlock(exitBlock));
- }
- builder.restoreInsertionPoint(ip);
- builder.create<BranchOp>(loc, compareBlocks[0]);
- for (size_t i = 0; i < caseCount; ++i) {
- auto compareBuilder = OpBuilder::atBlockBegin(compareBlocks[i]);
- auto caseValue =
- compareBuilder.createOrFold<arith::ConstantIndexOp>(loc, i);
- auto eqValue = compareBuilder.createOrFold<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, indexValue, caseValue);
- compareBuilder.create<CondBranchOp>(
- loc, eqValue, caseBlocks[i],
- i < caseCount - 1 ? compareBlocks[i + 1] : exitBlock);
-
- auto caseBuilder = OpBuilder::atBlockBegin(caseBlocks[i]);
- caseGenerator(i, caseBuilder);
- caseBuilder.create<BranchOp>(loc, exitBlock);
- }
- builder = OpBuilder::atBlockBegin(exitBlock);
- return exitBlock;
- }
-
- // Packs a shape into a list.
- void packShape(Location loc, Shape::RankedShapeType shapeType,
- Value shapeValue, Value listValue, OpBuilder &builder) {
- builder.create<IREE::Util::ListResizeOp>(
- loc, listValue,
- builder.createOrFold<arith::ConstantIndexOp>(loc, shapeType.getRank()));
- for (int i = 0; i < shapeType.getRank(); ++i) {
- auto dimValue =
- builder.createOrFold<Shape::RankedDimOp>(loc, shapeValue, i);
- builder.create<IREE::Util::ListSetOp>(
- loc, listValue, builder.createOrFold<arith::ConstantIndexOp>(loc, i),
- dimValue);
- }
- }
-
- // Unpacks a shape from a list.
- Value unpackShape(Location loc, Shape::RankedShapeType shapeType,
- Value listValue, OpBuilder &builder) {
- SmallVector<Value, 4> dynamicDims;
- for (int i = 0; i < shapeType.getRank(); ++i) {
- if (!shapeType.isDimDynamic(i)) continue;
- dynamicDims.push_back(builder.createOrFold<IREE::Util::ListGetOp>(
- loc, builder.getIndexType(), listValue,
- builder.createOrFold<arith::ConstantIndexOp>(loc, i)));
- }
- return builder.createOrFold<Shape::MakeRankedShapeOp>(loc, shapeType,
- dynamicDims);
- }
-
- // Creates a function to query the |inputGlobalOps| at runtime by the
- // bindings.
- //
- // func @_query_input_shape(%index : index, %shape : !util.list<index>)
- void createQueryInputShapeFunc(Location loc, StringRef namePrefix,
- ArrayRef<IREE::Util::GlobalOp> inputGlobalOps,
- OpBuilder &moduleBuilder) {
- auto queryFuncOp = moduleBuilder.create<FuncOp>(
- loc, namePrefix.str() + "_query_input_shape",
- moduleBuilder.getFunctionType(/*inputs=*/
- TypeRange{
- moduleBuilder.getIndexType(),
- IREE::Util::ListType::get(
- moduleBuilder.getIndexType()),
- },
- /*outputs=*/TypeRange{}));
- queryFuncOp->setAttr("iree.abi.stub", moduleBuilder.getUnitAttr());
- auto *entryBlock = queryFuncOp.addEntryBlock();
- auto entryBuilder = OpBuilder::atBlockBegin(entryBlock);
- auto listValue = entryBlock->getArgument(1);
-
- auto *exitBlock = buildSwitch(
- loc, entryBlock->getArgument(0), inputGlobalOps.size(),
- [&](size_t i, OpBuilder &caseBuilder) {
- auto inputGlobalOp = inputGlobalOps[i];
- auto shapeType = inputGlobalOp.type().cast<Shape::RankedShapeType>();
- auto shapeValue = caseBuilder.createOrFold<IREE::Util::GlobalLoadOp>(
- loc, inputGlobalOp.type(), inputGlobalOp.getName());
- packShape(loc, shapeType, shapeValue, listValue, caseBuilder);
- },
- entryBuilder);
-
- auto exitBuilder = OpBuilder::atBlockBegin(exitBlock);
- exitBuilder.create<ReturnOp>(loc);
- }
-
- // Creates a function to resize |inputGlobalOps| and sets the |dirtyGlobalOp|
- // flag.
- //
- // func @_resize_input_shape(%index : index, %shape : !util.list<index>)
- void createResizeInputShapeFunc(Location loc, StringRef namePrefix,
- ArrayRef<IREE::Util::GlobalOp> inputGlobalOps,
- IREE::Util::GlobalOp dirtyGlobalOp,
- OpBuilder &moduleBuilder) {
- auto resizeFuncOp = moduleBuilder.create<FuncOp>(
- loc, namePrefix.str() + "_resize_input_shape",
- moduleBuilder.getFunctionType(/*inputs=*/
- TypeRange{
- moduleBuilder.getIndexType(),
- IREE::Util::ListType::get(
- moduleBuilder.getIndexType()),
- },
- /*outputs=*/TypeRange{}));
- resizeFuncOp->setAttr("iree.abi.stub", moduleBuilder.getUnitAttr());
- auto *entryBlock = resizeFuncOp.addEntryBlock();
- auto entryBuilder = OpBuilder::atBlockBegin(entryBlock);
- auto listValue = entryBlock->getArgument(1);
-
- auto *exitBlock = buildSwitch(
- loc, entryBlock->getArgument(0), inputGlobalOps.size(),
- [&](size_t i, OpBuilder &caseBuilder) {
- auto inputGlobalOp = inputGlobalOps[i];
- auto shapeType = inputGlobalOp.type().cast<Shape::RankedShapeType>();
- auto shapeValue = unpackShape(loc, shapeType, listValue, caseBuilder);
- caseBuilder.create<IREE::Util::GlobalStoreOp>(
- loc, shapeValue, inputGlobalOp.getName());
- },
- entryBuilder);
-
- // Set the dirty flag so that shapes get recalculated as needed.
- auto exitBuilder = OpBuilder::atBlockBegin(exitBlock);
- auto trueValue = exitBuilder.createOrFold<arith::ConstantIntOp>(loc, 1, 1);
- exitBuilder.create<IREE::Util::GlobalStoreOp>(loc, trueValue,
- dirtyGlobalOp.getName());
- exitBuilder.create<ReturnOp>(loc);
- }
-
- // Creates a function to query the |outputGlobalOps| at runtime by the
- // bindings.
- //
- // func @_query_output_shape(%index : index, %shape : !util.list<index>)
- void createQueryOutputShapeFunc(
- Location loc, StringRef namePrefix,
- ArrayRef<IREE::Util::GlobalOp> outputGlobalOps,
- FuncOp calculateShapeFuncOp, OpBuilder &moduleBuilder) {
- auto queryFuncOp = moduleBuilder.create<FuncOp>(
- loc, namePrefix.str() + "_query_output_shape",
- moduleBuilder.getFunctionType(/*inputs=*/
- TypeRange{
- moduleBuilder.getIndexType(),
- IREE::Util::ListType::get(
- moduleBuilder.getIndexType()),
- },
- /*outputs=*/TypeRange{}));
- queryFuncOp->setAttr("iree.abi.stub", moduleBuilder.getUnitAttr());
- auto *entryBlock = queryFuncOp.addEntryBlock();
- auto entryBuilder = OpBuilder::atBlockBegin(entryBlock);
- auto listValue = entryBlock->getArgument(1);
-
- // Always call the recalculation function - it checks for whether it needs
- // to run based on the dirty flag value.
- entryBuilder.create<CallOp>(loc, calculateShapeFuncOp);
-
- auto *exitBlock = buildSwitch(
- loc, entryBlock->getArgument(0), outputGlobalOps.size(),
- [&](size_t i, OpBuilder &caseBuilder) {
- auto outputGlobalOp = outputGlobalOps[i];
- auto shapeType = outputGlobalOp.type().cast<Shape::RankedShapeType>();
- auto shapeValue = caseBuilder.createOrFold<IREE::Util::GlobalLoadOp>(
- loc, outputGlobalOp.type(), outputGlobalOp.getName());
- packShape(loc, shapeType, shapeValue, listValue, caseBuilder);
- },
- entryBuilder);
-
- auto exitBuilder = OpBuilder::atBlockBegin(exitBlock);
- exitBuilder.create<ReturnOp>(loc);
- }
-};
-
-std::unique_ptr<OperationPass<ModuleOp>> createMaterializeShapeSupportPass() {
- return std::make_unique<MaterializeShapeSupportPass>();
-}
-
-static PassRegistration<MaterializeShapeSupportPass> pass;
-
-} // namespace TFLite
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Bindings/TFLite/Transforms/Passes.cpp b/iree/compiler/Bindings/TFLite/Transforms/Passes.cpp
index 8758227..6ddf1e5 100644
--- a/iree/compiler/Bindings/TFLite/Transforms/Passes.cpp
+++ b/iree/compiler/Bindings/TFLite/Transforms/Passes.cpp
@@ -18,13 +18,9 @@
namespace TFLite {
void buildTransformPassPipeline(OpPassManager &passManager) {
- // Wraps the entry points in a "_tflite_xx" function.
+ // Wraps the entry points in a "_tflite_xx" function and adds shape support.
passManager.addPass(createWrapEntryPointsPass());
- // Materialize the functions required by the runtime bindings to manipulate
- // the program state.
- passManager.addPass(createMaterializeShapeSupportPass());
-
// Cleanup the IR after manipulating it.
passManager.addPass(createInlinerPass());
passManager.addNestedPass<FuncOp>(createCanonicalizerPass());
diff --git a/iree/compiler/Bindings/TFLite/Transforms/Passes.h b/iree/compiler/Bindings/TFLite/Transforms/Passes.h
index 6bfd57f..0b8f30c 100644
--- a/iree/compiler/Bindings/TFLite/Transforms/Passes.h
+++ b/iree/compiler/Bindings/TFLite/Transforms/Passes.h
@@ -35,17 +35,12 @@
// expected invocation semantics of the IREE TFLite bindings.
std::unique_ptr<OperationPass<ModuleOp>> createWrapEntryPointsPass();
-// Materialize the functions required by the runtime bindings to manipulate
-// the program state (like _resize_input_shape, etc).
-std::unique_ptr<OperationPass<ModuleOp>> createMaterializeShapeSupportPass();
-
//===----------------------------------------------------------------------===//
// Register all Passes
//===----------------------------------------------------------------------===//
inline void registerPasses() {
createWrapEntryPointsPass();
- createMaterializeShapeSupportPass();
}
} // namespace TFLite
diff --git a/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp b/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp
index d874d5a..ae54240 100644
--- a/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp
+++ b/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp
@@ -4,9 +4,14 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/SymbolTable.h"
@@ -22,12 +27,30 @@
namespace TFLite {
// Wraps each model entry point in a "_tflite_xx" function that matches the
-// expectations of the IREE TFLite C bindings.
+// expectations of the IREE TFLite C bindings and materializes shape query and
+// calculation functions for dynamically shaped I/O.
+//
+// For each exported function we produce:
+// - `_tflite_xx_argN`/`retN` globals carrying shape dimensions
+// - `_tflite_xx` entry function wrapping the existing export
+// - `_tflite_xx_calculate_shapes` shape calculation function
+// - `_tflite_xx_query_input_shape` shape query function
+// - `_tflite_xx_resize_input_shape` shape resize function
+// - `_tflite_xx_query_output_shape` shape query function
+//
+// Each I/O of the function gets one global per dynamic dimension storing the
+// provided or calculated dimension value at runtime. For example:
+// (%arg0: tensor<1x?x?xf32>)
+// ->
+// // no dim0 as it is static 1.
+// util.global private mutable @_tflite_xx_arg0_dim1 : index
+// util.global private mutable @_tflite_xx_arg0_dim2 : index
class WrapEntryPointsPass
: public PassWrapper<WrapEntryPointsPass, OperationPass<ModuleOp>> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<StandardOpsDialect, mlir::arith::ArithmeticDialect>();
+ registry.insert<mlir::StandardOpsDialect, mlir::arith::ArithmeticDialect,
+ mlir::tensor::TensorDialect, IREE::HAL::HALDialect,
+ IREE::Util::UtilDialect>();
}
StringRef getArgument() const override {
@@ -61,16 +84,383 @@
signalPassFailure();
return;
}
-
- // Create a wrapper function for the entry point.
- auto entryFuncOp = entryFuncOps.front();
- entryFuncOp.setPrivate();
- auto wrapperFuncOp = createWrapperFunc(entryFuncOp);
- wrapperFuncOp.setPublic();
- moduleOp.insert(Block::iterator(entryFuncOp), wrapperFuncOp);
+ wrapEntryPoint(entryFuncOps.front());
}
private:
+ // Globals representing each dynamic dimension of an I/O tensor.
+ struct DynamicDims {
+ TensorType tensorType;
+ mutable SmallVector<IREE::Util::GlobalOp> globalOps;
+
+ SmallVector<Value> loadDynamicDims(OpBuilder &builder) {
+ SmallVector<Value> dims;
+ unsigned dynamicDimIdx = 0;
+ for (unsigned i = 0; i < tensorType.getRank(); ++i) {
+ if (tensorType.isDynamicDim(i)) {
+ auto globalOp = globalOps[dynamicDimIdx++];
+ dims.push_back(
+ builder
+ .create<IREE::Util::GlobalLoadOp>(globalOp.getLoc(), globalOp)
+ .result());
+ }
+ }
+ return dims;
+ }
+ };
+
+ // Creates one index-typed util.global for each dynamic dimension of |tensorType|.
+ static DynamicDims createDynamicDimGlobals(Location loc, StringRef namePrefix,
+ TensorType tensorType,
+ OpBuilder &moduleBuilder) {
+ DynamicDims dynamicDims;
+ dynamicDims.tensorType = tensorType;
+ for (unsigned i = 0; i < tensorType.getRank(); ++i) {
+ if (tensorType.isDynamicDim(i)) {
+ auto globalOp = moduleBuilder.create<IREE::Util::GlobalOp>(
+ loc, (namePrefix + "_dim" + std::to_string(i)).str(),
+ /*isMutable=*/true, moduleBuilder.getIndexType());
+ globalOp.setPrivate();
+ dynamicDims.globalOps.push_back(globalOp);
+ }
+ }
+ return dynamicDims;
+ }
+
+ // Creates dynamic dim globals for each input and output of |funcOp|.
+ static std::pair<SmallVector<DynamicDims>, SmallVector<DynamicDims>>
+ createDynamicDimGlobals(Location loc, StringRef namePrefix,
+ mlir::FuncOp funcOp, OpBuilder &moduleBuilder) {
+ auto funcType = funcOp.getType();
+
+ // TFLite requires the tensor names at runtime. If they've previously been
+ // extracted into iree.identifiers we use those and otherwise fall back to
+ // a generic naming scheme that matches the IR (somewhat).
+ SmallVector<std::string, 4> inputNames;
+ SmallVector<std::string, 4> outputNames;
+ for (unsigned i = 0; i < funcType.getNumInputs(); ++i) {
+ auto identifier =
+ funcOp.getArgAttrOfType<StringAttr>(i, "iree.identifier");
+ if (identifier) {
+ inputNames.push_back(identifier.getValue().str());
+ } else {
+ inputNames.push_back(std::string("arg") + std::to_string(i));
+ }
+ }
+ for (unsigned i = 0; i < funcType.getNumResults(); ++i) {
+ auto identifier =
+ funcOp.getResultAttrOfType<StringAttr>(i, "iree.identifier");
+ if (identifier) {
+ outputNames.push_back(identifier.getValue().str());
+ } else {
+ outputNames.push_back(std::string("ret") + std::to_string(i));
+ }
+ }
+
+ SmallVector<DynamicDims> inputDynamicDims;
+ for (auto input :
+ llvm::zip(funcOp.getArguments(), inputNames, funcType.getInputs())) {
+ auto argLoc = std::get<0>(input).getLoc();
+ auto name = (namePrefix + "_" + std::get<1>(input) + "_shape").str();
+ auto type = std::get<2>(input);
+ auto tensorType = type.dyn_cast<TensorType>();
+ assert(tensorType && "expecting only tensors in tflite function I/O");
+ inputDynamicDims.push_back(
+ createDynamicDimGlobals(argLoc, name, tensorType, moduleBuilder));
+ }
+ SmallVector<DynamicDims> outputDynamicDims;
+ for (auto output : llvm::zip(outputNames, funcType.getResults())) {
+ auto name = (namePrefix + "_" + std::get<0>(output) + "_shape").str();
+ auto type = std::get<1>(output);
+ auto tensorType = type.dyn_cast<TensorType>();
+ assert(tensorType && "expecting only tensors in tflite function I/O");
+ outputDynamicDims.push_back(
+ createDynamicDimGlobals(loc, name, tensorType, moduleBuilder));
+ }
+
+ return std::make_pair(inputDynamicDims, outputDynamicDims);
+ }
+
+ // Derives a shape calculation function from the given entry point |funcOp|.
+ static mlir::FuncOp createShapeCalculationFunc(
+ Location loc, StringRef namePrefix, mlir::FuncOp funcOp,
+ ArrayRef<DynamicDims> inputDynamicDims,
+ ArrayRef<DynamicDims> outputDynamicDims,
+ IREE::Util::GlobalOp dirtyGlobalOp, OpBuilder &moduleBuilder) {
+ // Clone the entire entry function with all its IR.
+ auto calcFuncOp =
+ cast<mlir::FuncOp>(moduleBuilder.clone(*funcOp.getOperation()));
+ calcFuncOp.setName(
+ moduleBuilder.getStringAttr(namePrefix.str() + "_calculate_shapes"));
+ calcFuncOp.setPrivate();
+ // TODO(benvanik): find a better way to strip these attributes.
+ calcFuncOp->removeAttr("iree.abi.stub");
+ calcFuncOp->removeAttr("iree.reflection");
+ auto &entryBlock = calcFuncOp.front();
+ auto entryBuilder = OpBuilder::atBlockBegin(&entryBlock);
+
+ // Go back and insert a check for the dirty flag.
+ auto dirtyValue = entryBuilder.createOrFold<IREE::Util::GlobalLoadOp>(
+ loc, dirtyGlobalOp.type(), dirtyGlobalOp.getName());
+ auto *recalculateBlock = calcFuncOp.addBlock();
+ auto *returnBlock = calcFuncOp.addBlock();
+ entryBuilder.create<mlir::CondBranchOp>(loc, dirtyValue, recalculateBlock,
+ returnBlock);
+ auto *followBlock = entryBlock.splitBlock(entryBuilder.getInsertionPoint());
+
+ auto bufferType = entryBuilder.getType<IREE::HAL::BufferType>();
+
+ // Turn inputs into placeholder values and kill all return values.
+ // DCE then has an easy time ripping the tensor values all out.
+ // We need to tie the input variable shapes to the placeholders so shape
+ // propagation can use them.
+ auto recalculateBuilder = OpBuilder::atBlockBegin(recalculateBlock);
+ calcFuncOp.setType(
+ recalculateBuilder.getFunctionType(/*inputs=*/TypeRange{},
+ /*outputs=*/TypeRange{}));
+ for (auto inputValueDims :
+ llvm::zip(entryBlock.getArguments(), inputDynamicDims)) {
+ auto inputValue = std::get<0>(inputValueDims);
+ auto inputDynamicDims = std::get<1>(inputValueDims);
+ auto inputPlaceholder =
+ recalculateBuilder.createOrFold<IREE::Util::NullOp>(loc, bufferType);
+ auto dynamicDims = inputDynamicDims.loadDynamicDims(recalculateBuilder);
+ auto castOp = recalculateBuilder.create<IREE::HAL::TensorCastOp>(
+ loc, inputValue.getType(), inputPlaceholder, dynamicDims);
+ inputValue.replaceAllUsesWith(castOp.target());
+ }
+ while (entryBlock.getNumArguments() > 0) {
+ entryBlock.eraseArgument(entryBlock.getNumArguments() - 1);
+ }
+ recalculateBuilder.create<mlir::BranchOp>(loc, followBlock);
+ recalculateBlock->moveBefore(followBlock);
+
+ // Replace each exit from the function with a storage back to the shape
+ // variables.
+ for (auto returnOp :
+ llvm::to_vector<4>(calcFuncOp.getOps<mlir::ReturnOp>())) {
+ auto exitLoc = returnOp.getLoc();
+ OpBuilder exitBuilder(returnOp);
+
+ // Store the derived shape values into the output shape variables.
+ // We do this per exit-site so that if the function has multiple code
+ // paths that may return different shape sizes we capture them all.
+ for (auto outputValueDims :
+ llvm::zip(returnOp.getOperands(), outputDynamicDims)) {
+ auto outputValue = std::get<0>(outputValueDims);
+ auto outputDynamicDims = std::get<1>(outputValueDims);
+ SmallVector<Value> dynamicDims;
+ for (int64_t i = 0; i < outputDynamicDims.globalOps.size(); ++i) {
+ auto dimValue =
+ exitBuilder.createOrFold<tensor::DimOp>(exitLoc, outputValue, i);
+ exitBuilder.create<IREE::Util::GlobalStoreOp>(
+ exitLoc, dimValue,
+ outputDynamicDims.globalOps[i].getSymbolName());
+ }
+ }
+
+ // Clear the dirty flag now that the shapes have been updated.
+ auto falseValue =
+ exitBuilder.createOrFold<arith::ConstantIntOp>(exitLoc, 0, 1);
+ exitBuilder.create<IREE::Util::GlobalStoreOp>(
+ exitLoc, falseValue, dirtyGlobalOp.getSymbolName());
+ exitBuilder.create<mlir::ReturnOp>(exitLoc);
+ returnOp.erase();
+ }
+
+ OpBuilder::atBlockBegin(returnBlock).create<mlir::ReturnOp>(loc);
+
+ return calcFuncOp;
+ }
+
+ // Builds a switch-statement-like chain of blocks starting at |builder|.
+ // Returns a block that execution resumes at after the switch.
+ static Block *buildSwitch(
+ Location loc, Value indexValue, size_t caseCount,
+ std::function<void(size_t i, OpBuilder &caseBuilder)> caseGenerator,
+ OpBuilder &builder) {
+ auto *entryBlock = builder.getBlock();
+ auto ip = builder.saveInsertionPoint();
+ auto *exitBlock = builder.createBlock(entryBlock->getParent(),
+ ++Region::iterator(entryBlock));
+ if (caseCount == 0) {
+ builder.create<mlir::BranchOp>(loc, exitBlock);
+ return exitBlock;
+ }
+ SmallVector<Block *, 4> compareBlocks;
+ SmallVector<Block *, 4> caseBlocks;
+ for (size_t i = 0; i < caseCount; ++i) {
+ compareBlocks.push_back(builder.createBlock(exitBlock));
+ caseBlocks.push_back(builder.createBlock(exitBlock));
+ }
+ builder.restoreInsertionPoint(ip);
+ builder.create<mlir::BranchOp>(loc, compareBlocks[0]);
+ for (size_t i = 0; i < caseCount; ++i) {
+ auto compareBuilder = OpBuilder::atBlockBegin(compareBlocks[i]);
+ auto caseValue =
+ compareBuilder.createOrFold<arith::ConstantIndexOp>(loc, i);
+ auto eqValue = compareBuilder.createOrFold<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::eq, indexValue, caseValue);
+ compareBuilder.create<mlir::CondBranchOp>(
+ loc, eqValue, caseBlocks[i],
+ i < caseCount - 1 ? compareBlocks[i + 1] : exitBlock);
+
+ auto caseBuilder = OpBuilder::atBlockBegin(caseBlocks[i]);
+ caseGenerator(i, caseBuilder);
+ caseBuilder.create<mlir::BranchOp>(loc, exitBlock);
+ }
+ builder = OpBuilder::atBlockBegin(exitBlock);
+ return exitBlock;
+ }
+
+ // Packs a shape into a list.
+ void packShape(Location loc, const DynamicDims &dynamicDims, Value listValue,
+ OpBuilder &builder) {
+ auto shapeType = dynamicDims.tensorType;
+ builder.create<IREE::Util::ListResizeOp>(
+ loc, listValue,
+ builder.createOrFold<arith::ConstantIndexOp>(loc, shapeType.getRank()));
+ unsigned dynamicDimIdx = 0;
+ for (unsigned i = 0; i < shapeType.getRank(); ++i) {
+ Value dimValue;
+ if (shapeType.isDynamicDim(i)) {
+ dimValue = builder.create<IREE::Util::GlobalLoadOp>(
+ loc, dynamicDims.globalOps[dynamicDimIdx++]);
+ } else {
+ dimValue = builder.createOrFold<arith::ConstantIndexOp>(
+ loc, shapeType.getDimSize(i));
+ }
+ builder.create<IREE::Util::ListSetOp>(
+ loc, listValue, builder.createOrFold<arith::ConstantIndexOp>(loc, i),
+ dimValue);
+ }
+ }
+
+ // Unpacks a shape from a list.
+ void unpackShape(Location loc, Value listValue,
+ const DynamicDims &dynamicDims, OpBuilder &builder) {
+ auto shapeType = dynamicDims.tensorType;
+ unsigned dynamicDimIdx = 0;
+ for (unsigned i = 0; i < shapeType.getRank(); ++i) {
+ if (!shapeType.isDynamicDim(i)) continue;
+ auto dimValue =
+ builder
+ .create<IREE::Util::ListGetOp>(
+ loc, builder.getIndexType(), listValue,
+ builder.createOrFold<arith::ConstantIndexOp>(loc, i))
+ .result();
+ builder.create<IREE::Util::GlobalStoreOp>(
+ loc, dimValue,
+ dynamicDims.globalOps[dynamicDimIdx++].getSymbolName());
+ }
+ }
+
+ // Creates a function used by the bindings at runtime to query the input
+ // shapes stored in |inputDynamicDims|.
+ //
+ // func @_query_input_shape(%index : index, %shape : !util.list<index>)
+ void createQueryInputShapeFunc(Location loc, StringRef namePrefix,
+ ArrayRef<DynamicDims> inputDynamicDims,
+ OpBuilder &moduleBuilder) {
+ auto queryFuncOp = moduleBuilder.create<mlir::FuncOp>(
+ loc, namePrefix.str() + "_query_input_shape",
+ moduleBuilder.getFunctionType(/*inputs=*/
+ TypeRange{
+ moduleBuilder.getIndexType(),
+ IREE::Util::ListType::get(
+ moduleBuilder.getIndexType()),
+ },
+ /*outputs=*/TypeRange{}));
+ queryFuncOp->setAttr("iree.abi.stub", moduleBuilder.getUnitAttr());
+ auto *entryBlock = queryFuncOp.addEntryBlock();
+ auto entryBuilder = OpBuilder::atBlockBegin(entryBlock);
+ auto listValue = entryBlock->getArgument(1);
+
+ auto *exitBlock = buildSwitch(
+ loc, entryBlock->getArgument(0), inputDynamicDims.size(),
+ [&](size_t i, OpBuilder &caseBuilder) {
+ packShape(loc, inputDynamicDims[i], listValue, caseBuilder);
+ },
+ entryBuilder);
+
+ auto exitBuilder = OpBuilder::atBlockBegin(exitBlock);
+ exitBuilder.create<mlir::ReturnOp>(loc);
+ }
+
+ // Creates a function to resize the input shapes stored in |inputDynamicDims|
+ // and set the |dirtyGlobalOp| flag.
+ //
+ // func @_resize_input_shape(%index : index, %shape : !util.list<index>)
+ void createResizeInputShapeFunc(Location loc, StringRef namePrefix,
+ ArrayRef<DynamicDims> inputDynamicDims,
+ IREE::Util::GlobalOp dirtyGlobalOp,
+ OpBuilder &moduleBuilder) {
+ auto resizeFuncOp = moduleBuilder.create<mlir::FuncOp>(
+ loc, namePrefix.str() + "_resize_input_shape",
+ moduleBuilder.getFunctionType(/*inputs=*/
+ TypeRange{
+ moduleBuilder.getIndexType(),
+ IREE::Util::ListType::get(
+ moduleBuilder.getIndexType()),
+ },
+ /*outputs=*/TypeRange{}));
+ resizeFuncOp->setAttr("iree.abi.stub", moduleBuilder.getUnitAttr());
+ auto *entryBlock = resizeFuncOp.addEntryBlock();
+ auto entryBuilder = OpBuilder::atBlockBegin(entryBlock);
+ auto listValue = entryBlock->getArgument(1);
+
+ auto *exitBlock = buildSwitch(
+ loc, entryBlock->getArgument(0), inputDynamicDims.size(),
+ [&](size_t i, OpBuilder &caseBuilder) {
+ unpackShape(loc, listValue, inputDynamicDims[i], caseBuilder);
+ },
+ entryBuilder);
+
+ // Set the dirty flag so that shapes get recalculated as needed.
+ auto exitBuilder = OpBuilder::atBlockBegin(exitBlock);
+ auto trueValue = exitBuilder.createOrFold<arith::ConstantIntOp>(loc, 1, 1);
+ exitBuilder.create<IREE::Util::GlobalStoreOp>(loc, trueValue,
+ dirtyGlobalOp.getName());
+ exitBuilder.create<mlir::ReturnOp>(loc);
+ }
+
+ // Creates a function used by the bindings at runtime to query the output
+ // shapes stored in |outputDynamicDims|.
+ //
+ // func @_query_output_shape(%index : index, %shape : !util.list<index>)
+ void createQueryOutputShapeFunc(Location loc, StringRef namePrefix,
+ ArrayRef<DynamicDims> outputDynamicDims,
+ mlir::FuncOp calculateShapeFuncOp,
+ OpBuilder &moduleBuilder) {
+ auto queryFuncOp = moduleBuilder.create<FuncOp>(
+ loc, namePrefix.str() + "_query_output_shape",
+ moduleBuilder.getFunctionType(/*inputs=*/
+ TypeRange{
+ moduleBuilder.getIndexType(),
+ IREE::Util::ListType::get(
+ moduleBuilder.getIndexType()),
+ },
+ /*outputs=*/TypeRange{}));
+ queryFuncOp->setAttr("iree.abi.stub", moduleBuilder.getUnitAttr());
+ auto *entryBlock = queryFuncOp.addEntryBlock();
+ auto entryBuilder = OpBuilder::atBlockBegin(entryBlock);
+ auto listValue = entryBlock->getArgument(1);
+
+ // Always call the recalculation function - it checks whether it needs
+ // to run based on the dirty flag value.
+ entryBuilder.create<mlir::CallOp>(loc, calculateShapeFuncOp);
+
+ auto *exitBlock = buildSwitch(
+ loc, entryBlock->getArgument(0), outputDynamicDims.size(),
+ [&](size_t i, OpBuilder &caseBuilder) {
+ packShape(loc, outputDynamicDims[i], listValue, caseBuilder);
+ },
+ entryBuilder);
+
+ auto exitBuilder = OpBuilder::atBlockBegin(exitBlock);
+ exitBuilder.create<mlir::ReturnOp>(loc);
+ }
+
// Creates the corresponding wrapper function for the given entry point.
// The wrapper function will contain the reflection metadata required at
// runtime to get input/output tensor names, quantization parameters, etc.
@@ -82,18 +472,26 @@
//
// NOTE: today we only support a single entry point; with minor tweaks we
// could fix this up to support multiple if we wanted.
- FuncOp createWrapperFunc(FuncOp entryFuncOp) {
+ void createWrapperFunc(StringRef namePrefix, mlir::FuncOp entryFuncOp,
+ ArrayRef<DynamicDims> inputDynamicDims,
+ ArrayRef<DynamicDims> outputDynamicDims,
+ IREE::Util::GlobalOp dirtyGlobalOp,
+ OpBuilder &moduleBuilder) {
// NOTE: this is where we could change our signature to provide additional
// values from the runtime bindings as may be required - like semaphores for
// async behavior or cancellation.
auto entryFuncType = entryFuncOp.getType();
- auto wrapperFuncType = entryFuncType;
+ auto bufferType = moduleBuilder.getType<IREE::HAL::BufferType>();
+ SmallVector<Type> inputTypes(entryFuncType.getNumInputs(), bufferType);
+ SmallVector<Type> outputTypes(entryFuncType.getNumResults(), bufferType);
+ auto wrapperFuncType =
+ moduleBuilder.getFunctionType(inputTypes, outputTypes);
- auto wrapperFuncOp =
- FuncOp::create(entryFuncOp.getLoc(), "_tflite_main", wrapperFuncType);
+ auto wrapperFuncOp = moduleBuilder.create<mlir::FuncOp>(
+ entryFuncOp.getLoc(), "_tflite_main", wrapperFuncType);
wrapperFuncOp.setPublic();
wrapperFuncOp.getOperation()->setAttr("iree.abi.stub",
- UnitAttr::get(&getContext()));
+ moduleBuilder.getUnitAttr());
SmallVector<DictionaryAttr, 4> argAttrDict;
entryFuncOp.getAllArgAttrs(argAttrDict);
@@ -104,27 +502,107 @@
populateReflectionAttrs(entryFuncOp, wrapperFuncOp);
- // Just call the entryFuncOp and return the results.
+ // Call the entryFuncOp and return the results.
// If we wanted to perform additional work here to invalidate cached shapes
// from the shape support functions or validate the inputs we'd do that
// here. Format conversion/decomposition (interleaved complex ->
// deinterleaved, float <-> quantized conversions, etc) can also be inserted
// such that other bindings that don't need such things aren't impacted.
+ //
+ // To make the interface concrete we insert casts to HAL buffers so that
+ // in the final program we know they end up as the iree_hal_buffer_t we
+ // expect in the runtime.
auto *entryBlock = wrapperFuncOp.addEntryBlock();
auto entryBuilder = OpBuilder::atBlockBegin(entryBlock);
- auto results = entryBuilder.create<CallOp>(
- entryFuncOp.getLoc(), entryFuncOp,
- llvm::to_vector<4>(llvm::map_range(
- entryBlock->getArguments(),
- [](BlockArgument arg) { return static_cast<Value>(arg); })));
- entryBuilder.create<ReturnOp>(entryFuncOp.getLoc(), results.getResults());
+ SmallVector<Value> callOperands;
+ for (auto input : llvm::zip(entryBlock->getArguments(), inputDynamicDims)) {
+ auto arg = std::get<0>(input);
+ auto inputDynamicDims = std::get<1>(input);
+ SmallVector<Value> dynamicDims;
+ for (auto globalOp : inputDynamicDims.globalOps) {
+ dynamicDims.push_back(entryBuilder.create<IREE::Util::GlobalLoadOp>(
+ arg.getLoc(), globalOp));
+ }
+ callOperands.push_back(entryBuilder.create<IREE::HAL::TensorCastOp>(
+ arg.getLoc(), inputDynamicDims.tensorType, arg, dynamicDims));
+ }
+ auto callOp = entryBuilder.create<mlir::CallOp>(entryFuncOp.getLoc(),
+ entryFuncOp, callOperands);
+ SmallVector<Value> callResults;
+ for (auto output : llvm::zip(callOp.getResults(), outputDynamicDims)) {
+ auto result = std::get<0>(output);
+ auto outputDynamicDims = std::get<1>(output);
+ SmallVector<Value> dynamicDims;
+ for (unsigned i = 0; i < outputDynamicDims.tensorType.getRank(); ++i) {
+ if (outputDynamicDims.tensorType.isDynamicDim(i)) {
+ dynamicDims.push_back(
+ entryBuilder.create<tensor::DimOp>(result.getLoc(), result, i));
+ }
+ }
+ callResults.push_back(entryBuilder.create<IREE::HAL::TensorCastOp>(
+ result.getLoc(), bufferType, result, dynamicDims));
+ for (auto it : llvm::zip(dynamicDims, outputDynamicDims.globalOps)) {
+ auto dynamicDim = std::get<0>(it);
+ auto globalOp = std::get<1>(it);
+ entryBuilder.create<IREE::Util::GlobalStoreOp>(
+ result.getLoc(), dynamicDim, globalOp.getSymbolName());
+ }
+ }
- return wrapperFuncOp;
+ // We recomputed the shapes of the outputs and can clear the dirty flag.
+ entryBuilder.create<IREE::Util::GlobalStoreOp>(
+ entryFuncOp.getLoc(),
+ entryBuilder.create<arith::ConstantIntOp>(entryFuncOp.getLoc(), 0, 1),
+ dirtyGlobalOp.getSymbolName());
+
+ entryBuilder.create<mlir::ReturnOp>(entryFuncOp.getLoc(), callResults);
+ }
+
+ void wrapEntryPoint(mlir::FuncOp funcOp) {
+ auto loc = funcOp.getLoc();
+ auto namePrefix = ("_tflite_" + funcOp.getName()).str();
+ OpBuilder moduleBuilder(funcOp);
+
+ // Create a variable for each input and output dynamic dim. These variables
+ // may represent fully static shapes - in which case they'll get constant
+ // propagated - or dynamic shapes that will eventually get turned into
+ // dynamic runtime values.
+ auto dynamicDimGlobals =
+ createDynamicDimGlobals(loc, namePrefix, funcOp, moduleBuilder);
+
+ // Create internal shape calculation function that updates output shapes if
+ // needed. This is only required if there are dynamic shapes.
+ auto dirtyGlobalOp = moduleBuilder.create<IREE::Util::GlobalOp>(
+ loc, namePrefix + "_shapes_dirty",
+ /*isMutable=*/true, moduleBuilder.getI1Type(),
+ moduleBuilder.getIntegerAttr(moduleBuilder.getI1Type(), 1));
+ dirtyGlobalOp.setPrivate();
+ auto calculateShapeFuncOp = createShapeCalculationFunc(
+ loc, namePrefix, funcOp, dynamicDimGlobals.first,
+ dynamicDimGlobals.second, dirtyGlobalOp, moduleBuilder);
+
+ // Create input query function (just reads variables).
+ createQueryInputShapeFunc(loc, namePrefix, dynamicDimGlobals.first,
+ moduleBuilder);
+
+ // Create input resize function (updates variables, set dirty flag).
+ createResizeInputShapeFunc(loc, namePrefix, dynamicDimGlobals.first,
+ dirtyGlobalOp, moduleBuilder);
+
+ // Create output query function (if dirty recalculates shapes).
+ createQueryOutputShapeFunc(loc, namePrefix, dynamicDimGlobals.second,
+ calculateShapeFuncOp, moduleBuilder);
+
+ // Create a wrapper function for the entry point.
+ funcOp.setPrivate();
+ createWrapperFunc(namePrefix, funcOp, dynamicDimGlobals.first,
+ dynamicDimGlobals.second, dirtyGlobalOp, moduleBuilder);
}
// Populates attributes on |wrapperFuncOp| to support runtime reflection like
// IO tensor names and quantization information.
- void populateReflectionAttrs(FuncOp entryFuncOp, FuncOp wrapperFuncOp) {
+ void populateReflectionAttrs(mlir::FuncOp entryFuncOp,
+ mlir::FuncOp wrapperFuncOp) {
SmallVector<NamedAttribute, 4> attrs;
attrs.push_back(buildIONamesAttr(entryFuncOp));
// TODO(#3972): tfl.io.quant: quantization information.
@@ -137,7 +615,7 @@
// tfl.io.names=arg0;arg1;ret0;ret1
//
// Default names will be used if no iree.identifiers are set on the function.
- NamedAttribute buildIONamesAttr(FuncOp entryFuncOp) {
+ NamedAttribute buildIONamesAttr(mlir::FuncOp entryFuncOp) {
SmallVector<std::string, 4> pieces;
for (int i = 0; i < entryFuncOp.getNumArguments(); ++i) {
StringRef identifier =
@@ -165,7 +643,7 @@
}
};
-std::unique_ptr<OperationPass<ModuleOp>> createWrapEntryPointsPass() {
+std::unique_ptr<OperationPass<mlir::ModuleOp>> createWrapEntryPointsPass() {
return std::make_unique<WrapEntryPointsPass>();
}
diff --git a/iree/compiler/Bindings/TFLite/Transforms/test/BUILD b/iree/compiler/Bindings/TFLite/Transforms/test/BUILD
index 16b3671..287e4a5 100644
--- a/iree/compiler/Bindings/TFLite/Transforms/test/BUILD
+++ b/iree/compiler/Bindings/TFLite/Transforms/test/BUILD
@@ -17,7 +17,6 @@
name = "lit",
srcs = enforce_glob(
[
- "materialize_shape_support.mlir",
"wrap_entry_points.mlir",
],
include = ["*.mlir"],
diff --git a/iree/compiler/Bindings/TFLite/Transforms/test/CMakeLists.txt b/iree/compiler/Bindings/TFLite/Transforms/test/CMakeLists.txt
index 365af3d..b6cb5c6 100644
--- a/iree/compiler/Bindings/TFLite/Transforms/test/CMakeLists.txt
+++ b/iree/compiler/Bindings/TFLite/Transforms/test/CMakeLists.txt
@@ -14,7 +14,6 @@
NAME
lit
SRCS
- "materialize_shape_support.mlir"
"wrap_entry_points.mlir"
DATA
iree::tools::IreeFileCheck
diff --git a/iree/compiler/Bindings/TFLite/Transforms/test/materialize_shape_support.mlir b/iree/compiler/Bindings/TFLite/Transforms/test/materialize_shape_support.mlir
deleted file mode 100644
index 69be1d6..0000000
--- a/iree/compiler/Bindings/TFLite/Transforms/test/materialize_shape_support.mlir
+++ /dev/null
@@ -1,130 +0,0 @@
-// RUN: iree-opt -iree-tflite-materialize-shape-support -canonicalize -split-input-file %s | IreeFileCheck %s
-
-// NOTE: canonicalization is run because otherwise there's just way too much IR.
-
-// CHECK-DAG: util.global private mutable @_tflite_dynamicEntry_input0_shape : !shapex.ranked_shape<[?,8,8,3]>
-// CHECK-DAG: util.global private mutable @_tflite_dynamicEntry_input1_shape : !shapex.ranked_shape<[?,8,8,3]>
-// CHECK-DAG: util.global private mutable @_tflite_dynamicEntry_output0_shape : !shapex.ranked_shape<[?,8,8,3]>
-// CHECK-DAG: util.global private mutable @_tflite_dynamicEntry_output1_shape : !shapex.ranked_shape<[?,8,8,3]>
-// CHECK-DAG: util.global private mutable @_tflite_dynamicEntry_shapes_dirty = true
-
-// CHECK-LABEL: func private @_tflite_dynamicEntry_calculate_shapes() {
-// CHECK-NEXT: %false = arith.constant false
-// CHECK-NEXT: %[[IS_DIRTY:.+]] = util.global.load @_tflite_dynamicEntry_shapes_dirty : i1
-// CHECK-NEXT: cond_br %[[IS_DIRTY]], ^bb1, ^bb2
-// CHECK-NEXT: ^bb1:
-// CHECK-NEXT: %[[IN0_NULL:.+]] = util.null : tensor<?x8x8x3xf32>
-// CHECK-NEXT: %[[IN0_SHAPE:.+]] = util.global.load @_tflite_dynamicEntry_input0_shape : !shapex.ranked_shape<[?,8,8,3]>
-// CHECK-NEXT: %[[IN0:.+]] = shapex.tie_shape %[[IN0_NULL]], %[[IN0_SHAPE]] : tensor<?x8x8x3xf32>, !shapex.ranked_shape<[?,8,8,3]>
-// CHECK-NEXT: %[[IN1_NULL:.+]] = util.null : tensor<?x8x8x3xf32>
-// CHECK-NEXT: %[[IN1_SHAPE:.+]] = util.global.load @_tflite_dynamicEntry_input1_shape : !shapex.ranked_shape<[?,8,8,3]>
-// CHECK-NEXT: %[[IN1:.+]] = shapex.tie_shape %[[IN1_NULL]], %[[IN1_SHAPE]] : tensor<?x8x8x3xf32>, !shapex.ranked_shape<[?,8,8,3]>
-// CHECK-NEXT: %[[TMP:.+]]:2 = call @dynamicEntry(%[[IN0]], %[[IN1]])
-// CHECK-NEXT: %[[OUT0_SHAPE:.+]] = shapex.get_ranked_shape %[[TMP]]#0 : tensor<?x8x8x3xf32> -> !shapex.ranked_shape<[?,8,8,3]>
-// CHECK-NEXT: util.global.store %[[OUT0_SHAPE]], @_tflite_dynamicEntry_output0_shape : !shapex.ranked_shape<[?,8,8,3]>
-// CHECK-NEXT: %[[OUT1_SHAPE:.+]] = shapex.get_ranked_shape %[[TMP]]#1 : tensor<?x8x8x3xf32> -> !shapex.ranked_shape<[?,8,8,3]>
-// CHECK-NEXT: util.global.store %[[OUT1_SHAPE]], @_tflite_dynamicEntry_output1_shape : !shapex.ranked_shape<[?,8,8,3]>
-// CHECK-NEXT: util.global.store %false, @_tflite_dynamicEntry_shapes_dirty : i1
-// CHECK-NEXT: return
-// CHECK-NEXT: ^bb2:
-// CHECK-NEXT: return
-// CHECK-NEXT: }
-
-// CHECK-LABEL: func @_tflite_dynamicEntry_query_input_shape
-// CHECK-SAME: (%[[INDEX:.+]]: index, %[[LIST:.+]]: !util.list<index>)
-// CHECK: %[[IS_0:.+]] = arith.cmpi eq, %[[INDEX]], %c0 : index
-// CHECK-NEXT: cond_br %[[IS_0]], ^bb1, ^bb2
-// CHECK-NEXT: ^bb1:
-// CHECK-NEXT: %[[IN0_SHAPE:.+]] = util.global.load @_tflite_dynamicEntry_input0_shape : !shapex.ranked_shape<[?,8,8,3]>
-// CHECK-NEXT: util.list.resize %[[LIST]], %c4 : !util.list<index>
-// CHECK-NEXT: %[[IN0_D0:.+]] = shapex.ranked_dim %[[IN0_SHAPE]][0] : !shapex.ranked_shape<[?,8,8,3]> -> index
-// CHECK-NEXT: util.list.set %[[LIST]][%c0], %[[IN0_D0]] : !util.list<index>
-// CHECK-NEXT: util.list.set %[[LIST]][%c1], %c8 : !util.list<index>
-// CHECK-NEXT: util.list.set %[[LIST]][%c2], %c8 : !util.list<index>
-// CHECK-NEXT: util.list.set %[[LIST]][%c3], %c3 : !util.list<index>
-// CHECK-NEXT: br ^bb4
-// CHECK-NEXT: ^bb2:
-// CHECK-NEXT: %[[IS_1:.+]] = arith.cmpi eq, %[[INDEX]], %c1 : index
-// CHECK-NEXT: cond_br %[[IS_1]], ^bb3, ^bb4
-// CHECK-NEXT: ^bb3:
-// CHECK-NEXT: %[[IN1_SHAPE:.+]] = util.global.load @_tflite_dynamicEntry_input1_shape : !shapex.ranked_shape<[?,8,8,3]>
-// CHECK-NEXT: util.list.resize %[[LIST]], %c4 : !util.list<index>
-// CHECK-NEXT: %[[IN1_D0:.+]] = shapex.ranked_dim %[[IN1_SHAPE]][0] : !shapex.ranked_shape<[?,8,8,3]> -> index
-// CHECK-NEXT: util.list.set %[[LIST]][%c0], %[[IN1_D0]] : !util.list<index>
-// CHECK-NEXT: util.list.set %[[LIST]][%c1], %c8 : !util.list<index>
-// CHECK-NEXT: util.list.set %[[LIST]][%c2], %c8 : !util.list<index>
-// CHECK-NEXT: util.list.set %[[LIST]][%c3], %c3 : !util.list<index>
-// CHECK-NEXT: br ^bb4
-// CHECK-NEXT: ^bb4:
-// CHECK-NEXT: return
-// CHECK-NEXT: }
-
-// CHECK-LABEL: func @_tflite_dynamicEntry_resize_input_shape
-// CHECK-SAME: (%[[INDEX:.+]]: index, %[[LIST:.+]]: !util.list<index>)
-// CHECK: %[[IS_0:.+]] = arith.cmpi eq, %[[INDEX]], %c0 : index
-// CHECK-NEXT: cond_br %[[IS_0]], ^bb1, ^bb2
-// CHECK-NEXT: ^bb1:
-// CHECK-NEXT: %[[IN0_D0:.+]] = util.list.get %[[LIST]][%c0] : !util.list<index>
-// CHECK-NEXT: %[[IN0_SHAPE:.+]] = shapex.make_ranked_shape %[[IN0_D0]] : (index) -> !shapex.ranked_shape<[?,8,8,3]>
-// CHECK-NEXT: util.global.store %[[IN0_SHAPE]], @_tflite_dynamicEntry_input0_shape : !shapex.ranked_shape<[?,8,8,3]>
-// CHECK-NEXT: br ^bb4
-// CHECK-NEXT: ^bb2:
-// CHECK-NEXT: %[[IS_1:.+]] = arith.cmpi eq, %[[INDEX]], %c1 : index
-// CHECK-NEXT: cond_br %[[IS_1]], ^bb3, ^bb4
-// CHECK-NEXT: ^bb3:
-// CHECK-NEXT: %[[IN1_D0:.+]] = util.list.get %[[LIST]][%c0] : !util.list<index>
-// CHECK-NEXT: %[[IN1_SHAPE:.+]] = shapex.make_ranked_shape %[[IN1_D0]] : (index) -> !shapex.ranked_shape<[?,8,8,3]>
-// CHECK-NEXT: util.global.store %[[IN1_SHAPE]], @_tflite_dynamicEntry_input1_shape : !shapex.ranked_shape<[?,8,8,3]>
-// CHECK-NEXT: br ^bb4
-// CHECK-NEXT: ^bb4:
-// CHECK-NEXT: util.global.store %true, @_tflite_dynamicEntry_shapes_dirty : i1
-// CHECK-NEXT: return
-// CHECK-NEXT: }
-
-// CHECK-LABEL: func @_tflite_dynamicEntry_query_output_shape
-// CHECK-SAME: (%[[INDEX:.+]]: index, %[[LIST:.+]]: !util.list<index>)
-// CHECK: call @_tflite_dynamicEntry_calculate_shapes() : () -> ()
-// CHECK-NEXT: %[[IS_0:.+]] = arith.cmpi eq, %[[INDEX]], %c0 : index
-// CHECK-NEXT: cond_br %[[IS_0]], ^bb1, ^bb2
-// CHECK-NEXT: ^bb1:
-// CHECK-NEXT: %[[OUT0_SHAPE:.+]] = util.global.load @_tflite_dynamicEntry_output0_shape : !shapex.ranked_shape<[?,8,8,3]>
-// CHECK-NEXT: util.list.resize %[[LIST]], %c4 : !util.list<index>
-// CHECK-NEXT: %[[OUT0_D0:.+]] = shapex.ranked_dim %[[OUT0_SHAPE]][0] : !shapex.ranked_shape<[?,8,8,3]> -> index
-// CHECK-NEXT: util.list.set %[[LIST]][%c0], %[[OUT0_D0]] : !util.list<index>
-// CHECK-NEXT: util.list.set %[[LIST]][%c1], %c8 : !util.list<index>
-// CHECK-NEXT: util.list.set %[[LIST]][%c2], %c8 : !util.list<index>
-// CHECK-NEXT: util.list.set %[[LIST]][%c3], %c3 : !util.list<index>
-// CHECK-NEXT: br ^bb4
-// CHECK-NEXT: ^bb2:
-// CHECK-NEXT: %[[IS_1:.+]] = arith.cmpi eq, %[[INDEX]], %c1 : index
-// CHECK-NEXT: cond_br %[[IS_1]], ^bb3, ^bb4
-// CHECK-NEXT: ^bb3:
-// CHECK-NEXT: %[[OUT1_SHAPE:.+]] = util.global.load @_tflite_dynamicEntry_output1_shape : !shapex.ranked_shape<[?,8,8,3]>
-// CHECK-NEXT: util.list.resize %[[LIST]], %c4 : !util.list<index>
-// CHECK-NEXT: %[[OUT1_D0:.+]] = shapex.ranked_dim %[[OUT1_SHAPE]][0] : !shapex.ranked_shape<[?,8,8,3]> -> index
-// CHECK-NEXT: util.list.set %[[LIST]][%c0], %[[OUT1_D0]] : !util.list<index>
-// CHECK-NEXT: util.list.set %[[LIST]][%c1], %c8 : !util.list<index>
-// CHECK-NEXT: util.list.set %[[LIST]][%c2], %c8 : !util.list<index>
-// CHECK-NEXT: util.list.set %[[LIST]][%c3], %c3 : !util.list<index>
-// CHECK-NEXT: br ^bb4
-// CHECK-NEXT: ^bb4:
-// CHECK-NEXT: return
-// CHECK-NEXT: }
-
-// CHECK-LABEL: func @_tflite_dynamicEntry(
-func @_tflite_dynamicEntry(%arg0: tensor<?x8x8x3xf32> {iree.identifier = "input0"}, %arg1: tensor<?x8x8x3xf32> {iree.identifier = "input1"}) -> (tensor<?x8x8x3xf32> {iree.identifier = "output0"}, tensor<?x8x8x3xf32> {iree.identifier = "output1"}) attributes {
- iree.abi.stub,
- iree.reflection = {
- tfl.io.names = "input0;input1;output0;output1"
- }
-} {
- %0:2 = call @dynamicEntry(%arg0, %arg1) : (tensor<?x8x8x3xf32>, tensor<?x8x8x3xf32>) -> (tensor<?x8x8x3xf32>, tensor<?x8x8x3xf32>)
- return %0#0, %0#1 : tensor<?x8x8x3xf32>, tensor<?x8x8x3xf32>
-}
-
-// CHECK-LABEL: func private @dynamicEntry(
-func private @dynamicEntry(%arg0: tensor<?x8x8x3xf32> {iree.identifier = "input0"}, %arg1: tensor<?x8x8x3xf32> {iree.identifier = "input1"}) -> (tensor<?x8x8x3xf32> {iree.identifier = "output0"}, tensor<?x8x8x3xf32> {iree.identifier = "output1"}) {
- %0 = mhlo.add %arg0, %arg1 : tensor<?x8x8x3xf32>
- %1 = mhlo.add %0, %arg0 : tensor<?x8x8x3xf32>
- return %0, %1 : tensor<?x8x8x3xf32>, tensor<?x8x8x3xf32>
-}
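
The checks deleted above covered the shapex-based versions of these shape functions; the caller-facing contract is unchanged by the rewrite below: each query/resize function takes an I/O index plus a !util.list<index> that is resized to the tensor rank and then read or written one dimension at a time. A minimal C++ model of that contract (illustrative only; the names here are hypothetical and this is not the IREE runtime API):

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical in-process model of the query ABI for a tensor<?x8x8x3>.
// Dimension 0 is dynamic and lives in a per-tensor "global", mirroring the
// @_tflite_*_shape_dim0 globals emitted by the pass.
using IndexList = std::vector<int64_t>;

static int64_t input0_dim0 = -1;  // unknown until the caller resizes

static void queryInputShape(int64_t index, IndexList &list) {
  if (index != 0) return;  // out-of-range index: leave the list untouched
  list.resize(4);          // util.list.resize %list, %c4
  list[0] = input0_dim0;   // dynamic dim loaded from the global
  list[1] = 8;             // remaining dims are static constants
  list[2] = 8;
  list[3] = 3;
}

int main() {
  input0_dim0 = 5;
  IndexList shape;
  queryInputShape(0, shape);
  assert((shape == IndexList{5, 8, 8, 3}));
  return 0;
}
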
diff --git a/iree/compiler/Bindings/TFLite/Transforms/test/wrap_entry_points.mlir b/iree/compiler/Bindings/TFLite/Transforms/test/wrap_entry_points.mlir
index 52d5089..8f9fb2d 100644
--- a/iree/compiler/Bindings/TFLite/Transforms/test/wrap_entry_points.mlir
+++ b/iree/compiler/Bindings/TFLite/Transforms/test/wrap_entry_points.mlir
@@ -1,21 +1,197 @@
-// RUN: iree-opt -iree-tflite-wrap-entry-points -split-input-file %s | IreeFileCheck %s
+// RUN: iree-opt -iree-tflite-wrap-entry-points -canonicalize -cse -split-input-file %s | IreeFileCheck %s
+
+// NOTE: CSE is run because otherwise there is far too much IR and we don't
+// care about the dozens of duplicated index constants.
+// NOTE: CHECK-NEXT is used heavily here to ensure that ops are emitted in the
+// same order as they appear in the function signatures, which makes the IR
+// easier to read.
+
+// CHECK-DAG: util.global private mutable @_tflite_dynamicEntry_input0_shape_dim0 : index
+// CHECK-DAG: util.global private mutable @_tflite_dynamicEntry_input1_shape_dim0 : index
+// CHECK-DAG: util.global private mutable @_tflite_dynamicEntry_output0_shape_dim0 : index
+// CHECK-DAG: util.global private mutable @_tflite_dynamicEntry_output1_shape_dim0 : index
+// CHECK-DAG: util.global private mutable @_tflite_dynamicEntry_shapes_dirty = true
+
+
+
+// CHECK-LABEL: func private @_tflite_dynamicEntry_calculate_shapes() {
+
+// Only recalculate shapes if the shapes are dirty.
+// CHECK: %[[IS_DIRTY:.+]] = util.global.load @_tflite_dynamicEntry_shapes_dirty : i1
+// CHECK-NEXT: cond_br %[[IS_DIRTY]], ^bb1, ^bb2
+
+// CHECK: ^bb1:
+// CHECK-NEXT: %[[NULL:.+]] = util.null : !hal.buffer
+
+// Tie the stored input0 dimensions to a placeholder tensor.
+// CHECK-NEXT: %[[IN0_DIM0:.+]] = util.global.load @_tflite_dynamicEntry_input0_shape_dim0 : index
+// CHECK-NEXT: %[[IN0:.+]] = hal.tensor.cast %[[NULL]] : !hal.buffer -> tensor<?x8x8x3xf32>{%[[IN0_DIM0]]}
+
+// Tie the stored input1 dimensions to a placeholder tensor.
+// CHECK-NEXT: %[[IN1_DIM0:.+]] = util.global.load @_tflite_dynamicEntry_input1_shape_dim0 : index
+// CHECK-NEXT: %[[IN1:.+]] = hal.tensor.cast %[[NULL]] : !hal.buffer -> tensor<?x8x8x3xf32>{%[[IN1_DIM0]]}
+
+// The original model body, inlined here so that the output shapes can be
+// computed from the current input shapes.
+// CHECK-NEXT: %[[OUT0:.+]] = mhlo.add %[[IN0]], %[[IN1]]
+// CHECK-NEXT: %[[OUT1:.+]] = mhlo.add %[[OUT0]], %[[IN0]]
+
+// Store back the new dynamic dimensions of out0/out1.
+// CHECK: %[[OUT0_DIM0:.+]] = tensor.dim %[[OUT0]], %c0
+// CHECK-NEXT: util.global.store %[[OUT0_DIM0]], @_tflite_dynamicEntry_output0_shape_dim0 : index
+// CHECK: %[[OUT1_DIM0:.+]] = tensor.dim %[[OUT1]], %c0
+// CHECK-NEXT: util.global.store %[[OUT1_DIM0]], @_tflite_dynamicEntry_output1_shape_dim0 : index
+
+// Clear the dirty bit now that the shapes have been recalculated.
+// CHECK: util.global.store %false, @_tflite_dynamicEntry_shapes_dirty : i1
+// CHECK-NEXT: return
+
+// Exit taken when the shapes are not dirty and no work is needed.
+// CHECK-NEXT: ^bb2:
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+
+
+// CHECK-LABEL: func @_tflite_dynamicEntry_query_input_shape
+// CHECK-SAME: (%[[INDEX:.+]]: index, %[[LIST:.+]]: !util.list<index>)
+
+// Query input0 shape:
+// CHECK: %[[IS_0:.+]] = arith.cmpi eq, %[[INDEX]], %c0 : index
+// CHECK-NEXT: cond_br %[[IS_0]], ^bb1, ^bb2
+// CHECK-NEXT: ^bb1:
+// CHECK-NEXT: util.list.resize %[[LIST]], %c4 : !util.list<index>
+// CHECK-NEXT: %[[IN0_DIM0:.+]] = util.global.load @_tflite_dynamicEntry_input0_shape_dim0 : index
+// CHECK-NEXT: util.list.set %[[LIST]][%c0], %[[IN0_DIM0]] : !util.list<index>
+// CHECK-NEXT: util.list.set %[[LIST]][%c1], %c8 : !util.list<index>
+// CHECK-NEXT: util.list.set %[[LIST]][%c2], %c8 : !util.list<index>
+// CHECK-NEXT: util.list.set %[[LIST]][%c3], %c3 : !util.list<index>
+// CHECK-NEXT: br ^bb4
+
+// Query input1 shape:
+// CHECK: ^bb2:
+// CHECK-NEXT: %[[IS_1:.+]] = arith.cmpi eq, %[[INDEX]], %c1 : index
+// CHECK-NEXT: cond_br %[[IS_1]], ^bb3, ^bb4
+// CHECK-NEXT: ^bb3:
+// CHECK-NEXT: util.list.resize %[[LIST]], %c4 : !util.list<index>
+// CHECK-NEXT: %[[IN1_DIM0:.+]] = util.global.load @_tflite_dynamicEntry_input1_shape_dim0 : index
+// CHECK-NEXT: util.list.set %[[LIST]][%c0], %[[IN1_DIM0]] : !util.list<index>
+// CHECK-NEXT: util.list.set %[[LIST]][%c1], %c8 : !util.list<index>
+// CHECK-NEXT: util.list.set %[[LIST]][%c2], %c8 : !util.list<index>
+// CHECK-NEXT: util.list.set %[[LIST]][%c3], %c3 : !util.list<index>
+// CHECK-NEXT: br ^bb4
+
+// Common exit, also reached when the input index is out of range:
+// CHECK: ^bb4:
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+
+
+// CHECK-LABEL: func @_tflite_dynamicEntry_resize_input_shape
+// CHECK-SAME: (%[[INDEX:.+]]: index, %[[LIST:.+]]: !util.list<index>)
+
+// CHECK: %[[IS_0:.+]] = arith.cmpi eq, %[[INDEX]], %c0 : index
+// CHECK-NEXT: cond_br %[[IS_0]], ^bb1, ^bb2
+// CHECK-NEXT: ^bb1:
+// CHECK-NEXT: %[[IN0_DIM0:.+]] = util.list.get %[[LIST]][%c0] : !util.list<index>
+// CHECK-NEXT: util.global.store %[[IN0_DIM0]], @_tflite_dynamicEntry_input0_shape_dim0 : index
+// CHECK-NEXT: br ^bb4
+
+// CHECK: ^bb2:
+// CHECK-NEXT: %[[IS_1:.+]] = arith.cmpi eq, %[[INDEX]], %c1 : index
+// CHECK-NEXT: cond_br %[[IS_1]], ^bb3, ^bb4
+// CHECK-NEXT: ^bb3:
+// CHECK-NEXT: %[[IN1_DIM0:.+]] = util.list.get %[[LIST]][%c0] : !util.list<index>
+// CHECK-NEXT: util.global.store %[[IN1_DIM0]], @_tflite_dynamicEntry_input1_shape_dim0 : index
+// CHECK-NEXT: br ^bb4
+
+// Set the dirty flag so that shape calculation must run again.
+// CHECK-NEXT: ^bb4:
+// CHECK-NEXT: util.global.store %true, @_tflite_dynamicEntry_shapes_dirty : i1
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+
+
+// CHECK-LABEL: func @_tflite_dynamicEntry_query_output_shape
+// CHECK-SAME: (%[[INDEX:.+]]: index, %[[LIST:.+]]: !util.list<index>)
+
+// Recalculate shapes, if needed.
+// CHECK: call @_tflite_dynamicEntry_calculate_shapes() : () -> ()
+
+// Query output0:
+// CHECK: %[[IS_0:.+]] = arith.cmpi eq, %[[INDEX]], %c0 : index
+// CHECK-NEXT: cond_br %[[IS_0]], ^bb1, ^bb2
+// CHECK-NEXT: ^bb1:
+// CHECK-NEXT: util.list.resize %[[LIST]], %c4 : !util.list<index>
+// CHECK-NEXT: %[[OUT0_DIM0:.+]] = util.global.load @_tflite_dynamicEntry_output0_shape_dim0 : index
+// CHECK-NEXT: util.list.set %[[LIST]][%c0], %[[OUT0_DIM0]] : !util.list<index>
+// CHECK-NEXT: util.list.set %[[LIST]][%c1], %c8 : !util.list<index>
+// CHECK-NEXT: util.list.set %[[LIST]][%c2], %c8 : !util.list<index>
+// CHECK-NEXT: util.list.set %[[LIST]][%c3], %c3 : !util.list<index>
+// CHECK-NEXT: br ^bb4
+
+// Query output1:
+// CHECK: ^bb2:
+// CHECK-NEXT: %[[IS_1:.+]] = arith.cmpi eq, %[[INDEX]], %c1 : index
+// CHECK-NEXT: cond_br %[[IS_1]], ^bb3, ^bb4
+// CHECK-NEXT: ^bb3:
+// CHECK-NEXT: util.list.resize %[[LIST]], %c4 : !util.list<index>
+// CHECK-NEXT: %[[OUT1_DIM0:.+]] = util.global.load @_tflite_dynamicEntry_output1_shape_dim0 : index
+// CHECK-NEXT: util.list.set %[[LIST]][%c0], %[[OUT1_DIM0]] : !util.list<index>
+// CHECK-NEXT: util.list.set %[[LIST]][%c1], %c8 : !util.list<index>
+// CHECK-NEXT: util.list.set %[[LIST]][%c2], %c8 : !util.list<index>
+// CHECK-NEXT: util.list.set %[[LIST]][%c3], %c3 : !util.list<index>
+// CHECK-NEXT: br ^bb4
+
+// CHECK-NEXT: ^bb4:
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+
// CHECK-LABEL: func @_tflite_main(
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x8x8x3xf32> {iree.identifier = "input0"},
-// CHECK-SAME: %[[ARG1:.+]]: tensor<?x8x8x3xf32> {iree.identifier = "input1"})
+// CHECK-SAME: %[[IN0_BUFFER:.+]]: !hal.buffer {iree.identifier = "input0"},
+// CHECK-SAME: %[[IN1_BUFFER:.+]]: !hal.buffer {iree.identifier = "input1"})
// CHECK-SAME: -> (
-// CHECK-SAME: tensor<?x8x8x3xf32> {iree.identifier = "output0"},
-// CHECK-SAME: tensor<?x8x8x3xf32> {iree.identifier = "output1"}
+// CHECK-SAME: !hal.buffer {iree.identifier = "output0"},
+// CHECK-SAME: !hal.buffer {iree.identifier = "output1"}
// CHECK-SAME: ) attributes {
// CHECK-SAME: iree.abi.stub,
// CHECK-SAME: iree.reflection = {
// CHECK-SAME: tfl.io.names = "input0;input1;output0;output1"
// CHECK-SAME: }
// CHECK-SAME: } {
-// CHECK-NEXT: %[[RET:.+]]:2 = call @dynamicEntry(%[[ARG0]], %[[ARG1]])
-// CHECK-NEXT: return %[[RET]]#0, %[[RET]]#1
+
+// Cast input0 buffer to a shaped tensor.
+// CHECK: %[[IN0_DIM0:.+]] = util.global.load @_tflite_dynamicEntry_input0_shape_dim0 : index
+// CHECK-NEXT: %[[IN0:.+]] = hal.tensor.cast %[[IN0_BUFFER]] : !hal.buffer -> tensor<?x8x8x3xf32>{%[[IN0_DIM0]]}
+
+// Cast input1 buffer to a shaped tensor.
+// CHECK: %[[IN1_DIM0:.+]] = util.global.load @_tflite_dynamicEntry_input1_shape_dim0 : index
+// CHECK-NEXT: %[[IN1:.+]] = hal.tensor.cast %[[IN1_BUFFER]] : !hal.buffer -> tensor<?x8x8x3xf32>{%[[IN1_DIM0]]}
+
+// Call the original function with tensor arguments.
+// CHECK: %[[OUT:.+]]:2 = call @dynamicEntry(%[[IN0]], %[[IN1]]) : (tensor<?x8x8x3xf32>, tensor<?x8x8x3xf32>) -> (tensor<?x8x8x3xf32>, tensor<?x8x8x3xf32>)
+
+// Export output0 to a HAL buffer and persist its dynamic dimension.
+// CHECK: %[[OUT0_DIM0:.+]] = tensor.dim %[[OUT]]#0, %c0 : tensor<?x8x8x3xf32>
+// CHECK-NEXT: %[[OUT0_BUFFER:.+]] = hal.tensor.cast %[[OUT]]#0 : tensor<?x8x8x3xf32>{%[[OUT0_DIM0]]} -> !hal.buffer
+// CHECK-NEXT: util.global.store %[[OUT0_DIM0]], @_tflite_dynamicEntry_output0_shape_dim0 : index
+
+// Export output1 to a HAL buffer and persist its dynamic dimension.
+// CHECK: %[[OUT1_DIM0:.+]] = tensor.dim %[[OUT]]#1, %c0 : tensor<?x8x8x3xf32>
+// CHECK-NEXT: %[[OUT1_BUFFER:.+]] = hal.tensor.cast %[[OUT]]#1 : tensor<?x8x8x3xf32>{%[[OUT1_DIM0]]} -> !hal.buffer
+// CHECK-NEXT: util.global.store %[[OUT1_DIM0]], @_tflite_dynamicEntry_output1_shape_dim0 : index
+
+// Clear the shape dirty bit since the shapes were just updated unconditionally.
+// CHECK-NEXT: util.global.store %false, @_tflite_dynamicEntry_shapes_dirty : i1
+
+// CHECK-NEXT: return %[[OUT0_BUFFER]], %[[OUT1_BUFFER]]
// CHECK-NEXT: }
+
+
// CHECK-LABEL: func private @dynamicEntry(
func @dynamicEntry(
%arg0: tensor<?x8x8x3xf32> {iree.identifier = "input0"},
@@ -24,7 +200,10 @@
tensor<?x8x8x3xf32> {iree.identifier = "output0"},
tensor<?x8x8x3xf32> {iree.identifier = "output1"}
) {
- %0 = "mhlo.add"(%arg0, %arg1) : (tensor<?x8x8x3xf32>, tensor<?x8x8x3xf32>) -> tensor<?x8x8x3xf32>
- %1 = "mhlo.add"(%0, %arg0) : (tensor<?x8x8x3xf32>, tensor<?x8x8x3xf32>) -> tensor<?x8x8x3xf32>
+ // CHECK: = mhlo.add
+ %0 = mhlo.add %arg0, %arg1 : tensor<?x8x8x3xf32>
+ // CHECK: = mhlo.add
+ %1 = mhlo.add %0, %arg0 : tensor<?x8x8x3xf32>
+ // CHECK: return
return %0, %1 : tensor<?x8x8x3xf32>, tensor<?x8x8x3xf32>
}
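
Taken together, the functions checked above implement a small caching state machine: resize_input_shape stores the new input dims and sets the dirty bit, query_output_shape first calls calculate_shapes (which re-runs the model body on placeholder tensors only while the bit is set), and _tflite_main refreshes the output dims and clears the bit unconditionally. A compact C++ sketch of that state machine, assuming a single dynamic dim per tensor and using hypothetical names (this models the generated IR, not actual IREE code):

#include <cstdint>
#include <cstdio>

// Mirrors the @_tflite_* globals: one dynamic dim per I/O plus a dirty bit.
static int64_t input0_dim0 = 1;
static int64_t output0_dim0 = 1;
static bool shapes_dirty = true;

// Stand-in for the inlined model body: output shape == input shape here.
static int64_t inferOutputDim0(int64_t inputDim0) { return inputDim0; }

// @_tflite_..._calculate_shapes: only recompute when dirty, then clear.
static void calculateShapes() {
  if (!shapes_dirty) return;
  output0_dim0 = inferOutputDim0(input0_dim0);
  shapes_dirty = false;
}

// @_tflite_..._resize_input_shape: store the new dim and mark dirty.
static void resizeInputShape(int64_t dim0) {
  input0_dim0 = dim0;
  shapes_dirty = true;
}

// @_tflite_..._query_output_shape: ensure shapes are fresh before reading.
static int64_t queryOutputDim0() {
  calculateShapes();
  return output0_dim0;
}

int main() {
  resizeInputShape(5);
  printf("output dim0 = %lld\n", (long long)queryOutputDim0());  // 5
  // A second query performs no recomputation: the dirty bit is clear.
  printf("output dim0 = %lld\n", (long long)queryOutputDim0());  // 5
  return 0;
}

The point of the dirty bit is that shape queries between resizes reduce to O(1) global loads; the shape-only model body re-runs at most once after an actual resize.
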
diff --git a/iree/compiler/Dialect/HAL/Conversion2/StreamToHAL/ConvertStreamToHAL.cpp b/iree/compiler/Dialect/HAL/Conversion2/StreamToHAL/ConvertStreamToHAL.cpp
index feb9e48..fedd583 100644
--- a/iree/compiler/Dialect/HAL/Conversion2/StreamToHAL/ConvertStreamToHAL.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion2/StreamToHAL/ConvertStreamToHAL.cpp
@@ -425,12 +425,74 @@
}
};
-struct TensorImportOpPattern
+// Inserts IR to assert that the underlying buffer storage is compatible with
+// the intended usage in the program. The allocator used to allocate the
+// buffer must be compatible with our target device allocator, and the
+// buffer must have at least the minimum expected size (additional padding is
+// ok).
+static LogicalResult buildStorageAssertions(
+ Location loc, Value buffer, StringAttr message, Value allocator,
+ Value minimumLength, IREE::Stream::ResourceType resourceType,
+ OpBuilder &builder) {
+ auto memoryTypes = IREE::HAL::MemoryTypeBitfield::None;
+ auto bufferUsage = IREE::HAL::BufferUsageBitfield::None;
+ if (failed(deriveRequiredResourceBufferBits(loc, resourceType, memoryTypes,
+ bufferUsage))) {
+ return failure();
+ }
+
+ auto requiredTypes =
+ IREE::HAL::MemoryTypeBitfieldAttr::get(builder.getContext(), memoryTypes);
+ auto requiredUsage = IREE::HAL::BufferUsageBitfieldAttr::get(
+ builder.getContext(), bufferUsage);
+
+ builder.create<IREE::HAL::BufferAssertOp>(loc, buffer, message, allocator,
+ minimumLength, requiredTypes,
+ requiredUsage);
+ return success();
+}
+
+struct TensorImportBufferOpPattern
: public StreamConversionPattern<IREE::Stream::TensorImportOp> {
using StreamConversionPattern::StreamConversionPattern;
LogicalResult matchAndRewrite(
IREE::Stream::TensorImportOp importOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ if (!importOp.source().getType().isa<IREE::HAL::BufferType>()) {
+ return failure();
+ }
+
+ // TODO(benvanik): get a name for the tensor (argument name/etc).
+ auto message = rewriter.getStringAttr("tensor");
+
+ // Directly use the buffer.
+ auto buffer = adaptor.source();
+ rewriter.replaceOp(importOp, buffer);
+
+ // Assert the storage is compatible with our expected device and usage.
+ auto targetAllocator = lookupAllocatorFor(importOp, rewriter);
+ auto resourceType =
+ importOp.result().getType().cast<IREE::Stream::ResourceType>();
+ if (failed(buildStorageAssertions(
+ importOp.getLoc(), adaptor.source(), message, targetAllocator,
+ adaptor.result_size(), resourceType, rewriter))) {
+ return failure();
+ }
+
+ return success();
+ }
+};
+
+struct TensorImportBufferViewOpPattern
+ : public StreamConversionPattern<IREE::Stream::TensorImportOp> {
+ using StreamConversionPattern::StreamConversionPattern;
+ LogicalResult matchAndRewrite(
+ IREE::Stream::TensorImportOp importOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!importOp.source().getType().isa<IREE::HAL::BufferViewType>()) {
+ return failure();
+ }
+
auto loc = importOp.getLoc();
// TODO(benvanik): get a name for the tensor (argument name/etc).
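
The hal.buffer.assert emitted by buildStorageAssertions above guards two properties of a caller-provided buffer: it carries at least the memory-type and usage bits the compiled program requires, and it is no smaller than the minimum expected size (extra padding is allowed). Allocator compatibility is checked as well but is elided here. A plain-logic sketch of the bit checks, with invented stand-in bitfields (not HAL's actual enums):

#include <cassert>
#include <cstdint>

// Invented stand-ins for HAL's memory type and buffer usage bitfields.
enum : uint32_t { kMemoryDeviceVisible = 1u << 0, kMemoryHostVisible = 1u << 1 };
enum : uint32_t { kUsageTransfer = 1u << 0, kUsageDispatch = 1u << 1 };

struct BufferParams {
  uint32_t memoryTypes;
  uint32_t usage;
  uint64_t length;
};

// The property the assertion enforces: all required bits present and the
// buffer no smaller than the minimum expected size (padding beyond is fine).
static bool isCompatible(const BufferParams &buffer, uint32_t requiredTypes,
                         uint32_t requiredUsage, uint64_t minimumLength) {
  return (buffer.memoryTypes & requiredTypes) == requiredTypes &&
         (buffer.usage & requiredUsage) == requiredUsage &&
         buffer.length >= minimumLength;
}

int main() {
  BufferParams buffer = {kMemoryDeviceVisible, kUsageTransfer | kUsageDispatch,
                         1024};
  assert(isCompatible(buffer, kMemoryDeviceVisible, kUsageDispatch, 768));
  assert(!isCompatible(buffer, kMemoryHostVisible, kUsageDispatch, 768));
  return 0;
}
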
@@ -511,41 +573,32 @@
shapeDims);
return success();
}
-
- // Inserts IR to assert that the underlying buffer storage is compatible with
- // the intended usage in the program. The allocator used to allocate the
- // buffer must have compatibility with our target device allocator and the
- // buffer must have at least the minimum expected size (additional padding is
- // ok).
- static LogicalResult buildStorageAssertions(
- Location loc, Value buffer, StringAttr message, Value allocator,
- Value minimumLength, IREE::Stream::ResourceType resourceType,
- OpBuilder &builder) {
- auto memoryTypes = IREE::HAL::MemoryTypeBitfield::None;
- auto bufferUsage = IREE::HAL::BufferUsageBitfield::None;
- if (failed(deriveRequiredResourceBufferBits(loc, resourceType, memoryTypes,
- bufferUsage))) {
- return failure();
- }
-
- auto requiredTypes = IREE::HAL::MemoryTypeBitfieldAttr::get(
- builder.getContext(), memoryTypes);
- auto requiredUsage = IREE::HAL::BufferUsageBitfieldAttr::get(
- builder.getContext(), bufferUsage);
-
- builder.create<IREE::HAL::BufferAssertOp>(loc, buffer, message, allocator,
- minimumLength, requiredTypes,
- requiredUsage);
- return success();
- }
};
-struct TensorExportOpPattern
+struct TensorExportBufferOpPattern
: public StreamConversionPattern<IREE::Stream::TensorExportOp> {
using StreamConversionPattern::StreamConversionPattern;
LogicalResult matchAndRewrite(
IREE::Stream::TensorExportOp exportOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ if (!exportOp.result().getType().isa<IREE::HAL::BufferType>()) {
+ return failure();
+ }
+ rewriter.replaceOp(exportOp, adaptor.source());
+ return success();
+ }
+};
+
+struct TensorExportBufferViewOpPattern
+ : public StreamConversionPattern<IREE::Stream::TensorExportOp> {
+ using StreamConversionPattern::StreamConversionPattern;
+ LogicalResult matchAndRewrite(
+ IREE::Stream::TensorExportOp exportOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!exportOp.result().getType().isa<IREE::HAL::BufferViewType>()) {
+ return failure();
+ }
+
auto loc = exportOp.getLoc();
auto tensorType =
adaptor.source_encoding().getValue().cast<RankedTensorType>();
@@ -1132,7 +1185,8 @@
ResourceMapOpPattern, ResourceTryMapOpPattern,
ResourceLoadOpPattern, ResourceStoreOpPattern,
ResourceSubviewOpPattern>(mapping, typeConverter, context);
- patterns.insert<TensorImportOpPattern, TensorExportOpPattern,
+ patterns.insert<TensorImportBufferOpPattern, TensorImportBufferViewOpPattern,
+ TensorExportBufferOpPattern, TensorExportBufferViewOpPattern,
TensorTraceOpPattern>(mapping, typeConverter, context);
patterns
.insert<CmdFlushOpPattern, CmdInvalidateOpPattern, CmdDiscardOpPattern,
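
Splitting the single import/export patterns into *BufferOpPattern and *BufferViewOpPattern variants leans on the usual dialect-conversion idiom: each pattern returns failure() when its type guard does not match, and the driver falls through to the next registered pattern. A generic sketch of that first-match dispatch (a toy model, not MLIR's actual pattern driver):

#include <iostream>
#include <string>
#include <vector>

// A toy stand-in for matchAndRewrite: returns false (failure) when the
// pattern's type guard rejects the op, letting the driver try the next one.
struct Pattern {
  std::string guardType;
  std::string name;
  bool matchAndRewrite(const std::string &opType) const {
    if (opType != guardType) return false;  // type guard: not our case
    std::cout << name << " handled " << opType << "\n";
    return true;
  }
};

int main() {
  std::vector<Pattern> patterns = {
      {"!hal.buffer", "TensorImportBufferOpPattern"},
      {"!hal.buffer_view", "TensorImportBufferViewOpPattern"},
  };
  for (std::string op : {"!hal.buffer_view", "!hal.buffer"}) {
    for (const Pattern &pattern : patterns) {
      if (pattern.matchAndRewrite(op)) break;  // first match wins
    }
  }
  return 0;
}

Registering both variants keeps each matchAndRewrite simple: the buffer case is a direct replaceOp plus storage assertions, while the buffer-view case still performs the full element-type and shape metadata extraction.
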
diff --git a/iree/compiler/Dialect/Stream/Conversion/FlowToStream/ConvertFlowToStream.cpp b/iree/compiler/Dialect/Stream/Conversion/FlowToStream/ConvertFlowToStream.cpp
index 8fb42ad..1411009 100644
--- a/iree/compiler/Dialect/Stream/Conversion/FlowToStream/ConvertFlowToStream.cpp
+++ b/iree/compiler/Dialect/Stream/Conversion/FlowToStream/ConvertFlowToStream.cpp
@@ -49,16 +49,18 @@
LogicalResult matchAndRewrite(
IREE::HAL::TensorCastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- if (op.source().getType().isa<IREE::HAL::BufferViewType>()) {
+ auto sourceType = op.source().getType();
+ auto targetType = op.target().getType();
+ if (sourceType.isa<IREE::HAL::BufferType>() ||
+ sourceType.isa<IREE::HAL::BufferViewType>()) {
// Import (buffer view to stream resource).
auto resultType = rewriter.getType<IREE::Stream::ResourceType>(
IREE::Stream::Lifetime::External);
auto resultSize = buildResultSizeOf(op.getLoc(), op.target(),
adaptor.target_dims(), rewriter);
auto newOp = rewriter.create<IREE::Stream::TensorImportOp>(
- op.getLoc(), resultType, adaptor.source(),
- TypeAttr::get(op.target().getType()), adaptor.target_dims(),
- resultSize,
+ op.getLoc(), resultType, adaptor.source(), TypeAttr::get(targetType),
+ adaptor.target_dims(), resultSize,
/*affinity=*/nullptr);
auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
@@ -66,7 +68,8 @@
op, unknownType, newOp.result(), resultSize, resultSize,
/*source_affinity=*/nullptr,
/*result_affinity=*/nullptr);
- } else if (op.target().getType().isa<IREE::HAL::BufferViewType>()) {
+ } else if (targetType.isa<IREE::HAL::BufferType>() ||
+ targetType.isa<IREE::HAL::BufferViewType>()) {
auto source =
consumeTensorOperand(op.getLoc(), adaptor.source(), rewriter);
auto externalType = rewriter.getType<IREE::Stream::ResourceType>(
@@ -82,9 +85,8 @@
// Export (stream resource to buffer view).
rewriter.replaceOpWithNewOp<IREE::Stream::TensorExportOp>(
- op, op.target().getType(), exportSource,
- TypeAttr::get(op.source().getType()), adaptor.source_dims(),
- source.resourceSize,
+ op, targetType, exportSource, TypeAttr::get(op.source().getType()),
+ adaptor.source_dims(), source.resourceSize,
/*affinity=*/nullptr);
} else {
return rewriter.notifyMatchFailure(op, "unsupported HAL cast conversion");