| module attributes { transform.with_named_sequence } { |
| |
| // Codegen. |
| transform.named_sequence @codegen( |
| %variant_op: !transform.any_op {transform.consumed}) { |
| |
| // Get attention op |
| // ========================================== |
| %attention = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op |
| |
| // Tile and distribute to workgroups |
| // ========================================== |
| %tiled_attention, %forall_grid = |
| transform.structured.tile_using_forall %attention num_threads [1] |
| ( mapping = [#gpu.block<x>] ) |
| : (!transform.any_op) -> (!transform.any_op, !transform.any_op) |
| transform.iree.populate_workgroup_count_region_using_num_threads_slice %forall_grid |
| : (!transform.any_op) -> () |
| |
| // Tile and decompose attention |
| // ========================================== |
| %attention4 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op |
| %acc_fill, %max_fill, %sum_fill, %inner_loop, %final_scaling, %blocked_attention = transform.tile_attention %attention4 : |
| (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) |
| %fill_op, %first_matmul, %reduce_max, %partial_softmax, %scale_factor, %update, %reduce_sum, %scale_acc, %second_matmul |
| = transform.decompose_tiled_attention %blocked_attention : |
| (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) |
| |
| // Vectorize function |
| // ========================================== |
| %func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op |
| transform.apply_patterns to %func { |
| transform.apply_patterns.iree.fold_reshape_into_tensor_hal_interface |
| transform.apply_patterns.linalg.fold_unit_extent_dims_via_slices |
| transform.apply_patterns.vector.cast_away_vector_leading_one_dim |
| } : !transform.any_op |
| %func_3 = transform.structured.vectorize_children_and_apply_patterns %func : (!transform.any_op) -> !transform.any_op |
| transform.apply_patterns to %func_3 { |
| transform.apply_patterns.iree.fold_fill_into_pad |
| transform.apply_patterns.linalg.tiling_canonicalization |
| transform.apply_patterns.scf.for_loop_canonicalization |
| transform.apply_patterns.canonicalization |
| } : !transform.any_op |
| transform.iree.apply_licm %variant_op : !transform.any_op |
| transform.apply_cse to %variant_op : !transform.any_op |
| |
| // Bufferization |
| // ========================================== |
| transform.iree.eliminate_empty_tensors %variant_op : (!transform.any_op) -> () |
| transform.apply_patterns to %func_3 { |
| transform.apply_patterns.linalg.erase_unnecessary_inputs |
| } : !transform.any_op |
| %variant_op_3 = transform.iree.bufferize %variant_op : (!transform.any_op) -> (!transform.any_op) |
| |
| // Step 6. Post-bufferization vector distribution |
| // =========================================================================== |
| %func_7 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op |
| transform.iree.forall_to_workgroup %func_7 : (!transform.any_op) -> () |
| %func_8 = transform.structured.hoist_redundant_vector_transfers %func_7 |
| : (!transform.any_op) -> !transform.any_op |
| transform.apply_patterns to %func_8 { |
| transform.apply_patterns.canonicalization |
| } : !transform.any_op |
| transform.apply_cse to %func_8 : !transform.any_op |
| transform.memref.erase_dead_alloc_and_stores %func_8 : (!transform.any_op) -> () |
| transform.yield |
| } // codegen |
| |
| // Find `hal.executable.variant`. |
| transform.named_sequence @match_variant_for_codegen(%root: !transform.any_op {transform.readonly}) |
| -> !transform.any_op { |
| transform.match.operation_name %root ["hal.executable.variant"] : !transform.any_op |
| transform.yield %root : !transform.any_op |
| } |
| |
| // Transform entry-point |
| transform.named_sequence @__transform_main(%root: !transform.any_op {transform.consumed}) { |
| transform.foreach_match in %root |
| @match_variant_for_codegen -> @codegen |
| : (!transform.any_op) -> (!transform.any_op) |
| transform.yield |
| } |
| } // module |
| |