blob: 92d5cf85b8a52d707cd9ab89f49bbe0d4be26b1b [file] [log] [blame]
// RUN: iree-opt %s --split-input-file --iree-transform-dialect-interpreter | FileCheck %s
// Indexing maps over the iteration space (m, n, k): the first operand is
// indexed as (m, k), the second as (k, n) and the third as (m, n).
#matmat_accesses = [
affine_map<(m, n, k) -> (m, k)>,
affine_map<(m, n, k) -> (k, n)>,
affine_map<(m, n, k) -> (m, n)>
]
// Contraction trait built from the maps above: m and n are parallel
// dimensions, k is the reduction dimension.
#matmat_trait = {
indexing_maps = #matmat_accesses,
iterator_types = ["parallel", "parallel", "reduction"]
}
// Reads three 16x16 f32 tiles, contracts them with #matmat_trait and writes
// the result back to %c. The transform script in this split is expected to
// replace the contraction with gpu.subgroup_mma_compute ops.
//
// The label below pins the following checks to this function's output, and
// the absence check is repeated after the positive match so that a leftover
// contraction is rejected both before and after the first MMA op.
// CHECK-LABEL: func.func @wmma
func.func @wmma(%a: memref<16x16xf32>, %b: memref<16x16xf32>, %c: memref<16x16xf32>) {
  %c0 = arith.constant 0: index
  // Padding value required by the transfer_read ops.
  %cst = arith.constant 0.0: f32
  %va = vector.transfer_read %a[%c0, %c0], %cst: memref<16x16xf32>, vector<16x16xf32>
  %vb = vector.transfer_read %b[%c0, %c0], %cst: memref<16x16xf32>, vector<16x16xf32>
  %vc = vector.transfer_read %c[%c0, %c0], %cst: memref<16x16xf32>, vector<16x16xf32>
  // CHECK-NOT: vector.contract
  // CHECK: gpu.subgroup_mma_compute
  // CHECK-NOT: vector.contract
  %vres = vector.contract #matmat_trait %va, %vb, %vc
    : vector<16x16xf32>, vector<16x16xf32> into vector<16x16xf32>
  vector.transfer_write %vres, %c[%c0, %c0]: vector<16x16xf32>, memref<16x16xf32>
  return
}
// Transform script for this split: unroll vectors with the WMMA pattern set,
// convert to MMA ops with `use_wmma`, then canonicalize.
transform.sequence failures(propagate) {
^bb0(%root: !transform.any_op):
  // Collect every func.func nested under the payload root.
  %funcs = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op

  transform.apply_patterns to %funcs {
    transform.apply_patterns.iree.unroll_vectors_gpu_wmma_sync
  } : !transform.any_op

  transform.iree.vector.vector_to_mma_conversion %funcs { use_wmma } : (!transform.any_op) -> ()

  // Canonicalize afterwards so that DCE runs: the test only passes once every
  // vector.contract is dead and has been erased.
  // TODO: consider letting vector_to_mma_conversion perform the DCE itself.
  transform.apply_patterns to %funcs {
    transform.apply_patterns.canonicalization
  } : !transform.any_op
}
// -----
// Indexing maps over the iteration space (m, n, k): the first operand is
// indexed as (m, k), the second as (k, n) and the third as (m, n).
#matmat_accesses = [
affine_map<(m, n, k) -> (m, k)>,
affine_map<(m, n, k) -> (k, n)>,
affine_map<(m, n, k) -> (m, n)>
]
// Contraction trait built from the maps above: m and n are parallel
// dimensions, k is the reduction dimension.
#matmat_trait = {
indexing_maps = #matmat_accesses,
iterator_types = ["parallel", "parallel", "reduction"]
}
// Reads three 16x16 f32 tiles, contracts them with #matmat_trait and writes
// the result back to %c. The transform script in this split is expected to
// replace the contraction with nvgpu.mma.sync ops carrying tf32Enabled.
//
// The label below pins the following checks to this function's output (so they
// cannot scan text emitted for another split), and the absence check is
// repeated after the positive match so that a leftover contraction is rejected
// both before and after the first MMA op.
// CHECK-LABEL: func.func @mma_sync
func.func @mma_sync(%a: memref<16x16xf32>, %b: memref<16x16xf32>, %c: memref<16x16xf32>) {
  %c0 = arith.constant 0: index
  // Padding value required by the transfer_read ops.
  %cst = arith.constant 0.0: f32
  %va = vector.transfer_read %a[%c0, %c0], %cst: memref<16x16xf32>, vector<16x16xf32>
  %vb = vector.transfer_read %b[%c0, %c0], %cst: memref<16x16xf32>, vector<16x16xf32>
  %vc = vector.transfer_read %c[%c0, %c0], %cst: memref<16x16xf32>, vector<16x16xf32>
  // CHECK-NOT: vector.contract
  // CHECK: nvgpu.mma.sync{{.*}} tf32Enabled}
  // CHECK-NOT: vector.contract
  %vres = vector.contract #matmat_trait %va, %vb, %vc
    : vector<16x16xf32>, vector<16x16xf32> into vector<16x16xf32>
  vector.transfer_write %vres, %c[%c0, %c0]: vector<16x16xf32>, memref<16x16xf32>
  return
}
// Transform script for this split: unroll vectors with the mma.sync pattern
// set, convert to MMA ops with `use_mma_sync`, then canonicalize.
transform.sequence failures(propagate) {
^bb0(%root: !transform.any_op):
  // Collect every func.func nested under the payload root.
  %funcs = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op

  transform.apply_patterns to %funcs {
    transform.apply_patterns.iree.unroll_vectors_gpu_mma_sync
  } : !transform.any_op

  transform.iree.vector.vector_to_mma_conversion %funcs { use_mma_sync } : (!transform.any_op) -> ()

  // Canonicalize afterwards so that DCE runs: the test only passes once every
  // vector.contract is dead and has been erased.
  // TODO: consider letting vector_to_mma_conversion perform the DCE itself.
  transform.apply_patterns to %funcs {
    transform.apply_patterns.canonicalization
  } : !transform.any_op
}