Add support for converting bf16 to uint16 on func ops. (#15231)

This enables mmt4d ukernels on bf16xbf16->f32.
diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
index 2a44e1d..f53f616 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
@@ -238,6 +238,7 @@
         "@llvm-project//mlir:BufferizationTransforms",
         "@llvm-project//mlir:DialectUtils",
         "@llvm-project//mlir:FuncDialect",
+        "@llvm-project//mlir:FuncTransforms",
         "@llvm-project//mlir:GPUDialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:LLVMCommonConversion",
diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
index eb519d2..f047fe1 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -188,6 +188,7 @@
     MLIRBufferizationDialect
     MLIRBufferizationTransforms
     MLIRFuncDialect
+    MLIRFuncTransforms
     MLIRGPUDialect
     MLIRIR
     MLIRLLVMCommonConversion
diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ToUInt16Buffers.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ToUInt16Buffers.cpp
index 1df9f53..0fed462 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ToUInt16Buffers.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ToUInt16Buffers.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -54,6 +55,18 @@
     addConversion([this](ShapedType ty) -> std::optional<Type> {
       return ty.clone(convertType(ty.getElementType()));
     });
+
+    addConversion([this](FunctionType ty) -> std::optional<Type> {
+      SmallVector<Type> inputs;
+      if (failed(convertTypes(ty.getInputs(), inputs)))
+        return std::nullopt;
+
+      SmallVector<Type> results;
+      if (failed(convertTypes(ty.getResults(), results)))
+        return std::nullopt;
+
+      return FunctionType::get(ty.getContext(), inputs, results);
+    });
   }
 };
 
@@ -217,6 +230,10 @@
 
 static void populateIreeBf16EmulationPatterns(RewritePatternSet &patterns,
                                               TypeConverter &typeConverter) {
+  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
+                                                                 typeConverter);
+  populateCallOpTypeConversionPattern(patterns, typeConverter);
+  populateReturnOpTypeConversionPattern(patterns, typeConverter);
   patterns.add<GenericTypeConversionPattern, ConvertHalInterfaceBindingSubspan,
                ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefStore>(
       typeConverter, patterns.getContext());
@@ -244,7 +261,6 @@
     // Run the main emulation pass.
     {
       ConversionTarget target(*ctx);
-      target.addLegalOp<func::ReturnOp>();
       target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](
                                                      Operation *op) {
         return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/convert_bf16_to_uint16_buffers.mlir b/compiler/src/iree/compiler/Codegen/Common/test/convert_bf16_to_uint16_buffers.mlir
index 9df3a75..8b640bb 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/convert_bf16_to_uint16_buffers.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/convert_bf16_to_uint16_buffers.mlir
@@ -34,8 +34,52 @@
 // CHECK-LABEL: @bf16_constant
 func.func @bf16_constant(%arg0 : bf16) -> bf16 {
   // CHECK: %[[CNST:.+]] = arith.constant 16256 : i16
-  // CHECK: %[[CAST:.+]] = arith.bitcast %[[CNST]]
   %c0 = arith.constant 1.0 : bf16
-  // CHECK: return %[[CAST]]
+  // CHECK: return %[[CNST]]
   return %c0 : bf16
 }
+
+// -----
+
+// CHECK-LABEL: @iree_uk_mmt4d
+// CHECK-SAME:    memref<i16>
+// CHECK-SAME:    memref<i16>
+// CHECK-SAME:    memref<f32>
+func.func private @iree_uk_mmt4d(memref<bf16>, index, index, memref<bf16>, index, index, memref<f32>, index, index, index, index, index, i32, i32, i32, i32) attributes {hal.import.bitcode = true, hal.import.cconv = 1 : i32, hal.import.fields = ["processor_data"], llvm.bareptr = true}
+
+// CHECK-LABEL: @mmt4d_bf16xbf16xf32
+// CHECK:         func.call
+// CHECK-SAME:    memref<i16>
+// CHECK-SAME:    memref<i16>
+// CHECK-SAME:    memref<f32>
+func.func @mmt4d_bf16xbf16xf32() {
+  %c32 = arith.constant 32 : index
+  %c24 = arith.constant 24 : index
+  %c3 = arith.constant 3 : index
+  %c8_i32 = arith.constant 8 : i32
+  %c1_i32 = arith.constant 1 : i32
+  %c1029_i32 = arith.constant 1029 : i32
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %c128 = arith.constant 128 : index
+  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<1x3x8x1xbf16>
+  memref.assume_alignment %0, 64 : memref<1x3x8x1xbf16>
+  %1 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c64) flags(ReadOnly) : memref<1x3x8x1xbf16, strided<[24, 8, 1, 1], offset: 32>>
+  memref.assume_alignment %1, 64 : memref<1x3x8x1xbf16, strided<[24, 8, 1, 1], offset: 32>>
+  %2 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c128) : memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: 32>>
+  memref.assume_alignment %2, 64 : memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: 32>>
+  %workgroup_id_x = hal.interface.workgroup.id[0] : index
+  %workgroup_count_x = hal.interface.workgroup.count[0] : index
+  %workgroup_id_y = hal.interface.workgroup.id[1] : index
+  %workgroup_count_y = hal.interface.workgroup.count[1] : index
+  scf.for %arg0 = %workgroup_id_y to %c1 step %workgroup_count_y {
+    scf.for %arg1 = %workgroup_id_x to %c1 step %workgroup_count_x {
+      %base_buffer, %offset, %sizes:4, %strides:4 = memref.extract_strided_metadata %0 : memref<1x3x8x1xbf16> -> memref<bf16>, index, index, index, index, index, index, index, index, index
+      %base_buffer_0, %offset_1, %sizes_2:4, %strides_3:4 = memref.extract_strided_metadata %1 : memref<1x3x8x1xbf16, strided<[24, 8, 1, 1], offset: 32>> -> memref<bf16>, index, index, index, index, index, index, index, index, index
+      %base_buffer_4, %offset_5, %sizes_6:4, %strides_7:4 = memref.extract_strided_metadata %2 : memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: 32>> -> memref<f32>, index, index, index, index, index, index, index, index, index
+      func.call @iree_uk_mmt4d(%base_buffer, %c0, %c24, %base_buffer_0, %c32, %c24, %base_buffer_4, %c32, %c64, %c1, %c1, %c3, %c8_i32, %c8_i32, %c1_i32, %c1029_i32) : (memref<bf16>, index, index, memref<bf16>, index, index, memref<f32>, index, index, index, index, index, i32, i32, i32, i32) -> ()
+    }
+  }
+  return
+}