data-tiling for `f16` and `bf16` matmuls (#14207)
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMaterializeEncodingPass.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMaterializeEncodingPass.cpp
index b854beb..c2375d8 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMaterializeEncodingPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMaterializeEncodingPass.cpp
@@ -33,6 +33,15 @@
chooseMatmulTileParamsAArch64(MatmulType type, ExecutableTargetAttr target) {
switch (type) {
case MatmulType::F32F32F32:
+ case MatmulType::F16F16F32:
+ case MatmulType::F16F16F16:
+ case MatmulType::BF16BF16F32:
+ case MatmulType::BF16BF16BF16:
+ // Note: 16-bit floating point types currently use the same tile size as
+ // f32. This makes sense when either (1) the accumulator is f32, or (2)
+ // the arithmetic has to expand f16 to f32 in registers anyway. We may
+ // revisit this once we take advantage of native f16/bf16 arithmetic in
+ // the case where the accumulator itself is f16/bf16.
return {8, 1, 8};
case MatmulType::I8I8I32:
if (hasFeature(target, "+i8mm")) {
@@ -54,8 +63,18 @@
chooseMatmulTileParamsX86_64(MatmulType type, ExecutableTargetAttr target) {
switch (type) {
case MatmulType::F32F32F32:
- if (hasAVX512fFeature(target))
+ case MatmulType::F16F16F32:
+ case MatmulType::F16F16F16:
+ case MatmulType::BF16BF16F32:
+ case MatmulType::BF16BF16BF16:
+ // Note: 16-bit floating point types currently use the same tile size as
+ // f32. This makes sense when either (1) the accumulator is f32, or (2)
+ // the arithmetic has to expand f16 to f32 in registers anyway. We may
+ // revisit this once we take advantage of native f16/bf16 arithmetic in
+ // the case where the accumulator itself is f16/bf16.
+ if (hasFeature(target, "+avx512f")) {
return {16, 1, 16};
+ }
if (hasFeature(target, "+avx")) {
// Note: for good performance, most +avx users will also want to add
// +fma, but that's a local instruction selection detail and the tile
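A note on the `{M0, K0, N0}` triples returned by the `chooseMatmulTileParams*` functions above: they fix the inner tile shapes of the packed operands that `linalg.mmt4d` consumes, which is exactly what the materialize_encoding.mlir tests below assert. A minimal sketch of that shape math, with hypothetical struct and helper names rather than IREE's actual API:

```cpp
#include <array>
#include <cstdint>

// Illustrative only: how the {M0, K0, N0} tile parameters chosen above map
// to the packed (data-tiled) tensor shapes checked in the tests below.
struct TileParams {
  int64_t M0, K0, N0; // e.g. {8, 1, 8} on AArch64, {16, 1, 16} with +avx512f
};

// Same rounding as the materialized affine_map<()[s0] -> (s0 ceildiv M0)>.
constexpr int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

// An MxK lhs packs to [M/M0][K/K0][M0][K0], a KxN rhs packs (transposed) to
// [N/N0][K/K0][N0][K0], and the MxN result packs to [M/M0][N/N0][M0][N0],
// matching linalg.mmt4d's operand layouts.
std::array<int64_t, 4> packedLhsShape(int64_t M, int64_t K, TileParams t) {
  return {ceilDiv(M, t.M0), ceilDiv(K, t.K0), t.M0, t.K0};
}
std::array<int64_t, 4> packedRhsShape(int64_t K, int64_t N, TileParams t) {
  return {ceilDiv(N, t.N0), ceilDiv(K, t.K0), t.N0, t.K0};
}
std::array<int64_t, 4> packedResultShape(int64_t M, int64_t N, TileParams t) {
  return {ceilDiv(M, t.M0), ceilDiv(N, t.N0), t.M0, t.N0};
}
```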
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_encoding.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_encoding.mlir
index 15c7430..dc7ad00 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_encoding.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_encoding.mlir
@@ -212,6 +212,66 @@
// -----
+func.func @matmul_lowering_f16f16f16_aarch64() attributes {
+ hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="aarch64-xyz-xyz"}>
+} {
+ %c0 = arith.constant 0 : index
+ %M = hal.interface.constant.load[0] : index
+ %N = hal.interface.constant.load[1] : index
+ %K = hal.interface.constant.load[2] : index
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<readonly:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_LHS>>>{%M, %K}
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<readonly:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RHS>>>{%K, %N}
+ %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<readwrite:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RESULT>>>{%M, %N}
+ %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1]
+ : !flow.dispatch.tensor<readonly:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_LHS>>>{%M, %K}
+ -> tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_LHS>>
+ %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1]
+ : !flow.dispatch.tensor<readonly:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RHS>>>{%K, %N}
+ -> tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RHS>>
+ %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
+ : !flow.dispatch.tensor<readwrite:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RESULT>>>{%M, %N}
+ -> tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RESULT>>
+ %6 = linalg.matmul
+ ins(%3, %4 : tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_LHS>>,
+ tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RHS>>)
+ outs(%5 : tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RESULT>>)
+ -> tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RESULT>>
+ flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
+ : tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RESULT>>
+ -> !flow.dispatch.tensor<readwrite:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RESULT>>>{%M, %N}
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
+// CHECK: func @matmul_lowering_f16f16f16_aarch64()
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[M:.+]] = hal.interface.constant.load[0]
+// CHECK-DAG: %[[N:.+]] = hal.interface.constant.load[1]
+// CHECK-DAG: %[[K:.+]] = hal.interface.constant.load[2]
+// CHECK-DAG: %[[TILED_M:.+]] = affine.apply #[[MAP0]]()[%[[M]]]
+// CHECK: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(0)
+// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x?x8x1xf16>>{%[[TILED_M]], %[[K]]}
+// CHECK: %[[TILED_N:.+]] = affine.apply #[[MAP0]]()[%[[N]]]
+// CHECK: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(1)
+// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x?x8x1xf16>>{%[[TILED_N]], %[[K]]}
+// CHECK: %[[OUTS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(2)
+// CHECK-SAME: !flow.dispatch.tensor<readwrite:tensor<?x?x8x8xf16>>{%[[TILED_M]], %[[TILED_N]]}
+// CHECK: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]
+// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[K]], 8, 1], strides = [1, 1, 1, 1]
+// CHECK: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]
+// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_N]], %[[K]], 8, 1], strides = [1, 1, 1, 1]
+// CHECK: %[[OUTS:.+]] = flow.dispatch.tensor.load %[[OUTS_BINDING]]
+// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[TILED_N]], 8, 8], strides = [1, 1, 1, 1]
+// CHECK: %[[MMT4D:.+]] = linalg.mmt4d
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
+// CHECK-SAME: outs(%[[OUTS]] :
+// CHECK: flow.dispatch.tensor.store %[[MMT4D]], %[[OUTS_BINDING]]
+// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[TILED_N]], 8, 8], strides = [1, 1, 1, 1]
+
+// -----
+
func.func @matmul_lowering_f32f32f32_x86_64() attributes {
hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz"}>
} {
@@ -393,6 +453,246 @@
// -----
+func.func @matmul_lowering_f16f16f32_x86_64_avx512f() attributes {
+ hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
+} {
+ %c0 = arith.constant 0 : index
+ %M = hal.interface.constant.load[0] : index
+ %N = hal.interface.constant.load[1] : index
+ %K = hal.interface.constant.load[2] : index
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<readonly:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F32_LHS>>>{%M, %K}
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<readonly:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F32_RHS>>>{%K, %N}
+ %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<readwrite:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F16F16F32_RESULT>>>{%M, %N}
+ %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1]
+ : !flow.dispatch.tensor<readonly:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F32_LHS>>>{%M, %K}
+ -> tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F32_LHS>>
+ %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1]
+ : !flow.dispatch.tensor<readonly:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F32_RHS>>>{%K, %N}
+ -> tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F32_RHS>>
+ %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
+ : !flow.dispatch.tensor<readwrite:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F16F16F32_RESULT>>>{%M, %N}
+ -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F16F16F32_RESULT>>
+ %6 = linalg.matmul
+ ins(%3, %4 : tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F32_LHS>>,
+ tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F32_RHS>>)
+ outs(%5 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F16F16F32_RESULT>>)
+ -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F16F16F32_RESULT>>
+ flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
+ : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F16F16F32_RESULT>>
+ -> !flow.dispatch.tensor<readwrite:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F16F16F32_RESULT>>>{%M, %N}
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
+// CHECK: func @matmul_lowering_f16f16f32_x86_64_avx512f()
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[M:.+]] = hal.interface.constant.load[0]
+// CHECK-DAG: %[[N:.+]] = hal.interface.constant.load[1]
+// CHECK-DAG: %[[K:.+]] = hal.interface.constant.load[2]
+// CHECK-DAG: %[[TILED_M:.+]] = affine.apply #[[MAP0]]()[%[[M]]]
+// CHECK: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(0)
+// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x?x16x1xf16>>{%[[TILED_M]], %[[K]]}
+// CHECK: %[[TILED_N:.+]] = affine.apply #[[MAP0]]()[%[[N]]]
+// CHECK: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(1)
+// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x?x16x1xf16>>{%[[TILED_N]], %[[K]]}
+// CHECK: %[[OUTS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(2)
+// CHECK-SAME: !flow.dispatch.tensor<readwrite:tensor<?x?x16x16xf32>>{%[[TILED_M]], %[[TILED_N]]}
+// CHECK: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]
+// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[K]], 16, 1], strides = [1, 1, 1, 1]
+// CHECK: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]
+// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_N]], %[[K]], 16, 1], strides = [1, 1, 1, 1]
+// CHECK: %[[OUTS:.+]] = flow.dispatch.tensor.load %[[OUTS_BINDING]]
+// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[TILED_N]], 16, 16], strides = [1, 1, 1, 1]
+// CHECK: %[[MMT4D:.+]] = linalg.mmt4d
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
+// CHECK-SAME: outs(%[[OUTS]] :
+// CHECK: flow.dispatch.tensor.store %[[MMT4D]], %[[OUTS_BINDING]]
+// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[TILED_N]], 16, 16], strides = [1, 1, 1, 1]
+
+// -----
+
+func.func @matmul_lowering_f16f16f16_x86_64_avx512f() attributes {
+ hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
+} {
+ %c0 = arith.constant 0 : index
+ %M = hal.interface.constant.load[0] : index
+ %N = hal.interface.constant.load[1] : index
+ %K = hal.interface.constant.load[2] : index
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<readonly:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_LHS>>>{%M, %K}
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<readonly:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RHS>>>{%K, %N}
+ %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<readwrite:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RESULT>>>{%M, %N}
+ %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1]
+ : !flow.dispatch.tensor<readonly:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_LHS>>>{%M, %K}
+ -> tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_LHS>>
+ %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1]
+ : !flow.dispatch.tensor<readonly:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RHS>>>{%K, %N}
+ -> tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RHS>>
+ %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
+ : !flow.dispatch.tensor<readwrite:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RESULT>>>{%M, %N}
+ -> tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RESULT>>
+ %6 = linalg.matmul
+ ins(%3, %4 : tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_LHS>>,
+ tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RHS>>)
+ outs(%5 : tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RESULT>>)
+ -> tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RESULT>>
+ flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
+ : tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RESULT>>
+ -> !flow.dispatch.tensor<readwrite:tensor<?x?xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RESULT>>>{%M, %N}
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
+// CHECK: func @matmul_lowering_f16f16f16_x86_64_avx512f()
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[M:.+]] = hal.interface.constant.load[0]
+// CHECK-DAG: %[[N:.+]] = hal.interface.constant.load[1]
+// CHECK-DAG: %[[K:.+]] = hal.interface.constant.load[2]
+// CHECK-DAG: %[[TILED_M:.+]] = affine.apply #[[MAP0]]()[%[[M]]]
+// CHECK: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(0)
+// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x?x16x1xf16>>{%[[TILED_M]], %[[K]]}
+// CHECK: %[[TILED_N:.+]] = affine.apply #[[MAP0]]()[%[[N]]]
+// CHECK: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(1)
+// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x?x16x1xf16>>{%[[TILED_N]], %[[K]]}
+// CHECK: %[[OUTS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(2)
+// CHECK-SAME: !flow.dispatch.tensor<readwrite:tensor<?x?x16x16xf16>>{%[[TILED_M]], %[[TILED_N]]}
+// CHECK: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]
+// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[K]], 16, 1], strides = [1, 1, 1, 1]
+// CHECK: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]
+// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_N]], %[[K]], 16, 1], strides = [1, 1, 1, 1]
+// CHECK: %[[OUTS:.+]] = flow.dispatch.tensor.load %[[OUTS_BINDING]]
+// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[TILED_N]], 16, 16], strides = [1, 1, 1, 1]
+// CHECK: %[[MMT4D:.+]] = linalg.mmt4d
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
+// CHECK-SAME: outs(%[[OUTS]] :
+// CHECK: flow.dispatch.tensor.store %[[MMT4D]], %[[OUTS_BINDING]]
+// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[TILED_N]], 16, 16], strides = [1, 1, 1, 1]
+
+// -----
+
+func.func @matmul_lowering_bf16bf16f32_x86_64_avx512f() attributes {
+ hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
+} {
+ %c0 = arith.constant 0 : index
+ %M = hal.interface.constant.load[0] : index
+ %N = hal.interface.constant.load[1] : index
+ %K = hal.interface.constant.load[2] : index
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<readonly:tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_LHS>>>{%M, %K}
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<readonly:tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_RHS>>>{%K, %N}
+ %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<readwrite:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_RESULT>>>{%M, %N}
+ %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1]
+ : !flow.dispatch.tensor<readonly:tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_LHS>>>{%M, %K}
+ -> tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_LHS>>
+ %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1]
+ : !flow.dispatch.tensor<readonly:tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_RHS>>>{%K, %N}
+ -> tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_RHS>>
+ %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
+ : !flow.dispatch.tensor<readwrite:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_RESULT>>>{%M, %N}
+ -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_RESULT>>
+ %6 = linalg.matmul
+ ins(%3, %4 : tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_LHS>>,
+ tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_RHS>>)
+ outs(%5 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_RESULT>>)
+ -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_RESULT>>
+ flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
+ : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_RESULT>>
+ -> !flow.dispatch.tensor<readwrite:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_RESULT>>>{%M, %N}
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
+// CHECK: func @matmul_lowering_bf16bf16f32_x86_64_avx512f()
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[M:.+]] = hal.interface.constant.load[0]
+// CHECK-DAG: %[[N:.+]] = hal.interface.constant.load[1]
+// CHECK-DAG: %[[K:.+]] = hal.interface.constant.load[2]
+// CHECK-DAG: %[[TILED_M:.+]] = affine.apply #[[MAP0]]()[%[[M]]]
+// CHECK: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(0)
+// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x?x16x1xbf16>>{%[[TILED_M]], %[[K]]}
+// CHECK: %[[TILED_N:.+]] = affine.apply #[[MAP0]]()[%[[N]]]
+// CHECK: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(1)
+// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x?x16x1xbf16>>{%[[TILED_N]], %[[K]]}
+// CHECK: %[[OUTS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(2)
+// CHECK-SAME: !flow.dispatch.tensor<readwrite:tensor<?x?x16x16xf32>>{%[[TILED_M]], %[[TILED_N]]}
+// CHECK: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]
+// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[K]], 16, 1], strides = [1, 1, 1, 1]
+// CHECK: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]
+// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_N]], %[[K]], 16, 1], strides = [1, 1, 1, 1]
+// CHECK: %[[OUTS:.+]] = flow.dispatch.tensor.load %[[OUTS_BINDING]]
+// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[TILED_N]], 16, 16], strides = [1, 1, 1, 1]
+// CHECK: %[[MMT4D:.+]] = linalg.mmt4d
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
+// CHECK-SAME: outs(%[[OUTS]] :
+// CHECK: flow.dispatch.tensor.store %[[MMT4D]], %[[OUTS_BINDING]]
+// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[TILED_N]], 16, 16], strides = [1, 1, 1, 1]
+
+// -----
+
+func.func @matmul_lowering_bf16bf16bf16_x86_64_avx512f() attributes {
+ hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
+} {
+ %c0 = arith.constant 0 : index
+ %M = hal.interface.constant.load[0] : index
+ %N = hal.interface.constant.load[1] : index
+ %K = hal.interface.constant.load[2] : index
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<readonly:tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_LHS>>>{%M, %K}
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<readonly:tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_RHS>>>{%K, %N}
+ %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<readwrite:tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_RESULT>>>{%M, %N}
+ %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1]
+ : !flow.dispatch.tensor<readonly:tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_LHS>>>{%M, %K}
+ -> tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_LHS>>
+ %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1]
+ : !flow.dispatch.tensor<readonly:tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_RHS>>>{%K, %N}
+ -> tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_RHS>>
+ %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
+ : !flow.dispatch.tensor<readwrite:tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_RESULT>>>{%M, %N}
+ -> tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_RESULT>>
+ %6 = linalg.matmul
+ ins(%3, %4 : tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_LHS>>,
+ tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_RHS>>)
+ outs(%5 : tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_RESULT>>)
+ -> tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_RESULT>>
+ flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
+ : tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_RESULT>>
+ -> !flow.dispatch.tensor<readwrite:tensor<?x?xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_RESULT>>>{%M, %N}
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
+// CHECK: func @matmul_lowering_bf16bf16bf16_x86_64_avx512f()
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[M:.+]] = hal.interface.constant.load[0]
+// CHECK-DAG: %[[N:.+]] = hal.interface.constant.load[1]
+// CHECK-DAG: %[[K:.+]] = hal.interface.constant.load[2]
+// CHECK-DAG: %[[TILED_M:.+]] = affine.apply #[[MAP0]]()[%[[M]]]
+// CHECK: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(0)
+// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x?x16x1xbf16>>{%[[TILED_M]], %[[K]]}
+// CHECK: %[[TILED_N:.+]] = affine.apply #[[MAP0]]()[%[[N]]]
+// CHECK: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(1)
+// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x?x16x1xbf16>>{%[[TILED_N]], %[[K]]}
+// CHECK: %[[OUTS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(2)
+// CHECK-SAME: !flow.dispatch.tensor<readwrite:tensor<?x?x16x16xbf16>>{%[[TILED_M]], %[[TILED_N]]}
+// CHECK: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]
+// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[K]], 16, 1], strides = [1, 1, 1, 1]
+// CHECK: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]
+// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_N]], %[[K]], 16, 1], strides = [1, 1, 1, 1]
+// CHECK: %[[OUTS:.+]] = flow.dispatch.tensor.load %[[OUTS_BINDING]]
+// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[TILED_N]], 16, 16], strides = [1, 1, 1, 1]
+// CHECK: %[[MMT4D:.+]] = linalg.mmt4d
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
+// CHECK-SAME: outs(%[[OUTS]] :
+// CHECK: flow.dispatch.tensor.store %[[MMT4D]], %[[OUTS_BINDING]]
+// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[TILED_N]], 16, 16], strides = [1, 1, 1, 1]
+
+// -----
+
func.func @matmul_lowering_i8i8i32_aarch64() attributes {
hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="aarch64-xyz-xyz"}>
} {
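To make the CHECK math in the tests above concrete, here is a tiny worked example of the `ceildiv` sizes they assert. The concrete M, N, K values are made up for illustration:

```cpp
#include <cassert>
#include <cstdint>

// Worked example of the tiled sizes the CHECK lines above assert.
int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  const int64_t M = 100, N = 60, K = 42; // hypothetical dynamic dims
  // avx512f tile params {M0, K0, N0} = {16, 1, 16}:
  assert(ceilDiv(M, 16) == 7);  // TILED_M: lhs -> tensor<7x42x16x1xf16>
  assert(ceilDiv(N, 16) == 4);  // TILED_N: rhs -> tensor<4x42x16x1xf16>
  assert(ceilDiv(K, 1) == 42);  // K is untiled (K0 = 1)
  // result -> tensor<7x4x16x16xf32> for MATMUL_F16F16F32 (f32 accumulator),
  // or tensor<7x4x16x16xf16> for MATMUL_F16F16F16.
  return 0;
}
```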
diff --git a/compiler/src/iree/compiler/Codegen/Utils/EncodingUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/EncodingUtils.cpp
index 26ae6b4..b92223f 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/EncodingUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/EncodingUtils.cpp
@@ -23,11 +23,26 @@
resultElementType.isSignlessInteger(32)) {
return MatmulType::I8I8I32;
}
-
if (lhsElementType.isF32() && rhsElementType.isF32() &&
resultElementType.isF32()) {
return MatmulType::F32F32F32;
}
+ if (lhsElementType.isF16() && rhsElementType.isF16() &&
+ resultElementType.isF32()) {
+ return MatmulType::F16F16F32;
+ }
+ if (lhsElementType.isF16() && rhsElementType.isF16() &&
+ resultElementType.isF16()) {
+ return MatmulType::F16F16F16;
+ }
+ if (lhsElementType.isBF16() && rhsElementType.isBF16() &&
+ resultElementType.isF32()) {
+ return MatmulType::BF16BF16F32;
+ }
+ if (lhsElementType.isBF16() && rhsElementType.isBF16() &&
+ resultElementType.isBF16()) {
+ return MatmulType::BF16BF16BF16;
+ }
return std::nullopt;
}
@@ -40,36 +55,50 @@
return encodingAttr.getEncoding().getValue();
}
+#define IREE_GETMATMULTYPE_CASE(TYPE) \
+ case TensorEncoding::MATMUL_##TYPE##_LHS: \
+ case TensorEncoding::MATMUL_##TYPE##_RHS: \
+ case TensorEncoding::MATMUL_##TYPE##_RESULT: \
+ return MatmulType::TYPE;
+
std::optional<MatmulType> getMatmulType(TensorEncoding encoding) {
switch (encoding) {
- case TensorEncoding::MATMUL_F32F32F32_LHS:
- case TensorEncoding::MATMUL_F32F32F32_RHS:
- case TensorEncoding::MATMUL_F32F32F32_RESULT:
- return MatmulType::F32F32F32;
- case TensorEncoding::MATMUL_I8I8I32_LHS:
- case TensorEncoding::MATMUL_I8I8I32_RHS:
- case TensorEncoding::MATMUL_I8I8I32_RESULT:
- return MatmulType::I8I8I32;
+ IREE_GETMATMULTYPE_CASE(F32F32F32)
+ IREE_GETMATMULTYPE_CASE(I8I8I32)
+ IREE_GETMATMULTYPE_CASE(F16F16F32)
+ IREE_GETMATMULTYPE_CASE(F16F16F16)
+ IREE_GETMATMULTYPE_CASE(BF16BF16F32)
+ IREE_GETMATMULTYPE_CASE(BF16BF16BF16)
default:
return std::nullopt;
}
}
+#undef IREE_GETMATMULTYPE_CASE
+
+#define IREE_GETMATMULOPERANDROLE_CASE(TYPE) \
+ case TensorEncoding::MATMUL_##TYPE##_LHS: \
+ return MatmulOperandRole::LHS; \
+ case TensorEncoding::MATMUL_##TYPE##_RHS: \
+ return MatmulOperandRole::RHS; \
+ case TensorEncoding::MATMUL_##TYPE##_RESULT: \
+ return MatmulOperandRole::RESULT;
+
std::optional<MatmulOperandRole> getMatmulOperandRole(TensorEncoding encoding) {
switch (encoding) {
- case TensorEncoding::MATMUL_F32F32F32_LHS:
- case TensorEncoding::MATMUL_I8I8I32_LHS:
- return MatmulOperandRole::LHS;
- case TensorEncoding::MATMUL_F32F32F32_RHS:
- case TensorEncoding::MATMUL_I8I8I32_RHS:
- return MatmulOperandRole::RHS;
- case TensorEncoding::MATMUL_F32F32F32_RESULT:
- case TensorEncoding::MATMUL_I8I8I32_RESULT:
- return MatmulOperandRole::RESULT;
+ IREE_GETMATMULOPERANDROLE_CASE(F32F32F32)
+ IREE_GETMATMULOPERANDROLE_CASE(I8I8I32)
+ IREE_GETMATMULOPERANDROLE_CASE(F16F16F32)
+ IREE_GETMATMULOPERANDROLE_CASE(F16F16F16)
+ IREE_GETMATMULOPERANDROLE_CASE(BF16BF16F32)
+ IREE_GETMATMULOPERANDROLE_CASE(BF16BF16BF16)
+
default:
return std::nullopt;
}
}
+#undef IREE_GETMATMULOPERANDROLE_CASE
+
} // namespace iree_compiler
} // namespace mlir
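Since the X-macro refactoring above compresses a lot of case labels, here is a minimal self-contained demo of the pattern with stand-in enums (the real `TensorEncoding` and `MatmulType` live in IREE headers), showing what each invocation expands to:

```cpp
#include <optional>

// Stand-in enums; the real ones carry many more cases.
enum class TensorEncoding {
  MATMUL_F16F16F32_LHS, MATMUL_F16F16F32_RHS, MATMUL_F16F16F32_RESULT,
  MATMUL_BF16BF16F32_LHS, MATMUL_BF16BF16F32_RHS, MATMUL_BF16BF16F32_RESULT,
};
enum class MatmulType { F16F16F32, BF16BF16F32 };

// Each invocation expands to three case labels sharing one return, exactly
// the hand-written group it replaces.
#define GETMATMULTYPE_CASE(TYPE)                                               \
  case TensorEncoding::MATMUL_##TYPE##_LHS:                                    \
  case TensorEncoding::MATMUL_##TYPE##_RHS:                                    \
  case TensorEncoding::MATMUL_##TYPE##_RESULT:                                 \
    return MatmulType::TYPE;

std::optional<MatmulType> getMatmulType(TensorEncoding encoding) {
  switch (encoding) {
    GETMATMULTYPE_CASE(F16F16F32)
    GETMATMULTYPE_CASE(BF16BF16F32)
  }
  return std::nullopt;
}
#undef GETMATMULTYPE_CASE
```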
diff --git a/compiler/src/iree/compiler/Codegen/Utils/EncodingUtils.h b/compiler/src/iree/compiler/Codegen/Utils/EncodingUtils.h
index 1eb914b..22cc758 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/EncodingUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Utils/EncodingUtils.h
@@ -18,6 +18,10 @@
enum class MatmulType {
F32F32F32,
I8I8I32,
+ F16F16F32,
+ F16F16F16,
+ BF16BF16F32,
+ BF16BF16BF16
};
// Enumeration of the operands of a matmul-like operation such as linalg.matmul.
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/SetEncoding.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/SetEncoding.cpp
index f9808ae..fb8fc94 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/SetEncoding.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/SetEncoding.cpp
@@ -160,6 +160,26 @@
lhsEncoding = TensorEncoding::MATMUL_F32F32F32_LHS;
rhsEncoding = TensorEncoding::MATMUL_F32F32F32_RHS;
outEncoding = TensorEncoding::MATMUL_F32F32F32_RESULT;
+ } else if (lhsElemType.isF16() && rhsElemType.isF16() &&
+ outElemType.isF32()) {
+ lhsEncoding = TensorEncoding::MATMUL_F16F16F32_LHS;
+ rhsEncoding = TensorEncoding::MATMUL_F16F16F32_RHS;
+ outEncoding = TensorEncoding::MATMUL_F16F16F32_RESULT;
+ } else if (lhsElemType.isF16() && rhsElemType.isF16() &&
+ outElemType.isF16()) {
+ lhsEncoding = TensorEncoding::MATMUL_F16F16F16_LHS;
+ rhsEncoding = TensorEncoding::MATMUL_F16F16F16_RHS;
+ outEncoding = TensorEncoding::MATMUL_F16F16F16_RESULT;
+ } else if (lhsElemType.isBF16() && rhsElemType.isBF16() &&
+ outElemType.isF32()) {
+ lhsEncoding = TensorEncoding::MATMUL_BF16BF16F32_LHS;
+ rhsEncoding = TensorEncoding::MATMUL_BF16BF16F32_RHS;
+ outEncoding = TensorEncoding::MATMUL_BF16BF16F32_RESULT;
+ } else if (lhsElemType.isBF16() && rhsElemType.isBF16() &&
+ outElemType.isBF16()) {
+ lhsEncoding = TensorEncoding::MATMUL_BF16BF16BF16_LHS;
+ rhsEncoding = TensorEncoding::MATMUL_BF16BF16BF16_RHS;
+ outEncoding = TensorEncoding::MATMUL_BF16BF16BF16_RESULT;
} else if (lhsElemType.isSignlessInteger(8) &&
rhsElemType.isSignlessInteger(8) &&
outElemType.isSignlessInteger(32)) {
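The else-if ladder above keys the encoding triple purely off the (lhs, rhs, out) element types. For reference, the same mapping written as a table-driven sketch; this is hypothetical code using a stand-in `Elem` enum instead of `mlir::Type` queries, not what the patch does:

```cpp
#include <optional>

enum class Elem { F32, F16, BF16, I8, I32 }; // stand-in for mlir::Type checks

enum class TensorEncoding {
  MATMUL_F16F16F32_LHS, MATMUL_F16F16F32_RHS, MATMUL_F16F16F32_RESULT,
  MATMUL_F16F16F16_LHS, MATMUL_F16F16F16_RHS, MATMUL_F16F16F16_RESULT,
  MATMUL_BF16BF16F32_LHS, MATMUL_BF16BF16F32_RHS, MATMUL_BF16BF16F32_RESULT,
  MATMUL_BF16BF16BF16_LHS, MATMUL_BF16BF16BF16_RHS, MATMUL_BF16BF16BF16_RESULT,
};

struct EncodingTriple { TensorEncoding lhs, rhs, result; };

// One row per supported (lhs, rhs, out) element-type combination, mirroring
// the new else-if branches above.
std::optional<EncodingTriple> chooseEncodings(Elem lhs, Elem rhs, Elem out) {
  struct Row { Elem lhs, rhs, out; EncodingTriple enc; };
  static const Row rows[] = {
      {Elem::F16, Elem::F16, Elem::F32,
       {TensorEncoding::MATMUL_F16F16F32_LHS,
        TensorEncoding::MATMUL_F16F16F32_RHS,
        TensorEncoding::MATMUL_F16F16F32_RESULT}},
      {Elem::F16, Elem::F16, Elem::F16,
       {TensorEncoding::MATMUL_F16F16F16_LHS,
        TensorEncoding::MATMUL_F16F16F16_RHS,
        TensorEncoding::MATMUL_F16F16F16_RESULT}},
      {Elem::BF16, Elem::BF16, Elem::F32,
       {TensorEncoding::MATMUL_BF16BF16F32_LHS,
        TensorEncoding::MATMUL_BF16BF16F32_RHS,
        TensorEncoding::MATMUL_BF16BF16F32_RESULT}},
      {Elem::BF16, Elem::BF16, Elem::BF16,
       {TensorEncoding::MATMUL_BF16BF16BF16_LHS,
        TensorEncoding::MATMUL_BF16BF16BF16_RHS,
        TensorEncoding::MATMUL_BF16BF16BF16_RESULT}},
  };
  for (const Row &r : rows)
    if (r.lhs == lhs && r.rhs == rhs && r.out == out) return r.enc;
  return std::nullopt; // unsupported combination: leave tensors unencoded
}
```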
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/set_encoding.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/set_encoding.mlir
index ecde468..92de5a8 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/set_encoding.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/set_encoding.mlir
@@ -74,7 +74,6 @@
// PADDING: %[[RESULT:.+]] = iree_linalg_ext.unset_encoding %[[MATMUL]]
// PADDING: return %[[RESULT]]
-
// -----
func.func @matmul_dynamic(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
@@ -120,6 +119,126 @@
// -----
+func.func @matmul_i8i8i32(%arg0 : tensor<128x256xi8>, %arg1 : tensor<256x512xi8>,
+ %arg2 : tensor<128x512xi32>) -> tensor<128x512xi32> {
+ %0 = linalg.matmul ins(%arg0, %arg1 : tensor<128x256xi8>, tensor<256x512xi8>)
+ outs(%arg2 : tensor<128x512xi32>) -> tensor<128x512xi32>
+ return %0 : tensor<128x512xi32>
+}
+// CHECK: func @matmul_i8i8i32(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<128x256xi8>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<256x512xi8>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<128x512xi32>
+// CHECK: %[[LHS:.+]] = iree_linalg_ext.set_encoding %[[ARG0]]
+// CHECK-SAME: tensor<128x256xi8, #iree_linalg_ext.encoding<MATMUL_I8I8I32_LHS>>
+// CHECK: %[[RHS:.+]] = iree_linalg_ext.set_encoding %[[ARG1]]
+// CHECK-SAME: tensor<256x512xi8, #iree_linalg_ext.encoding<MATMUL_I8I8I32_RHS>>
+// CHECK: %[[OUTS:.+]] = iree_linalg_ext.set_encoding %[[ARG2]]
+// CHECK-SAME: tensor<128x512xi32, #iree_linalg_ext.encoding<MATMUL_I8I8I32_RESULT>>
+// CHECK: %[[MATMUL:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
+// CHECK-SAME: outs(%[[OUTS]] :
+// CHECK: %[[RESULT:.+]] = iree_linalg_ext.unset_encoding %[[MATMUL]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @matmul_f16f16f32(%arg0 : tensor<128x256xf16>, %arg1 : tensor<256x512xf16>,
+ %arg2 : tensor<128x512xf32>) -> tensor<128x512xf32> {
+ %0 = linalg.matmul ins(%arg0, %arg1 : tensor<128x256xf16>, tensor<256x512xf16>)
+ outs(%arg2 : tensor<128x512xf32>) -> tensor<128x512xf32>
+ return %0 : tensor<128x512xf32>
+}
+// CHECK: func @matmul_f16f16f32(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<128x256xf16>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<256x512xf16>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<128x512xf32>
+// CHECK: %[[LHS:.+]] = iree_linalg_ext.set_encoding %[[ARG0]]
+// CHECK-SAME: tensor<128x256xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F32_LHS>>
+// CHECK: %[[RHS:.+]] = iree_linalg_ext.set_encoding %[[ARG1]]
+// CHECK-SAME: tensor<256x512xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F32_RHS>>
+// CHECK: %[[OUTS:.+]] = iree_linalg_ext.set_encoding %[[ARG2]]
+// CHECK-SAME: tensor<128x512xf32, #iree_linalg_ext.encoding<MATMUL_F16F16F32_RESULT>>
+// CHECK: %[[MATMUL:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
+// CHECK-SAME: outs(%[[OUTS]] :
+// CHECK: %[[RESULT:.+]] = iree_linalg_ext.unset_encoding %[[MATMUL]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @matmul_f16f16f16(%arg0 : tensor<128x256xf16>, %arg1 : tensor<256x512xf16>,
+ %arg2 : tensor<128x512xf16>) -> tensor<128x512xf16> {
+ %0 = linalg.matmul ins(%arg0, %arg1 : tensor<128x256xf16>, tensor<256x512xf16>)
+ outs(%arg2 : tensor<128x512xf16>) -> tensor<128x512xf16>
+ return %0 : tensor<128x512xf16>
+}
+// CHECK: func @matmul_f16f16f16(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<128x256xf16>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<256x512xf16>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<128x512xf16>
+// CHECK: %[[LHS:.+]] = iree_linalg_ext.set_encoding %[[ARG0]]
+// CHECK-SAME: tensor<128x256xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_LHS>>
+// CHECK: %[[RHS:.+]] = iree_linalg_ext.set_encoding %[[ARG1]]
+// CHECK-SAME: tensor<256x512xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RHS>>
+// CHECK: %[[OUTS:.+]] = iree_linalg_ext.set_encoding %[[ARG2]]
+// CHECK-SAME: tensor<128x512xf16, #iree_linalg_ext.encoding<MATMUL_F16F16F16_RESULT>>
+// CHECK: %[[MATMUL:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
+// CHECK-SAME: outs(%[[OUTS]] :
+// CHECK: %[[RESULT:.+]] = iree_linalg_ext.unset_encoding %[[MATMUL]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @matmul_bf16bf16f32(%arg0 : tensor<128x256xbf16>, %arg1 : tensor<256x512xbf16>,
+ %arg2 : tensor<128x512xf32>) -> tensor<128x512xf32> {
+ %0 = linalg.matmul ins(%arg0, %arg1 : tensor<128x256xbf16>, tensor<256x512xbf16>)
+ outs(%arg2 : tensor<128x512xf32>) -> tensor<128x512xf32>
+ return %0 : tensor<128x512xf32>
+}
+// CHECK: func @matmul_bf16bf16f32(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<128x256xbf16>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<256x512xbf16>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<128x512xf32>
+// CHECK: %[[LHS:.+]] = iree_linalg_ext.set_encoding %[[ARG0]]
+// CHECK-SAME: tensor<128x256xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_LHS>>
+// CHECK: %[[RHS:.+]] = iree_linalg_ext.set_encoding %[[ARG1]]
+// CHECK-SAME: tensor<256x512xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_RHS>>
+// CHECK: %[[OUTS:.+]] = iree_linalg_ext.set_encoding %[[ARG2]]
+// CHECK-SAME: tensor<128x512xf32, #iree_linalg_ext.encoding<MATMUL_BF16BF16F32_RESULT>>
+// CHECK: %[[MATMUL:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
+// CHECK-SAME: outs(%[[OUTS]] :
+// CHECK: %[[RESULT:.+]] = iree_linalg_ext.unset_encoding %[[MATMUL]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @matmul_bf16bf16bf16(%arg0 : tensor<128x256xbf16>, %arg1 : tensor<256x512xbf16>,
+ %arg2 : tensor<128x512xbf16>) -> tensor<128x512xbf16> {
+ %0 = linalg.matmul ins(%arg0, %arg1 : tensor<128x256xbf16>, tensor<256x512xbf16>)
+ outs(%arg2 : tensor<128x512xbf16>) -> tensor<128x512xbf16>
+ return %0 : tensor<128x512xbf16>
+}
+// CHECK: func @matmul_bf16bf16bf16(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<128x256xbf16>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<256x512xbf16>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<128x512xbf16>
+// CHECK: %[[LHS:.+]] = iree_linalg_ext.set_encoding %[[ARG0]]
+// CHECK-SAME: tensor<128x256xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_LHS>>
+// CHECK: %[[RHS:.+]] = iree_linalg_ext.set_encoding %[[ARG1]]
+// CHECK-SAME: tensor<256x512xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_RHS>>
+// CHECK: %[[OUTS:.+]] = iree_linalg_ext.set_encoding %[[ARG2]]
+// CHECK-SAME: tensor<128x512xbf16, #iree_linalg_ext.encoding<MATMUL_BF16BF16BF16_RESULT>>
+// CHECK: %[[MATMUL:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
+// CHECK-SAME: outs(%[[OUTS]] :
+// CHECK: %[[RESULT:.+]] = iree_linalg_ext.unset_encoding %[[MATMUL]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
func.func @fold_fill_with_set_encoding(%arg0 : index, %arg1 : index)
-> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> {
%cst = arith.constant 0.0 : f32
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td
index fc228c1..68f1aec 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td
@@ -57,12 +57,40 @@
: I32EnumAttrCase<"MATMUL_I8I8I32_RHS", 4>;
def MATMUL_I8I8I32_RESULT
: I32EnumAttrCase<"MATMUL_I8I8I32_RESULT", 5>;
+def MATMUL_F16F16F32_LHS
+ : I32EnumAttrCase<"MATMUL_F16F16F32_LHS", 6>;
+def MATMUL_F16F16F32_RHS
+ : I32EnumAttrCase<"MATMUL_F16F16F32_RHS", 7>;
+def MATMUL_F16F16F32_RESULT
+ : I32EnumAttrCase<"MATMUL_F16F16F32_RESULT", 8>;
+def MATMUL_F16F16F16_LHS
+ : I32EnumAttrCase<"MATMUL_F16F16F16_LHS", 9>;
+def MATMUL_F16F16F16_RHS
+ : I32EnumAttrCase<"MATMUL_F16F16F16_RHS", 10>;
+def MATMUL_F16F16F16_RESULT
+ : I32EnumAttrCase<"MATMUL_F16F16F16_RESULT", 11>;
+def MATMUL_BF16BF16F32_LHS
+ : I32EnumAttrCase<"MATMUL_BF16BF16F32_LHS", 12>;
+def MATMUL_BF16BF16F32_RHS
+ : I32EnumAttrCase<"MATMUL_BF16BF16F32_RHS", 13>;
+def MATMUL_BF16BF16F32_RESULT
+ : I32EnumAttrCase<"MATMUL_BF16BF16F32_RESULT", 14>;
+def MATMUL_BF16BF16BF16_LHS
+ : I32EnumAttrCase<"MATMUL_BF16BF16BF16_LHS", 15>;
+def MATMUL_BF16BF16BF16_RHS
+ : I32EnumAttrCase<"MATMUL_BF16BF16BF16_RHS", 16>;
+def MATMUL_BF16BF16BF16_RESULT
+ : I32EnumAttrCase<"MATMUL_BF16BF16BF16_RESULT", 17>;
def TensorEncodingEnum
: I32EnumAttr<"TensorEncoding",
"identifier for encoding used for the tensor",[
MATMUL_F32F32F32_LHS, MATMUL_F32F32F32_RHS, MATMUL_F32F32F32_RESULT,
MATMUL_I8I8I32_LHS, MATMUL_I8I8I32_RHS, MATMUL_I8I8I32_RESULT,
+ MATMUL_F16F16F32_LHS, MATMUL_F16F16F32_RHS, MATMUL_F16F16F32_RESULT,
+ MATMUL_F16F16F16_LHS, MATMUL_F16F16F16_RHS, MATMUL_F16F16F16_RESULT,
+ MATMUL_BF16BF16F32_LHS, MATMUL_BF16BF16F32_RHS, MATMUL_BF16BF16F32_RESULT,
+ MATMUL_BF16BF16BF16_LHS, MATMUL_BF16BF16BF16_RHS, MATMUL_BF16BF16BF16_RESULT,
]> {
let cppNamespace = "::mlir::iree_compiler::IREE::LinalgExt";
let genSpecializedAttr = 0;
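For context on the tablegen above: `I32EnumAttr` generates, roughly, a C++ `enum class` with these explicit ordinals, so appending the new cases as 6-17 leaves the existing 0-5 values unchanged. A simplified view of the generated enum (the actual tablegen output also emits stringify/symbolize helpers):

```cpp
// Simplified sketch of what tablegen emits for TensorEncodingEnum.
namespace mlir::iree_compiler::IREE::LinalgExt {
enum class TensorEncoding : uint32_t {
  MATMUL_F32F32F32_LHS = 0,
  // ... 1-5: the F32F32F32 RHS/RESULT and the three I8I8I32 cases ...
  MATMUL_F16F16F32_LHS = 6,
  MATMUL_F16F16F32_RHS = 7,
  MATMUL_F16F16F32_RESULT = 8,
  MATMUL_F16F16F16_LHS = 9,
  MATMUL_F16F16F16_RHS = 10,
  MATMUL_F16F16F16_RESULT = 11,
  MATMUL_BF16BF16F32_LHS = 12,
  MATMUL_BF16BF16F32_RHS = 13,
  MATMUL_BF16BF16F32_RESULT = 14,
  MATMUL_BF16BF16BF16_LHS = 15,
  MATMUL_BF16BF16BF16_RHS = 16,
  MATMUL_BF16BF16BF16_RESULT = 17,
};
} // namespace mlir::iree_compiler::IREE::LinalgExt
```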
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
index 86ec963..53db781 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
@@ -203,15 +203,7 @@
getEncoding(inputs[1]->get().getType().cast<RankedTensorType>());
std::optional<TensorEncoding> resultEncoding =
getEncoding(outputs[0]->get().getType().cast<RankedTensorType>());
- if (!lhsEncoding ||
- (lhsEncoding.value() != TensorEncoding::MATMUL_F32F32F32_LHS &&
- lhsEncoding.value() != TensorEncoding::MATMUL_I8I8I32_LHS) ||
- !rhsEncoding ||
- (rhsEncoding.value() != TensorEncoding::MATMUL_F32F32F32_RHS &&
- rhsEncoding.value() != TensorEncoding::MATMUL_I8I8I32_RHS) ||
- !resultEncoding ||
- (resultEncoding.value() != TensorEncoding::MATMUL_F32F32F32_RESULT &&
- resultEncoding.value() != TensorEncoding::MATMUL_I8I8I32_RESULT)) {
+ if (!lhsEncoding || !rhsEncoding || !resultEncoding) {
return failure();
}
Operation *mmt4DOp = rewriter.create<linalg::Mmt4DOp>(
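One observation on the simplification above: the new guard only requires that all three operands carry *some* encoding, and no longer cross-checks that the three encodings belong to the same matmul type. If a stricter guard were ever wanted, a sketch composing the `getMatmulType`/`getMatmulOperandRole` helpers from EncodingUtils might look like this. Hypothetical, not part of this patch, and it assumes those declarations (and `mlir::LogicalResult`) are visible here:

```cpp
// Hypothetical stricter guard: require all three encodings to agree on the
// matmul type and to play the expected operand roles.
LogicalResult verifyMatmulEncodings(TensorEncoding lhs, TensorEncoding rhs,
                                    TensorEncoding result) {
  std::optional<MatmulType> type = getMatmulType(lhs);
  if (!type || getMatmulType(rhs) != type || getMatmulType(result) != type)
    return failure();
  if (getMatmulOperandRole(lhs) != MatmulOperandRole::LHS ||
      getMatmulOperandRole(rhs) != MatmulOperandRole::RHS ||
      getMatmulOperandRole(result) != MatmulOperandRole::RESULT)
    return failure();
  return success();
}
```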