Add a test for xla_hlo.transpose op.
This operation is already supported on all the backends. Add test cases match
the examples presented at
https://www.tensorflow.org/api_docs/python/tf/transpose
PiperOrigin-RevId: 306999939
diff --git a/iree/test/e2e/xla_ops/BUILD b/iree/test/e2e/xla_ops/BUILD
index d363e25..9c8bc15 100644
--- a/iree/test/e2e/xla_ops/BUILD
+++ b/iree/test/e2e/xla_ops/BUILD
@@ -67,6 +67,7 @@
"sine.mlir",
"slice.mlir",
"sqrt.mlir",
+ "transpose.mlir",
"while.mlir",
],
driver = "vulkan",
@@ -92,6 +93,7 @@
"rsqrt.mlir",
"select.mlir",
"sqrt.mlir",
+ "transpose.mlir",
],
compiler_flags = ["-iree-use-linalg-to-spirv-path"],
driver = "vulkan",
@@ -125,6 +127,7 @@
"rsqrt.mlir",
"select.mlir",
"sqrt.mlir",
+ "transpose.mlir",
"while.mlir",
],
driver = "llvm",
diff --git a/iree/test/e2e/xla_ops/CMakeLists.txt b/iree/test/e2e/xla_ops/CMakeLists.txt
index fde903e..e4aecf2 100644
--- a/iree/test/e2e/xla_ops/CMakeLists.txt
+++ b/iree/test/e2e/xla_ops/CMakeLists.txt
@@ -60,6 +60,7 @@
"sine.mlir"
"slice.mlir"
"sqrt.mlir"
+ "transpose.mlir"
"while.mlir"
TARGET_BACKEND
vulkan-spirv
@@ -87,6 +88,7 @@
"rsqrt.mlir"
"select.mlir"
"sqrt.mlir"
+ "transpose.mlir"
TARGET_BACKEND
vulkan-spirv
DRIVER
@@ -123,6 +125,7 @@
"rsqrt.mlir"
"select.mlir"
"sqrt.mlir"
+ "transpose.mlir"
"while.mlir"
TARGET_BACKEND
llvm-ir
diff --git a/iree/test/e2e/xla_ops/transpose.mlir b/iree/test/e2e/xla_ops/transpose.mlir
new file mode 100644
index 0000000..678b1dc
--- /dev/null
+++ b/iree/test/e2e/xla_ops/transpose.mlir
@@ -0,0 +1,29 @@
+func @transpose_2d() attributes { iree.module.export } {
+ %input = iree.unfoldable_constant dense<[[1, 2, 3],
+ [4, 5, 6]]> : tensor<2x3xi32>
+ %0 = "xla_hlo.transpose"(%input) {
+ permutation = dense<[1, 0]> : tensor<2xi64>
+ } : (tensor<2x3xi32>) -> tensor<3x2xi32>
+ check.expect_eq_const(%0, dense<[[1, 4],
+ [2, 5],
+ [3, 6]]> : tensor<3x2xi32>) : tensor<3x2xi32>
+ return
+}
+
+func @transpose_3d() attributes { iree.module.export } {
+ %input = iree.unfoldable_constant dense<[[[ 1, 2, 3],
+ [ 4, 5, 6]],
+ [[ 7, 8, 9],
+ [10, 11, 12]]]> : tensor<2x2x3xi32>
+ %0 = "xla_hlo.transpose"(%input) {
+ permutation = dense<[0, 2, 1]> : tensor<3xi64>
+ } : (tensor<2x2x3xi32>) -> tensor<2x3x2xi32>
+ check.expect_eq_const(%0, dense<[
+ [[ 1, 4],
+ [ 2, 5],
+ [ 3, 6]],
+ [[ 7, 10],
+ [ 8, 11],
+ [ 9, 12]]]> : tensor<2x3x2xi32>) : tensor<2x3x2xi32>
+ return
+}