Use tensor of i32 type instead of tensor index type for constant value in fft lowering. (#6630)
diff --git a/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp b/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp index f14b8e5..360f71b 100644 --- a/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp +++ b/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp
@@ -257,9 +257,9 @@ for (int j = 0; j < logn; ++j) { r |= ((i >> j) & 1) << (logn - j - 1); } - values.push_back(b.getIndexAttr(r)); + values.push_back(b.getI32IntegerAttr(r)); } - auto type = RankedTensorType::get({fftLength}, b.getIndexType()); + auto type = RankedTensorType::get({fftLength}, b.getI32Type()); return b.create<ConstantOp>(type, DenseIntElementsAttr::get(type, values)); } @@ -291,7 +291,7 @@ for (auto i : llvm::seq<unsigned>(0, rank - 1)) { ivs.push_back(b.create<linalg::IndexOp>(loc, i)); } - ivs.push_back(args[0]); + ivs.push_back(b.create<IndexCastOp>(loc, args[0], b.getIndexType())); b.create<linalg::YieldOp>( loc, b.create<tensor::ExtractOp>(loc, real, ivs).getResult()); });
diff --git a/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir b/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir index 8c80edc..e6eee1b 100644 --- a/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir +++ b/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir
@@ -262,15 +262,16 @@ // CHECK-DAG: #[[MAP:.+]] = affine_map<(d0) -> (d0)> // CHECK: func @rfft_1d // CHECK-SAME: %[[REAL:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[INDICES:.+]] = constant dense<[0, 4, 2, 6, 1, 5, 3, 7]> : tensor<8xindex> +// CHECK-DAG: %[[INDICES:.+]] = constant dense<[0, 4, 2, 6, 1, 5, 3, 7]> : tensor<8xi32> // CHECK-DAG: %[[INIT_TENSOR:.+]] = linalg.init_tensor [8] : tensor<8xf32> // CHECK: %[[REORDERED:.+]] = linalg.generic // CHECK-SAME: {indexing_maps = [#[[MAP]], #[[MAP]]] // CHECK-SAME: iterator_types = ["parallel"] // CHECK-SAME: ins(%[[INDICES]] // CHECK-SAME: outs(%[[INIT_TENSOR]] -// CHECK: ^bb0(%[[IDX:.+]]: index, %{{.+}}: f32): -// CHECK: %[[LOAD:.+]] = tensor.extract %[[REAL]][%[[IDX]]] : tensor<8xf32> +// CHECK: ^bb0(%[[IDX:.+]]: i32, %{{.+}}: f32): +// CHECK: %[[IDXVAL:.+]] = index_cast %[[IDX]] : i32 to index +// CHECK: %[[LOAD:.+]] = tensor.extract %[[REAL]][%[[IDXVAL]]] : tensor<8xf32> // CHECK: linalg.yield %[[LOAD]] : f32 // CHECK: %[[IMAG:.+]] = constant dense<0.000000e+00> : tensor<8xf32> // CHECK: %[[C1:.+]] = constant 1 : index @@ -303,16 +304,17 @@ // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK: func @rfft_2d // CHECK-SAME: %[[REAL:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[INDICES:.+]] = constant dense<[0, 4, 2, 6, 1, 5, 3, 7]> : tensor<8xindex> +// CHECK-DAG: %[[INDICES:.+]] = constant dense<[0, 4, 2, 6, 1, 5, 3, 7]> : tensor<8xi32> // CHECK-DAG: %[[INIT_TENSOR:.+]] = linalg.init_tensor [4, 8] : tensor<4x8xf32> // CHECK: %[[REORDERED:.+]] = linalg.generic // CHECK-SAME: {indexing_maps = [#[[MAP0]], #[[MAP1]]] // CHECK-SAME: iterator_types = ["parallel", "parallel"] // CHECK-SAME: ins(%[[INDICES]] // CHECK-SAME: outs(%[[INIT_TENSOR]] -// CHECK: ^bb0(%[[IDX:.+]]: index, %{{.+}}: f32): +// CHECK: ^bb0(%[[IDX:.+]]: i32, %{{.+}}: f32): // CHECK: %[[I:.+]] = linalg.index 0 -// CHECK: %[[LOAD:.+]] = tensor.extract %[[REAL]][%[[I]], %[[IDX]]] : tensor<4x8xf32> +// CHECK: %[[IDXVAL:.+]] = index_cast %[[IDX]] : i32 to index +// CHECK: %[[LOAD:.+]] = tensor.extract %[[REAL]][%[[I]], %[[IDXVAL]]] : tensor<4x8xf32> // CHECK: linalg.yield %[[LOAD]] : f32 // CHECK: %[[IMAG:.+]] = constant dense<0.000000e+00> : tensor<4x8xf32> // CHECK: %[[C1:.+]] = constant 1 : index