[VectorExt] Add custom parser/printer to elide identity orderings (#16522)
As written, the attribute is quite unwieldy. Add a custom parser/printer
for the orders, at least, to allow eliding identity orderings, which are
the most common case. We can revisit the parser in the future to see if
something similar would be worth doing for the sizes as well.
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir
index 8ee8923..e3f963c 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir
@@ -14,11 +14,7 @@
threads_per_outer = [32, 2],
elements_per_thread = [1, 4],
- subgroup_order = [0, 1],
- batch_order = [0, 1],
- outer_order = [0, 1],
thread_order = [1, 0],
- element_order = [0, 1],
subgroup_basis = [1, 1],
thread_basis = [32, 2]
@@ -32,10 +28,6 @@
threads_per_outer = [2, 32],
elements_per_thread = [4, 1],
- subgroup_order = [0, 1],
- batch_order = [0, 1],
- outer_order = [0, 1],
- thread_order = [0, 1],
element_order = [1, 0],
subgroup_basis = [1, 1],
@@ -50,10 +42,6 @@
threads_per_outer = [2, 32],
elements_per_thread = [4, 1],
- subgroup_order = [0, 1],
- batch_order = [0, 1],
- outer_order = [0, 1],
- thread_order = [0, 1],
element_order = [1, 0],
subgroup_basis = [1, 1],
@@ -118,11 +106,7 @@
threads_per_outer = [16, 4],
elements_per_thread = [1, 4],
- subgroup_order = [0, 1],
- batch_order = [0, 1],
- outer_order = [0, 1],
thread_order = [1, 0],
- element_order = [0, 1],
subgroup_basis = [1, 1],
thread_basis = [16, 4]
@@ -136,10 +120,6 @@
threads_per_outer = [4, 16],
elements_per_thread = [4, 1],
- subgroup_order = [0, 1],
- batch_order = [0, 1],
- outer_order = [0, 1],
- thread_order = [0, 1],
element_order = [1, 0],
subgroup_basis = [1, 1],
@@ -206,11 +186,7 @@
threads_per_outer = [32, 2],
elements_per_thread = [1, 4],
- subgroup_order = [0, 1],
- batch_order = [0, 1],
- outer_order = [0, 1],
thread_order = [1, 0],
- element_order = [0, 1],
subgroup_basis = [1, 1],
thread_basis = [32, 2]
@@ -224,10 +200,6 @@
threads_per_outer = [2, 32],
elements_per_thread = [4, 1],
- subgroup_order = [0, 1],
- batch_order = [0, 1],
- outer_order = [0, 1],
- thread_order = [0, 1],
element_order = [1, 0],
subgroup_basis = [1, 1],
@@ -242,10 +214,6 @@
threads_per_outer = [2, 32],
elements_per_thread = [4, 1],
- subgroup_order = [0, 1],
- batch_order = [0, 1],
- outer_order = [0, 1],
- thread_order = [0, 1],
element_order = [1, 0],
subgroup_basis = [1, 1],
@@ -311,11 +279,7 @@
threads_per_outer = [32, 2],
elements_per_thread = [1, 4],
- subgroup_order = [0, 1],
- batch_order = [0, 1],
- outer_order = [0, 1],
thread_order = [1, 0],
- element_order = [0, 1],
subgroup_basis = [1, 1],
thread_basis = [32, 2]
@@ -329,10 +293,6 @@
threads_per_outer = [2, 32],
elements_per_thread = [4, 1],
- subgroup_order = [0, 1],
- batch_order = [0, 1],
- outer_order = [0, 1],
- thread_order = [0, 1],
element_order = [1, 0],
subgroup_basis = [1, 1],
@@ -347,10 +307,6 @@
threads_per_outer = [2, 32],
elements_per_thread = [4, 1],
- subgroup_order = [0, 1],
- batch_order = [0, 1],
- outer_order = [0, 1],
- thread_order = [0, 1],
element_order = [1, 0],
subgroup_basis = [1, 1],
@@ -413,11 +369,7 @@
threads_per_outer = [32, 2],
elements_per_thread = [1, 4],
- subgroup_order = [0, 1],
- batch_order = [0, 1],
- outer_order = [0, 1],
thread_order = [1, 0],
- element_order = [0, 1],
subgroup_basis = [1, 1],
thread_basis = [32, 2]
@@ -431,10 +383,6 @@
threads_per_outer = [2, 32],
elements_per_thread = [4, 1],
- subgroup_order = [0, 1],
- batch_order = [0, 1],
- outer_order = [0, 1],
- thread_order = [0, 1],
element_order = [1, 0],
subgroup_basis = [1, 1],
@@ -449,10 +397,7 @@
threads_per_outer = [2, 32],
elements_per_thread = [4, 1],
- subgroup_order = [0, 1],
batch_order = [1, 0],
- outer_order = [0, 1],
- thread_order = [0, 1],
element_order = [1, 0],
subgroup_basis = [1, 1],
@@ -520,11 +465,7 @@
threads_per_outer = [32, 2],
elements_per_thread = [1, 4],
- subgroup_order = [0, 1],
- batch_order = [0, 1],
- outer_order = [0, 1],
thread_order = [1, 0],
- element_order = [0, 1],
subgroup_basis = [1, 1],
thread_basis = [32, 2]
@@ -538,11 +479,7 @@
threads_per_outer = [32, 2],
elements_per_thread = [1, 4],
- subgroup_order = [0, 1],
- batch_order = [0, 1],
- outer_order = [0, 1],
thread_order = [1, 0],
- element_order = [0, 1],
subgroup_basis = [1, 1],
thread_basis = [32, 2]
@@ -556,10 +493,6 @@
threads_per_outer = [2, 32],
elements_per_thread = [4, 1],
- subgroup_order = [0, 1],
- batch_order = [0, 1],
- outer_order = [0, 1],
- thread_order = [0, 1],
element_order = [1, 0],
subgroup_basis = [1, 1],
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
index c155e4a..2497666 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
@@ -7,11 +7,7 @@
threads_per_outer = [8, 1],
elements_per_thread = [1, 8],
- subgroup_order = [0, 1],
batch_order = [1, 0],
- outer_order = [0, 1],
- thread_order = [0, 1],
- element_order = [0, 1],
subgroup_basis = [1, 1],
thread_basis = [8, 1]
@@ -60,10 +56,7 @@
threads_per_outer = [4, 8],
elements_per_thread = [4, 1],
- subgroup_order = [0, 1],
batch_order = [1, 0],
- outer_order = [0, 1],
- thread_order = [0, 1],
element_order = [1, 0],
subgroup_basis = [1, 1],
@@ -112,11 +105,7 @@
threads_per_outer = [8, 1],
elements_per_thread = [1, 8],
- subgroup_order = [0, 1],
batch_order = [1, 0],
- outer_order = [0, 1],
- thread_order = [0, 1],
- element_order = [0, 1],
subgroup_basis = [1, 1],
thread_basis = [8, 1]
@@ -165,10 +154,7 @@
threads_per_outer = [4, 8],
elements_per_thread = [4, 1],
- subgroup_order = [0, 1],
batch_order = [1, 0],
- outer_order = [0, 1],
- thread_order = [0, 1],
element_order = [1, 0],
subgroup_basis = [1, 1],
@@ -213,11 +199,7 @@
threads_per_outer = [8, 1],
elements_per_thread = [1, 8],
- subgroup_order = [0, 1],
batch_order = [1, 0],
- outer_order = [0, 1],
- thread_order = [0, 1],
- element_order = [0, 1],
subgroup_basis = [1, 1],
thread_basis = [8, 1]
@@ -268,12 +250,6 @@
threads_per_outer = [4, 8],
elements_per_thread = [4, 1],
- subgroup_order = [0, 1],
- batch_order = [0, 1],
- outer_order = [0, 1],
- thread_order = [0, 1],
- element_order = [0, 1],
-
subgroup_basis = [1, 1],
thread_basis = [4, 8]
>
@@ -316,7 +292,6 @@
batch_order = [1, 2, 3, 0],
outer_order = [0, 3, 1, 2],
thread_order = [0, 1, 3, 2],
- element_order = [0, 1, 2, 3],
subgroup_basis = [7, 3, 1, 1],
thread_basis = [1, 1, 2, 2]
@@ -356,11 +331,7 @@
threads_per_outer = [8, 1],
elements_per_thread = [1, 8],
- subgroup_order = [0, 1],
batch_order = [1, 0],
- outer_order = [0, 1],
- thread_order = [0, 1],
- element_order = [0, 1],
subgroup_basis = [1, 1],
thread_basis = [8, 1]
@@ -406,10 +377,7 @@
threads_per_outer = [4, 8],
elements_per_thread = [4, 1],
- subgroup_order = [0, 1],
batch_order = [1, 0],
- outer_order = [0, 1],
- thread_order = [0, 1],
element_order = [1, 0],
subgroup_basis = [1, 1],
@@ -456,11 +424,7 @@
threads_per_outer = [8, 1],
elements_per_thread = [1, 8],
- subgroup_order = [0, 1],
batch_order = [1, 0],
- outer_order = [0, 1],
- thread_order = [0, 1],
- element_order = [0, 1],
subgroup_basis = [1, 1],
thread_basis = [8, 1]
@@ -514,11 +478,7 @@
threads_per_outer = [8, 1],
elements_per_thread = [1, 8],
- subgroup_order = [0, 1],
batch_order = [1, 0],
- outer_order = [0, 1],
- thread_order = [0, 1],
- element_order = [0, 1],
subgroup_basis = [1, 1],
thread_basis = [8, 1]
@@ -569,11 +529,7 @@
threads_per_outer = [32, 2],
elements_per_thread = [1, 4],
- subgroup_order = [0, 1],
- batch_order = [0, 1],
- outer_order = [0, 1],
thread_order = [1, 0],
- element_order = [0, 1],
subgroup_basis = [4, 2],
thread_basis = [2, 32]
@@ -587,10 +543,6 @@
threads_per_outer = [2, 32],
elements_per_thread = [4, 1],
- subgroup_order = [0, 1],
- batch_order = [0, 1],
- outer_order = [0, 1],
- thread_order = [0, 1],
element_order = [1, 0],
subgroup_basis = [4, 2],
@@ -605,10 +557,6 @@
threads_per_outer = [2, 32],
elements_per_thread = [4, 1],
- subgroup_order = [0, 1],
- batch_order = [0, 1],
- outer_order = [0, 1],
- thread_order = [0, 1],
element_order = [1, 0],
subgroup_basis = [4, 2],
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir
index 543f22e..c3b7fe9 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir
@@ -14,15 +14,15 @@
// CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 4],
-// CHECK-SAME: subgroup_order = [0, 1], batch_order = [0, 1], outer_order = [0, 1], thread_order = [1, 0], element_order = [0, 1],
+// CHECK-SAME: thread_order = [1, 0]
// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [2, 32]>
// CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [2, 32], elements_per_thread = [4, 1],
-// CHECK-SAME: subgroup_order = [0, 1], batch_order = [0, 1], outer_order = [0, 1], thread_order = [0, 1], element_order = [1, 0],
+// CHECK-SAME: element_order = [1, 0],
// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [2, 32]>
// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [4, 1], threads_per_outer = [2, 32], elements_per_thread = [4, 1],
-// CHECK-SAME: subgroup_order = [0, 1], batch_order = [0, 1], outer_order = [0, 1], thread_order = [0, 1], element_order = [1, 0],
+// CHECK-SAME: element_order = [1, 0],
// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [2, 32]>
// -----
@@ -41,15 +41,15 @@
// CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 4],
-// CHECK-SAME: subgroup_order = [0, 1], batch_order = [0, 1], outer_order = [0, 1], thread_order = [1, 0], element_order = [0, 1],
+// CHECK-SAME: thread_order = [1, 0]
// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [2, 32]>
// CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 4],
-// CHECK-SAME: subgroup_order = [1, 0], batch_order = [1, 0], outer_order = [1, 0], thread_order = [1, 0], element_order = [0, 1],
+// CHECK-SAME: subgroup_order = [1, 0], batch_order = [1, 0], outer_order = [1, 0], thread_order = [1, 0]
// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [2, 32]>
// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [4, 1], threads_per_outer = [2, 32], elements_per_thread = [4, 1],
-// CHECK-SAME: subgroup_order = [0, 1], batch_order = [0, 1], outer_order = [0, 1], thread_order = [0, 1], element_order = [1, 0],
+// CHECK-SAME: element_order = [1, 0],
// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [2, 32]>
// -----
@@ -109,24 +109,22 @@
// CHECK: transfer '{{.+}} memref<16x256xf16{{.+}}<storage_buffer>>, vector<16x32xf16>' vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [16, 4], elements_per_thread = [1, 8],
-// CHECK-SAME: subgroup_order = [0, 1], batch_order = [0, 1], outer_order = [0, 1], thread_order = [0, 1], element_order = [0, 1],
// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [16, 4]>
// CHECK: transfer '{{.+}} memref<256x16xf16{{.+}}<storage_buffer>>, vector<32x16xf16>' vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 8],
-// CHECK-SAME: subgroup_order = [0, 1], batch_order = [0, 1], outer_order = [0, 1], thread_order = [0, 1], element_order = [0, 1],
// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [32, 2]>
// CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 2], outers_per_batch = [1, 1], threads_per_outer = [16, 4], elements_per_thread = [1, 4],
-// CHECK-SAME: subgroup_order = [0, 1], batch_order = [0, 1], outer_order = [0, 1], thread_order = [1, 0], element_order = [0, 1],
+// CHECK-SAME: thread_order = [1, 0]
// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [4, 16]>
// CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [4, 1],
-// CHECK-SAME: subgroup_order = [0, 1], batch_order = [0, 1], outer_order = [0, 1], thread_order = [0, 1], element_order = [1, 0],
+// CHECK-SAME: element_order = [1, 0],
// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [4, 16]>
// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [4, 1],
-// CHECK-SAME: subgroup_order = [0, 1], batch_order = [0, 1], outer_order = [0, 1], thread_order = [0, 1], element_order = [1, 0],
+// CHECK-SAME: element_order = [1, 0],
// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [4, 16]>
// -----
@@ -171,7 +169,6 @@
// CHECK-NOT: transfer '{{.+}} memref<16x16xf16{{.+}}<storage_buffer>>, vector<16x16xf16>' vector layout
// CHECK: transfer '{{.+}} memref<16x256xf16{{.+}}<storage_buffer>>, vector<16x32xf16>' vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [16, 4], elements_per_thread = [1, 8],
-// CHECK-SAME: subgroup_order = [0, 1], batch_order = [0, 1], outer_order = [0, 1], thread_order = [0, 1], element_order = [0, 1],
// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [16, 4]>
// CHECK: transfer '{{.+}} memref<16x256xf16{{.+}}storage_buffer>>, vector<32x16xf16>' vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [8, 1],
@@ -180,14 +177,14 @@
// CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 2], outers_per_batch = [1, 1], threads_per_outer = [16, 4], elements_per_thread = [1, 4],
-// CHECK-SAME: subgroup_order = [0, 1], batch_order = [0, 1], outer_order = [0, 1], thread_order = [1, 0], element_order = [0, 1],
+// CHECK-SAME: thread_order = [1, 0]
// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [4, 16]>
// CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [4, 1],
-// CHECK-SAME: subgroup_order = [0, 1], batch_order = [0, 1], outer_order = [0, 1], thread_order = [0, 1], element_order = [1, 0],
+// CHECK-SAME: element_order = [1, 0],
// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [4, 16]>
// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [4, 1],
-// CHECK-SAME: subgroup_order = [0, 1], batch_order = [0, 1], outer_order = [0, 1], thread_order = [0, 1], element_order = [1, 0],
+// CHECK-SAME: element_order = [1, 0],
// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [4, 16]>
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td
index cb340d6..bba41a0 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td
@@ -285,8 +285,8 @@
ArrayRefParameter<"int64_t", "thread_basis">:$threadBasis
);
- // TODO: add custom parser/printer and builder to elide default value array
- // refs.
+ // By default, identity orderings are elided when parsing/printing.
+ // TODO: Improve custom parsing/printing for sizes/basis elements.
let assemblyFormat = [{
`<` `subgroups_per_workgroup` `=` `[` $subgroupsPerWorkgroup `]` `,`
`batches_per_subgroup` `=` `[` $batchesPerSubgroup `]` `,`
@@ -294,11 +294,11 @@
`threads_per_outer` `=` `[` $threadsPerOuter `]` `,`
`elements_per_thread` `=` `[` $elementsPerThread `]` `,`
- `subgroup_order` `=` `[` $subgroupOrder `]` `,`
- `batch_order` `=` `[` $batchOrder `]` `,`
- `outer_order` `=` `[` $outerOrder `]` `,`
- `thread_order` `=` `[` $threadOrder `]` `,`
- `element_order` `=` `[` $elementOrder `]` `,`
+ custom<Permutation>("\"subgroup_order\"", ref($subgroupsPerWorkgroup), "true", $subgroupOrder) ``
+ custom<Permutation>("\"batch_order\"", ref($batchesPerSubgroup), "true", $batchOrder) ``
+ custom<Permutation>("\"outer_order\"", ref($outersPerBatch), "true", $outerOrder) ``
+ custom<Permutation>("\"thread_order\"", ref($threadsPerOuter), "true", $threadOrder) ``
+ custom<Permutation>("\"element_order\"", ref($elementsPerThread), "true", $elementOrder) ``
`subgroup_basis` `=` `[` $subgroupBasis `]` `,`
`thread_basis` `=` `[` $threadBasis `]`
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
index d6a8768..ef56793 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
@@ -13,6 +13,7 @@
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
@@ -341,6 +342,63 @@
return delinearized;
}
+//===----------------------------------------------------------------------===//
+// Custom Parsers/Printers
+//===----------------------------------------------------------------------===//
+
+// Parses an optional `<baseName> = [<ints>]` clause into `permutation`. If the
+// keyword is absent, `permutation` is set to the identity of rank
+// `sizes.size()`. When `parseComma` is set, a `,` must follow the closing
+// bracket. Returns failure if the clause is present but malformed.
+static ParseResult parsePermutation(AsmParser &parser, StringRef baseName,
+                                    ArrayRef<int64_t> sizes, bool parseComma,
+                                    SmallVector<int64_t> &permutation) {
+  if (failed(parser.parseOptionalKeyword(baseName))) {
+    // Keyword not present: the order was elided, so default to identity.
+    permutation = llvm::to_vector(llvm::seq<int64_t>(0, sizes.size()));
+    return success();
+  }
+  if (failed(parser.parseEqual()) || failed(parser.parseLSquare())) {
+    return failure();
+  }
+  auto arrayParser = FieldParser<SmallVector<int64_t>>::parse(parser);
+  if (failed(arrayParser)) {
+    // Must bail out here: `*arrayParser` below is invalid on a failed result.
+    parser.emitError(parser.getCurrentLocation(),
+                     "failed to parse permutation parameter '")
+        << baseName << "' which is to be a `::llvm::ArrayRef<int64_t>`";
+    return failure();
+  }
+  if (failed(parser.parseRSquare())) {
+    return failure();
+  }
+  if (parseComma && failed(parser.parseComma())) {
+    return failure();
+  }
+  permutation = *arrayParser;
+  return success();
+}
+
+// Prints `<baseName> = [<permutation>]`, followed by `, ` when `printComma`
+// is set. Identity permutations print nothing at all.
+// `sizes` is not read here; it mirrors the parameter list of the parser.
+static void printPermutation(AsmPrinter &p, StringRef baseName,
+                             ArrayRef<int64_t> sizes, bool printComma,
+                             ArrayRef<int64_t> permutation) {
+  // Only non-identity orders are spelled out.
+  if (!isIdentityPermutation(permutation)) {
+    // No whitespace is inserted around this optional clause by default, so
+    // the spaces around `=` are written explicitly.
+    p << baseName;
+    p << ' ' << '=' << ' ' << '[';
+    llvm::interleaveComma(permutation, p);
+    p << ']';
+    if (printComma) {
+      p << ',' << ' ';
+    }
+  }
+}
+
} // namespace mlir::iree_compiler::IREE::VectorExt
using namespace mlir::iree_compiler::IREE::VectorExt;
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
index 3c92b66..844000c 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
@@ -47,7 +47,6 @@
subgroup_order = [1, 0],
batch_order = [1, 0],
- outer_order = [0, 1],
thread_order = [1, 0],
element_order = [1, 0],
@@ -63,8 +62,8 @@
return %2 : vector<32x32xf16>
}
-// CHECK-DAG: #[[LAYOUT0:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [4, 2], outers_per_batch = [1, 4], threads_per_outer = [2, 4], elements_per_thread = [4, 1], subgroup_order = [1, 0], batch_order = [1, 0], outer_order = [0, 1], thread_order = [1, 0], element_order = [1, 0], subgroup_basis = [1, 1], thread_basis = [2, 4]>
-// CHECK-DAG: #[[LAYOUT1:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 4], outers_per_batch = [4, 1], threads_per_outer = [4, 2], elements_per_thread = [1, 4], subgroup_order = [0, 1], batch_order = [0, 1], outer_order = [1, 0], thread_order = [0, 1], element_order = [0, 1], subgroup_basis = [1, 1], thread_basis = [4, 2]>
+// CHECK-DAG: #[[LAYOUT0:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [4, 2], outers_per_batch = [1, 4], threads_per_outer = [2, 4], elements_per_thread = [4, 1], subgroup_order = [1, 0], batch_order = [1, 0], thread_order = [1, 0], element_order = [1, 0], subgroup_basis = [1, 1], thread_basis = [2, 4]>
+// CHECK-DAG: #[[LAYOUT1:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 4], outers_per_batch = [4, 1], threads_per_outer = [4, 2], elements_per_thread = [1, 4], outer_order = [1, 0], subgroup_basis = [1, 1], thread_basis = [4, 2]>
// CHECK-LABEL: func.func @specify_nested
// CHECK: iree_vector_ext.layout_conflict_resolution
// CHECK-SAME: desiredLayout = #[[LAYOUT0]]