blob: e10a92ad6c78318bc4849af2241472ba1e03cdc5 [file] [log] [blame]
// RUN: iree-tf-opt -iree-xla-legalize-tf -split-input-file <%s | IreeFileCheck %s
// CHECK-LABEL: func @f
func @f(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<3xf32>) {
// CHECK: [[VAL0:%.+]] = mhlo.constant dense<2.000000e+00>
// CHECK: [[VAL1:%.+]] = mhlo.constant dense<1.000000e+00>
%0 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
%1 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
%2 = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
%3 = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
%4 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
%5 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
%6 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%7 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
%8 = "tf.GreaterEqual"(%2, %4) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
%9 = "tf.StridedSlice"(%5, %7, %5, %5) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
%10 = "tf.SelectV2"(%0, %4, %9) {device = ""} : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
%11 = "tf.Range"(%4, %9, %6) {device = ""} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<1xi32>
%12 = "tf.Equal"(%10, %11) {device = "", incompatible_shape_error = true} : (tensor<i32>, tensor<1xi32>) -> tensor<1xi1>
%13 = "tf.SelectV2"(%12, %2, %5) {device = ""} : (tensor<1xi1>, tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
%14 = "tf.Sub"(%2, %6) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
%15 = "tf.Maximum"(%14, %6) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
%16 = "tf.Cast"(%15) {Truncate = false, device = ""} : (tensor<i32>) -> tensor<f32>
%17 = "tf.SelectV2"(%8, %15, %1) {device = ""} : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
%18 = "tf.Cast"(%17) {Truncate = false, device = ""} : (tensor<i32>) -> tensor<i64>
%19 = "tf.Range"(%3, %18, %3) {device = ""} : (tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1xi64>
%20 = "tf.Cast"(%19) {Truncate = false, device = ""} : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: [[VAL2:%.+]] = "mhlo.reshape"(%arg0)
%21 = "tf.ExpandDims"(%arg0, %4) {device = ""} : (tensor<f32>, tensor<i32>) -> tensor<1xf32>
// CHECK: [[VAL3:%.+]] = "mhlo.reshape"(%arg1)
%22 = "tf.ExpandDims"(%arg1, %4) {device = ""} : (tensor<f32>, tensor<i32>) -> tensor<1xf32>
// CHECK: [[VAL4:%.+]] = mhlo.subtract [[VAL3]], [[VAL2]]
%23 = "tf.Sub"(%22, %21) {device = ""} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
// CHECK: [[VAL5:%.+]] = mhlo.divide [[VAL4]], [[VAL0]]
%24 = "tf.RealDiv"(%23, %16) {device = ""} : (tensor<1xf32>, tensor<f32>) -> tensor<1xf32>
// CHECK: [[VAL6:%.+]] = mhlo.multiply [[VAL5]], [[VAL1]]
%25 = "tf.Mul"(%24, %20) {device = ""} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
// CHECK: [[VAL7:%.+]] = mhlo.add [[VAL2]], [[VAL6]]
%26 = "tf.AddV2"(%21, %25) {device = ""} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
// CHECK: [[VAL8:%.+]] = "mhlo.concatenate"([[VAL2]], [[VAL7]], [[VAL3]]) {dimension = 0 : i64}
%27 = "tf.ConcatV2"(%21, %26, %22, %10) {device = ""} : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<i32>) -> tensor<3xf32>
%28 = "tf.Slice"(%27, %7, %13) {device = ""} : (tensor<3xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xf32>
%29 = "tf.Identity"(%28) {device = ""} : (tensor<3xf32>) -> tensor<3xf32>
// CHECK: return [[VAL8]]
return %29 : tensor<3xf32>
}