Cast tensor.empty type to TypeConverter's type during materialization. (#15375)
This fixes a bug introduced by
https://github.com/openxla/iree/commit/03d655a99408f43bd5d960172b15fb0d5823e633

VMVX makes a different assumption because it generally supports dynamic cases:
it treats every matmul as having dynamic shapes even when the inputs are
static. As a result, the inferred shape and the TypeConverter's target shape
can disagree in their dynamism. Cast the materialized result to the target
type to keep the dialect conversion happy.
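
As a sketch of the mismatch (the types below are hypothetical and not taken
from the new test; the direction of the static/dynamic difference can also be
reversed), materialization may infer a static packed shape while the
TypeConverter produces a dynamically shaped type, and a tensor.cast bridges
the two so the replacement value type-checks:

    // Hypothetical shapes, for illustration only.
    // Materialized result with a statically inferred packed shape.
    %packed = tensor.empty() : tensor<1x1x8x8xf32>
    // Cast to the dynamically shaped type produced by the TypeConverter so
    // the dialect conversion framework sees matching types.
    %cast = tensor.cast %packed : tensor<1x1x8x8xf32> to tensor<?x?x8x8xf32>
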
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/test/vmvx_materialize_encoding.mlir b/compiler/src/iree/compiler/Codegen/Common/CPU/test/vmvx_materialize_encoding.mlir
index 66001ba..04afc11 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/test/vmvx_materialize_encoding.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/test/vmvx_materialize_encoding.mlir
@@ -65,3 +65,48 @@
// CHECK-SAME: outs(%[[OUTS]] :
// CHECK: flow.dispatch.tensor.store %[[MMT4D]], %[[OUTS_BINDING]]
// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[RESULT_OUTER_SIZE0]], %[[RESULT_OUTER_SIZE1]], %[[RESULT_TILE_SIZES]]#0, %[[RESULT_TILE_SIZES]]#1], strides = [1, 1, 1, 1]
+
+// -----
+
+#map = affine_map<()[s0] -> ((3 ceildiv s0) * s0)>
+#map1 = affine_map<()[s0] -> ((1 ceildiv s0) * s0)>
+func.func @fill_matmul(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index) attributes {
+ hal.executable.target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb">
+} {
+ %c32_i64 = arith.constant 32 : i64
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<?x?xf32, #iree_linalg_ext.encoding<user = MATMUL, role = LHS, element_types = [f32, f32, f32], original_type = tensor<1x2xf32>>>>{%arg0, %arg1}
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<?x?xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RHS, element_types = [f32, f32, f32], original_type = tensor<2x3xf32>>>>{%arg2, %arg3}
+ %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<?x?xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [f32, f32, f32], original_type = tensor<1x3xf32>>>>{%arg4, %arg5}
+ %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%arg0, %arg1], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?xf32, #iree_linalg_ext.encoding<user = MATMUL, role = LHS, element_types = [f32, f32, f32], original_type = tensor<1x2xf32>>>>{%arg0, %arg1} -> tensor<?x?xf32, #iree_linalg_ext.encoding<user = MATMUL, role = LHS, element_types = [f32, f32, f32], original_type = tensor<1x2xf32>>>
+ %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [%arg2, %arg3], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RHS, element_types = [f32, f32, f32], original_type = tensor<2x3xf32>>>>{%arg2, %arg3} -> tensor<?x?xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RHS, element_types = [f32, f32, f32], original_type = tensor<2x3xf32>>>
+ %5 = affine.apply #map()[%arg6]
+ %6 = affine.apply #map1()[%arg7]
+ %7 = tensor.empty(%6, %5) : tensor<?x?xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [f32, f32, f32], original_type = tensor<1x3xf32>>>
+ %8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<?x?xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [f32, f32, f32], original_type = tensor<1x3xf32>>>) -> tensor<?x?xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [f32, f32, f32], original_type = tensor<1x3xf32>>>
+ %9 = linalg.matmul ins(%3, %4 : tensor<?x?xf32, #iree_linalg_ext.encoding<user = MATMUL, role = LHS, element_types = [f32, f32, f32], original_type = tensor<1x2xf32>>>, tensor<?x?xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RHS, element_types = [f32, f32, f32], original_type = tensor<2x3xf32>>>) outs(%8 : tensor<?x?xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [f32, f32, f32], original_type = tensor<1x3xf32>>>) -> tensor<?x?xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [f32, f32, f32], original_type = tensor<1x3xf32>>>
+ flow.dispatch.tensor.store %9, %2, offsets = [0, 0], sizes = [%arg4, %arg5], strides = [1, 1] : tensor<?x?xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [f32, f32, f32], original_type = tensor<1x3xf32>>> -> !flow.dispatch.tensor<writeonly:tensor<?x?xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [f32, f32, f32], original_type = tensor<1x3xf32>>>>{%arg4, %arg5}
+ return
+}
+// CHECK: func.func @fill_matmul
+// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(0)
+// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<1x1x8x4xf32>>
+// CHECK: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(1)
+// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<1x1x8x4xf32>>
+// CHECK: %[[OUTS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(2)
+// CHECK-SAME: !flow.dispatch.tensor<writeonly:tensor<1x1x8x8xf32>>
+// CHECK: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]
+// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [1, 1, 8, 4], strides = [1, 1, 1, 1]
+// CHECK: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]
+// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [1, 1, 8, 4], strides = [1, 1, 1, 1]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x8x8xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill
+// CHECK-SAME: ins(%[[ZERO]]
+// CHECK-SAME: outs(%[[EMPTY]]
+// CHECK: %[[MMT4D:.+]] = linalg.mmt4d
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]]
+// CHECK-SAME: outs(%[[FILL]]
+// CHECK: flow.dispatch.tensor.store %[[MMT4D]], %[[OUTS_BINDING]]
+// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [1, 1, 8, 8], strides = [1, 1, 1, 1]
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
index 0dd823b..780b9a4 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
@@ -530,7 +530,22 @@
         this->materializeEncodingValueFn);
     if (failed(convertedOp))
       return failure();
-    rewriter.replaceOp(op, convertedOp.value()->getResults());
+
+    // The materialized op's result types can differ from the TypeConverter's
+    // converted types in static vs. dynamic dims; insert tensor.cast ops so
+    // the dialect conversion sees matching types.
+    SmallVector<Value> replacements;
+    for (auto [type, res] : llvm::zip_equal(
+             op->getResultTypes(), convertedOp.value()->getResults())) {
+      Type targetType = this->getTypeConverter()->convertType(type);
+      if (targetType == res.getType()) {
+        replacements.push_back(res);
+      } else {
+        replacements.push_back(
+            rewriter.create<tensor::CastOp>(op.getLoc(), targetType, res));
+      }
+    }
+    rewriter.replaceOp(op, replacements);
     return success();
   }
 };