Update semantics of linalg_ext.reverse. (#7155)
1. Allow the op takes multiple reverse dimensions.
2. Update the loop bound of reverse dim with the whole size.
Re (2), it is required because of tiling interface. The current tiling
interface expects the implmenetation returned a tiled op. However, two
tiled ops are required in the previous definition. E.g., say that we
have `M` tiles in total.
[T_1], [T_2], ... , [T_M]
The result of `T_1` will be stored to the position of `T_M`, and vise
versa. If we iterate only half of the tensor, we need to create reversed
`T_1` and reversed `T_M`, then store them to corresponding offsets. In
this context, two tiled op are created, but only one can be returned.
The updated semantic also meets the current lowering, instead of
emitting the affine maps into the op itself.
This PR also adds e2e tests for the op.
It is a step toward https://github.com/google/iree/issues/5045
diff --git a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 4ae4556..7497397 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -8,6 +8,7 @@
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/SMLoc.h"
@@ -764,17 +765,27 @@
//===----------------------------------------------------------------------===//
static LogicalResult verifyReverseOp(ReverseOp op) {
- if (op.getNumInputs()) {
- return op.emitOpError("expected no inputs");
+ if (op.getNumInputs() != 1) {
+ return op.emitOpError("expected exactly one input");
}
if (op.getNumOutputs() != 1) {
return op.emitOpError("expected exactly one output");
}
+ if (op.input().getType() != op.output().getType()) {
+ return op.emitOpError("expected input/output types are identical");
+ }
int64_t rank = op.getOperandRank();
- int dimension = op.dimension();
- if (dimension < 0 || dimension >= rank) {
- return op.emitOpError("dimension must be within (0, ") << rank << "]";
+ llvm::SmallSetVector<int64_t, 4> s;
+ for (auto dim : op.dims()) {
+ if (dim < 0 || dim >= rank) {
+ return op.emitOpError("all the dimensions must be within [0, ")
+ << rank << ")";
+ }
+ if (s.contains(dim)) {
+ return op.emitOpError("expected dimensions numbers are all unique");
+ }
+ s.insert(dim);
}
return success();
@@ -783,8 +794,9 @@
bool ReverseOp::payloadUsesValueFromOperand(OpOperand *) { return false; }
SmallVector<StringRef> ReverseOp::getLoopIteratorTypes() {
+ // TODO(hanchung): Mark them parallel after tiling method is implemented.
SmallVector<StringRef> iteratorTypes(getOperandRank(),
- getParallelIteratorTypeName());
+ getReductionIteratorTypeName());
return iteratorTypes;
}
@@ -794,12 +806,9 @@
Value one = builder.create<ConstantIndexOp>(loc, 1);
SmallVector<Range> ranges;
for (auto dim : llvm::seq<int64_t>(0, getOperandRank())) {
- Value ub = getDimValue(builder, loc, operand(), dim);
+ Value ub = getDimValue(builder, loc, input(), dim);
ranges.emplace_back(Range{zero, ub, one});
}
- auto dim = dimension();
- ranges[dim].size = builder.create<SignedDivIOp>(
- loc, ranges[dim].size, builder.create<ConstantIndexOp>(loc, 2));
return ranges;
}
@@ -807,18 +816,13 @@
Location loc,
ValueRange ivs) {
SmallVector<Value> mirrorIndices(ivs.begin(), ivs.end());
- auto dim = dimension();
- auto size = getDimValue(b, loc, operand(), dim);
- size = b.create<SubIOp>(loc, size, b.create<ConstantIndexOp>(loc, 1));
- mirrorIndices[dim] = b.create<SubIOp>(loc, size, mirrorIndices[dim]);
-
- // for (int i = 0; i < n / 2; ++i) {
- // swap(array[i], array[n - 1 - i]);
- // }
- Value v1 = b.create<memref::LoadOp>(loc, operand(), ivs);
- Value v2 = b.create<memref::LoadOp>(loc, operand(), mirrorIndices);
- b.create<memref::StoreOp>(loc, v1, operand(), mirrorIndices);
- b.create<memref::StoreOp>(loc, v2, operand(), ivs);
+ for (auto dim : dims()) {
+ auto size = getDimValue(b, loc, input(), dim);
+ size = b.create<SubIOp>(loc, size, b.create<ConstantIndexOp>(loc, 1));
+ mirrorIndices[dim] = b.create<SubIOp>(loc, size, mirrorIndices[dim]);
+ }
+ Value val = b.create<memref::LoadOp>(loc, input(), ivs);
+ b.create<memref::StoreOp>(loc, val, output(), mirrorIndices);
return success();
}
diff --git a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index 9841da1..9894831 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -253,38 +253,45 @@
["payloadUsesValueFromOperand"]>]> {
let summary = "Reverse operator";
let description = [{
- A temporary solution of a reverse op. The loop bound of the reverse
- dimension is half of the shape because we can simply swap elements. E.g.,
-
- for (int i = 0; i < n / 2; ++i) {
- std::swap(a[i], a[n - 1 - i]);
+ A temporary solution for lowering reverse ops into IREE, allowing IREE to
+ tile and distribute them.
}
}];
let arguments = (ins Variadic<AnyShaped>:$inputs,
Variadic<AnyShaped>:$outputs,
- I64Attr:$dimension
+ I64ElementsAttr:$dimensions
);
let results = (outs Variadic<AnyRankedTensor>:$results);
let assemblyFormat = [{
- `dimension` `(` $dimension `)`
+ `dimensions` `(` $dimensions `)`
attr-dict (`ins` `(` $inputs^ `:` type($inputs) `)`)?
- `outs` `(` $outputs `:` type($outputs) `)`
+ (`outs` `(` $outputs^ `:` type($outputs) `)`)?
(`:` type($results)^)?
}];
let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
- Value operand() {
+ Value input() {
+ return getInputOperand(0)->get();
+ }
+ Value output() {
return getOutputOperand(0)->get();
-}
+ }
ShapedType getOperandType() {
- return operand().getType().cast<ShapedType>();
+ return input().getType().cast<ShapedType>();
}
int64_t getOperandRank() {
return getOperandType().getRank();
}
- ArrayRef<int64_t> getOperandShape() {
+ ArrayRef<int64_t> getOprerandShape() {
return getOperandType().getShape();
}
+ SmallVector<int64_t> dims() {
+ SmallVector<int64_t> ret;
+ for (const APInt& elem : dimensions()) {
+ ret.push_back(elem.getLimitedValue());
+ }
+ return ret;
+ }
}];
}
diff --git a/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir b/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
index 2cea9ce..f4bb586 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
+++ b/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
@@ -380,3 +380,27 @@
} -> tensor<?x?xi64>
return %0 : tensor<?x?xi64>
}
+
+// -----
+
+func @reverse_diff_types(%arg0: tensor<3x5xi32>) -> tensor<3x6xi32> {
+ %init = linalg.init_tensor [3, 6] : tensor<3x6xi32>
+ // expected-error @+1 {{expected input/output types are identical}}
+ %0 = linalg_ext.reverse
+ dimensions(dense<0> : tensor<1xi64>)
+ ins(%arg0 : tensor<3x5xi32>)
+ outs(%init : tensor<3x6xi32>) : tensor<3x6xi32>
+ return %0 : tensor<3x6xi32>
+}
+
+// -----
+
+func @reverse_dup_dims(%arg0: tensor<3x5xi32>) -> tensor<3x5xi32> {
+ %init = linalg.init_tensor [3, 5] : tensor<3x5xi32>
+ // expected-error @+1 {{expected dimensions numbers are all unique}}
+ %0 = linalg_ext.reverse
+ dimensions(dense<[0, 0]> : tensor<2xi64>)
+ ins(%arg0 : tensor<3x5xi32>)
+ outs(%init : tensor<3x5xi32>) : tensor<3x5xi32>
+ return %0 : tensor<3x5xi32>
+}
diff --git a/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir b/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
index e488bc1..51509f5 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
+++ b/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
@@ -391,38 +391,78 @@
// -----
func @reverse_tensor(%arg0: tensor<3x5xi32>) -> tensor<3x5xi32> {
+ %init = linalg.init_tensor [3, 5] : tensor<3x5xi32>
%0 = linalg_ext.reverse
- dimension(0)
- outs(%arg0 : tensor<3x5xi32>) : tensor<3x5xi32>
+ dimensions(dense<0> : tensor<1xi64>)
+ ins(%arg0 : tensor<3x5xi32>)
+ outs(%init : tensor<3x5xi32>) : tensor<3x5xi32>
return %0 : tensor<3x5xi32>
}
// CHECK-LABEL: func @reverse_tensor
-// CHECK-SAME: %[[ARG0:.+]]: tensor<3x5xi32>
-// CHECK: %[[RESULT:.+]] = linalg_ext.reverse dimension(0)
-// CHECK-SAME: outs(%[[ARG0]]
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<3x5xi32>
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [3, 5]
+// CHECK: %[[RESULT:.+]] = linalg_ext.reverse
+// CHECK-SAME: dimensions(dense<0> : tensor<1xi64>)
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[INIT]]
// -----
-func @reverse_memref(%arg0: memref<3x5xi32>) {
+func @reverse_memref(%arg0: memref<3x5xi32>, %arg1: memref<3x5xi32>) {
linalg_ext.reverse
- dimension(0)
- outs(%arg0 : memref<3x5xi32>)
+ dimensions(dense<0> : tensor<1xi64>)
+ ins(%arg0 : memref<3x5xi32>)
+ outs(%arg1 : memref<3x5xi32>)
return
}
// CHECK-LABEL: func @reverse_memref
-// CHECK-SAME: %[[ARG0:.+]]: memref<3x5xi32>
-// CHECK: linalg_ext.reverse dimension(0)
-// CHECK-SAME: outs(%[[ARG0]]
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<3x5xi32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<3x5xi32>
+// CHECK: linalg_ext.reverse
+// CHECK-SAME: dimensions(dense<0> : tensor<1xi64>)
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[ARG1]]
// -----
func @reverse_dynamic_tensor(%arg0: tensor<?x?xi32>) -> tensor<?x?xi32> {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %d0 = tensor.dim %arg0, %c0 : tensor<?x?xi32>
+ %d1 = tensor.dim %arg0, %c1 : tensor<?x?xi32>
+ %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xi32>
%0 = linalg_ext.reverse
- dimension(1)
- outs(%arg0 : tensor<?x?xi32>) : tensor<?x?xi32>
+ dimensions(dense<1> : tensor<1xi64>)
+ ins(%arg0 : tensor<?x?xi32>)
+ outs(%init : tensor<?x?xi32>) : tensor<?x?xi32>
return %0 : tensor<?x?xi32>
}
// CHECK-LABEL: func @reverse_dynamic_tensor
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>
-// CHECK: %[[RESULT:.+]] = linalg_ext.reverse dimension(1)
-// CHECK-SAME: outs(%[[ARG0]]
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xi32>
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]]
+// CHECK: %[[RESULT:.+]] = linalg_ext.reverse
+// CHECK-SAME: dimensions(dense<1> : tensor<1xi64>)
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[INIT]]
+
+// -----
+
+func @reverse_multi_dims(%arg0: tensor<3x5xi32>) -> tensor<3x5xi32> {
+ %init = linalg.init_tensor [3, 5] : tensor<3x5xi32>
+ %0 = linalg_ext.reverse
+ dimensions(dense<[0, 1]> : tensor<2xi64>)
+ ins(%arg0 : tensor<3x5xi32>)
+ outs(%init : tensor<3x5xi32>) : tensor<3x5xi32>
+ return %0 : tensor<3x5xi32>
+}
+// CHECK-LABEL: func @reverse_multi_dims
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<3x5xi32>
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [3, 5]
+// CHECK: %[[RESULT:.+]] = linalg_ext.reverse
+// CHECK-SAME: dimensions(dense<[0, 1]> : tensor<2xi64>)
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[INIT]]
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir b/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir
index 84883b9..d3fc5a4 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir
@@ -484,26 +484,24 @@
// -----
-func @reverse_dim_0(%arg0: memref<?x?xi32>) {
+func @reverse_dim_0(%arg0: memref<?x?xi32>, %arg1: memref<?x?xi32>) {
linalg_ext.reverse
- dimension(0)
- outs(%arg0 : memref<?x?xi32>)
+ dimensions(dense<0> : tensor<1xi64>)
+ ins(%arg0 : memref<?x?xi32>)
+ outs(%arg1 : memref<?x?xi32>)
return
}
// CHECK-LABEL: func @reverse_dim_0
-// CHECK-SAME: %[[BUF:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[OUT:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[C2:.+]] = constant 2 : index
// CHECK-DAG: %[[D0:.+]] = memref.dim %arg0, %c0 : memref<?x?xi32>
// CHECK-DAG: %[[D1:.+]] = memref.dim %arg0, %c1 : memref<?x?xi32>
-// CHECK-DAG: %[[REV_UB:.+]] = divi_signed %[[D0]], %[[C2]] : index
-// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[REV_UB]] step %[[C1]]
+// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[D0]] step %[[C1]]
// CHECK: scf.for %[[J:.+]] = %[[C0]] to %[[D1]] step %[[C1]]
-// CHECK: %[[T0:.+]] = memref.dim %[[BUF]], %[[C0]]
+// CHECK: %[[T0:.+]] = memref.dim %[[IN]], %[[C0]]
// CHECK: %[[T1:.+]] = subi %[[T0]], %[[C1]] : index
// CHECK: %[[T2:.+]] = subi %[[T1]], %[[I]] : index
-// CHECK: %[[V0:.+]] = memref.load %[[BUF]][%[[I]], %[[J]]]
-// CHECK: %[[V1:.+]] = memref.load %[[BUF]][%[[T2]], %[[J]]]
-// CHECK: memref.store %[[V0]], %[[BUF]][%[[T2]], %[[J]]] : memref<?x?xi32>
-// CHECK: memref.store %[[V1]], %[[BUF]][%[[I]], %[[J]]] : memref<?x?xi32>
+// CHECK: %[[V0:.+]] = memref.load %[[IN]][%[[I]], %[[J]]]
+// CHECK: memref.store %[[V0]], %[[OUT]][%[[T2]], %[[J]]] : memref<?x?xi32>
diff --git a/iree/test/e2e/linalg_ext_ops/BUILD b/iree/test/e2e/linalg_ext_ops/BUILD
new file mode 100644
index 0000000..9752101
--- /dev/null
+++ b/iree/test/e2e/linalg_ext_ops/BUILD
@@ -0,0 +1,109 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
+load("//build_tools/bazel:iree_check_test.bzl", "iree_check_single_backend_test_suite")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_check_single_backend_test_suite(
+ name = "check_cuda",
+ srcs = enforce_glob(
+ # keep sorted
+ [
+ "reverse.mlir",
+ ],
+ include = ["*.mlir"],
+ exclude = [
+ ],
+ ),
+ driver = "cuda",
+ tags = [
+ # CUDA cuInit fails with sanitizer on.
+ "noasan",
+ "nomsan",
+ "notsan",
+ "noubsan",
+ "requires-gpu-nvidia",
+ ],
+ target_backend = "cuda",
+)
+
+iree_check_single_backend_test_suite(
+ name = "check_dylib_embedded-llvm-aot_dylib",
+ srcs = enforce_glob(
+ # keep sorted
+ [
+ "reverse.mlir",
+ ],
+ include = ["*.mlir"],
+ exclude = [
+ ],
+ ),
+ compiler_flags = [
+ "-iree-llvm-link-embedded=true",
+ ],
+ driver = "dylib",
+ target_backend = "dylib-llvm-aot",
+)
+
+iree_check_single_backend_test_suite(
+ name = "check_dylib-llvm-aot_dylib",
+ srcs = enforce_glob(
+ # keep sorted
+ [
+ "reverse.mlir",
+ ],
+ include = ["*.mlir"],
+ exclude = [
+ ],
+ ),
+ driver = "dylib",
+ target_backend = "dylib-llvm-aot",
+)
+
+iree_check_single_backend_test_suite(
+ name = "check_vmvx_vmvx",
+ srcs = enforce_glob(
+ # keep sorted
+ [
+ "reverse.mlir",
+ ],
+ include = ["*.mlir"],
+ exclude = [
+ ],
+ ),
+ driver = "vmvx",
+ target_backend = "vmvx",
+)
+
+iree_check_single_backend_test_suite(
+ name = "check_vulkan-spirv_vulkan",
+ srcs = enforce_glob(
+ # keep sorted
+ [
+ "reverse.mlir",
+ ],
+ include = ["*.mlir"],
+ exclude = [
+ ],
+ ),
+ driver = "vulkan",
+ target_backend = "vulkan-spirv",
+)
+
+test_suite(
+ name = "check",
+ tests = [
+ ":check_dylib-llvm-aot_dylib",
+ ":check_vmvx_vmvx",
+ ":check_vulkan-spirv_vulkan",
+ ],
+)
diff --git a/iree/test/e2e/linalg_ext_ops/CMakeLists.txt b/iree/test/e2e/linalg_ext_ops/CMakeLists.txt
new file mode 100644
index 0000000..78b49f4
--- /dev/null
+++ b/iree/test/e2e/linalg_ext_ops/CMakeLists.txt
@@ -0,0 +1,76 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# iree/test/e2e/linalg_ext_ops/BUILD #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_check_single_backend_test_suite(
+ NAME
+ check_cuda
+ SRCS
+ "reverse.mlir"
+ TARGET_BACKEND
+ "cuda"
+ DRIVER
+ "cuda"
+ LABELS
+ "noasan"
+ "nomsan"
+ "notsan"
+ "noubsan"
+ "requires-gpu-nvidia"
+)
+
+iree_check_single_backend_test_suite(
+ NAME
+ check_dylib_embedded-llvm-aot_dylib
+ SRCS
+ "reverse.mlir"
+ TARGET_BACKEND
+ "dylib-llvm-aot"
+ DRIVER
+ "dylib"
+ COMPILER_FLAGS
+ "-iree-llvm-link-embedded=true"
+)
+
+iree_check_single_backend_test_suite(
+ NAME
+ check_dylib-llvm-aot_dylib
+ SRCS
+ "reverse.mlir"
+ TARGET_BACKEND
+ "dylib-llvm-aot"
+ DRIVER
+ "dylib"
+)
+
+iree_check_single_backend_test_suite(
+ NAME
+ check_vmvx_vmvx
+ SRCS
+ "reverse.mlir"
+ TARGET_BACKEND
+ "vmvx"
+ DRIVER
+ "vmvx"
+)
+
+iree_check_single_backend_test_suite(
+ NAME
+ check_vulkan-spirv_vulkan
+ SRCS
+ "reverse.mlir"
+ TARGET_BACKEND
+ "vulkan-spirv"
+ DRIVER
+ "vulkan"
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/test/e2e/linalg_ext_ops/reverse.mlir b/iree/test/e2e/linalg_ext_ops/reverse.mlir
new file mode 100644
index 0000000..95a69e0
--- /dev/null
+++ b/iree/test/e2e/linalg_ext_ops/reverse.mlir
@@ -0,0 +1,53 @@
+func @reverse_dim0() {
+ %input = util.unfoldable_constant dense<[[1.0, 2.0, 3.0],
+ [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
+
+ %init = linalg.init_tensor [2, 3] : tensor<2x3xf32>
+ %0 = linalg_ext.reverse
+ dimensions(dense<0> : tensor<1xi64>)
+ ins(%input : tensor<2x3xf32>)
+ outs(%init : tensor<2x3xf32>) : tensor<2x3xf32>
+
+ check.expect_almost_eq_const(
+ %0,
+ dense<[[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]]> : tensor<2x3xf32>
+ ) : tensor<2x3xf32>
+
+ return
+}
+
+func @reverse_dim1() {
+ %input = util.unfoldable_constant dense<[[1, 2, 3],
+ [4, 5, 6]]> : tensor<2x3xi32>
+
+ %init = linalg.init_tensor [2, 3] : tensor<2x3xi32>
+ %0 = linalg_ext.reverse
+ dimensions(dense<1> : tensor<1xi64>)
+ ins(%input : tensor<2x3xi32>)
+ outs(%init : tensor<2x3xi32>) : tensor<2x3xi32>
+
+ check.expect_eq_const(
+ %0,
+ dense<[[3, 2, 1], [6, 5, 4]]> : tensor<2x3xi32>
+ ) : tensor<2x3xi32>
+
+ return
+}
+
+func @reverse_multi_dims() {
+ %input = util.unfoldable_constant dense<[[1, 2, 3],
+ [4, 5, 6]]> : tensor<2x3xi32>
+
+ %init = linalg.init_tensor [2, 3] : tensor<2x3xi32>
+ %0 = linalg_ext.reverse
+ dimensions(dense<[0, 1]> : tensor<2xi64>)
+ ins(%input : tensor<2x3xi32>)
+ outs(%init : tensor<2x3xi32>) : tensor<2x3xi32>
+
+ check.expect_eq_const(
+ %0,
+ dense<[[6, 5, 4], [3, 2, 1]]> : tensor<2x3xi32>
+ ) : tensor<2x3xi32>
+
+ return
+}