blob: d5324388867639e931358e70e09730c74450459f [file]
// Test RDNA3 WMMA f16/bf16 accumulator correctness for matmul.
//
// Regression test for a bug where WMMAR3_F16_16x16x16_F16 accumulator layout
// on gfx1100 (RDNA3) was incorrect: outer={16,1} claimed 16 valid elements per
// lane but v_wmma_f16_16x16x16_f16 only produces 8 valid results at even
// indices. The fix changes the layout to outer={8,1}, thread={2,16}.
// See amdgpu.wmma op docs in AMDGPUOps.td for the gfx11 subwordOffset layout.
//
// Uses LHS * I = LHS (matmul with identity) so each output element has a
// known expected value. LHS values depend on both the row and the column
// index (they are distinct in f32; after truncation to f16/bf16 neighboring
// elements may round to the same value) to catch both row-swap and
// column-permutation bugs.
//
// The optimization_barrier prevents fusion of data generation with the matmul,
// ensuring the matmul gets its own dispatch and triggers the WMMA codegen path
// via vector distribution on gfx1100.
func.func @wmma_matmul_f16_identity() {
  // Scalar constants used by the two data generators below.
  %row_stride = arith.constant 128.0 : f32
  %scale = arith.constant 16384.0 : f32
  %f16_zero = arith.constant 0.0 : f16
  %f16_one = arith.constant 1.0 : f16

  // Left operand: A[i, j] = (i * 128 + j) / 16384. The value is computed in
  // f32 (where it lies in [0, 1)) and then truncated to f16.
  %a_init = tensor.empty() : tensor<128x128xf16>
  %a_data = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } outs(%a_init : tensor<128x128xf16>) {
  ^bb0(%ignored: f16):
    // Row index -> f32.
    %row_idx = linalg.index 0 : index
    %row_i32 = arith.index_cast %row_idx : index to i32
    %row_f32 = arith.sitofp %row_i32 : i32 to f32
    // Column index -> f32.
    %col_idx = linalg.index 1 : index
    %col_i32 = arith.index_cast %col_idx : index to i32
    %col_f32 = arith.sitofp %col_i32 : i32 to f32
    // Linearized position, scaled down and narrowed to the element type.
    %row_base = arith.mulf %row_f32, %row_stride : f32
    %flat = arith.addf %row_base, %col_f32 : f32
    %scaled = arith.divf %flat, %scale : f32
    %elem = arith.truncf %scaled : f32 to f16
    linalg.yield %elem : f16
  } -> tensor<128x128xf16>

  // Right operand: the 128x128 identity (1.0 where row == column, else 0.0).
  %b_init = tensor.empty() : tensor<128x128xf16>
  %b_data = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } outs(%b_init : tensor<128x128xf16>) {
  ^bb0(%ignored: f16):
    %row_idx = linalg.index 0 : index
    %col_idx = linalg.index 1 : index
    %on_diag = arith.cmpi eq, %row_idx, %col_idx : index
    %elem = arith.select %on_diag, %f16_one, %f16_zero : f16
    linalg.yield %elem : f16
  } -> tensor<128x128xf16>

  // The barriers keep the generators from being fused into the matmul, so
  // the matmul lands in a dispatch of its own and takes the WMMA codegen
  // path.
  %a = util.optimization_barrier %a_data : tensor<128x128xf16>
  %b = util.optimization_barrier %b_data : tensor<128x128xf16>

  // C = A * I, accumulated into a zero-filled f16 tensor.
  %acc_zero = arith.constant 0.0 : f16
  %c_init = tensor.empty() : tensor<128x128xf16>
  %c_zeroed = linalg.fill ins(%acc_zero : f16) outs(%c_init : tensor<128x128xf16>) -> tensor<128x128xf16>
  %product = linalg.matmul ins(%a, %b : tensor<128x128xf16>, tensor<128x128xf16>)
                           outs(%c_zeroed : tensor<128x128xf16>) -> tensor<128x128xf16>

  // Multiplying by the identity must reproduce the left operand.
  check.expect_almost_eq(%product, %a) : tensor<128x128xf16>
  return
}
func.func @wmma_matmul_bf16_identity() {
  // bf16 variant of the identity-matmul check, covering
  // WMMAR3_BF16_16x16x16_BF16.

  // Scalar constants used by the two data generators below.
  %row_stride = arith.constant 128.0 : f32
  %scale = arith.constant 16384.0 : f32
  %bf16_zero = arith.constant 0.0 : bf16
  %bf16_one = arith.constant 1.0 : bf16

  // Left operand: A[i, j] = (i * 128 + j) / 16384. The value is computed in
  // f32 (where it lies in [0, 1)) and then truncated to bf16.
  %a_init = tensor.empty() : tensor<128x128xbf16>
  %a_data = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } outs(%a_init : tensor<128x128xbf16>) {
  ^bb0(%ignored: bf16):
    // Row index -> f32.
    %row_idx = linalg.index 0 : index
    %row_i32 = arith.index_cast %row_idx : index to i32
    %row_f32 = arith.sitofp %row_i32 : i32 to f32
    // Column index -> f32.
    %col_idx = linalg.index 1 : index
    %col_i32 = arith.index_cast %col_idx : index to i32
    %col_f32 = arith.sitofp %col_i32 : i32 to f32
    // Linearized position, scaled down and narrowed to the element type.
    %row_base = arith.mulf %row_f32, %row_stride : f32
    %flat = arith.addf %row_base, %col_f32 : f32
    %scaled = arith.divf %flat, %scale : f32
    %elem = arith.truncf %scaled : f32 to bf16
    linalg.yield %elem : bf16
  } -> tensor<128x128xbf16>

  // Right operand: the 128x128 identity (1.0 where row == column, else 0.0).
  %b_init = tensor.empty() : tensor<128x128xbf16>
  %b_data = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } outs(%b_init : tensor<128x128xbf16>) {
  ^bb0(%ignored: bf16):
    %row_idx = linalg.index 0 : index
    %col_idx = linalg.index 1 : index
    %on_diag = arith.cmpi eq, %row_idx, %col_idx : index
    %elem = arith.select %on_diag, %bf16_one, %bf16_zero : bf16
    linalg.yield %elem : bf16
  } -> tensor<128x128xbf16>

  // Route both operands through optimization barriers before the matmul so
  // the data generation is not fused into it.
  %a = util.optimization_barrier %a_data : tensor<128x128xbf16>
  %b = util.optimization_barrier %b_data : tensor<128x128xbf16>

  // C = A * I, accumulated into a zero-filled bf16 tensor.
  %acc_zero = arith.constant 0.0 : bf16
  %c_init = tensor.empty() : tensor<128x128xbf16>
  %c_zeroed = linalg.fill ins(%acc_zero : bf16) outs(%c_init : tensor<128x128xbf16>) -> tensor<128x128xbf16>
  %product = linalg.matmul ins(%a, %b : tensor<128x128xbf16>, tensor<128x128xbf16>)
                           outs(%c_zeroed : tensor<128x128xbf16>) -> tensor<128x128xbf16>

  // Multiplying by the identity must reproduce the left operand.
  check.expect_almost_eq(%product, %a) : tensor<128x128xbf16>
  return
}