Adding input type legalization to chop up i64/f64 prior to processing.
In the future (when we have backends that can support those types) we can make this narrowing optional and/or warn on it. For now this matches the behavior of the existing input legalization so no functional change.

PiperOrigin-RevId: 284877942
diff --git a/iree/compiler/Dialect/Flow/Conversion/BUILD b/iree/compiler/Dialect/Flow/Conversion/BUILD
new file mode 100644
index 0000000..1fa35a9
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Conversion/BUILD
@@ -0,0 +1,34 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+    default_visibility = ["//visibility:public"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+cc_library(
+    name = "Conversion",
+    srcs = [
+        "TypeConverter.cpp",
+    ],
+    hdrs = [
+        "TypeConverter.h",
+    ],
+    deps = [
+        "//iree/compiler/Dialect",
+        "@local_config_mlir//:IR",
+        "@local_config_mlir//:Parser",
+        "@local_config_mlir//:Transforms",
+    ],
+)
diff --git a/iree/compiler/Dialect/Flow/Conversion/TypeConverter.cpp b/iree/compiler/Dialect/Flow/Conversion/TypeConverter.cpp
new file mode 100644
index 0000000..d445782
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Conversion/TypeConverter.cpp
@@ -0,0 +1,67 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/Flow/Conversion/TypeConverter.h"
+
+#include "iree/compiler/Dialect/Types.h"
+#include "mlir/IR/StandardTypes.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+Type FlowTypeConverter::convertType(Type t) {
+  if (t.isIndex()) {
+    // Always treat as 32-bit.
+    return IntegerType::get(32, t.getContext());
+  } else if (t.isIntOrIndexOrFloat()) {
+    if (auto integerType = t.dyn_cast<IntegerType>()) {
+      if (integerType.getWidth() > 32) {
+        // Don't support 64-bit types in general. Rewrite to i32 (if desired).
+        // TODO(benvanik): split to i32+i32? allow and use availability?
+        // TODO(benvanik): make an option.
+        return IntegerType::get(32, t.getContext());
+      }
+    } else if (auto floatType = t.dyn_cast<FloatType>()) {
+      if (floatType.getWidth() > 32) {
+        // Don't support 64-bit types in general. Rewrite to f32 (if desired).
+        // TODO(benvanik): make an option.
+        return FloatType::getF32(t.getContext());
+      }
+    }
+  } else if (auto tensorType = t.dyn_cast<RankedTensorType>()) {
+    auto convertedElementType = convertType(tensorType.getElementType());
+    if (!convertedElementType) {
+      return {};
+    }
+    return RankedTensorType::get(tensorType.getShape(), convertedElementType);
+  } else if (auto tensorType = t.dyn_cast<TensorType>()) {
+    // We only support ranked tensors. We could convert unranked to ranked
+    // here for certain cases (such as * on the LHS).
+    return {};
+  }
+  // Allow types through by default.
+  return t;
+}
+
+Operation *FlowTypeConverter::materializeConversion(PatternRewriter &rewriter,
+                                                    Type resultType,
+                                                    ArrayRef<Value *> inputs,
+                                                    Location loc) {
+  // TODO(b/145876978): materialize conversion when this is called.
+  llvm_unreachable("unhandled materialization");
+  return nullptr;
+}
+
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Conversion/TypeConverter.h b/iree/compiler/Dialect/Flow/Conversion/TypeConverter.h
new file mode 100644
index 0000000..618c6ef
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Conversion/TypeConverter.h
@@ -0,0 +1,36 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_DIALECT_FLOW_CONVERSION_TYPECONVERTER_H_
+#define IREE_COMPILER_DIALECT_FLOW_CONVERSION_TYPECONVERTER_H_
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+class FlowTypeConverter : public TypeConverter {
+ public:
+  Type convertType(Type t) override;
+  using TypeConverter::convertType;
+
+  Operation *materializeConversion(PatternRewriter &rewriter, Type resultType,
+                                   ArrayRef<Value *> inputs,
+                                   Location loc) override;
+};
+
+}  // namespace iree_compiler
+}  // namespace mlir
+
+#endif  // IREE_COMPILER_DIALECT_FLOW_CONVERSION_TYPECONVERTER_H_
diff --git a/iree/compiler/Dialect/Flow/Transforms/BUILD b/iree/compiler/Dialect/Flow/Transforms/BUILD
index e3d31a4..30b5c10 100644
--- a/iree/compiler/Dialect/Flow/Transforms/BUILD
+++ b/iree/compiler/Dialect/Flow/Transforms/BUILD
@@ -27,6 +27,7 @@
         "FormStreams.cpp",
         "IdentifyDispatchRegions.cpp",
         "IdentifyReductionRegions.cpp",
+        "LegalizeInputTypes.cpp",
         "OutlineDispatchRegions.cpp",
         "OutlineReductionRegions.cpp",
         "Passes.cpp",
@@ -38,6 +39,7 @@
     ],
     deps = [
         "//iree/compiler/Dialect/Flow/Analysis",
+        "//iree/compiler/Dialect/Flow/Conversion",
         "//iree/compiler/Dialect/Flow/Conversion/HLOToFlow",
         "//iree/compiler/Dialect/Flow/Conversion/StandardToFlow",
         "//iree/compiler/Dialect/Flow/IR",
diff --git a/iree/compiler/Dialect/Flow/Transforms/LegalizeInputTypes.cpp b/iree/compiler/Dialect/Flow/Transforms/LegalizeInputTypes.cpp
new file mode 100644
index 0000000..44ab3d9
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Transforms/LegalizeInputTypes.cpp
@@ -0,0 +1,253 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/Flow/Conversion/TypeConverter.h"
+#include "mlir/Analysis/Verifier.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Utils.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Flow {
+
+namespace {
+
+Attribute convertAttribute(Location loc, Attribute value,
+                           FlowTypeConverter &typeConverter) {
+  auto newType = typeConverter.convertType(value.getType());
+  if (value.getType() == newType) {
+    return value;
+  }
+
+  // TODO(benvanik): when std has a conversion op use that instead.
+
+  if (auto attr = value.dyn_cast<IntegerAttr>()) {
+    // TODO(b/130356985): saturate when signedness is known.
+    return IntegerAttr::get(
+        newType, attr.getValue().trunc(newType.getIntOrFloatBitWidth()));
+  } else if (auto attr = value.dyn_cast<FloatAttr>()) {
+    switch (newType.getIntOrFloatBitWidth()) {
+      case 32:
+        return FloatAttr::get(newType, attr.getValueAsDouble());
+      case 64:
+        return FloatAttr::get(newType, attr.getValueAsDouble());
+      default:
+        break;
+    }
+  } else if (auto attr = value.dyn_cast<SplatElementsAttr>()) {
+    return SplatElementsAttr::get(
+        newType.cast<ShapedType>(),
+        convertAttribute(loc, attr.getSplatValue(), typeConverter));
+  } else if (auto attr = value.dyn_cast<DenseIntElementsAttr>()) {
+    auto newElementType = newType.cast<ShapedType>().getElementType();
+    auto newElementBitWidth = newElementType.getIntOrFloatBitWidth();
+    return attr.mapValues(newElementType, [&](APInt src) {
+      // TODO(b/130356985): saturate when signedness is known.
+      return src.trunc(newElementBitWidth);
+    });
+  }
+
+  emitError(loc) << "unsupported attribute kind for conversion from "
+                 << value.getType() << " to " << newType;
+  return {};
+}
+
+LogicalResult convertRegion(Region &oldRegion, Region &newRegion,
+                            FlowTypeConverter &typeConverter,
+                            BlockAndValueMapping &mapping);
+
+LogicalResult convertOperation(Operation *oldOp,
+                               FlowTypeConverter &typeConverter,
+                               BlockAndValueMapping &mapping,
+                               OpBuilder &builder) {
+  OperationState state(oldOp->getLoc(), oldOp->getName());
+  for (auto oldType : oldOp->getResultTypes()) {
+    if (failed(typeConverter.convertType(oldType, state.types))) {
+      return failure();
+    }
+  }
+
+  if (auto constantOp = dyn_cast<mlir::ConstantOp>(oldOp)) {
+    auto newValue =
+        convertAttribute(oldOp->getLoc(), constantOp.value(), typeConverter);
+    if (!newValue) {
+      return failure();
+    }
+    state.addAttribute("value", newValue);
+  } else {
+    state.attributes = llvm::to_vector<4>(oldOp->getAttrs());
+  }
+
+  if (oldOp->getNumSuccessors() == 0) {
+    // Non-branching operations can just add all the operands.
+    for (auto *oldOperand : oldOp->getOperands()) {
+      state.operands.push_back(mapping.lookup(oldOperand));
+    }
+  } else {
+    // We add the operands separated by nullptr's for each successor.
+    unsigned firstSuccOperand = oldOp->getNumSuccessors()
+                                    ? oldOp->getSuccessorOperandIndex(0)
+                                    : oldOp->getNumOperands();
+    auto oldOperands = oldOp->getOpOperands();
+    for (unsigned i = 0; i != firstSuccOperand; ++i) {
+      state.operands.push_back(mapping.lookup(oldOperands[i].get()));
+    }
+    for (unsigned succ = 0, e = oldOp->getNumSuccessors(); succ != e; ++succ) {
+      state.successors.push_back(
+          mapping.lookupOrDefault(oldOp->getSuccessor(succ)));
+      // Add sentinel to delineate successor operands.
+      state.operands.push_back(nullptr);
+      // Remap the successors operands.
+      for (auto *operand : oldOp->getSuccessorOperands(succ)) {
+        state.operands.push_back(mapping.lookup(operand));
+      }
+    }
+  }
+
+  for (auto &oldRegion : oldOp->getRegions()) {
+    auto *newRegion = state.addRegion();
+    if (failed(convertRegion(oldRegion, *newRegion, typeConverter, mapping))) {
+      return failure();
+    }
+  }
+
+  auto *newOp = builder.createOperation(state);
+  if (failed(mlir::verify(newOp))) {
+    // TODO(benvanik): we could possibly try again with a different set of type
+    // conversions to see if that works. For example, we could lean toward
+    // materializing conversions/inserting cases instead of directly doing the
+    // conversions here. Unfortunately ops don't allow us to query what types
+    // they support so this is trial-and-error.
+    return newOp->emitOpError()
+           << "post-conversion verification failed - unsupported types";
+  }
+
+  for (auto oldNewResult :
+       llvm::zip(oldOp->getResults(), newOp->getResults())) {
+    auto *oldResult = std::get<0>(oldNewResult);
+    auto *newResult = std::get<1>(oldNewResult);
+    mapping.map(oldResult, newResult);
+  }
+
+  return success();
+}
+
+LogicalResult convertBlock(Block &oldBlock, Block &newBlock,
+                           FlowTypeConverter &typeConverter,
+                           BlockAndValueMapping &mapping) {
+  OpBuilder builder(oldBlock.getParent()->getContext());
+  builder.setInsertionPointToEnd(&newBlock);
+  for (auto &oldOp : oldBlock) {
+    if (failed(convertOperation(&oldOp, typeConverter, mapping, builder))) {
+      return oldOp.emitOpError() << "unable to legalize operation types";
+    }
+  }
+  return success();
+}
+
+LogicalResult convertRegion(Region &oldRegion, Region &newRegion,
+                            FlowTypeConverter &typeConverter,
+                            BlockAndValueMapping &mapping) {
+  OpBuilder builder(oldRegion.getContext());
+  for (auto &oldBlock : oldRegion) {
+    auto &newBlock = *builder.createBlock(&newRegion);
+    auto blockSignature = typeConverter.convertBlockSignature(&oldBlock);
+    if (!blockSignature) {
+      return oldBlock.front().emitError()
+             << "unable to legalize block signature";
+    }
+    newBlock.addArguments(blockSignature->getConvertedTypes());
+    for (auto oldNewArg :
+         llvm::zip(oldBlock.getArguments(), newBlock.getArguments())) {
+      mapping.map(std::get<0>(oldNewArg), std::get<1>(oldNewArg));
+    }
+    mapping.map(&oldBlock, &newBlock);
+  }
+  for (auto &oldBlock : oldRegion) {
+    if (failed(convertBlock(oldBlock, *mapping.lookup(&oldBlock), typeConverter,
+                            mapping))) {
+      return failure();
+    }
+  }
+  return success();
+}
+
+}  // namespace
+
+class LegalizeInputTypesPass : public ModulePass<LegalizeInputTypesPass> {
+ public:
+  void runOnModule() override {
+    auto moduleOp = getModule();
+    FlowTypeConverter typeConverter;
+
+    auto oldFuncOps = llvm::to_vector<16>(moduleOp.getOps<FuncOp>());
+    for (auto oldFuncOp : oldFuncOps) {
+      OpBuilder moduleBuilder(moduleOp);
+      moduleBuilder.setInsertionPoint(oldFuncOp);
+
+      auto oldType = oldFuncOp.getType();
+      TypeConverter::SignatureConversion signature(oldType.getNumInputs());
+      for (unsigned i = 0, e = oldType.getNumInputs(); i != e; ++i) {
+        if (failed(typeConverter.convertSignatureArg(i, oldType.getInput(i),
+                                                     signature))) {
+          oldFuncOp.emitOpError() << "unable to legalize type of input " << i;
+          return signalPassFailure();
+        }
+      }
+      SmallVector<Type, 1> convertedResults;
+      if (failed(typeConverter.convertTypes(oldType.getResults(),
+                                            convertedResults))) {
+        oldFuncOp.emitOpError() << "unable to legalize result types";
+        return signalPassFailure();
+      }
+
+      auto newFuncOp =
+          cast<FuncOp>(moduleBuilder.cloneWithoutRegions(*oldFuncOp));
+      newFuncOp.setType(FunctionType::get(signature.getConvertedTypes(),
+                                          convertedResults, &getContext()));
+
+      BlockAndValueMapping mapping;
+      if (failed(convertRegion(oldFuncOp.getBody(), newFuncOp.getBody(),
+                               typeConverter, mapping))) {
+        return signalPassFailure();
+      }
+
+      oldFuncOp.erase();
+    }
+  }
+};
+
+std::unique_ptr<OpPassBase<ModuleOp>> createLegalizeInputTypesPass() {
+  return std::make_unique<LegalizeInputTypesPass>();
+}
+
+static PassRegistration<LegalizeInputTypesPass> pass(
+    "iree-flow-legalize-input-types",
+    "Legalizes input types to ones supported by the IREE flow dialect");
+
+}  // namespace Flow
+}  // namespace IREE
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index 8ab26b2..0d19fbc 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -38,6 +38,10 @@
   passManager.addNestedPass<FuncOp>(createCanonicalizerPass());
   passManager.addNestedPass<FuncOp>(createCSEPass());
 
+  // Legalize input types. We do this after flattening tuples so that we don't
+  // have to deal with them.
+  passManager.addPass(IREE::Flow::createLegalizeInputTypesPass());
+
   // Convert into our expected input and (hopefully) some flow ops.
   passManager.addNestedPass<FuncOp>(
       IREE::Flow::createPrePartitioningConversionPass());
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.h b/iree/compiler/Dialect/Flow/Transforms/Passes.h
index 042f296..431dc49 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.h
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -51,6 +51,12 @@
 // Flattens tuple values in function signatures and blocks.
 std::unique_ptr<OpPassBase<ModuleOp>> createFlattenTuplesInCFGPass();
 
+// Legalizes the input types to those supported by the flow dialect.
+// This will fail if types that cannot be supported at all are present, however
+// conditionally supported types (based on availability, etc) may still be
+// allowed to pass through successfully.
+std::unique_ptr<OpPassBase<ModuleOp>> createLegalizeInputTypesPass();
+
 // Runs pre-partitioning conversion passes to convert to the flow dialect.
 // This converts some input ops directly to flow ops when doing so has a
 // benefit. Other ops are left unmodified and will be outlined later on.
diff --git a/iree/compiler/Dialect/Flow/Transforms/PrePostPartitioningConversion.cpp b/iree/compiler/Dialect/Flow/Transforms/PrePostPartitioningConversion.cpp
index 6993207..f4c6a11 100644
--- a/iree/compiler/Dialect/Flow/Transforms/PrePostPartitioningConversion.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/PrePostPartitioningConversion.cpp
@@ -16,6 +16,7 @@
 
 #include "iree/compiler/Dialect/Flow/Conversion/HLOToFlow/ConvertHLOToFlow.h"
 #include "iree/compiler/Dialect/Flow/Conversion/StandardToFlow/ConvertStandardToFlow.h"
+#include "iree/compiler/Dialect/Flow/Conversion/TypeConverter.h"
 #include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
 #include "mlir/Dialect/StandardOps/Ops.h"
 #include "mlir/IR/Builders.h"
@@ -35,7 +36,9 @@
     : public FunctionPass<PrePartitioningConversionPass> {
  public:
   void runOnFunction() override {
-    ConversionTarget conversionTarget(getContext());
+    auto *context = &getContext();
+    FlowTypeConverter typeConverter;
+    ConversionTarget conversionTarget(*context);
     OwningRewritePatternList conversionPatterns;
 
     conversionTarget.addLegalDialect<IREE::Flow::FlowDialect>();
@@ -55,18 +58,17 @@
     conversionTarget.addIllegalOp<xla_hlo::ConditionalOp, xla_hlo::WhileOp>();
 
     conversionTarget.addIllegalOp<xla_hlo::DotGeneralOp>();
-    xla_hlo::PopulateGeneralDotOpLoweringPatterns(&conversionPatterns,
-                                                  &getContext());
+    xla_hlo::PopulateGeneralDotOpLoweringPatterns(&conversionPatterns, context);
 
     // Early conversion of ops that have matches we want to route through.
     // For example, DynamicUpdateSlice should end up as a stream operation.
-    setupDirectHLOToFlowLegality(&getContext(), conversionTarget);
-    populateHLOToFlowPatterns(&getContext(), conversionPatterns);
-    setupDirectStandardToFlowLegality(&getContext(), conversionTarget);
-    populateStandardToFlowPatterns(&getContext(), conversionPatterns);
+    setupDirectHLOToFlowLegality(context, conversionTarget);
+    populateHLOToFlowPatterns(context, conversionPatterns);
+    setupDirectStandardToFlowLegality(context, conversionTarget);
+    populateStandardToFlowPatterns(context, conversionPatterns);
 
     if (failed(applyFullConversion(getFunction(), conversionTarget,
-                                   conversionPatterns))) {
+                                   conversionPatterns, &typeConverter))) {
       getFunction().emitError() << "module is not in a compatible input format";
       return signalPassFailure();
     }
@@ -77,7 +79,9 @@
     : public FunctionPass<PostPartitioningConversionPass> {
  public:
   void runOnFunction() override {
+    auto *context = &getContext();
     ConversionTarget conversionTarget(getContext());
+    FlowTypeConverter typeConverter;
     OwningRewritePatternList conversionPatterns;
 
     // We have completed all flow op creation at this point.
@@ -91,11 +95,11 @@
     conversionTarget.addLegalOp<ModuleOp, ModuleTerminatorOp, FuncOp>();
 
     // Pick up any remaining HLO ops that were not partitioned.
-    populateHLOToFlowPatterns(&getContext(), conversionPatterns);
-    populateStandardToFlowPatterns(&getContext(), conversionPatterns);
+    populateHLOToFlowPatterns(context, conversionPatterns);
+    populateStandardToFlowPatterns(context, conversionPatterns);
 
     if (failed(applyFullConversion(getFunction(), conversionTarget,
-                                   conversionPatterns))) {
+                                   conversionPatterns, &typeConverter))) {
       getFunction().emitError() << "module is not in a compatible input format";
       return signalPassFailure();
     }
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/legalize_input_types.mlir b/iree/compiler/Dialect/Flow/Transforms/test/legalize_input_types.mlir
new file mode 100644
index 0000000..fd99593
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Transforms/test/legalize_input_types.mlir
@@ -0,0 +1,124 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// RUN: iree-opt -split-input-file -verify-diagnostics -iree-flow-legalize-input-types %s | IreeFileCheck %s
+
+// CHECK-LABEL: func @constantI64
+// CHECK-SAME: () -> i32
+func @constantI64() -> i64 {
+  // CHECK-NEXT: constant 123 : i32
+  %c123 = constant 123 : i64
+  return %c123 : i64
+}
+
+// -----
+
+// CHECK-LABEL: func @constantF64
+// CHECK-SAME: () -> f32
+func @constantF64() -> f64 {
+  // CHECK-NEXT: constant 1.234000e+02 : f32
+  %c1234 = constant 123.4 : f64
+  return %c1234 : f64
+}
+
+// -----
+
+// CHECK-LABEL: func @constantSplatTensorI64
+// CHECK-SAME: () -> tensor<4xi32>
+func @constantSplatTensorI64() -> tensor<4xi64> {
+  // CHECK-NEXT: constant dense<123> : tensor<4xi32>
+  %c123 = constant dense<123> : tensor<4xi64>
+  return %c123 : tensor<4xi64>
+}
+
+// -----
+
+// CHECK-LABEL: func @constantDenseTensorI64
+// CHECK-SAME: () -> tensor<4xi32>
+func @constantDenseTensorI64() -> tensor<4xi64> {
+  // CHECK-NEXT: constant dense<[0, 1, 2, 3]> : tensor<4xi32>
+  %c123 = constant dense<[0, 1, 2, 3]> : tensor<4xi64>
+  return %c123 : tensor<4xi64>
+}
+
+// -----
+
+// CHECK-LABEL: func @typesIndex
+// CHECK-SAME: (%arg0: i32) -> i32
+func @typesIndex(%arg0 : index) -> index {
+  // CHECK-NEXT: return %arg0 : i32
+  return %arg0 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @typesI64
+// CHECK-SAME: (%arg0: i32) -> i32
+func @typesI64(%arg0 : i64) -> i64 {
+  // CHECK-NEXT: return %arg0 : i32
+  return %arg0 : i64
+}
+
+// -----
+
+// CHECK-LABEL: func @tensorTypesI64
+// CHECK-SAME: (%arg0: tensor<4x4xi32>) -> tensor<4x4xi32>
+func @tensorTypesI64(%arg0 : tensor<4x4xi64>) -> tensor<4x4xi64> {
+  // CHECK-NEXT: return %arg0 : tensor<4x4xi32>
+  return %arg0 : tensor<4x4xi64>
+}
+
+// -----
+
+// CHECK-LABEL: func @tensorTypesF64
+// CHECK-SAME: (%arg0: tensor<4x4xf32>) -> tensor<4x4xf32>
+func @tensorTypesF64(%arg0 : tensor<4x4xf64>) -> tensor<4x4xf64> {
+  // CHECK-NEXT: return %arg0 : tensor<4x4xf32>
+  return %arg0 : tensor<4x4xf64>
+}
+
+// -----
+// expected-error@+1 {{'func' op unable to legalize type of input 0}}
+func @tensorUnrankedArg(%arg0 : tensor<*xi64>) -> tensor<*xi64> {
+  return %arg0 : tensor<*xi64>
+}
+
+// -----
+func @tensorUnrankedValue(%arg0 : tensor<4xi64>) -> tensor<4xi64> {
+  // expected-error@+1 {{'std.tensor_cast' op unable to legalize operation types}}
+  %0 = tensor_cast %arg0 : tensor<4xi64> to tensor<*xi64>
+  %1 = tensor_cast %0 : tensor<*xi64> to tensor<4xi64>
+  return %1 : tensor<4xi64>
+}
+
+// -----
+
+// CHECK-LABEL: func @compareI64
+// CHECK-SAME: (%arg0: tensor<i32>, %arg1: tensor<i32>) -> (i1, tensor<i32>)
+func @compareI64(%arg0 : tensor<i64>, %arg1 : tensor<i64>) -> (i1, tensor<i64>) {
+  // CHECK-NEXT: %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+  // CHECK-NEXT: %1 = extract_element %0[] : tensor<i1>
+  // CHECK-NEXT: cond_br %1, ^bb1(%1, %arg0 : i1, tensor<i32>), ^bb2(%1, %arg1 : i1, tensor<i32>)
+  // CHECK-NEXT: ^bb1(%2: i1, %3: tensor<i32>): // pred: ^bb0
+  // CHECK-NEXT: return %2, %3 : i1, tensor<i32>
+  // CHECK-NEXT: ^bb2(%4: i1, %5: tensor<i32>): // pred: ^bb0
+  // CHECK-NEXT: return %4, %5 : i1, tensor<i32>
+  %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
+  %1 = extract_element %0[] : tensor<i1>
+  cond_br %1, ^bb1(%1, %arg0 : i1, tensor<i64>), ^bb2(%1, %arg1 : i1, tensor<i64>)
+^bb1(%2 : i1, %3 : tensor<i64>):
+  return %2, %3 : i1, tensor<i64>
+^bb2(%4 : i1, %5 : tensor<i64>):
+  return %4, %5 : i1, tensor<i64>
+}
diff --git a/iree/compiler/Transforms/Interpreter/LowerXLAToInterpreterDialect.cpp b/iree/compiler/Transforms/Interpreter/LowerXLAToInterpreterDialect.cpp
index 6d390cf..c207df1 100644
--- a/iree/compiler/Transforms/Interpreter/LowerXLAToInterpreterDialect.cpp
+++ b/iree/compiler/Transforms/Interpreter/LowerXLAToInterpreterDialect.cpp
@@ -400,12 +400,18 @@
                            ->getResult(0);
       }
 
-      llvm::SmallVector<int64_t, 4> zeroes;
-      zeroes.resize(extraDims, 0);
-
-      auto elementsAttr = DenseIntElementsAttr::get(
-          RankedTensorType::get(zeroes.size(), elementType),
-          llvm::makeArrayRef(zeroes));
+      ElementsAttr elementsAttr;
+      if (elementType.isInteger(32)) {
+        llvm::SmallVector<int32_t, 4> zeroes(extraDims);
+        elementsAttr = DenseIntElementsAttr::get(
+            RankedTensorType::get(zeroes.size(), elementType),
+            llvm::makeArrayRef(zeroes));
+      } else {
+        llvm::SmallVector<int64_t, 4> zeroes(extraDims);
+        elementsAttr = DenseIntElementsAttr::get(
+            RankedTensorType::get(zeroes.size(), elementType),
+            llvm::makeArrayRef(zeroes));
+      }
 
       auto extraStartIndices =
           rewriter.create<IREE::ConstantOp>(gatherOp.getLoc(), elementsAttr);