Preserve unsigned integer inputs/outputs (#13090)
Thanks so much to @rsuderman for the guidance and help!
This PR fixes a bug where the signedness of inputs and outputs is lost
because lowering to the linalg dialect discards signedness. To fix
it, I'm using a new op, `tensor.bitcast`, which casts the inputs/outputs
of the function between the unsigned type and the signless type used in the body.
I am completely new to IREE, so this PR is still a little rough. Any
critique and suggestions for improvement are highly welcome. I'm
particularly interested in ways to make this more robust.
Fixes #9282
Fixes #12665
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp b/compiler/src/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp
index 974e04e..0c901da 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp
@@ -70,9 +70,8 @@
castType = shapedType.clone(castType);
if (castType != fromType)
- fromValue =
- builder.create<UnrealizedConversionCastOp>(loc, castType, fromValue)
- ->getResult(0);
+ fromValue = builder.create<tensor::BitcastOp>(loc, castType, fromValue)
+ ->getResult(0);
}
if (fromType.getRank() != 0) return fromValue;
@@ -595,9 +594,6 @@
[](mhlo::ComplexOp complexOp) {
return !isInBodyOfLinalgExtOps(complexOp);
});
- // We deliberately allow unrealized casts to persist. These should fall away
- // when the rest of MHLO is converted.
- target.addLegalOp<UnrealizedConversionCastOp>();
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp b/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp
index 3197903..dce9642 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp
@@ -357,6 +357,29 @@
.getResult();
}
+std::optional<Value> materializeCastFromIllegal(OpBuilder &builder, Type type,
+ ValueRange inputs,
+ Location loc) {
+ Type fromType = getElementTypeOrSelf(inputs[0].getType());
+ Type toType = getElementTypeOrSelf(type);
+ if ((!fromType.isSignedInteger() && !fromType.isUnsignedInteger()) ||
+ !toType.isSignlessInteger())
+ return std::nullopt;
+ // Only signed/unsigned -> signless integer casts get here; emit a bitcast.
+ return builder.create<tensor::BitcastOp>(loc, type, inputs[0])->getResult(0);
+}
+
+std::optional<Value> materializeCastToIllegal(OpBuilder &builder, Type type,
+ ValueRange inputs, Location loc) {
+ Type fromType = getElementTypeOrSelf(inputs[0].getType());
+ Type toType = getElementTypeOrSelf(type);
+ if (!fromType.isSignlessInteger() ||
+ (!toType.isSignedInteger() && !toType.isUnsignedInteger()))
+ return std::nullopt;
+ // Only signless -> signed/unsigned integer casts get here; emit a bitcast.
+ return builder.create<tensor::BitcastOp>(loc, type, inputs[0])->getResult(0);
+}
+
struct ConvertMHLOToLinalgOnTensorsPass
: public ConvertMHLOToLinalgOnTensorsBase<
ConvertMHLOToLinalgOnTensorsPass> {
@@ -373,6 +396,9 @@
auto typeConverter = mhlo::createHloToLinalgTypeConverter();
typeConverter->addArgumentMaterialization(scalarToTensor);
+ typeConverter->addArgumentMaterialization(materializeCastFromIllegal);
+ typeConverter->addTargetMaterialization(materializeCastFromIllegal);
+ typeConverter->addSourceMaterialization(materializeCastToIllegal);
// NOTE: not using corresponding setupMHLOToFlowPatterns because the entire
// MHLO dialects are marked illegal by this pass.
// TODO: Collapse/rework all of these patterns once the consolidation
@@ -419,6 +445,11 @@
ConversionTarget target(getContext());
auto isIllegalType = [&](Type t) { return !typeConverter->isLegal(t); };
+ auto isIllegalFuncType = [&](Type t) {
+ // Allows unsigned integers for function inputs and outputs.
+ return !typeConverter->isLegal(t) &&
+ !getElementTypeOrSelf(t).isUnsignedInteger();
+ };
auto isLegallyTypedOp = [&](Operation *op) -> bool {
for (Type type : op->getResultTypes()) {
if (isIllegalType(type)) return false;
@@ -435,18 +466,34 @@
// Functions must have legal types.
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp funcOp) {
for (Type type : funcOp.getFunctionType().getInputs()) {
- if (isIllegalType(type)) return false;
+ if (isIllegalFuncType(type)) return false;
}
for (Type type : funcOp.getFunctionType().getResults()) {
- if (isIllegalType(type)) return false;
+ if (isIllegalFuncType(type)) return false;
}
for (Block &block : funcOp.getFunctionBody()) {
for (Type type : block.getArgumentTypes()) {
- if (isIllegalType(type)) return false;
+ if (isIllegalFuncType(type)) return false;
}
}
return true;
});
+ target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
+ return llvm::all_of(op.getOperandTypes(),
+ [&](Type type) { return !isIllegalFuncType(type); });
+ });
+ target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
+ return llvm::all_of(op.getOperandTypes(),
+ [&](Type type) { return !isIllegalFuncType(type); });
+ });
+ target.addDynamicallyLegalOp<cf::CondBranchOp>([&](cf::CondBranchOp op) {
+ return llvm::all_of(op.getOperandTypes(),
+ [&](Type type) { return !isIllegalFuncType(type); });
+ });
+ target.addDynamicallyLegalOp<cf::BranchOp>([&](cf::BranchOp op) {
+ return llvm::all_of(op.getOperandTypes(),
+ [&](Type type) { return !isIllegalFuncType(type); });
+ });
target.addDynamicallyLegalOp<ml_program::GlobalOp>(
[&](ml_program::GlobalOp op) {
return typeConverter->isLegal(op.getType());
@@ -455,6 +502,7 @@
// Let the rest fall through.
target.addLegalDialect<BuiltinDialect>();
target.addLegalDialect<IREE::LinalgExt::IREELinalgExtDialect>();
+ target.addLegalOp<tensor::BitcastOp>();
target.markUnknownOpDynamicallyLegal(isLegallyTypedOp);
if (failed(applyPartialConversion(getOperation(), target,
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/broadcasting.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/broadcasting.mlir
index 39da165..e9fac89 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/test/broadcasting.mlir
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/test/broadcasting.mlir
@@ -409,14 +409,16 @@
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK-DAG: %[[CAST_ARG:.*]] = tensor.bitcast %arg0 : tensor<4x?x3x?xui32> to tensor<4x?x3x?xi32>
// CHECK-DAG: %[[RESULT_D1:.*]] = tensor.extract %arg1[%[[C1]]] : tensor<5xindex>
// CHECK-DAG: %[[RESULT_D2:.*]] = tensor.extract %arg1[%[[C2]]] : tensor<5xindex>
// CHECK-DAG: %[[RESULT_D4:.*]] = tensor.extract %arg1[%[[C4]]] : tensor<5xindex>
- // CHECK-DAG: %[[ARG_D1:.*]] = tensor.dim %arg0, %[[C1]] : tensor<4x?x3x?xi32>
- // CHECK-DAG: %[[ARG_D3:.*]] = tensor.dim %arg0, %[[C3]] : tensor<4x?x3x?xi32>
- // CHECK-DAG: %[[RESULT:.*]] = flow.tensor.reshape %arg0 : tensor<4x?x3x?xi32>{%[[ARG_D1]], %[[ARG_D3]]} -> tensor<12x?x?x1x?xi32>{%[[RESULT_D1]], %[[RESULT_D2]], %[[RESULT_D4]]}
+ // CHECK-DAG: %[[ARG_D1:.*]] = tensor.dim %[[CAST_ARG]], %[[C1]] : tensor<4x?x3x?xi32>
+ // CHECK-DAG: %[[ARG_D3:.*]] = tensor.dim %[[CAST_ARG]], %[[C3]] : tensor<4x?x3x?xi32>
+ // CHECK-DAG: %[[RESULT:.*]] = flow.tensor.reshape %[[CAST_ARG]] : tensor<4x?x3x?xi32>{%[[ARG_D1]], %[[ARG_D3]]} -> tensor<12x?x?x1x?xi32>{%[[RESULT_D1]], %[[RESULT_D2]], %[[RESULT_D4]]}
%0 = "mhlo.dynamic_reshape"(%arg0, %arg1) : (tensor<4x?x3x?xui32>, tensor<5xindex>) -> tensor<12x?x?x1x?xui32>
- // CHECK: return %[[RESULT]]
+ // CHECK-DAG: %[[CAST_RESULT:.*]] = tensor.bitcast %[[RESULT]] : tensor<12x?x?x1x?xi32> to tensor<12x?x?x1x?xui32>
+ // CHECK: return %[[CAST_RESULT]]
return %0 : tensor<12x?x?x1x?xui32>
}
@@ -427,16 +429,18 @@
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+ // CHECK-DAG: %[[CAST_ARG:.*]] = tensor.bitcast %arg0 : tensor<4x?x3x?xui32> to tensor<4x?x3x?xi32>
// CHECK-DAG: %[[D1:.*]] = tensor.extract %arg1[%[[C1]]] : tensor<5xi32>
// CHECK-DAG: %[[D2:.*]] = tensor.extract %arg1[%[[C2]]] : tensor<5xi32>
// CHECK-DAG: %[[D4:.*]] = tensor.extract %arg1[%[[C4]]] : tensor<5xi32>
// CHECK-DAG: %[[RESULT_D1:.*]] = arith.index_cast %[[D1]] : i32 to index
// CHECK-DAG: %[[RESULT_D2:.*]] = arith.index_cast %[[D2]] : i32 to index
// CHECK-DAG: %[[RESULT_D4:.*]] = arith.index_cast %[[D4]] : i32 to index
- // CHECK-DAG: %[[ARG_D1:.*]] = tensor.dim %arg0, %[[C1]] : tensor<4x?x3x?xi32>
- // CHECK-DAG: %[[ARG_D3:.*]] = tensor.dim %arg0, %[[C3]] : tensor<4x?x3x?xi32>
- // CHECK-DAG: %[[RESULT:.*]] = flow.tensor.reshape %arg0 : tensor<4x?x3x?xi32>{%[[ARG_D1]], %[[ARG_D3]]} -> tensor<12x?x?x1x?xi32>{%[[RESULT_D1]], %[[RESULT_D2]], %[[RESULT_D4]]}
+ // CHECK-DAG: %[[ARG_D1:.*]] = tensor.dim %[[CAST_ARG]], %[[C1]] : tensor<4x?x3x?xi32>
+ // CHECK-DAG: %[[ARG_D3:.*]] = tensor.dim %[[CAST_ARG]], %[[C3]] : tensor<4x?x3x?xi32>
+ // CHECK-DAG: %[[RESULT:.*]] = flow.tensor.reshape %[[CAST_ARG]] : tensor<4x?x3x?xi32>{%[[ARG_D1]], %[[ARG_D3]]} -> tensor<12x?x?x1x?xi32>{%[[RESULT_D1]], %[[RESULT_D2]], %[[RESULT_D4]]}
%0 = "mhlo.dynamic_reshape"(%arg0, %arg1) : (tensor<4x?x3x?xui32>, tensor<5xi32>) -> tensor<12x?x?x1x?xui32>
- // CHECK: return %[[RESULT]]
+ // CHECK-DAG: %[[CAST_RESULT:.*]] = tensor.bitcast %[[RESULT]] : tensor<12x?x?x1x?xi32> to tensor<12x?x?x1x?xui32>
+ // CHECK: return %[[CAST_RESULT]]
return %0 : tensor<12x?x?x1x?xui32>
}
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_collective_ops.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_collective_ops.mlir
index f08df5d..b7a0936 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_collective_ops.mlir
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_collective_ops.mlir
@@ -6,7 +6,8 @@
// CHECK-DAG: [[RANK:%.+]] = flow.channel.rank [[CHANNEL]] : index
// CHECK-DAG: [[CAST:%.+]] = arith.index_castui [[RANK]] : index to i32
// CHECK-DAG: [[TENSOR:%.+]] = tensor.from_elements [[CAST]] : tensor<i32>
- // CHECK-DAG: return [[TENSOR]] : tensor<i32>
+ // CHECK-DAG: [[BITCAST:%.+]] = tensor.bitcast [[TENSOR]] : tensor<i32> to tensor<ui32>
+ // CHECK-DAG: return [[BITCAST]] : tensor<ui32>
%id = mhlo.replica_id : tensor<ui32>
return %id : tensor<ui32>
}
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir
index 1504703..e05b990 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir
@@ -34,14 +34,14 @@
// CHECK-LABEL: func.func @sort_1d_ui(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: )
-// CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : tensor<128xui32> to tensor<128xi32>
+// CHECK: %[[CAST:.+]] = tensor.bitcast %[[ARG0]] : tensor<128xui32> to tensor<128xi32>
// CHECK: %[[SORT:.+]] = iree_linalg_ext.sort
// CHECK-SAME: dimension(0)
// CHECK-SAME: outs(%[[CAST]] : tensor<128xi32>)
// CHECK: ^bb0(%[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32)
// CHECK: %[[CMP:.+]] = arith.cmpi ugt, %[[ARG1]], %[[ARG2]]
// CHECK: iree_linalg_ext.yield %[[CMP]]
-// CHECK: %[[RESULT:.+]] = builtin.unrealized_conversion_cast %[[SORT]] : tensor<128xi32> to tensor<128xui32>
+// CHECK: %[[RESULT:.+]] = tensor.bitcast %[[SORT]] : tensor<128xi32> to tensor<128xui32>
// CHECK: return %[[RESULT]]
// -----
@@ -154,7 +154,7 @@
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: )
// CHECK: %[[UI32:.+]] = mhlo.constant dense<2> : tensor<ui32>
-// CHECK: %[[CONVERSION_CAST_CST:.+]] = builtin.unrealized_conversion_cast %[[UI32]] : tensor<ui32> to tensor<i32>
+// CHECK: %[[CONVERSION_CAST_CST:.+]] = tensor.bitcast %[[UI32]] : tensor<ui32> to tensor<i32>
// CHECK: %[[EXTRACT_CST:.+]] = tensor.extract %[[CONVERSION_CAST_CST]][] : tensor<i32>
// CHECK: %[[SORT:.+]] = iree_linalg_ext.sort
// CHECK-SAME: dimension(1)
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_structural_types.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_structural_types.mlir
index 964b5a1..73455e7 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_structural_types.mlir
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_structural_types.mlir
@@ -2,24 +2,33 @@
// CHECK-LABEL: @func_cfg_conversion
module @func_cfg_conversion {
- // CHECK: func.func @caller(%arg0: tensor<2xi32>, %arg1: i1) -> tensor<2xi32>
+ // CHECK: func.func @caller(%arg0: tensor<2xui32>, %arg1: i1) -> tensor<2xui32>
func.func @caller(%arg0: tensor<2xui32>, %arg1 : i1) -> tensor<2xui32> {
- // CHECK: %[[RESULT:.*]] = call @callee(%arg0, %arg1) : (tensor<2xi32>, i1) -> tensor<2xi32>
+ // CHECK: %[[RESULT:.*]] = call @callee(%arg0, %arg1) : (tensor<2xui32>, i1) -> tensor<2xui32>
%1 = call @callee(%arg0, %arg1) : (tensor<2xui32>, i1) -> tensor<2xui32>
- // CHECK: return %[[RESULT]] : tensor<2xi32>
+ // CHECK: return %[[RESULT]] : tensor<2xui32>
return %1 : tensor<2xui32>
}
- // CHECK: func.func @callee(%arg0: tensor<2xi32>, %arg1: i1) -> tensor<2xi32>
+ // CHECK: func.func @callee(%arg0: tensor<2xui32>, %arg1: i1) -> tensor<2xui32>
func.func @callee(%arg0: tensor<2xui32>, %arg1: i1) -> tensor<2xui32> {
- // CHECK: cf.cond_br %arg1, ^bb1(%arg0 : tensor<2xi32>), ^bb2(%arg0 : tensor<2xi32>)
+ // CHECK: cf.cond_br %arg1, ^bb1(%arg0 : tensor<2xui32>), ^bb2(%arg0 : tensor<2xui32>)
cf.cond_br %arg1, ^bb1(%arg0 : tensor<2xui32>), ^bb2(%arg0 : tensor<2xui32>)
- // CHECK: ^bb1(%[[BB1_PHI:.*]]: tensor<2xi32>)
+ // CHECK: ^bb1(%[[BB1_PHI:.*]]: tensor<2xui32>)
^bb1(%phi0 : tensor<2xui32>) :
- // CHECK: cf.br ^bb2(%[[BB1_PHI]] : tensor<2xi32>)
- cf.br ^bb2(%phi0 : tensor<2xui32>)
- // CHECK: ^bb2(%[[BB2_PHI:.*]]: tensor<2xi32>)
+ // CHECK: tensor.bitcast %[[BB1_PHI]] : tensor<2xui32> to tensor<2xi32>
+ // CHECK: %[[BB1_PHI_ADD:.*]] = linalg.generic
+ // CHECK: %[[BB1_PHI_ADD_CAST:.*]] = tensor.bitcast %[[BB1_PHI_ADD]] : tensor<2xi32> to tensor<2xui32>
+ // CHECK: cf.br ^bb2(%[[BB1_PHI_ADD_CAST]] : tensor<2xui32>)
+ %0 = "mhlo.add"(%phi0, %phi0) : (tensor<2xui32>, tensor<2xui32>) -> tensor<2xui32>
+ cf.br ^bb2(%0 : tensor<2xui32>)
+ // CHECK: ^bb2(%[[BB2_PHI:.*]]: tensor<2xui32>)
^bb2(%phi1 : tensor<2xui32>):
- return %phi1 : tensor<2xui32>
+ // CHECK: tensor.bitcast %[[BB2_PHI]] : tensor<2xui32> to tensor<2xi32>
+ // CHECK: %[[BB2_PHI_ADD:.*]] = linalg.generic
+ // CHECK: %[[BB2_PHI_ADD_CAST:.*]] = tensor.bitcast %[[BB2_PHI_ADD]] : tensor<2xi32> to tensor<2xui32>
+ // CHECK: return %[[BB2_PHI_ADD_CAST]] : tensor<2xui32>
+ %1 = "mhlo.add"(%phi1, %phi1) : (tensor<2xui32>, tensor<2xui32>) -> tensor<2xui32>
+ return %1 : tensor<2xui32>
}
}
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_linalg.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_linalg.mlir
index a135925..b2001e1 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_linalg.mlir
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_linalg.mlir
@@ -19,11 +19,12 @@
// CHECK: ml_program.global private mutable @variable(dense<0> : tensor<2xi32>) : tensor<2xi32>
ml_program.global private mutable @variable(dense<0> : tensor<2xui32>) : tensor<2xui32>
-// CHECK: func.func @global_types() -> tensor<2xi32>
+// CHECK: func.func @global_types() -> tensor<2xui32>
func.func @global_types() -> tensor<2xui32> {
// CHECK-NEXT: %[[VALUE:.+]] = ml_program.global_load @variable : tensor<2xi32>
%0 = ml_program.global_load @variable : tensor<2xui32>
- // CHECK: return %[[VALUE]] : tensor<2xi32>
+ // CHECK-NEXT: %[[CAST:.*]] = tensor.bitcast %[[VALUE]] : tensor<2xi32> to tensor<2xui32>
+ // CHECK: return %[[CAST]] : tensor<2xui32>
return %0 : tensor<2xui32>
}
@@ -37,3 +38,22 @@
%0, %1 = "mhlo.optimization_barrier"(%arg0, %arg1) : (tensor<3x4xf32>, tensor<4xi32>) -> (tensor<3x4xf32>, tensor<4xi32>)
return %0, %1 : tensor<3x4xf32>, tensor<4xi32>
}
+
+// -----
+
+// CHECK: @unsigned_integer_input_output(%[[ARG0:.*]]: tensor<2x2xui32>, %[[ARG1:.*]]: tensor<2x2xui32>) -> tensor<2x2xui32>
+func.func @unsigned_integer_input_output(%arg0: tensor<2x2xui32>, %arg1: tensor<2x2xui32>) -> tensor<2x2xui32> {
+ // CHECK: %[[CAST0:.*]] = tensor.bitcast %[[ARG0]] : tensor<2x2xui32> to tensor<2x2xi32>
+ // CHECK: %[[CAST1:.*]] = tensor.bitcast %[[ARG1]] : tensor<2x2xui32> to tensor<2x2xi32>
+ // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<2x2xi32>
+ // CHECK: %[[LINALG:.*]] = linalg.generic
+ // CHECK-SAME: ins(%[[CAST0]], %[[CAST1]] : tensor<2x2xi32>, tensor<2x2xi32>
+ // CHECK-SAME: outs(%[[INIT]] : tensor<2x2xi32>)
+ // CHECK: ^bb0(%[[IN0:.*]]: i32, %[[IN1:.*]]: i32, %out: i32):
+ // CHECK: %[[ADD:.*]] = arith.addi %[[IN0]], %[[IN1]] : i32
+ // CHECK: linalg.yield %[[ADD]] : i32
+ %0 = "mhlo.add"(%arg0, %arg1) : (tensor<2x2xui32>, tensor<2x2xui32>) -> tensor<2x2xui32>
+ // CHECK: %[[CAST_RESULT:.*]] = tensor.bitcast %[[LINALG]] : tensor<2x2xi32> to tensor<2x2xui32>
+ // CHECK: return %[[CAST_RESULT]]
+ return %0 : tensor<2x2xui32>
+}