[Encoding] Support SetEncoding on scaled contraction ops (#21825)
Adds a new op type enum case for the generic `EncodingAttr` called
`SCALED_MATMUL`, and implements the logic to set encodings on scaled
contraction ops. A new flag is added for testing because the
materialization for the new encoding is not implemented yet. The flag
will be removed once materialization is implemented.
---------
Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td
index bd24e90..3766dec 100644
--- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td
+++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td
@@ -30,11 +30,13 @@
// Enums for tagging operand operation in an EncodingAttr
def MATMUL : I32EnumAttrCase<"matmul", 0>;
-def CONV : I32EnumAttrCase<"conv", 1>;
+def SCALED_MATMUL : I32EnumAttrCase<"scaled_matmul", 1>;
+def CONV : I32EnumAttrCase<"conv", 2>;
def EncodingOpType : IREEEncoding_I32EnumAttr<"EncodingOpType",
"Tracks the type of operation of the operand.", [
MATMUL,
+ SCALED_MATMUL,
CONV,
]>;
diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingTypes.h b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingTypes.h
index 7b892c2..b58c828 100644
--- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingTypes.h
+++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingTypes.h
@@ -25,6 +25,12 @@
const int64_t MATMUL_LHS = 0;
const int64_t MATMUL_RHS = 1;
const int64_t MATMUL_RESULT = 2;
+/// Scaled matmul
+const int64_t SCALED_MATMUL_LHS = 0;
+const int64_t SCALED_MATMUL_RHS = 1;
+const int64_t SCALED_MATMUL_LHS_SCALES = 2;
+const int64_t SCALED_MATMUL_RHS_SCALES = 3;
+const int64_t SCALED_MATMUL_RESULT = 4;
/// Convert operand index to strings for printing
std::string stringifyOperandIndex(IntegerAttr);
diff --git a/compiler/src/iree/compiler/DispatchCreation/AnnotateDataTilingHints.cpp b/compiler/src/iree/compiler/DispatchCreation/AnnotateDataTilingHints.cpp
index ef67d2b..4bcc7ad 100644
--- a/compiler/src/iree/compiler/DispatchCreation/AnnotateDataTilingHints.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/AnnotateDataTilingHints.cpp
@@ -7,7 +7,9 @@
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Dialect/Encoding/Utils/Utils.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/MatchUtils.h"
#include "iree/compiler/DispatchCreation/Passes.h"
+#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
@@ -21,6 +23,15 @@
#define GEN_PASS_DEF_ANNOTATEDATATILINGHINTSPASS
#include "iree/compiler/DispatchCreation/Passes.h.inc"
+/// Command line options used purely for development purposes. Not to be relied
+/// on in any way.
+/// TODO(Max191): Remove this once materialization for scaled matmul encodings
+/// is implemented.
+static llvm::cl::opt<bool> clTestSetScaledMatmulEncodings(
+ "iree-dispatch-creation-test-set-scaled-matmul-encodings",
+    llvm::cl::desc("Set encodings on scaled matmul ops"), llvm::cl::init(false),
+ llvm::cl::Hidden);
+
namespace {
struct AnnotateDataTilingHintsPass final
: impl::AnnotateDataTilingHintsPassBase<AnnotateDataTilingHintsPass> {
@@ -70,23 +81,13 @@
return true;
}
-/// Not all contractions are supported by data tiling, so return true if:
+/// Common pre-conditions for data tiling. Return true if:
/// 1) linalgOp has pure tensor semantics.
/// 2) linalgOp does not have a preset compilation info.
/// 3) The workgroup count is not present if linalgOp is wrapped within
/// Flow::DispatchRegionOp.
/// 4) All the operands do not have encodings.
-/// 5) linalgOp has contraction indexingMaps.
-/// 6) There are not more than one of each contraction dimension.
-/// 7) There is an M or N dimension, and there is a K dimension.
-/// 8) linalgOp has the same body as an ordinary int or float matmul.
-///
-/// These restrictions are required because data tiling currently creates
-/// an Mmt4DOp or BatchMmt4DOp on the packed inputs.
-///
-/// TODO(#16176): Loosen restrictions on contraction ops once data tiling
-/// can support more cases.
-static bool isSupportedContractionOp(linalg::LinalgOp linalgOp) {
+static bool dataTilablePreCondition(linalg::LinalgOp linalgOp) {
if (!linalgOp.hasPureTensorSemantics()) {
return false;
}
@@ -108,7 +109,26 @@
llvm::any_of(linalgOp.getDpsInits(), hasEncoding)) {
return false;
}
+ return true;
+}
+/// Not all contractions are supported by data tiling, so return true if:
+/// 1) `linalgOp` meets the pre-conditions for data tiling defined in
+/// `dataTilablePreCondition`.
+/// 2) `linalgOp` has contraction indexingMaps.
+/// 3) There are not more than one of each contraction dimension.
+/// 4) There is an M or N dimension, and there is a K dimension.
+/// 5) `linalgOp` has the same body as an ordinary int or float matmul.
+///
+/// These restrictions are required because data tiling currently creates
+/// an Mmt4DOp or BatchMmt4DOp on the packed inputs.
+///
+/// TODO(#16176): Loosen restrictions on contraction ops once data tiling
+/// can support more cases.
+static bool isSupportedContractionOp(linalg::LinalgOp linalgOp) {
+ if (!dataTilablePreCondition(linalgOp)) {
+ return false;
+ }
if (!linalg::isaContractionOpInterface(linalgOp)) {
return false;
}
@@ -126,6 +146,37 @@
return true;
}
+/// Not all scaled contractions are supported by data tiling, so return true if:
+/// 1) `linalgOp` meets the pre-conditions for data tiling defined in
+/// `dataTilablePreCondition`.
+/// 2) `linalgOp` is a scaled contraction op, as defined by
+/// `IREE::LinalgExt::isaScaledContractionOpInterface`.
+/// 3) There are exactly one K and one Kb scaled contraction dimension.
+///
+/// These restrictions are required because the current data tiling
+/// implementation requires a single K and Kb dimension.
+///
+/// TODO(Max191): Loosen restrictions on scaled contraction ops once data tiling
+/// can support more cases.
+static bool isSupportedScaledContractionOp(linalg::LinalgOp linalgOp) {
+ if (!clTestSetScaledMatmulEncodings) {
+ return false;
+ }
+ if (!dataTilablePreCondition(linalgOp)) {
+ return false;
+ }
+ if (!IREE::LinalgExt::isaScaledContractionOpInterface(linalgOp)) {
+ return false;
+ }
+ FailureOr<IREE::LinalgExt::ScaledContractionDimensions> cDims =
+ IREE::LinalgExt::inferScaledContractionDims(
+ linalgOp.getIndexingMapsArray());
+ if (failed(cDims) || cDims->k.size() != 1 || cDims->kB.size() != 1) {
+ return false;
+ }
+ return true;
+}
+
void AnnotateDataTilingHintsPass::runOnOperation() {
FunctionOpInterface funcOp = getOperation();
SmallVector<Operation *> candidates;
@@ -134,7 +185,8 @@
return WalkResult::interrupt();
}
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
- if (linalgOp && isSupportedContractionOp(linalgOp)) {
+ if (linalgOp && (isSupportedContractionOp(linalgOp) ||
+ isSupportedScaledContractionOp(linalgOp))) {
candidates.push_back(op);
return WalkResult::advance();
}
diff --git a/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp b/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp
index 1c44973..8da216b 100644
--- a/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp
@@ -11,6 +11,7 @@
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/MatchUtils.h"
#include "iree/compiler/DispatchCreation/FusionUtils.h"
#include "iree/compiler/DispatchCreation/Passes.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -86,55 +87,102 @@
return result;
}
+/// Contains the invariant information across operands for the
+/// iree_encoding.encoding. The operand number is not included because
+/// it is not invariant across operands.
+struct GenericEncodingCommonInfo {
+ IREE::Encoding::EncodingOpType opType;
+ SmallVector<Type> elemTypes;
+ SmallVector<AffineMap> maps;
+ SmallVector<int64_t> iterationSizes;
+};
+
+/// Get the `GenericEncodingCommonInfo` for the `linalgOp` or return failure
+/// if op is not supported. Supported ops are contraction ops and scaled
+/// contraction ops.
+static FailureOr<GenericEncodingCommonInfo>
+getGenericEncodingCommonInfo(RewriterBase &rewriter,
+ linalg::LinalgOp linalgOp) {
+ // Case 1: ContractionOpInterface
+ if (linalg::isaContractionOpInterface(linalgOp)) {
+ Type lhsElemType = getContractionInputTypeWithSignedness(
+ rewriter, linalgOp, linalgOp.getDpsInputOperand(0));
+ Type rhsElemType = getContractionInputTypeWithSignedness(
+ rewriter, linalgOp, linalgOp.getDpsInputOperand(1));
+ Type outElemType = getContractionInputTypeWithSignedness(
+ rewriter, linalgOp, linalgOp.getDpsInitOperand(0));
+ if (!lhsElemType || !rhsElemType || !outElemType) {
+ return failure();
+ }
+ return GenericEncodingCommonInfo(
+ {/*opType=*/IREE::Encoding::EncodingOpType::matmul,
+ /*elemTypes=*/{lhsElemType, rhsElemType, outElemType},
+        /*maps=*/linalgOp.getIndexingMapsArray(),
+ /*iterationSizes=*/linalgOp.getStaticLoopRanges()});
+ }
+ // Case 2: Scaled ContractionOpInterface
+ if (!IREE::LinalgExt::isaScaledContractionOpInterface(linalgOp)) {
+ return failure();
+ }
+ FailureOr<IREE::LinalgExt::ScaledContractionDimensions> cDims =
+ IREE::LinalgExt::inferScaledContractionDims(
+ linalgOp.getIndexingMapsArray());
+ Type lhsElemType =
+ getElementTypeOrSelf(linalgOp.getDpsInputOperand(0)->get().getType());
+ Type rhsElemType =
+ getElementTypeOrSelf(linalgOp.getDpsInputOperand(1)->get().getType());
+ Type lhsScalesElemType =
+ getElementTypeOrSelf(linalgOp.getDpsInputOperand(2)->get().getType());
+ Type rhsScalesElemType =
+ getElementTypeOrSelf(linalgOp.getDpsInputOperand(3)->get().getType());
+ Type outElemType =
+ getElementTypeOrSelf(linalgOp.getDpsInitOperand(0)->get().getType());
+ return GenericEncodingCommonInfo(
+ {/*opType=*/IREE::Encoding::EncodingOpType::scaled_matmul,
+ /*elemTypes=*/
+ {lhsElemType, rhsElemType, lhsScalesElemType, rhsScalesElemType,
+ outElemType},
+      /*maps=*/linalgOp.getIndexingMapsArray(),
+ /*iterationSizes=*/linalgOp.getStaticLoopRanges()});
+}
+
static LogicalResult setDataTilingEncodings(RewriterBase &rewriter,
linalg::LinalgOp linalgOp,
EncodingOptions encodingOption) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(linalgOp);
-
- Value lhs = linalgOp.getDpsInputOperand(0)->get();
- Value rhs = linalgOp.getDpsInputOperand(1)->get();
- Value out = linalgOp.getDpsInitOperand(0)->get();
- Type lhsElemType = getContractionInputTypeWithSignedness(
- rewriter, linalgOp, linalgOp.getDpsInputOperand(0));
- Type rhsElemType = getContractionInputTypeWithSignedness(
- rewriter, linalgOp, linalgOp.getDpsInputOperand(1));
- Type outElemType = getContractionInputTypeWithSignedness(
- rewriter, linalgOp, linalgOp.getDpsInitOperand(0));
- if (!lhsElemType || !rhsElemType || !outElemType) {
- return failure();
- }
- SmallVector<Type> elemTypes = {lhsElemType, rhsElemType, outElemType};
-
- // The `iteration_sizes` are the linalg op's static loop ranges. From the
- // combination of `iteration_sizes` and `user_indexing_maps`, we can later
- // derive information such as the iteration size of the M/N dimensions of a
- // matmul-like operation for example.
- FailureOr<SmallVector<int64_t>> maybeIterationSizes =
- linalgOp.getStaticLoopRanges();
- if (failed(maybeIterationSizes)) {
- return failure();
- }
- SmallVector<int64_t> iterationSizes = std::move(maybeIterationSizes.value());
-
Location loc = linalgOp.getLoc();
- SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray();
- auto opType = IREE::Encoding::EncodingOpType::matmul;
+ FailureOr<GenericEncodingCommonInfo> encodingInfo =
+ getGenericEncodingCommonInfo(rewriter, linalgOp);
+ if (failed(encodingInfo)) {
+ return failure();
+ }
auto setEncodingWrapper = [&](Value src, int64_t operandIndex) -> Value {
MLIRContext *ctx = linalgOp.getContext();
Attribute encoding;
switch (encodingOption) {
case EncodingOptions::Generic: {
- encoding = EncodingAttr::get(ctx, operandIndex, opType, elemTypes, maps,
- iterationSizes);
+ encoding = EncodingAttr::get(ctx, operandIndex, encodingInfo->opType,
+ encodingInfo->elemTypes, encodingInfo->maps,
+ encodingInfo->iterationSizes);
break;
}
case EncodingOptions::MatmulK: {
SmallVector<int32_t> kDims;
- AffineMap indexingMap = maps[operandIndex];
+ AffineMap indexingMap = encodingInfo->maps[operandIndex];
+ SmallVector<int64_t> kCDims;
auto cDims = linalg::inferContractionDims(linalgOp);
- for (auto k : cDims->k) {
+ if (!failed(cDims)) {
+ kCDims.append(cDims->k.begin(), cDims->k.end());
+ }
+ FailureOr<IREE::LinalgExt::ScaledContractionDimensions> scaledCDims =
+ IREE::LinalgExt::inferScaledContractionDims(linalgOp);
+ if (!failed(scaledCDims)) {
+ kCDims.append(scaledCDims->k.begin(), scaledCDims->k.end());
+ kCDims.append(scaledCDims->kB.begin(), scaledCDims->kB.end());
+ }
+ for (auto k : kCDims) {
std::optional<unsigned> dimIdx =
indexingMap.getResultPosition(rewriter.getAffineDimExpr(k));
if (!dimIdx) {
@@ -152,16 +200,37 @@
}
return setEncoding(rewriter, loc, src, encoding);
};
- auto encodedLhs = setEncodingWrapper(lhs, IREE::Encoding::MATMUL_LHS);
- auto encodedRhs = setEncodingWrapper(rhs, IREE::Encoding::MATMUL_RHS);
- auto encodedOut = setEncodingWrapper(out, IREE::Encoding::MATMUL_RESULT);
- Value opTiled = clone(rewriter, linalgOp, encodedOut.getType(),
- ValueRange{encodedLhs, encodedRhs, encodedOut})
- ->getResult(0);
+
+ SmallVector<Value> encodedInputOperands;
+ Value encodedInitOperand;
+ if (linalg::isaContractionOpInterface(linalgOp)) {
+ encodedInputOperands.push_back(setEncodingWrapper(
+ linalgOp.getDpsInputs()[0], IREE::Encoding::MATMUL_LHS));
+ encodedInputOperands.push_back(setEncodingWrapper(
+ linalgOp.getDpsInputs()[1], IREE::Encoding::MATMUL_RHS));
+ encodedInitOperand = setEncodingWrapper(linalgOp.getDpsInits()[0],
+ IREE::Encoding::MATMUL_RESULT);
+ } else {
+ encodedInputOperands.push_back(setEncodingWrapper(
+ linalgOp.getDpsInputs()[0], IREE::Encoding::SCALED_MATMUL_LHS));
+ encodedInputOperands.push_back(setEncodingWrapper(
+ linalgOp.getDpsInputs()[1], IREE::Encoding::SCALED_MATMUL_RHS));
+ encodedInputOperands.push_back(setEncodingWrapper(
+ linalgOp.getDpsInputs()[2], IREE::Encoding::SCALED_MATMUL_LHS_SCALES));
+ encodedInputOperands.push_back(setEncodingWrapper(
+ linalgOp.getDpsInputs()[3], IREE::Encoding::SCALED_MATMUL_RHS_SCALES));
+ encodedInitOperand = setEncodingWrapper(
+ linalgOp.getDpsInits()[0], IREE::Encoding::SCALED_MATMUL_RESULT);
+ }
+ SmallVector<Value> encodedOperands(encodedInputOperands);
+ encodedOperands.push_back(encodedInitOperand);
+ Value opTiled =
+ clone(rewriter, linalgOp, encodedInitOperand.getType(), encodedOperands)
+ ->getResult(0);
// Sizes are computed by original output size.
SmallVector<OpFoldResult> outSizes =
- tensor::getMixedSizes(rewriter, loc, out);
+ tensor::getMixedSizes(rewriter, loc, linalgOp.getDpsInits()[0]);
Value result = unsetEncoding(rewriter, loc, opTiled, outSizes);
rewriter.replaceOp(linalgOp, result);
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/set_encoding.mlir b/compiler/src/iree/compiler/DispatchCreation/test/set_encoding.mlir
index 68d2b07..7839175 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/set_encoding.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/set_encoding.mlir
@@ -1,5 +1,7 @@
-// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-annotate-data-tiling-hints,iree-dispatch-creation-set-encoding))" %s | FileCheck %s --check-prefixes=CHECK-ALL,CHECK
-// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-annotate-data-tiling-hints,iree-dispatch-creation-set-encoding{encoding-option=matmulk}))" %s | FileCheck %s --check-prefixes=CHECK-ALL,MATMULK
+// RUN: iree-opt --split-input-file --iree-dispatch-creation-test-set-scaled-matmul-encodings \
+// RUN: --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-annotate-data-tiling-hints,iree-dispatch-creation-set-encoding))" %s | FileCheck %s --check-prefixes=CHECK-ALL,CHECK
+// RUN: iree-opt --split-input-file --iree-dispatch-creation-test-set-scaled-matmul-encodings \
+// RUN: --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-annotate-data-tiling-hints,iree-dispatch-creation-set-encoding{encoding-option=matmulk}))" %s | FileCheck %s --check-prefixes=CHECK-ALL,MATMULK
// Test with `iree-dispatch-creation-annotate-data-tiling-hints` that adds the
// data-tiling hint to all the available gemm ops. Otherwise, we have to add the
// hint to all the gemm ops in the file.
@@ -1203,6 +1205,130 @@
util.return %4 : !hal.buffer_view
}
-// CHECK-LABEL: util.func public @region_with_workgroup_count
-// CHECK-NOT: iree_encoding.set_encoding
-// CHECK-NOT: iree_encoding.unset_encoding
+// CHECK-ALL-LABEL: util.func public @region_with_workgroup_count
+// CHECK-NOT: iree_encoding.set_encoding
+// CHECK-NOT: iree_encoding.unset_encoding
+
+// -----
+
+util.func public @scaled_contraction_f4_f4_f8_f8_f32(
+ %a : tensor<256x128x32xf4E2M1FN>, %b : tensor<512x128x32xf4E2M1FN>,
+ %a_scales : tensor<256x128xf8E8M0FNU>, %b_scales : tensor<512x128xf8E8M0FNU>,
+ %c : tensor<256x512xf32>) -> tensor<256x512xf32> {
+ %0 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ ins(%a, %b, %a_scales, %b_scales : tensor<256x128x32xf4E2M1FN>, tensor<512x128x32xf4E2M1FN>, tensor<256x128xf8E8M0FNU>, tensor<512x128xf8E8M0FNU>)
+ outs(%c : tensor<256x512xf32>) {
+ ^bb0(%in: f4E2M1FN, %in_0: f4E2M1FN, %in_1: f8E8M0FNU, %in_2: f8E8M0FNU, %out: f32):
+ %12 = arith.scaling_extf %in, %in_1 : f4E2M1FN, f8E8M0FNU to f32
+ %13 = arith.scaling_extf %in_0, %in_2 : f4E2M1FN, f8E8M0FNU to f32
+ %14 = arith.mulf %12, %13 : f32
+ %15 = arith.addf %out, %14 : f32
+ linalg.yield %15 : f32
+ } -> tensor<256x512xf32>
+ util.return %0 : tensor<256x512xf32>
+}
+
+// CHECK-ALL-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK-ALL-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
+// CHECK-ALL-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
+// CHECK-ALL-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+// CHECK-ALL-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+
+// CHECK-DAG: #[[$ENC0:.+]] = #iree_encoding.encoding<operand_index = 0 : index, op_type = scaled_matmul, element_types = [f4E2M1FN, f4E2M1FN, f8E8M0FNU, f8E8M0FNU, f32], user_indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]], #[[$MAP4]]], iteration_sizes = [256, 512, 128, 32]>
+// CHECK-DAG: #[[$ENC1:.+]] = #iree_encoding.encoding<operand_index = 1 : index, op_type = scaled_matmul, element_types = [f4E2M1FN, f4E2M1FN, f8E8M0FNU, f8E8M0FNU, f32], user_indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]], #[[$MAP4]]], iteration_sizes = [256, 512, 128, 32]>
+// CHECK-DAG: #[[$ENC2:.+]] = #iree_encoding.encoding<operand_index = 2 : index, op_type = scaled_matmul, element_types = [f4E2M1FN, f4E2M1FN, f8E8M0FNU, f8E8M0FNU, f32], user_indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]], #[[$MAP4]]], iteration_sizes = [256, 512, 128, 32]>
+// CHECK-DAG: #[[$ENC3:.+]] = #iree_encoding.encoding<operand_index = 3 : index, op_type = scaled_matmul, element_types = [f4E2M1FN, f4E2M1FN, f8E8M0FNU, f8E8M0FNU, f32], user_indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]], #[[$MAP4]]], iteration_sizes = [256, 512, 128, 32]>
+// CHECK-DAG: #[[$ENC4:.+]] = #iree_encoding.encoding<operand_index = 4 : index, op_type = scaled_matmul, element_types = [f4E2M1FN, f4E2M1FN, f8E8M0FNU, f8E8M0FNU, f32], user_indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]], #[[$MAP4]]], iteration_sizes = [256, 512, 128, 32]>
+
+// MATMULK-DAG: #[[$ENC0:.+]] = #iree_encoding.matmul_k<k_dims = [1, 2]>
+// MATMULK-DAG: #[[$ENC1:.+]] = #iree_encoding.matmul_k<k_dims = [1]>
+// MATMULK-DAG: #[[$ENC2:.+]] = #iree_encoding.matmul_k<k_dims = []>
+
+// CHECK-ALL: util.func public @scaled_contraction_f4_f4_f8_f8_f32
+// CHECK-ALL-SAME: %[[A:.*]]: tensor<256x128x32xf4E2M1FN>
+// CHECK-ALL-SAME: %[[B:.*]]: tensor<512x128x32xf4E2M1FN>
+// CHECK-ALL-SAME: %[[AS:.*]]: tensor<256x128xf8E8M0FNU>
+// CHECK-ALL-SAME: %[[BS:.*]]: tensor<512x128xf8E8M0FNU>
+// CHECK-ALL-SAME: %[[C:.*]]: tensor<256x512xf32>
+
+// CHECK-DAG: %[[A_ENC:.*]] = iree_encoding.set_encoding %[[A]] : tensor<256x128x32xf4E2M1FN> -> tensor<256x128x32xf4E2M1FN, #[[$ENC0]]>
+// CHECK-DAG: %[[B_ENC:.*]] = iree_encoding.set_encoding %[[B]] : tensor<512x128x32xf4E2M1FN> -> tensor<512x128x32xf4E2M1FN, #[[$ENC1]]>
+// CHECK-DAG: %[[AS_ENC:.*]] = iree_encoding.set_encoding %[[AS]] : tensor<256x128xf8E8M0FNU> -> tensor<256x128xf8E8M0FNU, #[[$ENC2]]>
+// CHECK-DAG: %[[BS_ENC:.*]] = iree_encoding.set_encoding %[[BS]] : tensor<512x128xf8E8M0FNU> -> tensor<512x128xf8E8M0FNU, #[[$ENC3]]>
+// CHECK-DAG: %[[C_ENC:.*]] = iree_encoding.set_encoding %[[C]] : tensor<256x512xf32> -> tensor<256x512xf32, #[[$ENC4]]>
+
+// MATMULK-DAG: %[[A_ENC:.*]] = iree_encoding.set_encoding %[[A]] : tensor<256x128x32xf4E2M1FN> -> tensor<256x128x32xf4E2M1FN, #[[$ENC0]]>
+// MATMULK-DAG: %[[B_ENC:.*]] = iree_encoding.set_encoding %[[B]] : tensor<512x128x32xf4E2M1FN> -> tensor<512x128x32xf4E2M1FN, #[[$ENC0]]>
+// MATMULK-DAG: %[[AS_ENC:.*]] = iree_encoding.set_encoding %[[AS]] : tensor<256x128xf8E8M0FNU> -> tensor<256x128xf8E8M0FNU, #[[$ENC1]]
+// MATMULK-DAG: %[[BS_ENC:.*]] = iree_encoding.set_encoding %[[BS]] : tensor<512x128xf8E8M0FNU> -> tensor<512x128xf8E8M0FNU, #[[$ENC1]]
+// MATMULK-DAG: %[[C_ENC:.*]] = iree_encoding.set_encoding %[[C]] : tensor<256x512xf32> -> tensor<256x512xf32, #[[$ENC2]]>
+
+// CHECK-ALL: %[[GENERIC:.*]] = linalg.generic
+// CHECK-ALL-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]], #[[$MAP4]]]
+// CHECK-ALL-SAME: iterator_types = ["parallel", "parallel", "reduction", "reduction"]
+// CHECK-ALL-SAME: ins(%[[A_ENC]], %[[B_ENC]], %[[AS_ENC]], %[[BS_ENC]]
+// CHECK-ALL-SAME: outs(%[[C_ENC]]
+// CHECK: %[[RESULT:.*]] = iree_encoding.unset_encoding %[[GENERIC]] : tensor<256x512xf32, #[[$ENC4]]> -> tensor<256x512xf32>
+// MATMULK: %[[RESULT:.*]] = iree_encoding.unset_encoding %[[GENERIC]] : tensor<256x512xf32, #[[$ENC2]]> -> tensor<256x512xf32>
+// CHECK-ALL: util.return %[[RESULT]] : tensor<256x512xf32>
+
+// -----
+
+util.func public @scaled_contraction_multi_k_f4_f4_f8_f8_f32(
+ %a : tensor<256x2x64x32xf4E2M1FN>, %b : tensor<512x2x64x32xf4E2M1FN>,
+ %a_scales : tensor<256x2x64xf8E8M0FNU>, %b_scales : tensor<512x2x64xf8E8M0FNU>,
+ %c : tensor<256x512xf32>) -> tensor<256x512xf32> {
+ %0 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction", "reduction"]}
+ ins(%a, %b, %a_scales, %b_scales : tensor<256x2x64x32xf4E2M1FN>, tensor<512x2x64x32xf4E2M1FN>, tensor<256x2x64xf8E8M0FNU>, tensor<512x2x64xf8E8M0FNU>)
+ outs(%c : tensor<256x512xf32>) {
+ ^bb0(%in: f4E2M1FN, %in_0: f4E2M1FN, %in_1: f8E8M0FNU, %in_2: f8E8M0FNU, %out: f32):
+ %12 = arith.scaling_extf %in, %in_1 : f4E2M1FN, f8E8M0FNU to f32
+ %13 = arith.scaling_extf %in_0, %in_2 : f4E2M1FN, f8E8M0FNU to f32
+ %14 = arith.mulf %12, %13 : f32
+ %15 = arith.addf %out, %14 : f32
+ linalg.yield %15 : f32
+ } -> tensor<256x512xf32>
+ util.return %0 : tensor<256x512xf32>
+}
+
+// CHECK-ALL: util.func public @scaled_contraction_multi_k_f4_f4_f8_f8_f32
+// CHECK-ALL-NOT: iree_encoding.set_encoding
+
+// -----
+util.func public @scaled_contraction_multi_kb_f4_f4_f8_f8_f32(
+ %a : tensor<256x128x2x32xf4E2M1FN>, %b : tensor<512x128x2x32xf4E2M1FN>,
+ %a_scales : tensor<256x128xf8E8M0FNU>, %b_scales : tensor<512x128xf8E8M0FNU>,
+ %c : tensor<256x512xf32>) -> tensor<256x512xf32> {
+ %0 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d2)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction", "reduction"]}
+ ins(%a, %b, %a_scales, %b_scales : tensor<256x128x2x32xf4E2M1FN>, tensor<512x128x2x32xf4E2M1FN>, tensor<256x128xf8E8M0FNU>, tensor<512x128xf8E8M0FNU>)
+ outs(%c : tensor<256x512xf32>) {
+ ^bb0(%in: f4E2M1FN, %in_0: f4E2M1FN, %in_1: f8E8M0FNU, %in_2: f8E8M0FNU, %out: f32):
+ %12 = arith.scaling_extf %in, %in_1 : f4E2M1FN, f8E8M0FNU to f32
+ %13 = arith.scaling_extf %in_0, %in_2 : f4E2M1FN, f8E8M0FNU to f32
+ %14 = arith.mulf %12, %13 : f32
+ %15 = arith.addf %out, %14 : f32
+ linalg.yield %15 : f32
+ } -> tensor<256x512xf32>
+ util.return %0 : tensor<256x512xf32>
+}
+
+// CHECK-ALL: util.func public @scaled_contraction_multi_kb_f4_f4_f8_f8_f32
+// CHECK-ALL-NOT: iree_encoding.set_encoding