Preserve unsigned integer inputs/outputs (#13090)

Thanks so much to @rsuderman for the guidance and help!

This PR fixes a bug where the signedness of function inputs and outputs is lost,
because the lowering to the linalg dialect discards signedness. To fix it, I'm
using the new `tensor.bitcast` op to cast the function's unsigned inputs/outputs
to the signless types used in the body.
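
To illustrate the idea, here is a simplified sketch of what the conversion produces for a function with unsigned operands (the lowered linalg body is elided; the updated tests below show the full output):

```mlir
// Before conversion: the boundary types carry signedness.
func.func @add(%arg0: tensor<2xui32>, %arg1: tensor<2xui32>) -> tensor<2xui32> {
  %0 = "mhlo.add"(%arg0, %arg1) : (tensor<2xui32>, tensor<2xui32>) -> tensor<2xui32>
  return %0 : tensor<2xui32>
}

// After conversion: the ui32 boundary is preserved, and tensor.bitcast bridges
// to the signless i32 values used inside the body (linalg body elided).
func.func @add(%arg0: tensor<2xui32>, %arg1: tensor<2xui32>) -> tensor<2xui32> {
  %lhs = tensor.bitcast %arg0 : tensor<2xui32> to tensor<2xi32>
  %rhs = tensor.bitcast %arg1 : tensor<2xui32> to tensor<2xi32>
  // ... signless linalg.generic computing %sum : tensor<2xi32> ...
  %result = tensor.bitcast %sum : tensor<2xi32> to tensor<2xui32>
  return %result : tensor<2xui32>
}
```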

I am completely new to IREE, so this PR is still a little rough. Any critique
and suggestions for improvement are highly welcome; I'm particularly interested
in ways to make this more robust.

Fixes #9282
Fixes #12665

diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp b/compiler/src/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp
index 974e04e..0c901da 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp
@@ -70,9 +70,8 @@
       castType = shapedType.clone(castType);
 
     if (castType != fromType)
-      fromValue =
-          builder.create<UnrealizedConversionCastOp>(loc, castType, fromValue)
-              ->getResult(0);
+      fromValue = builder.create<tensor::BitcastOp>(loc, castType, fromValue)
+                      ->getResult(0);
   }
 
   if (fromType.getRank() != 0) return fromValue;
@@ -595,9 +594,6 @@
         [](mhlo::ComplexOp complexOp) {
           return !isInBodyOfLinalgExtOps(complexOp);
         });
-    // We deliberately allow unrealized casts to persist. These should fall away
-    // when the rest of MHLO is converted.
-    target.addLegalOp<UnrealizedConversionCastOp>();
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns)))) {
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp b/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp
index 3197903..dce9642 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp
@@ -357,6 +357,29 @@
       .getResult();
 }
 
+std::optional<Value> materializeCastFromIllegal(OpBuilder &builder, Type type,
+                                                ValueRange inputs,
+                                                Location loc) {
+  Type fromType = getElementTypeOrSelf(inputs[0].getType());
+  Type toType = getElementTypeOrSelf(type);
+  if ((!fromType.isSignedInteger() && !fromType.isUnsignedInteger()) ||
+      !toType.isSignlessInteger())
+    return std::nullopt;
+  // Use bitcast to do signless->signful conversions.
+  return builder.create<tensor::BitcastOp>(loc, type, inputs[0])->getResult(0);
+}
+
+std::optional<Value> materializeCastToIllegal(OpBuilder &builder, Type type,
+                                              ValueRange inputs, Location loc) {
+  Type fromType = getElementTypeOrSelf(inputs[0].getType());
+  Type toType = getElementTypeOrSelf(type);
+  if (!fromType.isSignlessInteger() ||
+      (!toType.isSignedInteger() && !toType.isUnsignedInteger()))
+    return std::nullopt;
+  // Use bitcast to do signless->signful conversions.
+  return builder.create<tensor::BitcastOp>(loc, type, inputs[0])->getResult(0);
+}
+
 struct ConvertMHLOToLinalgOnTensorsPass
     : public ConvertMHLOToLinalgOnTensorsBase<
           ConvertMHLOToLinalgOnTensorsPass> {
@@ -373,6 +396,9 @@
 
     auto typeConverter = mhlo::createHloToLinalgTypeConverter();
     typeConverter->addArgumentMaterialization(scalarToTensor);
+    typeConverter->addArgumentMaterialization(materializeCastFromIllegal);
+    typeConverter->addTargetMaterialization(materializeCastFromIllegal);
+    typeConverter->addSourceMaterialization(materializeCastToIllegal);
     // NOTE: not using corresponding setupMHLOToFlowPatterns because the entire
     // MHLO dialects are marked illegal by this pass.
     // TODO: Collapse/rework all of these patterns once the consolidation
@@ -419,6 +445,11 @@
     ConversionTarget target(getContext());
 
     auto isIllegalType = [&](Type t) { return !typeConverter->isLegal(t); };
+    auto isIllegalFuncType = [&](Type t) {
+      // Allows unsigned integers for function inputs and outputs.
+      return !typeConverter->isLegal(t) &&
+             !getElementTypeOrSelf(t).isUnsignedInteger();
+    };
     auto isLegallyTypedOp = [&](Operation *op) -> bool {
       for (Type type : op->getResultTypes()) {
         if (isIllegalType(type)) return false;
@@ -435,18 +466,34 @@
     // Functions must have legal types.
     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp funcOp) {
       for (Type type : funcOp.getFunctionType().getInputs()) {
-        if (isIllegalType(type)) return false;
+        if (isIllegalFuncType(type)) return false;
       }
       for (Type type : funcOp.getFunctionType().getResults()) {
-        if (isIllegalType(type)) return false;
+        if (isIllegalFuncType(type)) return false;
       }
       for (Block &block : funcOp.getFunctionBody()) {
         for (Type type : block.getArgumentTypes()) {
-          if (isIllegalType(type)) return false;
+          if (isIllegalFuncType(type)) return false;
         }
       }
       return true;
     });
+    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
+      return llvm::all_of(op.getOperandTypes(),
+                          [&](Type type) { return !isIllegalFuncType(type); });
+    });
+    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
+      return llvm::all_of(op.getOperandTypes(),
+                          [&](Type type) { return !isIllegalFuncType(type); });
+    });
+    target.addDynamicallyLegalOp<cf::CondBranchOp>([&](cf::CondBranchOp op) {
+      return llvm::all_of(op.getOperandTypes(),
+                          [&](Type type) { return !isIllegalFuncType(type); });
+    });
+    target.addDynamicallyLegalOp<cf::BranchOp>([&](cf::BranchOp op) {
+      return llvm::all_of(op.getOperandTypes(),
+                          [&](Type type) { return !isIllegalFuncType(type); });
+    });
     target.addDynamicallyLegalOp<ml_program::GlobalOp>(
         [&](ml_program::GlobalOp op) {
           return typeConverter->isLegal(op.getType());
@@ -455,6 +502,7 @@
     // Let the rest fall through.
     target.addLegalDialect<BuiltinDialect>();
     target.addLegalDialect<IREE::LinalgExt::IREELinalgExtDialect>();
+    target.addLegalOp<tensor::BitcastOp>();
     target.markUnknownOpDynamicallyLegal(isLegallyTypedOp);
 
     if (failed(applyPartialConversion(getOperation(), target,
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/broadcasting.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/broadcasting.mlir
index 39da165..e9fac89 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/test/broadcasting.mlir
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/test/broadcasting.mlir
@@ -409,14 +409,16 @@
   // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
   // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
   // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK-DAG: %[[CAST_ARG:.*]] = tensor.bitcast %arg0 : tensor<4x?x3x?xui32> to tensor<4x?x3x?xi32>
   // CHECK-DAG: %[[RESULT_D1:.*]] = tensor.extract %arg1[%[[C1]]] : tensor<5xindex>
   // CHECK-DAG: %[[RESULT_D2:.*]] = tensor.extract %arg1[%[[C2]]] : tensor<5xindex>
   // CHECK-DAG: %[[RESULT_D4:.*]] = tensor.extract %arg1[%[[C4]]] : tensor<5xindex>
-  // CHECK-DAG: %[[ARG_D1:.*]] = tensor.dim %arg0, %[[C1]] : tensor<4x?x3x?xi32>
-  // CHECK-DAG: %[[ARG_D3:.*]] = tensor.dim %arg0, %[[C3]] : tensor<4x?x3x?xi32>
-  // CHECK-DAG: %[[RESULT:.*]] = flow.tensor.reshape %arg0 : tensor<4x?x3x?xi32>{%[[ARG_D1]], %[[ARG_D3]]} -> tensor<12x?x?x1x?xi32>{%[[RESULT_D1]], %[[RESULT_D2]], %[[RESULT_D4]]}
+  // CHECK-DAG: %[[ARG_D1:.*]] = tensor.dim %[[CAST_ARG]], %[[C1]] : tensor<4x?x3x?xi32>
+  // CHECK-DAG: %[[ARG_D3:.*]] = tensor.dim %[[CAST_ARG]], %[[C3]] : tensor<4x?x3x?xi32>
+  // CHECK-DAG: %[[RESULT:.*]] = flow.tensor.reshape %[[CAST_ARG]] : tensor<4x?x3x?xi32>{%[[ARG_D1]], %[[ARG_D3]]} -> tensor<12x?x?x1x?xi32>{%[[RESULT_D1]], %[[RESULT_D2]], %[[RESULT_D4]]}
   %0 = "mhlo.dynamic_reshape"(%arg0, %arg1) : (tensor<4x?x3x?xui32>, tensor<5xindex>) -> tensor<12x?x?x1x?xui32>
-  // CHECK: return %[[RESULT]]
+  // CHECK-DAG: %[[CAST_RESULT:.*]] = tensor.bitcast %[[RESULT]] : tensor<12x?x?x1x?xi32> to tensor<12x?x?x1x?xui32>
+  // CHECK: return %[[CAST_RESULT]]
   return %0 : tensor<12x?x?x1x?xui32>
 }
 
@@ -427,16 +429,18 @@
   // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
   // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
   // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK-DAG: %[[CAST_ARG:.*]] = tensor.bitcast %arg0 : tensor<4x?x3x?xui32> to tensor<4x?x3x?xi32>
   // CHECK-DAG: %[[D1:.*]] = tensor.extract %arg1[%[[C1]]] : tensor<5xi32>
   // CHECK-DAG: %[[D2:.*]] = tensor.extract %arg1[%[[C2]]] : tensor<5xi32>
   // CHECK-DAG: %[[D4:.*]] = tensor.extract %arg1[%[[C4]]] : tensor<5xi32>
   // CHECK-DAG: %[[RESULT_D1:.*]] = arith.index_cast %[[D1]] : i32 to index
   // CHECK-DAG: %[[RESULT_D2:.*]] = arith.index_cast %[[D2]] : i32 to index
   // CHECK-DAG: %[[RESULT_D4:.*]] = arith.index_cast %[[D4]] : i32 to index
-  // CHECK-DAG: %[[ARG_D1:.*]] = tensor.dim %arg0, %[[C1]] : tensor<4x?x3x?xi32>
-  // CHECK-DAG: %[[ARG_D3:.*]] = tensor.dim %arg0, %[[C3]] : tensor<4x?x3x?xi32>
-  // CHECK-DAG: %[[RESULT:.*]] = flow.tensor.reshape %arg0 : tensor<4x?x3x?xi32>{%[[ARG_D1]], %[[ARG_D3]]} -> tensor<12x?x?x1x?xi32>{%[[RESULT_D1]], %[[RESULT_D2]], %[[RESULT_D4]]}
+  // CHECK-DAG: %[[ARG_D1:.*]] = tensor.dim %[[CAST_ARG]], %[[C1]] : tensor<4x?x3x?xi32>
+  // CHECK-DAG: %[[ARG_D3:.*]] = tensor.dim %[[CAST_ARG]], %[[C3]] : tensor<4x?x3x?xi32>
+  // CHECK-DAG: %[[RESULT:.*]] = flow.tensor.reshape %[[CAST_ARG]] : tensor<4x?x3x?xi32>{%[[ARG_D1]], %[[ARG_D3]]} -> tensor<12x?x?x1x?xi32>{%[[RESULT_D1]], %[[RESULT_D2]], %[[RESULT_D4]]}
   %0 = "mhlo.dynamic_reshape"(%arg0, %arg1) : (tensor<4x?x3x?xui32>, tensor<5xi32>) -> tensor<12x?x?x1x?xui32>
-  // CHECK: return %[[RESULT]]
+  // CHECK-DAG: %[[CAST_RESULT:.*]] = tensor.bitcast %[[RESULT]] : tensor<12x?x?x1x?xi32> to tensor<12x?x?x1x?xui32>
+  // CHECK: return %[[CAST_RESULT]]
   return %0 : tensor<12x?x?x1x?xui32>
 }
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_collective_ops.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_collective_ops.mlir
index f08df5d..b7a0936 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_collective_ops.mlir
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_collective_ops.mlir
@@ -6,7 +6,8 @@
   // CHECK-DAG: [[RANK:%.+]] = flow.channel.rank [[CHANNEL]] : index
   // CHECK-DAG: [[CAST:%.+]] = arith.index_castui [[RANK]] : index to i32
   // CHECK-DAG: [[TENSOR:%.+]] = tensor.from_elements [[CAST]] : tensor<i32>
-  // CHECK-DAG: return [[TENSOR]] : tensor<i32>
+  // CHECK-DAG: [[BITCAST:%.+]] = tensor.bitcast [[TENSOR]] : tensor<i32> to tensor<ui32> 
+  // CHECK-DAG: return [[BITCAST]] : tensor<ui32>
   %id = mhlo.replica_id : tensor<ui32>
   return %id : tensor<ui32>
 }
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir
index 1504703..e05b990 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir
@@ -34,14 +34,14 @@
 // CHECK-LABEL: func.func @sort_1d_ui(
 // CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:  )
-// CHECK:         %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : tensor<128xui32> to tensor<128xi32>
+// CHECK:         %[[CAST:.+]] = tensor.bitcast %[[ARG0]] : tensor<128xui32> to tensor<128xi32>
 // CHECK:         %[[SORT:.+]] = iree_linalg_ext.sort
 // CHECK-SAME:      dimension(0)
 // CHECK-SAME:      outs(%[[CAST]] : tensor<128xi32>)
 // CHECK:           ^bb0(%[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32)
 // CHECK:             %[[CMP:.+]] = arith.cmpi ugt, %[[ARG1]], %[[ARG2]]
 // CHECK:             iree_linalg_ext.yield %[[CMP]]
-// CHECK:         %[[RESULT:.+]] = builtin.unrealized_conversion_cast %[[SORT]] : tensor<128xi32> to tensor<128xui32>
+// CHECK:         %[[RESULT:.+]] = tensor.bitcast %[[SORT]] : tensor<128xi32> to tensor<128xui32>
 // CHECK:         return %[[RESULT]]
 
 // -----
@@ -154,7 +154,7 @@
 // CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:  )
 // CHECK:         %[[UI32:.+]] = mhlo.constant dense<2> : tensor<ui32>
-// CHECK:         %[[CONVERSION_CAST_CST:.+]] = builtin.unrealized_conversion_cast %[[UI32]] : tensor<ui32> to tensor<i32>
+// CHECK:         %[[CONVERSION_CAST_CST:.+]] = tensor.bitcast %[[UI32]] : tensor<ui32> to tensor<i32>
 // CHECK:         %[[EXTRACT_CST:.+]] = tensor.extract %[[CONVERSION_CAST_CST]][] : tensor<i32>
 // CHECK:         %[[SORT:.+]] = iree_linalg_ext.sort
 // CHECK-SAME:      dimension(1)
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_structural_types.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_structural_types.mlir
index 964b5a1..73455e7 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_structural_types.mlir
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_structural_types.mlir
@@ -2,24 +2,33 @@
 
 // CHECK-LABEL: @func_cfg_conversion
 module @func_cfg_conversion {
-  // CHECK: func.func @caller(%arg0: tensor<2xi32>, %arg1: i1) -> tensor<2xi32>
+  // CHECK: func.func @caller(%arg0: tensor<2xui32>, %arg1: i1) -> tensor<2xui32>
   func.func @caller(%arg0: tensor<2xui32>, %arg1 : i1) -> tensor<2xui32> {
-    // CHECK: %[[RESULT:.*]] = call @callee(%arg0, %arg1) : (tensor<2xi32>, i1) -> tensor<2xi32>
+    // CHECK: %[[RESULT:.*]] = call @callee(%arg0, %arg1) : (tensor<2xui32>, i1) -> tensor<2xui32>
     %1 = call @callee(%arg0, %arg1) : (tensor<2xui32>, i1) -> tensor<2xui32>
-    // CHECK: return %[[RESULT]] : tensor<2xi32>
+    // CHECK: return %[[RESULT]] : tensor<2xui32>
     return %1 : tensor<2xui32>
   }
 
-  // CHECK: func.func @callee(%arg0: tensor<2xi32>, %arg1: i1) -> tensor<2xi32>
+  // CHECK: func.func @callee(%arg0: tensor<2xui32>, %arg1: i1) -> tensor<2xui32>
   func.func @callee(%arg0: tensor<2xui32>, %arg1: i1) -> tensor<2xui32> {
-    // CHECK: cf.cond_br %arg1, ^bb1(%arg0 : tensor<2xi32>), ^bb2(%arg0 : tensor<2xi32>)
+    // CHECK: cf.cond_br %arg1, ^bb1(%arg0 : tensor<2xui32>), ^bb2(%arg0 : tensor<2xui32>)
     cf.cond_br %arg1, ^bb1(%arg0 : tensor<2xui32>), ^bb2(%arg0 : tensor<2xui32>)
-  // CHECK: ^bb1(%[[BB1_PHI:.*]]: tensor<2xi32>)
+  // CHECK: ^bb1(%[[BB1_PHI:.*]]: tensor<2xui32>)
   ^bb1(%phi0 : tensor<2xui32>) :
-    // CHECK: cf.br ^bb2(%[[BB1_PHI]] : tensor<2xi32>)
-    cf.br ^bb2(%phi0 : tensor<2xui32>)
-  // CHECK: ^bb2(%[[BB2_PHI:.*]]: tensor<2xi32>)
+    // CHECK: tensor.bitcast %[[BB1_PHI]] : tensor<2xui32> to tensor<2xi32> 
+    // CHECK: %[[BB1_PHI_ADD:.*]] = linalg.generic
+    // CHECK: %[[BB1_PHI_ADD_CAST:.*]] = tensor.bitcast %[[BB1_PHI_ADD]] : tensor<2xi32> to tensor<2xui32> 
+    // CHECK: cf.br ^bb2(%[[BB1_PHI_ADD_CAST]] : tensor<2xui32>)
+    %0 = "mhlo.add"(%phi0, %phi0) : (tensor<2xui32>, tensor<2xui32>) -> tensor<2xui32>
+    cf.br ^bb2(%0 : tensor<2xui32>)
+  // CHECK: ^bb2(%[[BB2_PHI:.*]]: tensor<2xui32>)
   ^bb2(%phi1 : tensor<2xui32>):
-    return %phi1 : tensor<2xui32>
+    // CHECK: tensor.bitcast %[[BB2_PHI]] : tensor<2xui32> to tensor<2xi32> 
+    // CHECK: %[[BB2_PHI_ADD:.*]] = linalg.generic
+    // CHECK: %[[BB2_PHI_ADD_CAST:.*]] = tensor.bitcast %[[BB2_PHI_ADD]] : tensor<2xi32> to tensor<2xui32> 
+    // CHECK: return %[[BB2_PHI_ADD_CAST]] : tensor<2xui32>
+    %1 = "mhlo.add"(%phi1, %phi1) : (tensor<2xui32>, tensor<2xui32>) -> tensor<2xui32>
+    return %1 : tensor<2xui32>
   }
 }
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_linalg.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_linalg.mlir
index a135925..b2001e1 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_linalg.mlir
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_linalg.mlir
@@ -19,11 +19,12 @@
 
 // CHECK: ml_program.global private mutable @variable(dense<0> : tensor<2xi32>) : tensor<2xi32>
 ml_program.global private mutable @variable(dense<0> : tensor<2xui32>) : tensor<2xui32>
-// CHECK: func.func @global_types() -> tensor<2xi32>
+// CHECK: func.func @global_types() -> tensor<2xui32>
 func.func @global_types() -> tensor<2xui32> {
   // CHECK-NEXT: %[[VALUE:.+]] = ml_program.global_load @variable : tensor<2xi32>
   %0 = ml_program.global_load @variable : tensor<2xui32>
-  // CHECK: return %[[VALUE]] : tensor<2xi32>
+  // CHECK-NEXT: %[[CAST:.*]] = tensor.bitcast %[[VALUE]] : tensor<2xi32> to tensor<2xui32>
+  // CHECK: return %[[CAST]] : tensor<2xui32>
   return %0 : tensor<2xui32>
 }
 
@@ -37,3 +38,22 @@
   %0, %1 = "mhlo.optimization_barrier"(%arg0, %arg1) : (tensor<3x4xf32>, tensor<4xi32>) -> (tensor<3x4xf32>, tensor<4xi32>)
   return %0, %1 : tensor<3x4xf32>, tensor<4xi32>
 }
+
+// -----
+
+// CHECK: @unsigned_integer_input_output(%[[ARG0:.*]]: tensor<2x2xui32>, %[[ARG1:.*]]: tensor<2x2xui32>) -> tensor<2x2xui32>
+func.func @unsigned_integer_input_output(%arg0: tensor<2x2xui32>, %arg1: tensor<2x2xui32>) -> tensor<2x2xui32> {
+  // CHECK: %[[CAST0:.*]] = tensor.bitcast %[[ARG0]] : tensor<2x2xui32> to tensor<2x2xi32>
+  // CHECK: %[[CAST1:.*]] = tensor.bitcast %[[ARG1]] : tensor<2x2xui32> to tensor<2x2xi32>
+  // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<2x2xi32> 
+  // CHECK: %[[LINALG:.*]] = linalg.generic
+  //  CHECK-SAME:       ins(%[[CAST0]], %[[CAST1]] : tensor<2x2xi32>, tensor<2x2xi32>
+  //  CHECK-SAME:       outs(%[[INIT]] : tensor<2x2xi32>)
+  // CHECK: ^bb0(%[[IN0:.*]]: i32, %[[IN1:.*]]: i32, %out: i32): 
+  // CHECK: %[[ADD:.*]] = arith.addi %[[IN0]], %[[IN1]] : i32
+  // CHECK: linalg.yield %[[ADD:.*]] : i32 
+  %0 = "mhlo.add"(%arg0, %arg1) : (tensor<2x2xui32>, tensor<2x2xui32>) -> tensor<2x2xui32>
+  // CHECK: %[[CAST_RESULT:.*]] = tensor.bitcast %[[LINALG]] : tensor<2x2xi32> to tensor<2x2xui32>
+  // CHECK: return %[[CAST_RESULT]]
+  return %0 : tensor<2x2xui32>
+}