[Codegen][VectorExt] Fix VectorExt ops for 0-d vectors (#18915)
Upstream "AnyVector" does not actually allow 0d vectors. Instead, the
upstream macro, AnyVectorOfAnyRank allows them instead.
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td
index bf31ba6..446ff77 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td
@@ -256,25 +256,24 @@
}];
let parameters = (ins
- ArrayRefParameter<"int64_t", "subgroup_tile">:$subgroupTile,
- ArrayRefParameter<"int64_t", "batch_tile">:$batchTile,
- ArrayRefParameter<"int64_t", "outer_tile">:$outerTile,
- ArrayRefParameter<"int64_t", "thread_tile">:$threadTile,
- ArrayRefParameter<"int64_t", "element_tile">:$elementTile,
+ OptionalArrayRefParameter<"int64_t", "subgroup_tile">:$subgroupTile,
+ OptionalArrayRefParameter<"int64_t", "batch_tile">:$batchTile,
+ OptionalArrayRefParameter<"int64_t", "outer_tile">:$outerTile,
+ OptionalArrayRefParameter<"int64_t", "thread_tile">:$threadTile,
+ OptionalArrayRefParameter<"int64_t", "element_tile">:$elementTile,
- ArrayRefParameter<"int64_t", "subgroup_strides">:$subgroupStrides,
- ArrayRefParameter<"int64_t", "thread_strides">:$threadStrides
+ OptionalArrayRefParameter<"int64_t", "subgroup_strides">:$subgroupStrides,
+ OptionalArrayRefParameter<"int64_t", "thread_strides">:$threadStrides
);
let assemblyFormat = [{
- `<` `subgroup_tile` `=` `[` $subgroupTile `]` `,`
- `batch_tile` `=` `[` $batchTile `]` `,`
- `outer_tile` `=` `[` $outerTile `]` `,`
- `thread_tile` `=` `[` $threadTile `]` `,`
- `element_tile` `=` `[` $elementTile `]` `,`
-
- `subgroup_strides` `=` `[` $subgroupStrides `]` `,`
- `thread_strides` `=` `[` $threadStrides `]`
+ `<` `subgroup_tile` `=` `[` (`]`) : ($subgroupTile^ `]`)? `,`
+ `batch_tile` `=` `[` (`]`) : ($batchTile^ `]`)? `,`
+ `outer_tile` `=` `[` (`]`) : ($outerTile^ `]`)? `,`
+ `thread_tile` `=` `[` (`]`) : ($threadTile^ `]`)? `,`
+ `element_tile` `=` `[` (`]`) : ($elementTile^ `]`)? `,`
+ `subgroup_strides` `=` `[` (`]`) : ($subgroupStrides^ `]`)? `,`
+ `thread_strides` `=` `[` (`]`) : ($threadStrides^ `]`)?
`>`
}];
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.td b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.td
index 4e40cd8..04055bf 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.td
@@ -84,10 +84,10 @@
distributed vectors.
}];
let arguments = (ins
- AnyVector:$input
+ AnyVectorOfAnyRank:$input
);
let results = (outs
- AnyVector:$output
+ AnyVectorOfAnyRank:$output
);
let extraClassDeclaration = [{}];
let assemblyFormat = "$input attr-dict `:` type($input) `->` type($output)";
@@ -103,10 +103,10 @@
distributed vectors.
}];
let arguments = (ins
- AnyVector:$input
+ AnyVectorOfAnyRank:$input
);
let results = (outs
- AnyVector:$output
+ AnyVectorOfAnyRank:$output
);
let extraClassDeclaration = [{}];
let assemblyFormat = "$input attr-dict `:` type($input) `->` type($output)";
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/test/roundtrip.mlir b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/test/roundtrip.mlir
index f320654..fc14c3b 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/test/roundtrip.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/test/roundtrip.mlir
@@ -88,6 +88,37 @@
// -----
+#nested_0 = #iree_vector_ext.nested_layout<
+ subgroup_tile = [],
+ batch_tile = [],
+ outer_tile = [],
+ thread_tile = [],
+ element_tile = [],
+
+ subgroup_strides = [],
+ thread_strides = []
+>
+
+func.func @specify_nested_0d(%lhs: vector<f16>) -> vector<f16> {
+ %result = iree_vector_ext.to_layout %lhs to layout(#nested_0) : vector<f16>
+ func.return %result : vector<f16>
+}
+
+// CHECK: #[[$LAYOUT0:.+]] = #iree_vector_ext.nested_layout<
+// CHECK-SAME: subgroup_tile = [],
+// CHECK-SAME: batch_tile = [],
+// CHECK-SAME: outer_tile = [],
+// CHECK-SAME: thread_tile = [],
+// CHECK-SAME: element_tile = [],
+// CHECK-SAME: subgroup_strides = [],
+// CHECK-SAME: thread_strides = []>
+
+// CHECK-LABEL: func.func @specify_nested_0d
+// CHECK: to_layout
+// CHECK-SAME: layout(#[[$LAYOUT0]])
+
+// -----
+
func.func @to_simd_op(%simt: vector<4x4x4xf16>) -> vector<64x64xf16> {
%simd = iree_vector_ext.to_simd %simt : vector<4x4x4xf16> -> vector<64x64xf16>
func.return %simd : vector<64x64xf16>
@@ -103,3 +134,21 @@
}
// CHECK-LABEL: func.func @to_simt_op
// CHECK: iree_vector_ext.to_simd
+
+// -----
+
+func.func @to_simd_op_0d(%simt: vector<f16>) -> vector<f16> {
+ %simd = iree_vector_ext.to_simd %simt : vector<f16> -> vector<f16>
+ func.return %simd : vector<f16>
+}
+// CHECK-LABEL: func.func @to_simd_op
+// CHECK: iree_vector_ext.to_simd
+
+// -----
+
+func.func @to_simt_op_0d(%simd: vector<f32>) -> vector<f32> {
+ %simt = iree_vector_ext.to_simd %simd : vector<f32> -> vector<f32>
+ func.return %simt : vector<f32>
+}
+// CHECK-LABEL: func.func @to_simt_op
+// CHECK: iree_vector_ext.to_simd