Adding input type legalization to chop up i64/f64 prior to processing. In the future (when we have backends that can support those types) we can make this narrowing optional and/or warn on it. For now this matches the behavior of the existing input legalization so no functional change. PiperOrigin-RevId: 284877942
diff --git a/iree/compiler/Dialect/Flow/Conversion/BUILD b/iree/compiler/Dialect/Flow/Conversion/BUILD new file mode 100644 index 0000000..1fa35a9 --- /dev/null +++ b/iree/compiler/Dialect/Flow/Conversion/BUILD
@@ -0,0 +1,34 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "Conversion", + srcs = [ + "TypeConverter.cpp", + ], + hdrs = [ + "TypeConverter.h", + ], + deps = [ + "//iree/compiler/Dialect", + "@local_config_mlir//:IR", + "@local_config_mlir//:Parser", + "@local_config_mlir//:Transforms", + ], +)
diff --git a/iree/compiler/Dialect/Flow/Conversion/TypeConverter.cpp b/iree/compiler/Dialect/Flow/Conversion/TypeConverter.cpp new file mode 100644 index 0000000..d445782 --- /dev/null +++ b/iree/compiler/Dialect/Flow/Conversion/TypeConverter.cpp
@@ -0,0 +1,67 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/compiler/Dialect/Flow/Conversion/TypeConverter.h" + +#include "iree/compiler/Dialect/Types.h" +#include "mlir/IR/StandardTypes.h" + +namespace mlir { +namespace iree_compiler { + +Type FlowTypeConverter::convertType(Type t) { + if (t.isIndex()) { + // Always treat as 32-bit. + return IntegerType::get(32, t.getContext()); + } else if (t.isIntOrIndexOrFloat()) { + if (auto integerType = t.dyn_cast<IntegerType>()) { + if (integerType.getWidth() > 32) { + // Don't support 64-bit types in general. Rewrite to i32 (if desired). + // TODO(benvanik): split to i32+i32? allow and use availability? + // TODO(benvanik): make an option. + return IntegerType::get(32, t.getContext()); + } + } else if (auto floatType = t.dyn_cast<FloatType>()) { + if (floatType.getWidth() > 32) { + // Don't support 64-bit types in general. Rewrite to f32 (if desired). + // TODO(benvanik): make an option. + return FloatType::getF32(t.getContext()); + } + } + } else if (auto tensorType = t.dyn_cast<RankedTensorType>()) { + auto convertedElementType = convertType(tensorType.getElementType()); + if (!convertedElementType) { + return {}; + } + return RankedTensorType::get(tensorType.getShape(), convertedElementType); + } else if (auto tensorType = t.dyn_cast<TensorType>()) { + // We only support ranked tensors. We could convert unranked to ranked + // here for certain cases (such as * on the LHS). + return {}; + } + // Allow types through by default. + return t; +} + +Operation *FlowTypeConverter::materializeConversion(PatternRewriter &rewriter, + Type resultType, + ArrayRef<Value *> inputs, + Location loc) { + // TODO(b/145876978): materialize conversion when this is called. + llvm_unreachable("unhandled materialization"); + return nullptr; +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Conversion/TypeConverter.h b/iree/compiler/Dialect/Flow/Conversion/TypeConverter.h new file mode 100644 index 0000000..618c6ef --- /dev/null +++ b/iree/compiler/Dialect/Flow/Conversion/TypeConverter.h
@@ -0,0 +1,36 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_DIALECT_FLOW_CONVERSION_TYPECONVERTER_H_ +#define IREE_COMPILER_DIALECT_FLOW_CONVERSION_TYPECONVERTER_H_ + +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace iree_compiler { + +class FlowTypeConverter : public TypeConverter { + public: + Type convertType(Type t) override; + using TypeConverter::convertType; + + Operation *materializeConversion(PatternRewriter &rewriter, Type resultType, + ArrayRef<Value *> inputs, + Location loc) override; +}; + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_DIALECT_FLOW_CONVERSION_TYPECONVERTER_H_
diff --git a/iree/compiler/Dialect/Flow/Transforms/BUILD b/iree/compiler/Dialect/Flow/Transforms/BUILD index e3d31a4..30b5c10 100644 --- a/iree/compiler/Dialect/Flow/Transforms/BUILD +++ b/iree/compiler/Dialect/Flow/Transforms/BUILD
@@ -27,6 +27,7 @@ "FormStreams.cpp", "IdentifyDispatchRegions.cpp", "IdentifyReductionRegions.cpp", + "LegalizeInputTypes.cpp", "OutlineDispatchRegions.cpp", "OutlineReductionRegions.cpp", "Passes.cpp", @@ -38,6 +39,7 @@ ], deps = [ "//iree/compiler/Dialect/Flow/Analysis", + "//iree/compiler/Dialect/Flow/Conversion", "//iree/compiler/Dialect/Flow/Conversion/HLOToFlow", "//iree/compiler/Dialect/Flow/Conversion/StandardToFlow", "//iree/compiler/Dialect/Flow/IR",
diff --git a/iree/compiler/Dialect/Flow/Transforms/LegalizeInputTypes.cpp b/iree/compiler/Dialect/Flow/Transforms/LegalizeInputTypes.cpp new file mode 100644 index 0000000..44ab3d9 --- /dev/null +++ b/iree/compiler/Dialect/Flow/Transforms/LegalizeInputTypes.cpp
@@ -0,0 +1,253 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/compiler/Dialect/Flow/Conversion/TypeConverter.h" +#include "mlir/Analysis/Verifier.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/Utils.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace Flow { + +namespace { + +Attribute convertAttribute(Location loc, Attribute value, + FlowTypeConverter &typeConverter) { + auto newType = typeConverter.convertType(value.getType()); + if (value.getType() == newType) { + return value; + } + + // TODO(benvanik): when std has a conversion op use that instead. + + if (auto attr = value.dyn_cast<IntegerAttr>()) { + // TODO(b/130356985): saturate when signedness is known. + return IntegerAttr::get( + newType, attr.getValue().trunc(newType.getIntOrFloatBitWidth())); + } else if (auto attr = value.dyn_cast<FloatAttr>()) { + switch (newType.getIntOrFloatBitWidth()) { + case 32: + return FloatAttr::get(newType, attr.getValueAsDouble()); + case 64: + return FloatAttr::get(newType, attr.getValueAsDouble()); + default: + break; + } + } else if (auto attr = value.dyn_cast<SplatElementsAttr>()) { + return SplatElementsAttr::get( + newType.cast<ShapedType>(), + convertAttribute(loc, attr.getSplatValue(), typeConverter)); + } else if (auto attr = value.dyn_cast<DenseIntElementsAttr>()) { + auto newElementType = newType.cast<ShapedType>().getElementType(); + auto newElementBitWidth = newElementType.getIntOrFloatBitWidth(); + return attr.mapValues(newElementType, [&](APInt src) { + // TODO(b/130356985): saturate when signedness is known. + return src.trunc(newElementBitWidth); + }); + } + + emitError(loc) << "unsupported attribute kind for conversion from " + << value.getType() << " to " << newType; + return {}; +} + +LogicalResult convertRegion(Region &oldRegion, Region &newRegion, + FlowTypeConverter &typeConverter, + BlockAndValueMapping &mapping); + +LogicalResult convertOperation(Operation *oldOp, + FlowTypeConverter &typeConverter, + BlockAndValueMapping &mapping, + OpBuilder &builder) { + OperationState state(oldOp->getLoc(), oldOp->getName()); + for (auto oldType : oldOp->getResultTypes()) { + if (failed(typeConverter.convertType(oldType, state.types))) { + return failure(); + } + } + + if (auto constantOp = dyn_cast<mlir::ConstantOp>(oldOp)) { + auto newValue = + convertAttribute(oldOp->getLoc(), constantOp.value(), typeConverter); + if (!newValue) { + return failure(); + } + state.addAttribute("value", newValue); + } else { + state.attributes = llvm::to_vector<4>(oldOp->getAttrs()); + } + + if (oldOp->getNumSuccessors() == 0) { + // Non-branching operations can just add all the operands. + for (auto *oldOperand : oldOp->getOperands()) { + state.operands.push_back(mapping.lookup(oldOperand)); + } + } else { + // We add the operands separated by nullptr's for each successor. + unsigned firstSuccOperand = oldOp->getNumSuccessors() + ? oldOp->getSuccessorOperandIndex(0) + : oldOp->getNumOperands(); + auto oldOperands = oldOp->getOpOperands(); + for (unsigned i = 0; i != firstSuccOperand; ++i) { + state.operands.push_back(mapping.lookup(oldOperands[i].get())); + } + for (unsigned succ = 0, e = oldOp->getNumSuccessors(); succ != e; ++succ) { + state.successors.push_back( + mapping.lookupOrDefault(oldOp->getSuccessor(succ))); + // Add sentinel to delineate successor operands. + state.operands.push_back(nullptr); + // Remap the successors operands. + for (auto *operand : oldOp->getSuccessorOperands(succ)) { + state.operands.push_back(mapping.lookup(operand)); + } + } + } + + for (auto &oldRegion : oldOp->getRegions()) { + auto *newRegion = state.addRegion(); + if (failed(convertRegion(oldRegion, *newRegion, typeConverter, mapping))) { + return failure(); + } + } + + auto *newOp = builder.createOperation(state); + if (failed(mlir::verify(newOp))) { + // TODO(benvanik): we could possibly try again with a different set of type + // conversions to see if that works. For example, we could lean toward + // materializing conversions/inserting cases instead of directly doing the + // conversions here. Unfortunately ops don't allow us to query what types + // they support so this is trial-and-error. + return newOp->emitOpError() + << "post-conversion verification failed - unsupported types"; + } + + for (auto oldNewResult : + llvm::zip(oldOp->getResults(), newOp->getResults())) { + auto *oldResult = std::get<0>(oldNewResult); + auto *newResult = std::get<1>(oldNewResult); + mapping.map(oldResult, newResult); + } + + return success(); +} + +LogicalResult convertBlock(Block &oldBlock, Block &newBlock, + FlowTypeConverter &typeConverter, + BlockAndValueMapping &mapping) { + OpBuilder builder(oldBlock.getParent()->getContext()); + builder.setInsertionPointToEnd(&newBlock); + for (auto &oldOp : oldBlock) { + if (failed(convertOperation(&oldOp, typeConverter, mapping, builder))) { + return oldOp.emitOpError() << "unable to legalize operation types"; + } + } + return success(); +} + +LogicalResult convertRegion(Region &oldRegion, Region &newRegion, + FlowTypeConverter &typeConverter, + BlockAndValueMapping &mapping) { + OpBuilder builder(oldRegion.getContext()); + for (auto &oldBlock : oldRegion) { + auto &newBlock = *builder.createBlock(&newRegion); + auto blockSignature = typeConverter.convertBlockSignature(&oldBlock); + if (!blockSignature) { + return oldBlock.front().emitError() + << "unable to legalize block signature"; + } + newBlock.addArguments(blockSignature->getConvertedTypes()); + for (auto oldNewArg : + llvm::zip(oldBlock.getArguments(), newBlock.getArguments())) { + mapping.map(std::get<0>(oldNewArg), std::get<1>(oldNewArg)); + } + mapping.map(&oldBlock, &newBlock); + } + for (auto &oldBlock : oldRegion) { + if (failed(convertBlock(oldBlock, *mapping.lookup(&oldBlock), typeConverter, + mapping))) { + return failure(); + } + } + return success(); +} + +} // namespace + +class LegalizeInputTypesPass : public ModulePass<LegalizeInputTypesPass> { + public: + void runOnModule() override { + auto moduleOp = getModule(); + FlowTypeConverter typeConverter; + + auto oldFuncOps = llvm::to_vector<16>(moduleOp.getOps<FuncOp>()); + for (auto oldFuncOp : oldFuncOps) { + OpBuilder moduleBuilder(moduleOp); + moduleBuilder.setInsertionPoint(oldFuncOp); + + auto oldType = oldFuncOp.getType(); + TypeConverter::SignatureConversion signature(oldType.getNumInputs()); + for (unsigned i = 0, e = oldType.getNumInputs(); i != e; ++i) { + if (failed(typeConverter.convertSignatureArg(i, oldType.getInput(i), + signature))) { + oldFuncOp.emitOpError() << "unable to legalize type of input " << i; + return signalPassFailure(); + } + } + SmallVector<Type, 1> convertedResults; + if (failed(typeConverter.convertTypes(oldType.getResults(), + convertedResults))) { + oldFuncOp.emitOpError() << "unable to legalize result types"; + return signalPassFailure(); + } + + auto newFuncOp = + cast<FuncOp>(moduleBuilder.cloneWithoutRegions(*oldFuncOp)); + newFuncOp.setType(FunctionType::get(signature.getConvertedTypes(), + convertedResults, &getContext())); + + BlockAndValueMapping mapping; + if (failed(convertRegion(oldFuncOp.getBody(), newFuncOp.getBody(), + typeConverter, mapping))) { + return signalPassFailure(); + } + + oldFuncOp.erase(); + } + } +}; + +std::unique_ptr<OpPassBase<ModuleOp>> createLegalizeInputTypesPass() { + return std::make_unique<LegalizeInputTypesPass>(); +} + +static PassRegistration<LegalizeInputTypesPass> pass( + "iree-flow-legalize-input-types", + "Legalizes input types to ones supported by the IREE flow dialect"); + +} // namespace Flow +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp index 8ab26b2..0d19fbc 100644 --- a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -38,6 +38,10 @@ passManager.addNestedPass<FuncOp>(createCanonicalizerPass()); passManager.addNestedPass<FuncOp>(createCSEPass()); + // Legalize input types. We do this after flattening tuples so that we don't + // have to deal with them. + passManager.addPass(IREE::Flow::createLegalizeInputTypesPass()); + // Convert into our expected input and (hopefully) some flow ops. passManager.addNestedPass<FuncOp>( IREE::Flow::createPrePartitioningConversionPass());
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.h b/iree/compiler/Dialect/Flow/Transforms/Passes.h index 042f296..431dc49 100644 --- a/iree/compiler/Dialect/Flow/Transforms/Passes.h +++ b/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -51,6 +51,12 @@ // Flattens tuple values in function signatures and blocks. std::unique_ptr<OpPassBase<ModuleOp>> createFlattenTuplesInCFGPass(); +// Legalizes the input types to those supported by the flow dialect. +// This will fail if types that cannot be supported at all are present, however +// conditionally supported types (based on availability, etc) may still be +// allowed to pass through successfully. +std::unique_ptr<OpPassBase<ModuleOp>> createLegalizeInputTypesPass(); + // Runs pre-partitioning conversion passes to convert to the flow dialect. // This converts some input ops directly to flow ops when doing so has a // benefit. Other ops are left unmodified and will be outlined later on.
diff --git a/iree/compiler/Dialect/Flow/Transforms/PrePostPartitioningConversion.cpp b/iree/compiler/Dialect/Flow/Transforms/PrePostPartitioningConversion.cpp index 6993207..f4c6a11 100644 --- a/iree/compiler/Dialect/Flow/Transforms/PrePostPartitioningConversion.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/PrePostPartitioningConversion.cpp
@@ -16,6 +16,7 @@ #include "iree/compiler/Dialect/Flow/Conversion/HLOToFlow/ConvertHLOToFlow.h" #include "iree/compiler/Dialect/Flow/Conversion/StandardToFlow/ConvertStandardToFlow.h" +#include "iree/compiler/Dialect/Flow/Conversion/TypeConverter.h" #include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Builders.h" @@ -35,7 +36,9 @@ : public FunctionPass<PrePartitioningConversionPass> { public: void runOnFunction() override { - ConversionTarget conversionTarget(getContext()); + auto *context = &getContext(); + FlowTypeConverter typeConverter; + ConversionTarget conversionTarget(*context); OwningRewritePatternList conversionPatterns; conversionTarget.addLegalDialect<IREE::Flow::FlowDialect>(); @@ -55,18 +58,17 @@ conversionTarget.addIllegalOp<xla_hlo::ConditionalOp, xla_hlo::WhileOp>(); conversionTarget.addIllegalOp<xla_hlo::DotGeneralOp>(); - xla_hlo::PopulateGeneralDotOpLoweringPatterns(&conversionPatterns, - &getContext()); + xla_hlo::PopulateGeneralDotOpLoweringPatterns(&conversionPatterns, context); // Early conversion of ops that have matches we want to route through. // For example, DynamicUpdateSlice should end up as a stream operation. - setupDirectHLOToFlowLegality(&getContext(), conversionTarget); - populateHLOToFlowPatterns(&getContext(), conversionPatterns); - setupDirectStandardToFlowLegality(&getContext(), conversionTarget); - populateStandardToFlowPatterns(&getContext(), conversionPatterns); + setupDirectHLOToFlowLegality(context, conversionTarget); + populateHLOToFlowPatterns(context, conversionPatterns); + setupDirectStandardToFlowLegality(context, conversionTarget); + populateStandardToFlowPatterns(context, conversionPatterns); if (failed(applyFullConversion(getFunction(), conversionTarget, - conversionPatterns))) { + conversionPatterns, &typeConverter))) { getFunction().emitError() << "module is not in a compatible input format"; return signalPassFailure(); } @@ -77,7 +79,9 @@ : public FunctionPass<PostPartitioningConversionPass> { public: void runOnFunction() override { + auto *context = &getContext(); ConversionTarget conversionTarget(getContext()); + FlowTypeConverter typeConverter; OwningRewritePatternList conversionPatterns; // We have completed all flow op creation at this point. @@ -91,11 +95,11 @@ conversionTarget.addLegalOp<ModuleOp, ModuleTerminatorOp, FuncOp>(); // Pick up any remaining HLO ops that were not partitioned. - populateHLOToFlowPatterns(&getContext(), conversionPatterns); - populateStandardToFlowPatterns(&getContext(), conversionPatterns); + populateHLOToFlowPatterns(context, conversionPatterns); + populateStandardToFlowPatterns(context, conversionPatterns); if (failed(applyFullConversion(getFunction(), conversionTarget, - conversionPatterns))) { + conversionPatterns, &typeConverter))) { getFunction().emitError() << "module is not in a compatible input format"; return signalPassFailure(); }
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/legalize_input_types.mlir b/iree/compiler/Dialect/Flow/Transforms/test/legalize_input_types.mlir new file mode 100644 index 0000000..fd99593 --- /dev/null +++ b/iree/compiler/Dialect/Flow/Transforms/test/legalize_input_types.mlir
@@ -0,0 +1,124 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// RUN: iree-opt -split-input-file -verify-diagnostics -iree-flow-legalize-input-types %s | IreeFileCheck %s + +// CHECK-LABEL: func @constantI64 +// CHECK-SAME: () -> i32 +func @constantI64() -> i64 { + // CHECK-NEXT: constant 123 : i32 + %c123 = constant 123 : i64 + return %c123 : i64 +} + +// ----- + +// CHECK-LABEL: func @constantF64 +// CHECK-SAME: () -> f32 +func @constantF64() -> f64 { + // CHECK-NEXT: constant 1.234000e+02 : f32 + %c1234 = constant 123.4 : f64 + return %c1234 : f64 +} + +// ----- + +// CHECK-LABEL: func @constantSplatTensorI64 +// CHECK-SAME: () -> tensor<4xi32> +func @constantSplatTensorI64() -> tensor<4xi64> { + // CHECK-NEXT: constant dense<123> : tensor<4xi32> + %c123 = constant dense<123> : tensor<4xi64> + return %c123 : tensor<4xi64> +} + +// ----- + +// CHECK-LABEL: func @constantDenseTensorI64 +// CHECK-SAME: () -> tensor<4xi32> +func @constantDenseTensorI64() -> tensor<4xi64> { + // CHECK-NEXT: constant dense<[0, 1, 2, 3]> : tensor<4xi32> + %c123 = constant dense<[0, 1, 2, 3]> : tensor<4xi64> + return %c123 : tensor<4xi64> +} + +// ----- + +// CHECK-LABEL: func @typesIndex +// CHECK-SAME: (%arg0: i32) -> i32 +func @typesIndex(%arg0 : index) -> index { + // CHECK-NEXT: return %arg0 : i32 + return %arg0 : index +} + +// ----- + +// CHECK-LABEL: func @typesI64 +// CHECK-SAME: (%arg0: i32) -> i32 +func @typesI64(%arg0 : i64) -> i64 { + // CHECK-NEXT: return %arg0 : i32 + return %arg0 : i64 +} + +// ----- + +// CHECK-LABEL: func @tensorTypesI64 +// CHECK-SAME: (%arg0: tensor<4x4xi32>) -> tensor<4x4xi32> +func @tensorTypesI64(%arg0 : tensor<4x4xi64>) -> tensor<4x4xi64> { + // CHECK-NEXT: return %arg0 : tensor<4x4xi32> + return %arg0 : tensor<4x4xi64> +} + +// ----- + +// CHECK-LABEL: func @tensorTypesF64 +// CHECK-SAME: (%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> +func @tensorTypesF64(%arg0 : tensor<4x4xf64>) -> tensor<4x4xf64> { + // CHECK-NEXT: return %arg0 : tensor<4x4xf32> + return %arg0 : tensor<4x4xf64> +} + +// ----- +// expected-error@+1 {{'func' op unable to legalize type of input 0}} +func @tensorUnrankedArg(%arg0 : tensor<*xi64>) -> tensor<*xi64> { + return %arg0 : tensor<*xi64> +} + +// ----- +func @tensorUnrankedValue(%arg0 : tensor<4xi64>) -> tensor<4xi64> { + // expected-error@+1 {{'std.tensor_cast' op unable to legalize operation types}} + %0 = tensor_cast %arg0 : tensor<4xi64> to tensor<*xi64> + %1 = tensor_cast %0 : tensor<*xi64> to tensor<4xi64> + return %1 : tensor<4xi64> +} + +// ----- + +// CHECK-LABEL: func @compareI64 +// CHECK-SAME: (%arg0: tensor<i32>, %arg1: tensor<i32>) -> (i1, tensor<i32>) +func @compareI64(%arg0 : tensor<i64>, %arg1 : tensor<i64>) -> (i1, tensor<i64>) { + // CHECK-NEXT: %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1> + // CHECK-NEXT: %1 = extract_element %0[] : tensor<i1> + // CHECK-NEXT: cond_br %1, ^bb1(%1, %arg0 : i1, tensor<i32>), ^bb2(%1, %arg1 : i1, tensor<i32>) + // CHECK-NEXT: ^bb1(%2: i1, %3: tensor<i32>): // pred: ^bb0 + // CHECK-NEXT: return %2, %3 : i1, tensor<i32> + // CHECK-NEXT: ^bb2(%4: i1, %5: tensor<i32>): // pred: ^bb0 + // CHECK-NEXT: return %4, %5 : i1, tensor<i32> + %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1> + %1 = extract_element %0[] : tensor<i1> + cond_br %1, ^bb1(%1, %arg0 : i1, tensor<i64>), ^bb2(%1, %arg1 : i1, tensor<i64>) +^bb1(%2 : i1, %3 : tensor<i64>): + return %2, %3 : i1, tensor<i64> +^bb2(%4 : i1, %5 : tensor<i64>): + return %4, %5 : i1, tensor<i64> +}
diff --git a/iree/compiler/Transforms/Interpreter/LowerXLAToInterpreterDialect.cpp b/iree/compiler/Transforms/Interpreter/LowerXLAToInterpreterDialect.cpp index 6d390cf..c207df1 100644 --- a/iree/compiler/Transforms/Interpreter/LowerXLAToInterpreterDialect.cpp +++ b/iree/compiler/Transforms/Interpreter/LowerXLAToInterpreterDialect.cpp
@@ -400,12 +400,18 @@ ->getResult(0); } - llvm::SmallVector<int64_t, 4> zeroes; - zeroes.resize(extraDims, 0); - - auto elementsAttr = DenseIntElementsAttr::get( - RankedTensorType::get(zeroes.size(), elementType), - llvm::makeArrayRef(zeroes)); + ElementsAttr elementsAttr; + if (elementType.isInteger(32)) { + llvm::SmallVector<int32_t, 4> zeroes(extraDims); + elementsAttr = DenseIntElementsAttr::get( + RankedTensorType::get(zeroes.size(), elementType), + llvm::makeArrayRef(zeroes)); + } else { + llvm::SmallVector<int64_t, 4> zeroes(extraDims); + elementsAttr = DenseIntElementsAttr::get( + RankedTensorType::get(zeroes.size(), elementType), + llvm::makeArrayRef(zeroes)); + } auto extraStartIndices = rewriter.create<IREE::ConstantOp>(gatherOp.getLoc(), elementsAttr);