Encode the matmul type triple in `TensorEncoding` (#11355)
`TensorEncoding` is used to determine tile sizes in
`MaterializeEncoding`. These depend on which SIMD instructions we want
to use to perform a matmul, which in turn depends on the (LHS, RHS,
RESULT) element type triple of the matmul --- not just one tensor's
element type in isolation.
While updating `SetEncoding`, I found some opportunity to tidy up. The
pattern was performing some rewrites (padding) before it was done
potentially returning failure, so I have reordered that. And the local
variables are renamed to be more explanatory.
diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPass.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPass.cpp
index e0d5c42..6780903 100644
--- a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPass.cpp
@@ -22,6 +22,9 @@
namespace mlir {
namespace iree_compiler {
+using IREE::LinalgExt::MaterializeEncodingInfo;
+using IREE::LinalgExt::TensorEncoding;
+
/// For `dispatchTensorType` that bind a `RankedTensorType` with encoding,
/// returns the materialized shape of the `dispatchTensorType`. The
/// dynamic dimensions of the `dispatchTensorType` are provided in
@@ -44,7 +47,7 @@
IREE::LinalgExt::MaterializeEncodingFn materializeEncodingFn =
typeConverter.getMaterializeEncodingFn();
- FailureOr<IREE::LinalgExt::MaterializeEncodingInfo> encodingInfo =
+ FailureOr<MaterializeEncodingInfo> encodingInfo =
materializeEncodingFn(boundTensorType);
if (failed(encodingInfo)) {
return failure();
@@ -86,8 +89,7 @@
namespace {
/// Extract encoding from the `tensorType` if specified.
-static Optional<IREE::LinalgExt::TensorEncoding> getEncoding(
- RankedTensorType tensorType) {
+static Optional<TensorEncoding> getEncoding(RankedTensorType tensorType) {
auto encodingAttr = tensorType.getEncoding()
.dyn_cast_or_null<IREE::LinalgExt::EncodingAttr>();
if (!encodingAttr) return llvm::None;
@@ -98,25 +100,29 @@
/// materializing the pack op.
// TODO(bjacob): This is in the process of being actually implemented in a way
// that actually uses target information.
-static FailureOr<IREE::LinalgExt::MaterializeEncodingInfo> chooseEncodingInfo(
+static FailureOr<MaterializeEncodingInfo> chooseEncodingInfo(
RankedTensorType tensorType, Operation *op) {
auto target = IREE::HAL::ExecutableTargetAttr::lookup(op);
// TODO: actually use `target`.
(void)target;
- Optional<IREE::LinalgExt::TensorEncoding> encoding = getEncoding(tensorType);
+ Optional<TensorEncoding> encoding = getEncoding(tensorType);
if (!encoding) return failure();
switch (*encoding) {
- case IREE::LinalgExt::TensorEncoding::GEMM_LHS:
- return IREE::LinalgExt::MaterializeEncodingInfo{{0, 1}, {8, 4}, {}};
+ case TensorEncoding::MATMUL_F32F32F32_LHS:
+ case TensorEncoding::MATMUL_I8I8I32_LHS:
+ return MaterializeEncodingInfo{{0, 1}, {8, 4}, {}};
break;
- case IREE::LinalgExt::TensorEncoding::GEMM_RHS:
- return IREE::LinalgExt::MaterializeEncodingInfo{{0, 1}, {4, 8}, {}};
+ case TensorEncoding::MATMUL_F32F32F32_RHS:
+ case TensorEncoding::MATMUL_I8I8I32_RHS:
+ return MaterializeEncodingInfo{{0, 1}, {4, 8}, {}};
break;
- case IREE::LinalgExt::TensorEncoding::GEMM_RESULT:
- return IREE::LinalgExt::MaterializeEncodingInfo{{0, 1}, {8, 8}, {}};
+ case TensorEncoding::MATMUL_F32F32F32_RHS_TRANSPOSE:
+ case TensorEncoding::MATMUL_I8I8I32_RHS_TRANSPOSE:
+ return MaterializeEncodingInfo{{1, 0}, {8, 4}, {1, 0}};
break;
- case IREE::LinalgExt::TensorEncoding::GEMM_RHS_TRANSPOSE:
- return IREE::LinalgExt::MaterializeEncodingInfo{{1, 0}, {8, 4}, {1, 0}};
+ case TensorEncoding::MATMUL_F32F32F32_RESULT:
+ case TensorEncoding::MATMUL_I8I8I32_RESULT:
+ return MaterializeEncodingInfo{{0, 1}, {8, 8}, {}};
break;
default:
return failure();
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/materialize_encoding.mlir b/compiler/src/iree/compiler/Codegen/Common/test/materialize_encoding.mlir
index 6b5a159..84888e4 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/materialize_encoding.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/materialize_encoding.mlir
@@ -10,7 +10,7 @@
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64)
: !flow.dispatch.tensor<readonly:tensor<?x?xf32>>{%d0, %d1}
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64)
- : !flow.dispatch.tensor<writeonly:tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>>{%outd0, %outd1}
+ : !flow.dispatch.tensor<writeonly:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>>{%outd0, %outd1}
%2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%d0, %d1], strides = [1, 1]
: !flow.dispatch.tensor<readonly:tensor<?x?xf32>>{%d0, %d1} -> tensor<?x?xf32>
%p0 = affine.apply affine_map<()[s0, s1] -> (-s0 + s1)>()[%d0, %outd0]
@@ -19,10 +19,10 @@
^bb0(%arg0: index, %arg1: index):
tensor.yield %cst : f32
} : tensor<?x?xf32> to tensor<?x?xf32>
- %3 = iree_linalg_ext.set_encoding %padded : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+ %3 = iree_linalg_ext.set_encoding %padded : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
flow.dispatch.tensor.store %3, %1, offsets = [0, 0], sizes = [%outd0, %outd1], strides = [1, 1]
- : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- -> !flow.dispatch.tensor<writeonly:tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>>{%outd0, %outd1}
+ : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ -> !flow.dispatch.tensor<writeonly:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>>{%outd0, %outd1}
return
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
@@ -56,14 +56,14 @@
%outd0 = hal.interface.constant.load [2] : index
%outd1 = hal.interface.constant.load [3] : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64)
- : !flow.dispatch.tensor<readonly:tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>>{%d0, %d1}
+ : !flow.dispatch.tensor<readonly:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>>{%d0, %d1}
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64)
: !flow.dispatch.tensor<writeonly:tensor<?x?xf32>>{%outd0, %outd1}
%2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%d0, %d1], strides = [1, 1]
- : !flow.dispatch.tensor<readonly:tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>>{%d0, %d1}
- -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+ : !flow.dispatch.tensor<readonly:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>>{%d0, %d1}
+ -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
%3 = iree_linalg_ext.unset_encoding %2
- : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>> -> tensor<?x?xf32>
+ : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> -> tensor<?x?xf32>
%4 = tensor.extract_slice %3[0, 0] [%outd0, %outd1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
flow.dispatch.tensor.store %4, %1, offsets = [0, 0], sizes = [%outd0, %outd1], strides = [1, 1]
: tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:tensor<?x?xf32>>{%outd0, %outd1}
@@ -97,28 +97,28 @@
%N = hal.interface.constant.load[1] : index
%K = hal.interface.constant.load[2] : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64)
- : !flow.dispatch.tensor<readonly:tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>>{%M, %K}
+ : !flow.dispatch.tensor<readonly:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>>{%M, %K}
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64)
- : !flow.dispatch.tensor<readonly:tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>>{%K, %N}
+ : !flow.dispatch.tensor<readonly:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>>{%K, %N}
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64)
- : !flow.dispatch.tensor<readwrite:tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>>{%M, %N}
+ : !flow.dispatch.tensor<readwrite:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>>{%M, %N}
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1]
- : !flow.dispatch.tensor<readonly:tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>>{%M, %K}
- -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+ : !flow.dispatch.tensor<readonly:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>>{%M, %K}
+ -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1]
- : !flow.dispatch.tensor<readonly:tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>>{%K, %N}
- -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>
+ : !flow.dispatch.tensor<readonly:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>>{%K, %N}
+ -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>
%5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
- : !flow.dispatch.tensor<readwrite:tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>>{%M, %N}
- -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
+ : !flow.dispatch.tensor<readwrite:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>>{%M, %N}
+ -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
%6 = linalg.matmul
- ins(%3, %4 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>,
- tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>)
- outs(%5 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>)
- -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
+ ins(%3, %4 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>,
+ tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>)
+ outs(%5 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>)
+ -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
- : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
- -> !flow.dispatch.tensor<readwrite:tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>>{%M, %N}
+ : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
+ -> !flow.dispatch.tensor<readwrite:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>>{%M, %N}
return
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/SetEncoding.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/SetEncoding.cpp
index 50332cf..27cb387 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/SetEncoding.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/SetEncoding.cpp
@@ -29,6 +29,8 @@
namespace IREE {
namespace Flow {
+using IREE::LinalgExt::TensorEncoding;
+
//===---------------------------------------------------------------------===//
// Utility functions
//===---------------------------------------------------------------------===//
@@ -128,44 +130,81 @@
llvm::any_of(outputs, hasEncoding)) {
return failure();
}
+
+ Value origLhs = inputs[0]->get();
+ Value origRhs = inputs[1]->get();
+ Value origOut = outputs[0]->get();
+
+ auto getElemType = [](Value v) -> Type {
+ if (auto tensorType = v.getType().dyn_cast<RankedTensorType>()) {
+ return tensorType.getElementType();
+ }
+ return {};
+ };
+ Type lhsElemType = getElemType(origLhs);
+ Type rhsElemType = getElemType(origRhs);
+ Type outElemType = getElemType(origOut);
+
+ if (!lhsElemType || !rhsElemType || !outElemType) {
+ return failure();
+ }
+
+ TensorEncoding lhsEncoding;
+ TensorEncoding rhsEncoding;
+ TensorEncoding outEncoding;
+
+ if (lhsElemType.isF32() && rhsElemType.isF32() && outElemType.isF32()) {
+ lhsEncoding = TensorEncoding::MATMUL_F32F32F32_LHS;
+ rhsEncoding = TensorEncoding::MATMUL_F32F32F32_RHS_TRANSPOSE;
+ outEncoding = TensorEncoding::MATMUL_F32F32F32_RESULT;
+ } else if (lhsElemType.isSignlessInteger(8) &&
+ rhsElemType.isSignlessInteger(8) &&
+ outElemType.isSignlessInteger(32)) {
+ lhsEncoding = TensorEncoding::MATMUL_I8I8I32_LHS;
+ rhsEncoding = TensorEncoding::MATMUL_I8I8I32_RHS_TRANSPOSE;
+ outEncoding = TensorEncoding::MATMUL_I8I8I32_RESULT;
+ } else {
+ return rewriter.notifyMatchFailure(
+ matmulOp,
+ "unhandled combination of (lhs, rhs, result) element types");
+ }
+
Location loc = matmulOp.getLoc();
// Set encoding for LHS (pad if necessary)
- FailureOr<Value> lhs =
- padIfNeeded(rewriter, loc, inputs[0]->get(), padding);
- if (failed(lhs)) {
+ FailureOr<Value> paddedLhs = padIfNeeded(rewriter, loc, origLhs, padding);
+ if (failed(paddedLhs)) {
return rewriter.notifyMatchFailure(matmulOp, "failed to pad lhs");
}
- Value lhsEncoding = rewriter.create<IREE::LinalgExt::SetEncodingOp>(
- loc, lhs.value(), IREE::LinalgExt::TensorEncoding::GEMM_LHS);
// Set encoding for RHS (pad if necessary)
- FailureOr<Value> rhs =
- padIfNeeded(rewriter, loc, inputs[1]->get(), padding);
- if (failed(rhs)) {
+ FailureOr<Value> paddedRhs = padIfNeeded(rewriter, loc, origRhs, padding);
+ if (failed(paddedRhs)) {
return rewriter.notifyMatchFailure(matmulOp, "failed to pad rhs");
}
- Value rhsEncoding = rewriter.create<IREE::LinalgExt::SetEncodingOp>(
- loc, rhs.value(), IREE::LinalgExt::TensorEncoding::GEMM_RHS_TRANSPOSE);
// Set encoding for OUTS (pad if necessary)
- FailureOr<Value> output =
- padIfNeeded(rewriter, loc, outputs[0]->get(), padding);
- if (failed(output)) {
+ FailureOr<Value> paddedOut = padIfNeeded(rewriter, loc, origOut, padding);
+ if (failed(paddedOut)) {
return rewriter.notifyMatchFailure(matmulOp, "failed to pad output");
}
- Value outsEncoding = rewriter.create<IREE::LinalgExt::SetEncodingOp>(
- loc, output.value(), IREE::LinalgExt::TensorEncoding::GEMM_RESULT);
+
+ Value encodedLhs = rewriter.create<IREE::LinalgExt::SetEncodingOp>(
+ loc, paddedLhs.value(), lhsEncoding);
+ Value encodedRhs = rewriter.create<IREE::LinalgExt::SetEncodingOp>(
+ loc, paddedRhs.value(), rhsEncoding);
+ Value encodedOut = rewriter.create<IREE::LinalgExt::SetEncodingOp>(
+ loc, paddedOut.value(), outEncoding);
auto matmulTiled = rewriter.create<linalg::MatmulOp>(
- loc, outsEncoding.getType(), ValueRange{lhsEncoding, rhsEncoding},
- outsEncoding);
+ loc, encodedOut.getType(), ValueRange{encodedLhs, encodedRhs},
+ encodedOut);
auto unsetEncoding = rewriter.create<IREE::LinalgExt::UnsetEncodingOp>(
loc, matmulTiled.getResult(0));
Value replacement = unsetEncoding.getResult();
// If the output was padded, extract the actual output.
- if (output.value() != outputs[0]->get()) {
+ if (paddedOut.value() != origOut) {
auto replacementRank =
replacement.getType().cast<RankedTensorType>().getRank();
// Offsets are all 0.
@@ -177,7 +216,7 @@
// Sizes are computed by original output size.
FailureOr<SmallVector<OpFoldResult>> sizes =
- LinalgExt::getDims(rewriter, loc, outputs[0]->get());
+ LinalgExt::getDims(rewriter, loc, origOut);
if (failed(sizes)) {
return rewriter.notifyMatchFailure(matmulOp,
"failed to get shape of result");
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
index 8d27060..1c12fca 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
@@ -1816,10 +1816,10 @@
// -----
func.func @set_encoding_op(%arg0 : tensor<?x?xf32>)
- -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>> {
+ -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> {
%0 = iree_linalg_ext.set_encoding %arg0
- : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- return %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+ : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ return %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
}
// CHECK: func @set_encoding_op
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
@@ -1831,13 +1831,13 @@
// CHECK-NEXT: %[[INARG:.+]]: !flow.dispatch.tensor<readonly:tensor<?x?xf32>>
// CHECK-SAME: %[[INDEXARG0:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[INDEXARG1:[a-zA-Z0-9]+]]: index
-// CHECK-SAME: %[[OUTARG:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<writeonly:tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>>
+// CHECK-SAME: %[[OUTARG:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<writeonly:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>>
// CHECK: %[[LOAD:.+]] = flow.dispatch.tensor.load %[[INARG]]
// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x?xf32>>{%[[INDEXARG0]], %[[INDEXARG1]]}
// CHECK: %[[ENCODING:.+]] = iree_linalg_ext.set_encoding %[[LOAD]]
// CHECK: flow.dispatch.tensor.store %[[ENCODING]], %[[OUTARG]]
// CHECK-SAME: sizes = [%[[INDEXARG0]], %[[INDEXARG1]]]
-// CHECK-SAME: !flow.dispatch.tensor<writeonly:tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>>{%[[INDEXARG0]], %[[INDEXARG1]]}
+// CHECK-SAME: !flow.dispatch.tensor<writeonly:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>>{%[[INDEXARG0]], %[[INDEXARG1]]}
// CHECK: flow.return
// CHECK: count(%[[WL0:[a-zA-Z0-9]+]]: index, %[[WL1:.+]]: index)
// CHECK: %[[X:[a-zA-Z0-9]+]], %[[Y:[a-zA-Z0-9]+]], %[[Z:.+]] = flow.dispatch.workgroup_count_from_set_encoding_op %[[WL0]], %[[WL1]]
@@ -1846,26 +1846,26 @@
// -----
-func.func @unset_encoding_op(%arg0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>)
+func.func @unset_encoding_op(%arg0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>)
-> tensor<?x?xf32> {
%0 = iree_linalg_ext.unset_encoding %arg0
- : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>> -> tensor<?x?xf32>
+ : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// CHECK: func @unset_encoding_op
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK: %[[DISPATCH:.+]] = flow.dispatch.workgroups[%[[D0]], %[[D1]]](%[[ARG0]], %[[D0]], %[[D1]])
-// CHECK-NEXT: %[[INARG:.+]]: !flow.dispatch.tensor<readonly:tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>>
+// CHECK-NEXT: %[[INARG:.+]]: !flow.dispatch.tensor<readonly:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>>
// CHECK-SAME: %[[INDEXARG0:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[INDEXARG1:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[OUTARG:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<writeonly:tensor<?x?xf32>>
// CHECK: %[[LOAD:.+]] = flow.dispatch.tensor.load %[[INARG]]
// CHECK-SAME: sizes = [%[[INDEXARG0]], %[[INDEXARG1]]]
-// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>>{%[[INDEXARG0]], %[[INDEXARG1]]}
+// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>>{%[[INDEXARG0]], %[[INDEXARG1]]}
// CHECK: %[[ENCODING:.+]] = iree_linalg_ext.unset_encoding %[[LOAD]]
// CHECK: flow.dispatch.tensor.store %[[ENCODING]], %[[OUTARG]]
// CHECK-SAME: !flow.dispatch.tensor<writeonly:tensor<?x?xf32>>{%[[INDEXARG0]], %[[INDEXARG1]]}
@@ -1879,7 +1879,7 @@
#map = affine_map<()[s0] -> (-s0 + (s0 ceildiv 16) * 16)>
func.func @pad_and_set_encoding_op(%arg0 : tensor<?x?xf32>)
- -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>> {
+ -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%cst = arith.constant 0.0 : f32
@@ -1892,8 +1892,8 @@
tensor.yield %cst : f32
} : tensor<?x?xf32> to tensor<?x?xf32>
%encoding = iree_linalg_ext.set_encoding %pad
- : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- return %encoding : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+ : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ return %encoding : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> ((s0 ceildiv 16) * 16)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (-s0 + (s0 ceildiv 16) * 16)>
@@ -1911,7 +1911,7 @@
// CHECK-SAME: %[[INARG:.+]]: !flow.dispatch.tensor<readonly:tensor<?x?xf32>>
// CHECK-SAME: %[[PADDED_D0:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[PADDED_D1:[a-zA-Z0-9]+]]: index
-// CHECK-SAME: %[[OUTARG:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<writeonly:tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>>
+// CHECK-SAME: %[[OUTARG:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<writeonly:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>>
// CHECK: %[[LOAD:.+]] = flow.dispatch.tensor.load %[[INARG]]
// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x?xf32>>{%[[INDEXARG0]], %[[INDEXARG1]]}
// CHECK: %[[HIGHPAD1:.+]] = affine.apply #[[MAP1]]()[%[[INDEXARG1]]]
@@ -1920,7 +1920,7 @@
// CHECK: %[[SET_ENCODING:.+]] = iree_linalg_ext.set_encoding %[[PADDED]]
// CHECK: flow.dispatch.tensor.store %[[SET_ENCODING]], %[[OUTARG]]
// CHECK-SAME: sizes = [%[[PADDED_D0]], %[[PADDED_D1]]]
-// CHECK-SAME: !flow.dispatch.tensor<writeonly:tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>>{%[[PADDED_D0]], %[[PADDED_D1]]}
+// CHECK-SAME: !flow.dispatch.tensor<writeonly:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>>{%[[PADDED_D0]], %[[PADDED_D1]]}
// CHECK: flow.return
// CHECK: count(%[[WL0:[a-zA-Z0-9]+]]: index, %[[WL1:.+]]: index)
// CHECK: %[[X:[a-zA-Z0-9]+]], %[[Y:[a-zA-Z0-9]+]], %[[Z:.+]] = flow.dispatch.workgroup_count_from_set_encoding_op %[[WL0]], %[[WL1]]
@@ -1930,16 +1930,16 @@
// -----
func.func @unset_encoding_and_slice(
- %arg0: tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>,
+ %arg0: tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>,
%arg1 : index, %arg2 : index) -> tensor<?x?xf32> {
%0 = iree_linalg_ext.unset_encoding %arg0
- : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>> -> tensor<?x?xf32>
+ : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> -> tensor<?x?xf32>
%1 = tensor.extract_slice %0[0, 0] [%arg1, %arg2] [1, 1]
: tensor<?x?xf32> to tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK: func @unset_encoding_and_slice
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
@@ -1947,7 +1947,7 @@
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK: %[[DISPATCH:.+]] = flow.dispatch.workgroups[%[[ARG1]], %[[ARG2]]](%[[ARG0]], %[[D0]], %[[D1]], %[[ARG1]], %[[ARG2]])
-// CHECK-NEXT: %[[INARG:.+]]: !flow.dispatch.tensor<readonly:tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>>
+// CHECK-NEXT: %[[INARG:.+]]: !flow.dispatch.tensor<readonly:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>>
// CHECK-SAME: %[[INDEXARG0:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[INDEXARG1:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[INDEXARG2:[a-zA-Z0-9]+]]: index
@@ -1955,7 +1955,7 @@
// CHECK-SAME: %[[OUTARG:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<writeonly:tensor<?x?xf32>>
// CHECK: %[[LOAD:.+]] = flow.dispatch.tensor.load %[[INARG]]
// CHECK-SAME: sizes = [%[[INDEXARG0]], %[[INDEXARG1]]]
-// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>>{%[[INDEXARG0]], %[[INDEXARG1]]}
+// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>>{%[[INDEXARG0]], %[[INDEXARG1]]}
// CHECK: %[[ENCODING:.+]] = iree_linalg_ext.unset_encoding %[[LOAD]]
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ENCODING]][0, 0] [%[[INDEXARG2]], %[[INDEXARG3]]]
// CHECK: flow.dispatch.tensor.store %[[SLICE]], %[[OUTARG]]
@@ -1966,26 +1966,26 @@
// -----
func.func @gemm_encoded(
- %arg0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>,
- %arg1 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>,
- %arg2 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>)
- -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>> {
+ %arg0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>,
+ %arg1 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>,
+ %arg2 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>)
+ -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>> {
%0 = linalg.matmul
ins(%arg0, %arg1
- : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>,
- tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>)
- outs(%arg2 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>)
- -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
- return %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
+ : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>,
+ tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>)
+ outs(%arg2 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>)
+ -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
+ return %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
}
// CHECK: func.func @gemm_encoded
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
// CHECK: %[[DISPATCH:.+]] = flow.dispatch.workgroups
-// CHECK-NEXT: %[[LHS_IN:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>>
-// CHECK-SAME: %[[RHS_IN:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>>
-// CHECK-SAME: %[[INIT_IN:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readwrite:tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>>
+// CHECK-NEXT: %[[LHS_IN:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>>
+// CHECK-SAME: %[[RHS_IN:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>>
+// CHECK-SAME: %[[INIT_IN:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readwrite:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>>
// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_IN]]
// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_IN]]
// CHECK-DAG: %[[INIT:.+]] = flow.dispatch.tensor.load %[[INIT_IN]]
@@ -1997,32 +1997,32 @@
// -----
func.func @gemm_fill_encoded(
- %arg0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>,
- %arg1 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>)
- -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>> {
+ %arg0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>,
+ %arg1 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>)
+ -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%cst = arith.constant 0.0 : f32
- %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>
- %empty = tensor.empty(%d0, %d1) : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
- %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>)
- -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
+ %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>
+ %empty = tensor.empty(%d0, %d1) : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
+ %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>)
+ -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
%0 = linalg.matmul
ins(%arg0, %arg1
- : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>,
- tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>)
- outs(%fill : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>)
- -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
- return %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
+ : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>,
+ tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>)
+ outs(%fill : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>)
+ -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
+ return %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
}
// CHECK: func.func @gemm_fill_encoded
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>
// CHECK: %[[DISPATCH:.+]] = flow.dispatch.workgroups
-// CHECK-NEXT: %[[LHS_IN:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>>
-// CHECK-SAME: %[[RHS_IN:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>>
-// CHECK-SAME: %[[RESULT:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<writeonly:tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>>
+// CHECK-NEXT: %[[LHS_IN:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>>
+// CHECK-SAME: %[[RHS_IN:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>>
+// CHECK-SAME: %[[RESULT:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<writeonly:tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>>
// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_IN]]
// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_IN]]
// CHECK: %[[EMPTY:.+]] = tensor.empty
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/set_encoding.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/set_encoding.mlir
index 2d96ea9..3093d36 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/set_encoding.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/set_encoding.mlir
@@ -12,11 +12,11 @@
// CHECK-SAME: %[[ARG1:.+]]: tensor<256x512xf32>
// CHECK-SAME: %[[ARG2:.+]]: tensor<128x512xf32>
// CHECK: %[[LHS:.+]] = iree_linalg_ext.set_encoding %[[ARG0]]
-// CHECK-SAME: tensor<128x256xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+// CHECK-SAME: tensor<128x256xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
// CHECK: %[[RHS:.+]] = iree_linalg_ext.set_encoding %[[ARG1]]
-// CHECK-SAME: tensor<256x512xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>
+// CHECK-SAME: tensor<256x512xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>
// CHECK: %[[OUTS:.+]] = iree_linalg_ext.set_encoding %[[ARG2]]
-// CHECK-SAME: tensor<128x512xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
+// CHECK-SAME: tensor<128x512xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
// CHECK: %[[MATMUL:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
// CHECK-SAME: outs(%[[OUTS]] :
@@ -37,16 +37,16 @@
// CHECK-SAME: %[[ARG2:.+]]: tensor<100x500xf32>
// CHECK: %[[LHS_PAD:.+]] = tensor.pad %[[ARG0]] low[0, 0] high[12, 6]
// CHECK: tensor<100x250xf32> to tensor<112x256xf32>
-// CHECK: %[[LHS:.+]] = iree_linalg_ext.set_encoding %[[LHS_PAD]]
-// CHECK-SAME: tensor<112x256xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
// CHECK: %[[RHS_PAD:.+]] = tensor.pad %[[ARG1]] low[0, 0] high[6, 12]
// CHECK: tensor<250x500xf32> to tensor<256x512xf32>
-// CHECK: %[[RHS:.+]] = iree_linalg_ext.set_encoding %[[RHS_PAD]]
-// CHECK-SAME: tensor<256x512xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>
// CHECK: %[[OUTS_PAD:.+]] = tensor.pad %[[ARG2]] low[0, 0] high[12, 12]
// CHECK: tensor<100x500xf32> to tensor<112x512xf32>
+// CHECK: %[[LHS:.+]] = iree_linalg_ext.set_encoding %[[LHS_PAD]]
+// CHECK-SAME: tensor<112x256xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+// CHECK: %[[RHS:.+]] = iree_linalg_ext.set_encoding %[[RHS_PAD]]
+// CHECK-SAME: tensor<256x512xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>
// CHECK: %[[OUTS:.+]] = iree_linalg_ext.set_encoding %[[OUTS_PAD]]
-// CHECK-SAME: tensor<112x512xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
+// CHECK-SAME: tensor<112x512xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
// CHECK: %[[MATMUL:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
// CHECK-SAME: outs(%[[OUTS]] :
@@ -60,14 +60,14 @@
// PADDING-SAME: %[[ARG2:.+]]: tensor<100x500xf32>
// PADDING: %[[LHS_PAD:.+]] = tensor.pad %[[ARG0]] low[0, 0] high[0, 2]
// PADDING: tensor<100x250xf32> to tensor<100x252xf32>
-// PADDING: %[[LHS:.+]] = iree_linalg_ext.set_encoding %[[LHS_PAD]]
-// PADDING-SAME: tensor<100x252xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
// PADDING: %[[RHS_PAD:.+]] = tensor.pad %[[ARG1]] low[0, 0] high[2, 0]
// PADDING: tensor<250x500xf32> to tensor<252x500xf32>
+// PADDING: %[[LHS:.+]] = iree_linalg_ext.set_encoding %[[LHS_PAD]]
+// PADDING-SAME: tensor<100x252xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
// PADDING: %[[RHS:.+]] = iree_linalg_ext.set_encoding %[[RHS_PAD]]
-// PADDING-SAME: tensor<252x500xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>
+// PADDING-SAME: tensor<252x500xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>
// PADDING: %[[OUTS:.+]] = iree_linalg_ext.set_encoding %[[ARG2]]
-// PADDING-SAME: tensor<100x500xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
+// PADDING-SAME: tensor<100x500xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
// PADDING: %[[MATMUL:.+]] = linalg.matmul
// PADDING-SAME: ins(%[[LHS]], %[[RHS]] :
// PADDING-SAME: outs(%[[OUTS]] :
@@ -95,22 +95,22 @@
// CHECK-DAG: %[[HIGHPAD_LHS_0:.+]] = affine.apply #[[MAP]]()[%[[LHS_D0]]]
// CHECK-DAG: %[[HIGHPAD_LHS_1:.+]] = affine.apply #[[MAP]]()[%[[LHS_D1]]]
// CHECK: %[[LHS_PAD:.+]] = tensor.pad %[[ARG0]] low[0, 0] high[%[[HIGHPAD_LHS_0]], %[[HIGHPAD_LHS_1]]]
-// CHECK: %[[LHS:.+]] = iree_linalg_ext.set_encoding %[[LHS_PAD]]
-// CHECK-SAME: tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
// CHECK-DAG: %[[RHS_D0:.+]] = tensor.dim %[[ARG1]], %[[C0]]
// CHECK-DAG: %[[RHS_D1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK-DAG: %[[HIGHPAD_RHS_0:.+]] = affine.apply #[[MAP]]()[%[[RHS_D0]]]
// CHECK-DAG: %[[HIGHPAD_RHS_1:.+]] = affine.apply #[[MAP]]()[%[[RHS_D1]]]
// CHECK: %[[RHS_PAD:.+]] = tensor.pad %[[ARG1]] low[0, 0] high[%[[HIGHPAD_RHS_0]], %[[HIGHPAD_RHS_1]]]
-// CHECK: %[[RHS:.+]] = iree_linalg_ext.set_encoding %[[RHS_PAD]]
-// CHECK-SAME: tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>
// CHECK-DAG: %[[OUTS_D0:.+]] = tensor.dim %[[ARG2]], %[[C0]]
// CHECK-DAG: %[[OUTS_D1:.+]] = tensor.dim %[[ARG2]], %[[C1]]
// CHECK-DAG: %[[HIGHPAD_OUTS_0:.+]] = affine.apply #[[MAP]]()[%[[OUTS_D0]]]
// CHECK-DAG: %[[HIGHPAD_OUTS_1:.+]] = affine.apply #[[MAP]]()[%[[OUTS_D1]]]
// CHECK: %[[OUTS_PAD:.+]] = tensor.pad %[[ARG2]] low[0, 0] high[%[[HIGHPAD_OUTS_0]], %[[HIGHPAD_OUTS_1]]]
+// CHECK: %[[LHS:.+]] = iree_linalg_ext.set_encoding %[[LHS_PAD]]
+// CHECK-SAME: tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+// CHECK: %[[RHS:.+]] = iree_linalg_ext.set_encoding %[[RHS_PAD]]
+// CHECK-SAME: tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>
// CHECK: %[[OUTS:.+]] = iree_linalg_ext.set_encoding %[[OUTS_PAD]]
-// CHECK-SAME: tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
+// CHECK-SAME: tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
// CHECK: %[[MATMUL:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
// CHECK-SAME: outs(%[[OUTS]] :
@@ -121,24 +121,24 @@
// -----
func.func @fold_fill_with_set_encoding(%arg0 : index, %arg1 : index)
- -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>> {
+ -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> {
%cst = arith.constant 0.0 : f32
%0 = tensor.empty(%arg0, %arg1) : tensor<?x?xf32>
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
%2 = iree_linalg_ext.set_encoding %1 : tensor<?x?xf32>
- -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- return %2 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+ -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ return %2 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
}
// CHECK: func @fold_fill_with_set_encoding(
-// CHECK: %[[EMPTY:.+]] = tensor.empty(%{{.+}}, %{{.+}}) : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%{{.+}}, %{{.+}}) : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
// CHECK: %[[FILL:.+]] = linalg.fill
-// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>)
// CHECK: return %[[FILL]]
// -----
func.func @fold_fill_with_tensor_pad(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index)
- -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>> {
+ -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>> {
%cst = arith.constant 0.0 : f32
%0 = tensor.empty(%arg0, %arg1) : tensor<?x?xf32>
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
@@ -147,12 +147,12 @@
tensor.yield %cst : f32
} : tensor<?x?xf32> to tensor<?x?xf32>
%3 = iree_linalg_ext.set_encoding %2 : tensor<?x?xf32>
- -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
- return %3 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
+ -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
+ return %3 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
}
// CHECK: func @fold_fill_with_tensor_pad(
// CHECK: %[[EMPTY:.+]] = tensor.empty(
-// CHECK-SAME: tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
+// CHECK-SAME: tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
// CHECK: %[[FILL:.+]] = linalg.fill
// CHECK-SAME: outs(%[[EMPTY]] :
// CHECK: return %[[FILL]]
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td
index afb6c3a..c0530fe 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td
@@ -45,19 +45,28 @@
: AttrDef<IREELinalgExt_Dialect, name, traits>;
// List of pre-defined data layout encoding attributes.
-def GEMM_LHS
- : I32EnumAttrCase<"GEMM_LHS", 0>;
-def GEMM_RESULT
- : I32EnumAttrCase<"GEMM_RESULT", 1>;
-def GEMM_RHS
- : I32EnumAttrCase<"GEMM_RHS", 2>;
-def GEMM_RHS_TRANSPOSE
- : I32EnumAttrCase<"GEMM_RHS_TRANSPOSE", 3>;
+def MATMUL_F32F32F32_LHS
+ : I32EnumAttrCase<"MATMUL_F32F32F32_LHS", 0>;
+def MATMUL_F32F32F32_RHS
+ : I32EnumAttrCase<"MATMUL_F32F32F32_RHS", 1>;
+def MATMUL_F32F32F32_RHS_TRANSPOSE
+ : I32EnumAttrCase<"MATMUL_F32F32F32_RHS_TRANSPOSE", 2>;
+def MATMUL_F32F32F32_RESULT
+ : I32EnumAttrCase<"MATMUL_F32F32F32_RESULT", 3>;
+def MATMUL_I8I8I32_LHS
+ : I32EnumAttrCase<"MATMUL_I8I8I32_LHS", 4>;
+def MATMUL_I8I8I32_RHS
+ : I32EnumAttrCase<"MATMUL_I8I8I32_RHS", 5>;
+def MATMUL_I8I8I32_RHS_TRANSPOSE
+ : I32EnumAttrCase<"MATMUL_I8I8I32_RHS_TRANSPOSE", 6>;
+def MATMUL_I8I8I32_RESULT
+ : I32EnumAttrCase<"MATMUL_I8I8I32_RESULT", 7>;
def TensorEncodingEnum
: I32EnumAttr<"TensorEncoding",
"identifier for encoding used for the tensor",[
- GEMM_LHS, GEMM_RESULT, GEMM_RHS, GEMM_RHS_TRANSPOSE
+ MATMUL_F32F32F32_LHS, MATMUL_F32F32F32_RHS, MATMUL_F32F32F32_RHS_TRANSPOSE, MATMUL_F32F32F32_RESULT,
+ MATMUL_I8I8I32_LHS, MATMUL_I8I8I32_RHS, MATMUL_I8I8I32_RHS_TRANSPOSE, MATMUL_I8I8I32_RESULT,
]> {
let cppNamespace = "::mlir::iree_compiler::IREE::LinalgExt";
let genSpecializedAttr = 0;
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
index cdf0258..e29223d 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
@@ -85,18 +85,22 @@
if (!encoding)
return failure();
switch (*encoding) {
- case TensorEncoding::GEMM_LHS:
+ case TensorEncoding::MATMUL_F32F32F32_LHS:
+ case TensorEncoding::MATMUL_I8I8I32_LHS:
return MaterializeEncodingInfo{{0, 1}, {8, 4}, {}};
break;
- case TensorEncoding::GEMM_RHS:
+ case TensorEncoding::MATMUL_F32F32F32_RHS:
+ case TensorEncoding::MATMUL_I8I8I32_RHS:
return MaterializeEncodingInfo{{0, 1}, {4, 8}, {}};
break;
- case TensorEncoding::GEMM_RESULT:
- return MaterializeEncodingInfo{{0, 1}, {8, 8}, {}};
- break;
- case TensorEncoding::GEMM_RHS_TRANSPOSE:
+ case TensorEncoding::MATMUL_F32F32F32_RHS_TRANSPOSE:
+ case TensorEncoding::MATMUL_I8I8I32_RHS_TRANSPOSE:
return MaterializeEncodingInfo{{1, 0}, {8, 4}, {1, 0}};
break;
+ case TensorEncoding::MATMUL_F32F32F32_RESULT:
+ case TensorEncoding::MATMUL_I8I8I32_RESULT:
+ return MaterializeEncodingInfo{{0, 1}, {8, 8}, {}};
+ break;
default:
return failure();
}
@@ -182,9 +186,9 @@
}
/// Utility method to convert from `linalg.matmul` with
-/// - lhs encoding of GEMM_LHS
-/// - rhs encoding of GEMM_RHS_TRANSPOSE
-/// - result encoding of GEMM_RESULT
+/// - lhs encoding of MATMUL_F32F32F32_LHS
+/// - rhs encoding of MATMUL_F32F32F32_RHS_TRANSPOSE
+/// - result encoding of MATMUL_F32F32F32_RESULT
/// to linalg.mmt4d op.
static FailureOr<Operation *>
lowerOpWithEncoding(RewriterBase &rewriter, linalg::MatmulOp matmulOp,
@@ -201,11 +205,12 @@
getEncoding(inputs[1]->get().getType().cast<RankedTensorType>());
Optional<TensorEncoding> resultEncoding =
getEncoding(outputs[0]->get().getType().cast<RankedTensorType>());
- if (!lhsEncoding || lhsEncoding.value() != TensorEncoding::GEMM_LHS ||
+ if (!lhsEncoding ||
+ lhsEncoding.value() != TensorEncoding::MATMUL_F32F32F32_LHS ||
!rhsEncoding ||
- rhsEncoding.value() != TensorEncoding::GEMM_RHS_TRANSPOSE ||
+ rhsEncoding.value() != TensorEncoding::MATMUL_F32F32F32_RHS_TRANSPOSE ||
!resultEncoding ||
- resultEncoding.value() != TensorEncoding::GEMM_RESULT) {
+ resultEncoding.value() != TensorEncoding::MATMUL_F32F32F32_RESULT) {
return failure();
}
Operation *mmt4DOp = rewriter.create<linalg::Mmt4DOp>(
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/invalid.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/invalid.mlir
index 82be355..aee10ed 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/invalid.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/invalid.mlir
@@ -590,9 +590,9 @@
// -----
-func.func @illegal_set_encoding_op_with_source_encoding(%arg0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>) -> tensor<?x?xf32> {
+func.func @illegal_set_encoding_op_with_source_encoding(%arg0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>) -> tensor<?x?xf32> {
// expected-error @+1 {{source of set_encoding op cannot have a tensor encoding}}
- %0 = iree_linalg_ext.set_encoding %arg0: tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>> -> tensor<?x?xf32>
+ %0 = iree_linalg_ext.set_encoding %arg0: tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
@@ -606,18 +606,18 @@
// -----
-func.func @illegal_set_encoding_op_with_rank_change(%arg0 : tensor<?x?xf32>) -> tensor<?xf32, #iree_linalg_ext.encoding<GEMM_LHS>> {
+func.func @illegal_set_encoding_op_with_rank_change(%arg0 : tensor<?x?xf32>) -> tensor<?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> {
// expected-error @+1 {{cannot change the rank of the tensor}}
- %0 = iree_linalg_ext.set_encoding %arg0: tensor<?x?xf32> -> tensor<?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- return %0 : tensor<?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+ %0 = iree_linalg_ext.set_encoding %arg0: tensor<?x?xf32> -> tensor<?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ return %0 : tensor<?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
}
// -----
-func.func @illegal_set_encoding_op_with_shape_change(%arg0 : tensor<10x20xf32>) -> tensor<20x30xf32, #iree_linalg_ext.encoding<GEMM_LHS>> {
+func.func @illegal_set_encoding_op_with_shape_change(%arg0 : tensor<10x20xf32>) -> tensor<20x30xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> {
// expected-error @+1 {{expected to preserve the logical shape of the tensor}}
- %0 = iree_linalg_ext.set_encoding %arg0: tensor<10x20xf32> -> tensor<20x30xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- return %0 : tensor<20x30xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+ %0 = iree_linalg_ext.set_encoding %arg0: tensor<10x20xf32> -> tensor<20x30xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ return %0 : tensor<20x30xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
}
// -----
@@ -630,10 +630,10 @@
// -----
-func.func @illegal_unset_encoding_op_with_result_encoding(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>> {
+func.func @illegal_unset_encoding_op_with_result_encoding(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> {
// expected-error @+1 {{result of unset_encoding op cannot have a tensor encoding}}
- %0 = iree_linalg_ext.unset_encoding %arg0: tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- return %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+ %0 = iree_linalg_ext.unset_encoding %arg0: tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ return %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
}
// -----
@@ -646,17 +646,17 @@
// -----
-func.func @illegal_unset_encoding_op_with_rank_change(%arg0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>) -> tensor<?xf32> {
+func.func @illegal_unset_encoding_op_with_rank_change(%arg0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>) -> tensor<?xf32> {
// expected-error @+1 {{cannot change the rank of the tensor}}
- %0 = iree_linalg_ext.unset_encoding %arg0: tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>> -> tensor<?xf32>
+ %0 = iree_linalg_ext.unset_encoding %arg0: tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
-func.func @illegal_unset_encoding_op_with_shape_change(%arg0 : tensor<20x30xf32, #iree_linalg_ext.encoding<GEMM_LHS>>) -> tensor<10x20xf32> {
+func.func @illegal_unset_encoding_op_with_shape_change(%arg0 : tensor<20x30xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>) -> tensor<10x20xf32> {
// expected-error @+1 {{expected to preserve the logical shape of the tensor}}
- %0 = iree_linalg_ext.unset_encoding %arg0: tensor<20x30xf32, #iree_linalg_ext.encoding<GEMM_LHS>> -> tensor<10x20xf32>
+ %0 = iree_linalg_ext.unset_encoding %arg0: tensor<20x30xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> -> tensor<10x20xf32>
return %0 : tensor<10x20xf32>
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
index fe91dc5..35e2053 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
@@ -1,8 +1,8 @@
// RUN: iree-dialects-opt --iree-linalg-ext-materialize-encoding -cse -split-input-file %s | FileCheck %s
func.func @pack_unpack_gemm_lhs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
- %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>> -> tensor<?x?xf32>
+ %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
@@ -24,8 +24,8 @@
// -----
func.func @pack_unpack_gemm_rhs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
- %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>>
- %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>> -> tensor<?x?xf32>
+ %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>>
+ %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>> -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: func @pack_unpack_gemm_rhs(
@@ -35,8 +35,8 @@
// -----
func.func @pack_unpack_gemm_rhs_transpose(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
- %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>
- %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>> -> tensor<?x?xf32>
+ %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>
+ %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>> -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: func @pack_unpack_gemm_rhs_transpose(
@@ -46,8 +46,8 @@
// -----
func.func @pack_unpack_gemm_result(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
- %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
- %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>> -> tensor<?x?xf32>
+ %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
+ %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>> -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: func @pack_unpack_gemm_result(
@@ -62,20 +62,20 @@
^bb0(%b0: index, %b1 : index):
tensor.yield %pad_value : f32
} : tensor<100x250xf32> to tensor<104x252xf32>
- %lhs = iree_linalg_ext.set_encoding %pad_lhs : tensor<104x252xf32> -> tensor<104x252xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+ %lhs = iree_linalg_ext.set_encoding %pad_lhs : tensor<104x252xf32> -> tensor<104x252xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
%pad_rhs = tensor.pad %arg1 low[0, 0] high[2, 4] {
^bb0(%b0: index, %b1 : index):
tensor.yield %pad_value : f32
} : tensor<250x500xf32> to tensor<252x504xf32>
- %rhs = iree_linalg_ext.set_encoding %pad_rhs : tensor<252x504xf32> -> tensor<252x504xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>
+ %rhs = iree_linalg_ext.set_encoding %pad_rhs : tensor<252x504xf32> -> tensor<252x504xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>
%pad_output = tensor.pad %arg2 low[0, 0] high[4, 4] {
^bb0(%b0: index, %b1 : index):
tensor.yield %pad_value : f32
} : tensor<100x500xf32> to tensor<104x504xf32>
- %output = iree_linalg_ext.set_encoding %pad_output : tensor<104x504xf32> -> tensor<104x504xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
- %gemm_packed = linalg.matmul ins(%lhs, %rhs : tensor<104x252xf32, #iree_linalg_ext.encoding<GEMM_LHS>>, tensor<252x504xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>)
- outs(%output : tensor<104x504xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>) -> tensor<104x504xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
- %gemm = iree_linalg_ext.unset_encoding %gemm_packed : tensor<104x504xf32, #iree_linalg_ext.encoding<GEMM_RESULT>> -> tensor<104x504xf32>
+ %output = iree_linalg_ext.set_encoding %pad_output : tensor<104x504xf32> -> tensor<104x504xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
+ %gemm_packed = linalg.matmul ins(%lhs, %rhs : tensor<104x252xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>, tensor<252x504xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>)
+ outs(%output : tensor<104x504xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>) -> tensor<104x504xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
+ %gemm = iree_linalg_ext.unset_encoding %gemm_packed : tensor<104x504xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>> -> tensor<104x504xf32>
%result = tensor.extract_slice %gemm[0, 0] [100, 500] [1, 1] : tensor<104x504xf32> to tensor<100x500xf32>
return %result : tensor<100x500xf32>
}
@@ -102,12 +102,12 @@
// -----
func.func @pack_gemm_dynamic(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
- %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- %1 = iree_linalg_ext.set_encoding %arg1 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>
- %2 = iree_linalg_ext.set_encoding %arg2 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
- %3 = linalg.matmul ins(%0, %1 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>, tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>)
- outs(%2 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>) -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
- %4 = iree_linalg_ext.unset_encoding %3 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>> -> tensor<?x?xf32>
+ %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ %1 = iree_linalg_ext.set_encoding %arg1 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>
+ %2 = iree_linalg_ext.set_encoding %arg2 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
+ %3 = linalg.matmul ins(%0, %1 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>, tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>)
+ outs(%2 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>) -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
+ %4 = iree_linalg_ext.unset_encoding %3 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>> -> tensor<?x?xf32>
return %4 : tensor<?x?xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
@@ -133,14 +133,14 @@
%cst = arith.constant 0.0 : f32
%d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
- %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- %1 = iree_linalg_ext.set_encoding %arg1 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>
- %2 = tensor.empty(%d0, %d1) : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
- %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>)
- -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
- %4 = linalg.matmul ins(%0, %1 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>, tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>)
- outs(%3 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>) -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
- %5 = iree_linalg_ext.unset_encoding %4 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>> -> tensor<?x?xf32>
+ %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ %1 = iree_linalg_ext.set_encoding %arg1 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>
+ %2 = tensor.empty(%d0, %d1) : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
+ %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>)
+ -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
+ %4 = linalg.matmul ins(%0, %1 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>, tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>)
+ outs(%3 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>) -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
+ %5 = iree_linalg_ext.unset_encoding %4 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>> -> tensor<?x?xf32>
return %5 : tensor<?x?xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/resolve-shaped-type-result-dims.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/resolve-shaped-type-result-dims.mlir
index f5619f7..41bd0f8 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/resolve-shaped-type-result-dims.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/resolve-shaped-type-result-dims.mlir
@@ -3,9 +3,9 @@
func.func @pack_static(%arg0 : tensor<100x250xf32>) -> (index, index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
- %0 = iree_linalg_ext.set_encoding %arg0 : tensor<100x250xf32> -> tensor<100x250xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- %1 = tensor.dim %0, %c0 : tensor<100x250xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- %2 = tensor.dim %0, %c1 : tensor<100x250xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+ %0 = iree_linalg_ext.set_encoding %arg0 : tensor<100x250xf32> -> tensor<100x250xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ %1 = tensor.dim %0, %c0 : tensor<100x250xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ %2 = tensor.dim %0, %c1 : tensor<100x250xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
return %1, %2 : index, index
}
// CHECK-LABEL: func @pack_static(
@@ -18,9 +18,9 @@
func.func @pack_dynamic(%arg0 : tensor<?x?xf32>) -> (index, index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
- %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- %1 = tensor.dim %0, %c0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- %2 = tensor.dim %0, %c1 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+ %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ %1 = tensor.dim %0, %c0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ %2 = tensor.dim %0, %c1 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
return %1, %2 : index, index
}
// CHECK: func @pack_dynamic(%[[ARG0:.+]]: tensor<?x?xf32>)
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/roundtrip.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/roundtrip.mlir
index 71222ab..5cd5f74 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/roundtrip.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/roundtrip.mlir
@@ -869,36 +869,36 @@
// -----
// CHECK: @set_encoding_ops(%[[ARG0:.+]]: tensor<?x?xf32>)
-func.func @set_encoding_ops(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>> {
- // CHECK: iree_linalg_ext.set_encoding %[[ARG0]] : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- return %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+func.func @set_encoding_ops(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> {
+ // CHECK: iree_linalg_ext.set_encoding %[[ARG0]] : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ return %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
}
// -----
// CHECK: @set_encoding_ops_mixed_dynamic_static(%[[ARG0:.+]]: tensor<?x10xf32>)
-func.func @set_encoding_ops_mixed_dynamic_static(%arg0: tensor<?x10xf32>) -> tensor<20x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>> {
- // CHECK: iree_linalg_ext.set_encoding %[[ARG0]] : tensor<?x10xf32> -> tensor<20x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x10xf32> -> tensor<20x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- return %0 : tensor<20x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+func.func @set_encoding_ops_mixed_dynamic_static(%arg0: tensor<?x10xf32>) -> tensor<20x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> {
+ // CHECK: iree_linalg_ext.set_encoding %[[ARG0]] : tensor<?x10xf32> -> tensor<20x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x10xf32> -> tensor<20x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ return %0 : tensor<20x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
}
// -----
-// CHECK: @unset_encoding_ops(%[[ARG0:.+]]: tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>>)
-func.func @unset_encoding_ops(%arg0: tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>>) -> tensor<?x?xf32> {
- // CHECK: iree_linalg_ext.unset_encoding %[[ARG0]] : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>> -> tensor<?x?xf32>
- %0 = iree_linalg_ext.unset_encoding %arg0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>> -> tensor<?x?xf32>
+// CHECK: @unset_encoding_ops(%[[ARG0:.+]]: tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>>)
+func.func @unset_encoding_ops(%arg0: tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>>) -> tensor<?x?xf32> {
+ // CHECK: iree_linalg_ext.unset_encoding %[[ARG0]] : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>> -> tensor<?x?xf32>
+ %0 = iree_linalg_ext.unset_encoding %arg0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>> -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// -----
-// CHECK: @unset_encoding_ops_mixed_dynamic_static(%[[ARG0:.+]]: tensor<10x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>>)
-func.func @unset_encoding_ops_mixed_dynamic_static(%arg0: tensor<10x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>>) -> tensor<?x20xf32> {
- // CHECK: iree_linalg_ext.unset_encoding %[[ARG0]] : tensor<10x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>>
- %0 = iree_linalg_ext.unset_encoding %arg0 : tensor<10x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>> -> tensor<?x20xf32>
+// CHECK: @unset_encoding_ops_mixed_dynamic_static(%[[ARG0:.+]]: tensor<10x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>>)
+func.func @unset_encoding_ops_mixed_dynamic_static(%arg0: tensor<10x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>>) -> tensor<?x20xf32> {
+ // CHECK: iree_linalg_ext.unset_encoding %[[ARG0]] : tensor<10x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>>
+ %0 = iree_linalg_ext.unset_encoding %arg0 : tensor<10x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>> -> tensor<?x20xf32>
return %0 : tensor<?x20xf32>
}
@@ -906,14 +906,14 @@
func.func @encoding_tensors_with_ops(%arg0 : tensor<?x?xf32>,
%arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
- %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- %1 = iree_linalg_ext.set_encoding %arg1 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>>
- %2 = iree_linalg_ext.set_encoding %arg2 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
+ %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ %1 = iree_linalg_ext.set_encoding %arg1 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>>
+ %2 = iree_linalg_ext.set_encoding %arg2 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
%3 = linalg.matmul
- ins(%0, %1 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>, tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>>)
- outs(%2 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>)
- -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
- %4 = iree_linalg_ext.unset_encoding %3 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>> -> tensor<?x?xf32>
+ ins(%0, %1 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>, tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>>)
+ outs(%2 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>)
+ -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
+ %4 = iree_linalg_ext.unset_encoding %3 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>> -> tensor<?x?xf32>
return %4 : tensor<?x?xf32>
}
// CHECK-LABEL: func.func @encoding_tensors_with_ops
@@ -921,11 +921,11 @@
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK: %[[LHS:.+]] = iree_linalg_ext.set_encoding %[[ARG0]]
-// CHECK-SAME: tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+// CHECK-SAME: tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
// CHECK: %[[RHS:.+]] = iree_linalg_ext.set_encoding %[[ARG1]]
-// CHECK-SAME: tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>>
+// CHECK-SAME: tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>>
// CHECK: %[[OUT:.+]] = iree_linalg_ext.set_encoding %[[ARG2]]
-// CHECK-SAME: tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
+// CHECK-SAME: tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
// CHECK: %[[GEMM:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
// CHECK-SAME: outs(%[[OUT]] :