data-tiling for `f16` and `bf16` matmuls (#14207)

diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMaterializeEncodingPass.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMaterializeEncodingPass.cpp
index b854beb..c2375d8 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMaterializeEncodingPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMaterializeEncodingPass.cpp
@@ -33,6 +33,15 @@
 chooseMatmulTileParamsAArch64(MatmulType type, ExecutableTargetAttr target) {
   switch (type) {
   case MatmulType::F32F32F32:
+  case MatmulType::F16F16F32:
+  case MatmulType::F16F16F16:
+  case MatmulType::BF16BF16F32:
+  case MatmulType::BF16BF16BF16:
+    // Note: 16-bit floating-point types currently use the same tile sizes as
+    // f32. This makes sense when either (1) the accumulator is f32, or (2)
+    // the arithmetic has to widen f16/bf16 to f32 in registers. We may want
+    // to revisit this once we take advantage of native f16/bf16 arithmetic
+    // when the accumulator itself is f16/bf16.
     return {8, 1, 8};
   case MatmulType::I8I8I32:
     if (hasFeature(target, "+i8mm")) {
@@ -54,8 +63,18 @@
 chooseMatmulTileParamsX86_64(MatmulType type, ExecutableTargetAttr target) {
   switch (type) {
   case MatmulType::F32F32F32:
-    if (hasAVX512fFeature(target))
+  case MatmulType::F16F16F32:
+  case MatmulType::F16F16F16:
+  case MatmulType::BF16BF16F32:
+  case MatmulType::BF16BF16BF16:
+    // Note: 16-bit floating-point types currently use the same tile sizes as
+    // f32. This makes sense when either (1) the accumulator is f32, or (2)
+    // the arithmetic has to widen f16/bf16 to f32 in registers. We may want
+    // to revisit this once we take advantage of native f16/bf16 arithmetic
+    // when the accumulator itself is f16/bf16.
+    if (hasFeature(target, "+avx512f")) {
       return {16, 1, 16};
+    }
     if (hasFeature(target, "+avx")) {
       // Note: for good performance, most +avx users will also want to add
       // +fma, but that's a local instruction selection detail and the tile
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_encoding.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_encoding.mlir
index 15c7430..dc7ad00 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_encoding.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_encoding.mlir
@@ -212,6 +212,66 @@
 
 // -----
 
+func.func @matmul_lowering_f16f16f16_aarch64() attributes {
+  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="aarch64-xyz-xyz"}>
+} {
+  %c0 = arith.constant 0 : index
+  %M = hal.interface.constant.load[0] : index
+  %N = hal.interface.constant.load[1] : index
+  %K = hal.interface.constant.load[2] : index
+  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0)
+      : !flow.dispatch.tensor<readonly:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_LHS>>>{%M, %K}
+  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0)
+      : !flow.dispatch.tensor<readonly:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RHS>>>{%K, %N}
+  %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0)
+      : !flow.dispatch.tensor<readwrite:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RESULT>>>{%M, %N}
+  %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1]
+      : !flow.dispatch.tensor<readonly:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_LHS>>>{%M, %K}
+      -> tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_LHS>>
+  %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1]
+      : !flow.dispatch.tensor<readonly:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RHS>>>{%K, %N}
+      -> tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RHS>>
+  %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
+      : !flow.dispatch.tensor<readwrite:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RESULT>>>{%M, %N}
+      -> tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RESULT>>
+  %6 = linalg.matmul
+      ins(%3, %4 : tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_LHS>>,
+                   tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RHS>>)
+      outs(%5 : tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RESULT>>)
+      -> tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RESULT>>
+  flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
+      : tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RESULT>>
+      -> !flow.dispatch.tensor<readwrite:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RESULT>>>{%M, %N}
+  return
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
+//      CHECK: func @matmul_lowering_f16f16f16_aarch64()
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[M:.+]] = hal.interface.constant.load[0]
+//  CHECK-DAG:   %[[N:.+]] = hal.interface.constant.load[1]
+//  CHECK-DAG:   %[[K:.+]] = hal.interface.constant.load[2]
+//  CHECK-DAG:   %[[TILED_M:.+]] = affine.apply #[[MAP0]]()[%[[M]]]
+//      CHECK:   %[[LHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(0)
+// CHECK-SAME:       !flow.dispatch.tensor<readonly:tensor<?x?x8x1xf16>>{%[[TILED_M]], %[[K]]}
+//      CHECK:   %[[TILED_N:.+]] = affine.apply #[[MAP0]]()[%[[N]]]
+//      CHECK:   %[[RHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(1)
+// CHECK-SAME:       !flow.dispatch.tensor<readonly:tensor<?x?x8x1xf16>>{%[[TILED_N]], %[[K]]}
+//      CHECK:   %[[OUTS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(2)
+// CHECK-SAME:       !flow.dispatch.tensor<readwrite:tensor<?x?x8x8xf16>>{%[[TILED_M]], %[[TILED_N]]}
+//      CHECK:   %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]
+// CHECK-SAME:       offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[K]], 8, 1], strides = [1, 1, 1, 1]
+//      CHECK:   %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]
+// CHECK-SAME:       offsets = [0, 0, 0, 0], sizes = [%[[TILED_N]], %[[K]], 8, 1], strides = [1, 1, 1, 1]
+//      CHECK:   %[[OUTS:.+]] = flow.dispatch.tensor.load %[[OUTS_BINDING]]
+// CHECK-SAME:       offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[TILED_N]], 8, 8], strides = [1, 1, 1, 1]
+//      CHECK:   %[[MMT4D:.+]] = linalg.mmt4d
+// CHECK-SAME:       ins(%[[LHS]], %[[RHS]] :
+// CHECK-SAME:       outs(%[[OUTS]] :
+//      CHECK:   flow.dispatch.tensor.store %[[MMT4D]], %[[OUTS_BINDING]]
+// CHECK-SAME:       offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[TILED_N]], 8, 8], strides = [1, 1, 1, 1]
+
+// -----
+
 func.func @matmul_lowering_f32f32f32_x86_64() attributes {
   hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz"}>
 } {
@@ -393,6 +453,246 @@
 
 // -----
 
+func.func @matmul_lowering_f16f16f32_x86_64_avx512f() attributes {
+  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
+} {
+  %c0 = arith.constant 0 : index
+  %M = hal.interface.constant.load[0] : index
+  %N = hal.interface.constant.load[1] : index
+  %K = hal.interface.constant.load[2] : index
+  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0)
+      : !flow.dispatch.tensor<readonly:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F32_LHS>>>{%M, %K}
+  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0)
+      : !flow.dispatch.tensor<readonly:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F32_RHS>>>{%K, %N}
+  %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0)
+      : !flow.dispatch.tensor<readwrite:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F16F16F32_RESULT>>>{%M, %N}
+  %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1]
+      : !flow.dispatch.tensor<readonly:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F32_LHS>>>{%M, %K}
+      -> tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F32_LHS>>
+  %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1]
+      : !flow.dispatch.tensor<readonly:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F32_RHS>>>{%K, %N}
+      -> tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F32_RHS>>
+  %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
+      : !flow.dispatch.tensor<readwrite:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F16F16F32_RESULT>>>{%M, %N}
+      -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F16F16F32_RESULT>>
+  %6 = linalg.matmul
+      ins(%3, %4 : tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F32_LHS>>,
+                   tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F32_RHS>>)
+      outs(%5 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F16F16F32_RESULT>>)
+      -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F16F16F32_RESULT>>
+  flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
+      : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F16F16F32_RESULT>>
+      -> !flow.dispatch.tensor<readwrite:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F16F16F32_RESULT>>>{%M, %N}
+  return
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
+//      CHECK: func @matmul_lowering_f16f16f32_x86_64_avx512f()
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[M:.+]] = hal.interface.constant.load[0]
+//  CHECK-DAG:   %[[N:.+]] = hal.interface.constant.load[1]
+//  CHECK-DAG:   %[[K:.+]] = hal.interface.constant.load[2]
+//  CHECK-DAG:   %[[TILED_M:.+]] = affine.apply #[[MAP0]]()[%[[M]]]
+//      CHECK:   %[[LHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(0)
+// CHECK-SAME:       !flow.dispatch.tensor<readonly:tensor<?x?x16x1xf16>>{%[[TILED_M]], %[[K]]}
+//      CHECK:   %[[TILED_N:.+]] = affine.apply #[[MAP0]]()[%[[N]]]
+//      CHECK:   %[[RHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(1)
+// CHECK-SAME:       !flow.dispatch.tensor<readonly:tensor<?x?x16x1xf16>>{%[[TILED_N]], %[[K]]}
+//      CHECK:   %[[OUTS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(2)
+// CHECK-SAME:       !flow.dispatch.tensor<readwrite:tensor<?x?x16x16xf32>>{%[[TILED_M]], %[[TILED_N]]}
+//      CHECK:   %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]
+// CHECK-SAME:       offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[K]], 16, 1], strides = [1, 1, 1, 1]
+//      CHECK:   %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]
+// CHECK-SAME:       offsets = [0, 0, 0, 0], sizes = [%[[TILED_N]], %[[K]], 16, 1], strides = [1, 1, 1, 1]
+//      CHECK:   %[[OUTS:.+]] = flow.dispatch.tensor.load %[[OUTS_BINDING]]
+// CHECK-SAME:       offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[TILED_N]], 16, 16], strides = [1, 1, 1, 1]
+//      CHECK:   %[[MMT4D:.+]] = linalg.mmt4d
+// CHECK-SAME:       ins(%[[LHS]], %[[RHS]] :
+// CHECK-SAME:       outs(%[[OUTS]] :
+//      CHECK:   flow.dispatch.tensor.store %[[MMT4D]], %[[OUTS_BINDING]]
+// CHECK-SAME:       offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[TILED_N]], 16, 16], strides = [1, 1, 1, 1]
+
+// -----
+
+func.func @matmul_lowering_f16f16f16_x86_64_avx512f() attributes {
+  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
+} {
+  %c0 = arith.constant 0 : index
+  %M = hal.interface.constant.load[0] : index
+  %N = hal.interface.constant.load[1] : index
+  %K = hal.interface.constant.load[2] : index
+  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0)
+      : !flow.dispatch.tensor<readonly:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_LHS>>>{%M, %K}
+  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0)
+      : !flow.dispatch.tensor<readonly:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RHS>>>{%K, %N}
+  %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0)
+      : !flow.dispatch.tensor<readwrite:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RESULT>>>{%M, %N}
+  %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1]
+      : !flow.dispatch.tensor<readonly:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_LHS>>>{%M, %K}
+      -> tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_LHS>>
+  %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1]
+      : !flow.dispatch.tensor<readonly:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RHS>>>{%K, %N}
+      -> tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RHS>>
+  %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
+      : !flow.dispatch.tensor<readwrite:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RESULT>>>{%M, %N}
+      -> tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RESULT>>
+  %6 = linalg.matmul
+      ins(%3, %4 : tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_LHS>>,
+                   tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RHS>>)
+      outs(%5 : tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RESULT>>)
+      -> tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RESULT>>
+  flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
+      : tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RESULT>>
+      -> !flow.dispatch.tensor<readwrite:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RESULT>>>{%M, %N}
+  return
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
+//      CHECK: func @matmul_lowering_f16f16f16_x86_64_avx512f()
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[M:.+]] = hal.interface.constant.load[0]
+//  CHECK-DAG:   %[[N:.+]] = hal.interface.constant.load[1]
+//  CHECK-DAG:   %[[K:.+]] = hal.interface.constant.load[2]
+//  CHECK-DAG:   %[[TILED_M:.+]] = affine.apply #[[MAP0]]()[%[[M]]]
+//      CHECK:   %[[LHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(0)
+// CHECK-SAME:       !flow.dispatch.tensor<readonly:tensor<?x?x16x1xf16>>{%[[TILED_M]], %[[K]]}
+//      CHECK:   %[[TILED_N:.+]] = affine.apply #[[MAP0]]()[%[[N]]]
+//      CHECK:   %[[RHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(1)
+// CHECK-SAME:       !flow.dispatch.tensor<readonly:tensor<?x?x16x1xf16>>{%[[TILED_N]], %[[K]]}
+//      CHECK:   %[[OUTS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(2)
+// CHECK-SAME:       !flow.dispatch.tensor<readwrite:tensor<?x?x16x16xf16>>{%[[TILED_M]], %[[TILED_N]]}
+//      CHECK:   %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]
+// CHECK-SAME:       offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[K]], 16, 1], strides = [1, 1, 1, 1]
+//      CHECK:   %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]
+// CHECK-SAME:       offsets = [0, 0, 0, 0], sizes = [%[[TILED_N]], %[[K]], 16, 1], strides = [1, 1, 1, 1]
+//      CHECK:   %[[OUTS:.+]] = flow.dispatch.tensor.load %[[OUTS_BINDING]]
+// CHECK-SAME:       offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[TILED_N]], 16, 16], strides = [1, 1, 1, 1]
+//      CHECK:   %[[MMT4D:.+]] = linalg.mmt4d
+// CHECK-SAME:       ins(%[[LHS]], %[[RHS]] :
+// CHECK-SAME:       outs(%[[OUTS]] :
+//      CHECK:   flow.dispatch.tensor.store %[[MMT4D]], %[[OUTS_BINDING]]
+// CHECK-SAME:       offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[TILED_N]], 16, 16], strides = [1, 1, 1, 1]
+
+// -----
+
+func.func @matmul_lowering_bf16bf16f32_x86_64_avx512f() attributes {
+  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
+} {
+  %c0 = arith.constant 0 : index
+  %M = hal.interface.constant.load[0] : index
+  %N = hal.interface.constant.load[1] : index
+  %K = hal.interface.constant.load[2] : index
+  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0)
+      : !flow.dispatch.tensor<readonly:tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_LHS>>>{%M, %K}
+  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0)
+      : !flow.dispatch.tensor<readonly:tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_RHS>>>{%K, %N}
+  %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0)
+      : !flow.dispatch.tensor<readwrite:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_RESULT>>>{%M, %N}
+  %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1]
+      : !flow.dispatch.tensor<readonly:tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_LHS>>>{%M, %K}
+      -> tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_LHS>>
+  %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1]
+      : !flow.dispatch.tensor<readonly:tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_RHS>>>{%K, %N}
+      -> tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_RHS>>
+  %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
+      : !flow.dispatch.tensor<readwrite:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_RESULT>>>{%M, %N}
+      -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_RESULT>>
+  %6 = linalg.matmul
+      ins(%3, %4 : tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_LHS>>,
+                   tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_RHS>>)
+      outs(%5 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_RESULT>>)
+      -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_RESULT>>
+  flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
+      : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_RESULT>>
+      -> !flow.dispatch.tensor<readwrite:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_RESULT>>>{%M, %N}
+  return
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
+//      CHECK: func @matmul_lowering_bf16bf16f32_x86_64_avx512f()
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[M:.+]] = hal.interface.constant.load[0]
+//  CHECK-DAG:   %[[N:.+]] = hal.interface.constant.load[1]
+//  CHECK-DAG:   %[[K:.+]] = hal.interface.constant.load[2]
+//  CHECK-DAG:   %[[TILED_M:.+]] = affine.apply #[[MAP0]]()[%[[M]]]
+//      CHECK:   %[[LHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(0)
+// CHECK-SAME:       !flow.dispatch.tensor<readonly:tensor<?x?x16x1xbf16>>{%[[TILED_M]], %[[K]]}
+//      CHECK:   %[[TILED_N:.+]] = affine.apply #[[MAP0]]()[%[[N]]]
+//      CHECK:   %[[RHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(1)
+// CHECK-SAME:       !flow.dispatch.tensor<readonly:tensor<?x?x16x1xbf16>>{%[[TILED_N]], %[[K]]}
+//      CHECK:   %[[OUTS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(2)
+// CHECK-SAME:       !flow.dispatch.tensor<readwrite:tensor<?x?x16x16xf32>>{%[[TILED_M]], %[[TILED_N]]}
+//      CHECK:   %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]
+// CHECK-SAME:       offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[K]], 16, 1], strides = [1, 1, 1, 1]
+//      CHECK:   %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]
+// CHECK-SAME:       offsets = [0, 0, 0, 0], sizes = [%[[TILED_N]], %[[K]], 16, 1], strides = [1, 1, 1, 1]
+//      CHECK:   %[[OUTS:.+]] = flow.dispatch.tensor.load %[[OUTS_BINDING]]
+// CHECK-SAME:       offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[TILED_N]], 16, 16], strides = [1, 1, 1, 1]
+//      CHECK:   %[[MMT4D:.+]] = linalg.mmt4d
+// CHECK-SAME:       ins(%[[LHS]], %[[RHS]] :
+// CHECK-SAME:       outs(%[[OUTS]] :
+//      CHECK:   flow.dispatch.tensor.store %[[MMT4D]], %[[OUTS_BINDING]]
+// CHECK-SAME:       offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[TILED_N]], 16, 16], strides = [1, 1, 1, 1]
+
+// -----
+
+func.func @matmul_lowering_bf16bf16bf16_x86_64_avx512f() attributes {
+  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
+} {
+  %c0 = arith.constant 0 : index
+  %M = hal.interface.constant.load[0] : index
+  %N = hal.interface.constant.load[1] : index
+  %K = hal.interface.constant.load[2] : index
+  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0)
+      : !flow.dispatch.tensor<readonly:tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_LHS>>>{%M, %K}
+  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0)
+      : !flow.dispatch.tensor<readonly:tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_RHS>>>{%K, %N}
+  %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0)
+      : !flow.dispatch.tensor<readwrite:tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_RESULT>>>{%M, %N}
+  %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1]
+      : !flow.dispatch.tensor<readonly:tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_LHS>>>{%M, %K}
+      -> tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_LHS>>
+  %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1]
+      : !flow.dispatch.tensor<readonly:tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_RHS>>>{%K, %N}
+      -> tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_RHS>>
+  %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
+      : !flow.dispatch.tensor<readwrite:tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_RESULT>>>{%M, %N}
+      -> tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_RESULT>>
+  %6 = linalg.matmul
+      ins(%3, %4 : tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_LHS>>,
+                   tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_RHS>>)
+      outs(%5 : tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_RESULT>>)
+      -> tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_RESULT>>
+  flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
+      : tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_RESULT>>
+      -> !flow.dispatch.tensor<readwrite:tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_RESULT>>>{%M, %N}
+  return
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
+//      CHECK: func @matmul_lowering_bf16bf16bf16_x86_64_avx512f()
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[M:.+]] = hal.interface.constant.load[0]
+//  CHECK-DAG:   %[[N:.+]] = hal.interface.constant.load[1]
+//  CHECK-DAG:   %[[K:.+]] = hal.interface.constant.load[2]
+//  CHECK-DAG:   %[[TILED_M:.+]] = affine.apply #[[MAP0]]()[%[[M]]]
+//      CHECK:   %[[LHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(0)
+// CHECK-SAME:       !flow.dispatch.tensor<readonly:tensor<?x?x16x1xbf16>>{%[[TILED_M]], %[[K]]}
+//      CHECK:   %[[TILED_N:.+]] = affine.apply #[[MAP0]]()[%[[N]]]
+//      CHECK:   %[[RHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(1)
+// CHECK-SAME:       !flow.dispatch.tensor<readonly:tensor<?x?x16x1xbf16>>{%[[TILED_N]], %[[K]]}
+//      CHECK:   %[[OUTS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(2)
+// CHECK-SAME:       !flow.dispatch.tensor<readwrite:tensor<?x?x16x16xbf16>>{%[[TILED_M]], %[[TILED_N]]}
+//      CHECK:   %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]
+// CHECK-SAME:       offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[K]], 16, 1], strides = [1, 1, 1, 1]
+//      CHECK:   %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]
+// CHECK-SAME:       offsets = [0, 0, 0, 0], sizes = [%[[TILED_N]], %[[K]], 16, 1], strides = [1, 1, 1, 1]
+//      CHECK:   %[[OUTS:.+]] = flow.dispatch.tensor.load %[[OUTS_BINDING]]
+// CHECK-SAME:       offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[TILED_N]], 16, 16], strides = [1, 1, 1, 1]
+//      CHECK:   %[[MMT4D:.+]] = linalg.mmt4d
+// CHECK-SAME:       ins(%[[LHS]], %[[RHS]] :
+// CHECK-SAME:       outs(%[[OUTS]] :
+//      CHECK:   flow.dispatch.tensor.store %[[MMT4D]], %[[OUTS_BINDING]]
+// CHECK-SAME:       offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[TILED_N]], 16, 16], strides = [1, 1, 1, 1]
+
+// -----
+
 func.func @matmul_lowering_i8i8i32_aarch64() attributes {
   hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="aarch64-xyz-xyz"}>
 } {
diff --git a/compiler/src/iree/compiler/Codegen/Utils/EncodingUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/EncodingUtils.cpp
index 26ae6b4..b92223f 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/EncodingUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/EncodingUtils.cpp
@@ -23,11 +23,26 @@
       resultElementType.isSignlessInteger(32)) {
     return MatmulType::I8I8I32;
   }
-
   if (lhsElementType.isF32() && rhsElementType.isF32() &&
       resultElementType.isF32()) {
     return MatmulType::F32F32F32;
   }
+  if (lhsElementType.isF16() && rhsElementType.isF16() &&
+      resultElementType.isF32()) {
+    return MatmulType::F16F16F32;
+  }
+  if (lhsElementType.isF16() && rhsElementType.isF16() &&
+      resultElementType.isF16()) {
+    return MatmulType::F16F16F16;
+  }
+  if (lhsElementType.isBF16() && rhsElementType.isBF16() &&
+      resultElementType.isF32()) {
+    return MatmulType::BF16BF16F32;
+  }
+  if (lhsElementType.isBF16() && rhsElementType.isBF16() &&
+      resultElementType.isBF16()) {
+    return MatmulType::BF16BF16BF16;
+  }
 
   return std::nullopt;
 }
@@ -40,36 +55,50 @@
   return encodingAttr.getEncoding().getValue();
 }
 
+#define IREE_GETMATMULTYPE_CASE(TYPE)                                          \
+  case TensorEncoding::MATMUL_##TYPE##_LHS:                                    \
+  case TensorEncoding::MATMUL_##TYPE##_RHS:                                    \
+  case TensorEncoding::MATMUL_##TYPE##_RESULT:                                 \
+    return MatmulType::TYPE;
+
 std::optional<MatmulType> getMatmulType(TensorEncoding encoding) {
   switch (encoding) {
-  case TensorEncoding::MATMUL_F32F32F32_LHS:
-  case TensorEncoding::MATMUL_F32F32F32_RHS:
-  case TensorEncoding::MATMUL_F32F32F32_RESULT:
-    return MatmulType::F32F32F32;
-  case TensorEncoding::MATMUL_I8I8I32_LHS:
-  case TensorEncoding::MATMUL_I8I8I32_RHS:
-  case TensorEncoding::MATMUL_I8I8I32_RESULT:
-    return MatmulType::I8I8I32;
+    IREE_GETMATMULTYPE_CASE(F32F32F32)
+    IREE_GETMATMULTYPE_CASE(I8I8I32)
+    IREE_GETMATMULTYPE_CASE(F16F16F32)
+    IREE_GETMATMULTYPE_CASE(F16F16F16)
+    IREE_GETMATMULTYPE_CASE(BF16BF16F32)
+    IREE_GETMATMULTYPE_CASE(BF16BF16BF16)
   default:
     return std::nullopt;
   }
 }
 
+#undef IREE_GETMATMULTYPE_CASE
+
+#define IREE_GETMATMULOPERANDROLE_CASE(TYPE)                                   \
+  case TensorEncoding::MATMUL_##TYPE##_LHS:                                    \
+    return MatmulOperandRole::LHS;                                             \
+  case TensorEncoding::MATMUL_##TYPE##_RHS:                                    \
+    return MatmulOperandRole::RHS;                                             \
+  case TensorEncoding::MATMUL_##TYPE##_RESULT:                                 \
+    return MatmulOperandRole::RESULT;
+
 std::optional<MatmulOperandRole> getMatmulOperandRole(TensorEncoding encoding) {
   switch (encoding) {
-  case TensorEncoding::MATMUL_F32F32F32_LHS:
-  case TensorEncoding::MATMUL_I8I8I32_LHS:
-    return MatmulOperandRole::LHS;
-  case TensorEncoding::MATMUL_F32F32F32_RHS:
-  case TensorEncoding::MATMUL_I8I8I32_RHS:
-    return MatmulOperandRole::RHS;
-  case TensorEncoding::MATMUL_F32F32F32_RESULT:
-  case TensorEncoding::MATMUL_I8I8I32_RESULT:
-    return MatmulOperandRole::RESULT;
+    IREE_GETMATMULOPERANDROLE_CASE(F32F32F32)
+    IREE_GETMATMULOPERANDROLE_CASE(I8I8I32)
+    IREE_GETMATMULOPERANDROLE_CASE(F16F16F32)
+    IREE_GETMATMULOPERANDROLE_CASE(F16F16F16)
+    IREE_GETMATMULOPERANDROLE_CASE(BF16BF16F32)
+    IREE_GETMATMULOPERANDROLE_CASE(BF16BF16BF16)
+
   default:
     return std::nullopt;
   }
 }
 
+#undef IREE_GETMATMULOPERANDROLE_CASE
+
 } // namespace iree_compiler
 } // namespace mlir
diff --git a/compiler/src/iree/compiler/Codegen/Utils/EncodingUtils.h b/compiler/src/iree/compiler/Codegen/Utils/EncodingUtils.h
index 1eb914b..22cc758 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/EncodingUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Utils/EncodingUtils.h
@@ -18,6 +18,10 @@
 enum class MatmulType {
   F32F32F32,
   I8I8I32,
+  F16F16F32,
+  F16F16F16,
+  BF16BF16F32,
+  BF16BF16BF16
 };
 
 // Enumeration of the operands of a matmul-like operation such as linalg.matmul.
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/SetEncoding.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/SetEncoding.cpp
index f9808ae..fb8fc94 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/SetEncoding.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/SetEncoding.cpp
@@ -160,6 +160,26 @@
       lhsEncoding = TensorEncoding::MATMUL_F32F32F32_LHS;
       rhsEncoding = TensorEncoding::MATMUL_F32F32F32_RHS;
       outEncoding = TensorEncoding::MATMUL_F32F32F32_RESULT;
+    } else if (lhsElemType.isF16() && rhsElemType.isF16() &&
+               outElemType.isF32()) {
+      lhsEncoding = TensorEncoding::MATMUL_F16F16F32_LHS;
+      rhsEncoding = TensorEncoding::MATMUL_F16F16F32_RHS;
+      outEncoding = TensorEncoding::MATMUL_F16F16F32_RESULT;
+    } else if (lhsElemType.isF16() && rhsElemType.isF16() &&
+               outElemType.isF16()) {
+      lhsEncoding = TensorEncoding::MATMUL_F16F16F16_LHS;
+      rhsEncoding = TensorEncoding::MATMUL_F16F16F16_RHS;
+      outEncoding = TensorEncoding::MATMUL_F16F16F16_RESULT;
+    } else if (lhsElemType.isBF16() && rhsElemType.isBF16() &&
+               outElemType.isF32()) {
+      lhsEncoding = TensorEncoding::MATMUL_BF16BF16F32_LHS;
+      rhsEncoding = TensorEncoding::MATMUL_BF16BF16F32_RHS;
+      outEncoding = TensorEncoding::MATMUL_BF16BF16F32_RESULT;
+    } else if (lhsElemType.isBF16() && rhsElemType.isBF16() &&
+               outElemType.isBF16()) {
+      lhsEncoding = TensorEncoding::MATMUL_BF16BF16BF16_LHS;
+      rhsEncoding = TensorEncoding::MATMUL_BF16BF16BF16_RHS;
+      outEncoding = TensorEncoding::MATMUL_BF16BF16BF16_RESULT;
     } else if (lhsElemType.isSignlessInteger(8) &&
                rhsElemType.isSignlessInteger(8) &&
                outElemType.isSignlessInteger(32)) {
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/set_encoding.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/set_encoding.mlir
index ecde468..92de5a8 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/set_encoding.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/set_encoding.mlir
@@ -74,7 +74,6 @@
 //      PADDING:   %[[RESULT:.+]] = iree_linalg_ext.unset_encoding %[[MATMUL]]
 //      PADDING:   return %[[RESULT]]
 
-
 // -----
 
 func.func @matmul_dynamic(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
@@ -120,6 +119,126 @@
 
 // -----
 
+func.func @matmul_i8i8i32(%arg0 : tensor<128x256xi8>, %arg1 : tensor<256x512xi8>,
+    %arg2 : tensor<128x512xi32>) -> tensor<128x512xi32> {
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<128x256xi8>, tensor<256x512xi8>)
+      outs(%arg2 : tensor<128x512xi32>) -> tensor<128x512xi32>
+  return %0 : tensor<128x512xi32>
+}
+//      CHECK: func @matmul_i8i8i32(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<128x256xi8>
+// CHECK-SAME:     %[[ARG1:.+]]: tensor<256x512xi8>
+// CHECK-SAME:     %[[ARG2:.+]]: tensor<128x512xi32>
+//      CHECK:   %[[LHS:.+]] = iree_linalg_ext.set_encoding %[[ARG0]]
+// CHECK-SAME:       tensor<128x256xi8, #iree_linalg_ext.encoding<MATMUL_I8I8I32_LHS>>
+//      CHECK:   %[[RHS:.+]] = iree_linalg_ext.set_encoding %[[ARG1]]
+// CHECK-SAME:       tensor<256x512xi8, #iree_linalg_ext.encoding<MATMUL_I8I8I32_RHS>>
+//      CHECK:   %[[OUTS:.+]] = iree_linalg_ext.set_encoding %[[ARG2]]
+// CHECK-SAME:       tensor<128x512xi32, #iree_linalg_ext.encoding<MATMUL_I8I8I32_RESULT>>
+//      CHECK:   %[[MATMUL:.+]] = linalg.matmul
+// CHECK-SAME:       ins(%[[LHS]], %[[RHS]] :
+// CHECK-SAME:       outs(%[[OUTS]] :
+//      CHECK:   %[[RESULT:.+]] = iree_linalg_ext.unset_encoding %[[MATMUL]]
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
+func.func @matmul_f16f16f32(%arg0 : tensor<128x256xf16>, %arg1 : tensor<256x512xf16>,
+    %arg2 : tensor<128x512xf32>) -> tensor<128x512xf32> {
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<128x256xf16>, tensor<256x512xf16>)
+      outs(%arg2 : tensor<128x512xf32>) -> tensor<128x512xf32>
+  return %0 : tensor<128x512xf32>
+}
+//      CHECK: func @matmul_f16f16f32(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<128x256xf16>
+// CHECK-SAME:     %[[ARG1:.+]]: tensor<256x512xf16>
+// CHECK-SAME:     %[[ARG2:.+]]: tensor<128x512xf32>
+//      CHECK:   %[[LHS:.+]] = iree_linalg_ext.set_encoding %[[ARG0]]
+// CHECK-SAME:       tensor<128x256xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F32_LHS>>
+//      CHECK:   %[[RHS:.+]] = iree_linalg_ext.set_encoding %[[ARG1]]
+// CHECK-SAME:       tensor<256x512xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F32_RHS>>
+//      CHECK:   %[[OUTS:.+]] = iree_linalg_ext.set_encoding %[[ARG2]]
+// CHECK-SAME:       tensor<128x512xf32, #iree_linalg_ext.encoding<MATMUL_F16F16F32_RESULT>>
+//      CHECK:   %[[MATMUL:.+]] = linalg.matmul
+// CHECK-SAME:       ins(%[[LHS]], %[[RHS]] :
+// CHECK-SAME:       outs(%[[OUTS]] :
+//      CHECK:   %[[RESULT:.+]] = iree_linalg_ext.unset_encoding %[[MATMUL]]
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
+func.func @matmul_f16f16f16(%arg0 : tensor<128x256xf16>, %arg1 : tensor<256x512xf16>,
+    %arg2 : tensor<128x512xf16>) -> tensor<128x512xf16> {
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<128x256xf16>, tensor<256x512xf16>)
+      outs(%arg2 : tensor<128x512xf16>) -> tensor<128x512xf16>
+  return %0 : tensor<128x512xf16>
+}
+//      CHECK: func @matmul_f16f16f16(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<128x256xf16>
+// CHECK-SAME:     %[[ARG1:.+]]: tensor<256x512xf16>
+// CHECK-SAME:     %[[ARG2:.+]]: tensor<128x512xf16>
+//      CHECK:   %[[LHS:.+]] = iree_linalg_ext.set_encoding %[[ARG0]]
+// CHECK-SAME:       tensor<128x256xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_LHS>>
+//      CHECK:   %[[RHS:.+]] = iree_linalg_ext.set_encoding %[[ARG1]]
+// CHECK-SAME:       tensor<256x512xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RHS>>
+//      CHECK:   %[[OUTS:.+]] = iree_linalg_ext.set_encoding %[[ARG2]]
+// CHECK-SAME:       tensor<128x512xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RESULT>>
+//      CHECK:   %[[MATMUL:.+]] = linalg.matmul
+// CHECK-SAME:       ins(%[[LHS]], %[[RHS]] :
+// CHECK-SAME:       outs(%[[OUTS]] :
+//      CHECK:   %[[RESULT:.+]] = iree_linalg_ext.unset_encoding %[[MATMUL]]
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
+func.func @matmul_bf16bf16f32(%arg0 : tensor<128x256xbf16>, %arg1 : tensor<256x512xbf16>,
+    %arg2 : tensor<128x512xf32>) -> tensor<128x512xf32> {
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<128x256xbf16>, tensor<256x512xbf16>)
+      outs(%arg2 : tensor<128x512xf32>) -> tensor<128x512xf32>
+  return %0 : tensor<128x512xf32>
+}
+//      CHECK: func @matmul_bf16bf16f32(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<128x256xbf16>
+// CHECK-SAME:     %[[ARG1:.+]]: tensor<256x512xbf16>
+// CHECK-SAME:     %[[ARG2:.+]]: tensor<128x512xf32>
+//      CHECK:   %[[LHS:.+]] = iree_linalg_ext.set_encoding %[[ARG0]]
+// CHECK-SAME:       tensor<128x256xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_LHS>>
+//      CHECK:   %[[RHS:.+]] = iree_linalg_ext.set_encoding %[[ARG1]]
+// CHECK-SAME:       tensor<256x512xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_RHS>>
+//      CHECK:   %[[OUTS:.+]] = iree_linalg_ext.set_encoding %[[ARG2]]
+// CHECK-SAME:       tensor<128x512xf32, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_RESULT>>
+//      CHECK:   %[[MATMUL:.+]] = linalg.matmul
+// CHECK-SAME:       ins(%[[LHS]], %[[RHS]] :
+// CHECK-SAME:       outs(%[[OUTS]] :
+//      CHECK:   %[[RESULT:.+]] = iree_linalg_ext.unset_encoding %[[MATMUL]]
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
+func.func @matmul_bf16bf16bf16(%arg0 : tensor<128x256xbf16>, %arg1 : tensor<256x512xbf16>,
+    %arg2 : tensor<128x512xbf16>) -> tensor<128x512xbf16> {
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<128x256xbf16>, tensor<256x512xbf16>)
+      outs(%arg2 : tensor<128x512xbf16>) -> tensor<128x512xbf16>
+  return %0 : tensor<128x512xbf16>
+}
+//      CHECK: func @matmul_bf16bf16bf16(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<128x256xbf16>
+// CHECK-SAME:     %[[ARG1:.+]]: tensor<256x512xbf16>
+// CHECK-SAME:     %[[ARG2:.+]]: tensor<128x512xbf16>
+//      CHECK:   %[[LHS:.+]] = iree_linalg_ext.set_encoding %[[ARG0]]
+// CHECK-SAME:       tensor<128x256xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_LHS>>
+//      CHECK:   %[[RHS:.+]] = iree_linalg_ext.set_encoding %[[ARG1]]
+// CHECK-SAME:       tensor<256x512xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_RHS>>
+//      CHECK:   %[[OUTS:.+]] = iree_linalg_ext.set_encoding %[[ARG2]]
+// CHECK-SAME:       tensor<128x512xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_RESULT>>
+//      CHECK:   %[[MATMUL:.+]] = linalg.matmul
+// CHECK-SAME:       ins(%[[LHS]], %[[RHS]] :
+// CHECK-SAME:       outs(%[[OUTS]] :
+//      CHECK:   %[[RESULT:.+]] = iree_linalg_ext.unset_encoding %[[MATMUL]]
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
 func.func @fold_fill_with_set_encoding(%arg0 : index, %arg1 : index)
   -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> {
   %cst = arith.constant 0.0 : f32
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td
index fc228c1..68f1aec 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td
@@ -57,12 +57,40 @@
     : I32EnumAttrCase<"MATMUL_I8I8I32_RHS", 4>;
 def MATMUL_I8I8I32_RESULT
     : I32EnumAttrCase<"MATMUL_I8I8I32_RESULT", 5>;
+def MATMUL_F16F16F32_LHS
+    : I32EnumAttrCase<"MATMUL_F16F16F32_LHS", 6>;
+def MATMUL_F16F16F32_RHS
+    : I32EnumAttrCase<"MATMUL_F16F16F32_RHS", 7>;
+def MATMUL_F16F16F32_RESULT
+    : I32EnumAttrCase<"MATMUL_F16F16F32_RESULT", 8>;
+def MATMUL_F16F16F16_LHS
+    : I32EnumAttrCase<"MATMUL_F16F16F16_LHS", 9>;
+def MATMUL_F16F16F16_RHS
+    : I32EnumAttrCase<"MATMUL_F16F16F16_RHS", 10>;
+def MATMUL_F16F16F16_RESULT
+    : I32EnumAttrCase<"MATMUL_F16F16F16_RESULT", 11>;
+def MATMUL_BF16BF16F32_LHS
+    : I32EnumAttrCase<"MATMUL_BF16BF16F32_LHS", 12>;
+def MATMUL_BF16BF16F32_RHS
+    : I32EnumAttrCase<"MATMUL_BF16BF16F32_RHS", 13>;
+def MATMUL_BF16BF16F32_RESULT
+    : I32EnumAttrCase<"MATMUL_BF16BF16F32_RESULT", 14>;
+def MATMUL_BF16BF16BF16_LHS
+    : I32EnumAttrCase<"MATMUL_BF16BF16BF16_LHS", 15>;
+def MATMUL_BF16BF16BF16_RHS
+    : I32EnumAttrCase<"MATMUL_BF16BF16BF16_RHS", 16>;
+def MATMUL_BF16BF16BF16_RESULT
+    : I32EnumAttrCase<"MATMUL_BF16BF16BF16_RESULT", 17>;
 
 def TensorEncodingEnum
     : I32EnumAttr<"TensorEncoding",
                   "identifier for encoding used for the tensor",[
                     MATMUL_F32F32F32_LHS, MATMUL_F32F32F32_RHS, MATMUL_F32F32F32_RESULT,
                     MATMUL_I8I8I32_LHS, MATMUL_I8I8I32_RHS, MATMUL_I8I8I32_RESULT,
+                    MATMUL_F16F16F32_LHS, MATMUL_F16F16F32_RHS, MATMUL_F16F16F32_RESULT,
+                    MATMUL_F16F16F16_LHS, MATMUL_F16F16F16_RHS, MATMUL_F16F16F16_RESULT,
+                    MATMUL_BF16BF16F32_LHS, MATMUL_BF16BF16F32_RHS, MATMUL_BF16BF16F32_RESULT,
+                    MATMUL_BF16BF16BF16_LHS, MATMUL_BF16BF16BF16_RHS, MATMUL_BF16BF16BF16_RESULT,
                   ]> {
   let cppNamespace = "::mlir::iree_compiler::IREE::LinalgExt";
   let genSpecializedAttr = 0;
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
index 86ec963..53db781 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
@@ -203,15 +203,7 @@
       getEncoding(inputs[1]->get().getType().cast<RankedTensorType>());
   std::optional<TensorEncoding> resultEncoding =
       getEncoding(outputs[0]->get().getType().cast<RankedTensorType>());
-  if (!lhsEncoding ||
-      (lhsEncoding.value() != TensorEncoding::MATMUL_F32F32F32_LHS &&
-       lhsEncoding.value() != TensorEncoding::MATMUL_I8I8I32_LHS) ||
-      !rhsEncoding ||
-      (rhsEncoding.value() != TensorEncoding::MATMUL_F32F32F32_RHS &&
-       rhsEncoding.value() != TensorEncoding::MATMUL_I8I8I32_RHS) ||
-      !resultEncoding ||
-      (resultEncoding.value() != TensorEncoding::MATMUL_F32F32F32_RESULT &&
-       resultEncoding.value() != TensorEncoding::MATMUL_I8I8I32_RESULT)) {
+  if (!lhsEncoding || !rhsEncoding || !resultEncoding) {
     return failure();
   }
   Operation *mmt4DOp = rewriter.create<linalg::Mmt4DOp>(