[Encoding] Support SetEncoding on scaled contraction ops (#21825)

Adds a new op type enum case, `SCALED_MATMUL`, to the generic
`EncodingAttr`, and implements the logic to set encodings on scaled
contraction ops. Because materialization for the new encoding is not
implemented yet, the behavior is gated behind a testing-only flag
(`--iree-dispatch-creation-test-set-scaled-matmul-encodings`) that will
be removed once materialization is implemented.
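
For reference, a sketch of the IR this produces for the LHS operand of a
scaled matmul, mirroring the lit test added in this PR (the alias names are
illustrative); the remaining operands carry the same attribute with
`operand_index` 1 through 4:

```mlir
// Indexing maps of the scaled matmul: LHS, RHS, LHS scales, RHS scales, result.
#map_a  = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
#map_b  = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
#map_as = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
#map_bs = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
#map_c  = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
// Encoding attached to the LHS (operand_index = 0); element_types covers
// LHS, RHS, LHS scales, RHS scales, and the result.
#enc_lhs = #iree_encoding.encoding<operand_index = 0 : index,
    op_type = scaled_matmul,
    element_types = [f4E2M1FN, f4E2M1FN, f8E8M0FNU, f8E8M0FNU, f32],
    user_indexing_maps = [#map_a, #map_b, #map_as, #map_bs, #map_c],
    iteration_sizes = [256, 512, 128, 32]>
%a_enc = iree_encoding.set_encoding %a
    : tensor<256x128x32xf4E2M1FN> -> tensor<256x128x32xf4E2M1FN, #enc_lhs>
```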

---------

Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td
index bd24e90..3766dec 100644
--- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td
+++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td
@@ -30,11 +30,13 @@
 
 // Enums for tagging operand operation in an EncodingAttr
 def MATMUL : I32EnumAttrCase<"matmul", 0>;
-def CONV   : I32EnumAttrCase<"conv", 1>;
+def SCALED_MATMUL : I32EnumAttrCase<"scaled_matmul", 1>;
+def CONV   : I32EnumAttrCase<"conv", 2>;
 
 def EncodingOpType : IREEEncoding_I32EnumAttr<"EncodingOpType",
     "Tracks the type of operation of the operand.", [
       MATMUL,
+      SCALED_MATMUL,
       CONV,
     ]>;
 
diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingTypes.h b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingTypes.h
index 7b892c2..b58c828 100644
--- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingTypes.h
+++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingTypes.h
@@ -25,6 +25,12 @@
 const int64_t MATMUL_LHS = 0;
 const int64_t MATMUL_RHS = 1;
 const int64_t MATMUL_RESULT = 2;
+/// Scaled matmul
+const int64_t SCALED_MATMUL_LHS = 0;
+const int64_t SCALED_MATMUL_RHS = 1;
+const int64_t SCALED_MATMUL_LHS_SCALES = 2;
+const int64_t SCALED_MATMUL_RHS_SCALES = 3;
+const int64_t SCALED_MATMUL_RESULT = 4;
 
 /// Convert operand index to strings for printing
 std::string stringifyOperandIndex(IntegerAttr);
diff --git a/compiler/src/iree/compiler/DispatchCreation/AnnotateDataTilingHints.cpp b/compiler/src/iree/compiler/DispatchCreation/AnnotateDataTilingHints.cpp
index ef67d2b..4bcc7ad 100644
--- a/compiler/src/iree/compiler/DispatchCreation/AnnotateDataTilingHints.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/AnnotateDataTilingHints.cpp
@@ -7,7 +7,9 @@
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
 #include "iree/compiler/Dialect/Encoding/Utils/Utils.h"
 #include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/MatchUtils.h"
 #include "iree/compiler/DispatchCreation/Passes.h"
+#include "llvm/Support/CommandLine.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
@@ -21,6 +23,15 @@
 #define GEN_PASS_DEF_ANNOTATEDATATILINGHINTSPASS
 #include "iree/compiler/DispatchCreation/Passes.h.inc"
 
+/// Command line options used purely for development purposes. Not to be relied
+/// on in any way.
+/// TODO(Max191): Remove this once materialization for scaled matmul encodings
+/// is implemented.
+static llvm::cl::opt<bool> clTestSetScaledMatmulEncodings(
+    "iree-dispatch-creation-test-set-scaled-matmul-encodings",
+    llvm::cl::desc("Set encodings on scaled matmul ops"), llvm::cl::init(false),
+    llvm::cl::Hidden);
+
 namespace {
 struct AnnotateDataTilingHintsPass final
     : impl::AnnotateDataTilingHintsPassBase<AnnotateDataTilingHintsPass> {
@@ -70,23 +81,13 @@
   return true;
 }
 
-/// Not all contractions are supported by data tiling, so return true if:
+/// Common pre-conditions for data tiling. Return true if:
 ///   1) linalgOp has pure tensor semantics.
 ///   2) linalgOp does not have a preset compilation info.
 ///   3) The workgroup count is not present if linalgOp is wrapped within
 ///      Flow::DispatchRegionOp.
 ///   4) All the operands do not have encodings.
-///   5) linalgOp has contraction indexingMaps.
-///   6) There are not more than one of each contraction dimension.
-///   7) There is an M or N dimension, and there is a K dimension.
-///   8) linalgOp has the same body as an ordinary int or float matmul.
-///
-/// These restrictions are required because data tiling currently creates
-/// an Mmt4DOp or BatchMmt4DOp on the packed inputs.
-///
-/// TODO(#16176): Loosen restrictions on contraction ops once data tiling
-/// can support more cases.
-static bool isSupportedContractionOp(linalg::LinalgOp linalgOp) {
+static bool dataTilablePreCondition(linalg::LinalgOp linalgOp) {
   if (!linalgOp.hasPureTensorSemantics()) {
     return false;
   }
@@ -108,7 +109,26 @@
       llvm::any_of(linalgOp.getDpsInits(), hasEncoding)) {
     return false;
   }
+  return true;
+}
 
+/// Not all contractions are supported by data tiling, so return true if:
+///   1) `linalgOp` meets the pre-conditions for data tiling defined in
+///      `dataTilablePreCondition`.
+///   2) `linalgOp` has contraction indexingMaps.
+///   3) There is at most one of each contraction dimension.
+///   4) There is an M or N dimension, and there is a K dimension.
+///   5) `linalgOp` has the same body as an ordinary int or float matmul.
+///
+/// These restrictions are required because data tiling currently creates
+/// an Mmt4DOp or BatchMmt4DOp on the packed inputs.
+///
+/// TODO(#16176): Loosen restrictions on contraction ops once data tiling
+/// can support more cases.
+static bool isSupportedContractionOp(linalg::LinalgOp linalgOp) {
+  if (!dataTilablePreCondition(linalgOp)) {
+    return false;
+  }
   if (!linalg::isaContractionOpInterface(linalgOp)) {
     return false;
   }
@@ -126,6 +146,37 @@
   return true;
 }
 
+/// Not all scaled contractions are supported by data tiling, so return true if:
+///   1) `linalgOp` meets the pre-conditions for data tiling defined in
+///      `dataTilablePreCondition`.
+///   2) `linalgOp` is a scaled contraction op, as defined by
+///      `IREE::LinalgExt::isaScaledContractionOpInterface`.
+///   3) There is exactly one K and one Kb scaled contraction dimension.
+///
+/// These restrictions exist because the current data tiling implementation
+/// requires a single K and a single Kb dimension.
+///
+/// TODO(Max191): Loosen restrictions on scaled contraction ops once data tiling
+/// can support more cases.
+static bool isSupportedScaledContractionOp(linalg::LinalgOp linalgOp) {
+  if (!clTestSetScaledMatmulEncodings) {
+    return false;
+  }
+  if (!dataTilablePreCondition(linalgOp)) {
+    return false;
+  }
+  if (!IREE::LinalgExt::isaScaledContractionOpInterface(linalgOp)) {
+    return false;
+  }
+  FailureOr<IREE::LinalgExt::ScaledContractionDimensions> cDims =
+      IREE::LinalgExt::inferScaledContractionDims(
+          linalgOp.getIndexingMapsArray());
+  if (failed(cDims) || cDims->k.size() != 1 || cDims->kB.size() != 1) {
+    return false;
+  }
+  return true;
+}
+
 void AnnotateDataTilingHintsPass::runOnOperation() {
   FunctionOpInterface funcOp = getOperation();
   SmallVector<Operation *> candidates;
@@ -134,7 +185,8 @@
       return WalkResult::interrupt();
     }
     auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
-    if (linalgOp && isSupportedContractionOp(linalgOp)) {
+    if (linalgOp && (isSupportedContractionOp(linalgOp) ||
+                     isSupportedScaledContractionOp(linalgOp))) {
       candidates.push_back(op);
       return WalkResult::advance();
     }
diff --git a/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp b/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp
index 1c44973..8da216b 100644
--- a/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp
@@ -11,6 +11,7 @@
 #include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
 #include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/MatchUtils.h"
 #include "iree/compiler/DispatchCreation/FusionUtils.h"
 #include "iree/compiler/DispatchCreation/Passes.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -86,55 +87,102 @@
   return result;
 }
 
+/// Holds the encoding information that is shared across all operands of an
+/// `iree_encoding.encoding` attribute. The operand index is not included
+/// because it differs per operand.
+struct GenericEncodingCommonInfo {
+  IREE::Encoding::EncodingOpType opType;
+  SmallVector<Type> elemTypes;
+  SmallVector<AffineMap> maps;
+  SmallVector<int64_t> iterationSizes;
+};
+
+/// Get the `GenericEncodingCommonInfo` for `linalgOp`, or return failure if
+/// the op is not supported. Supported ops are contraction ops and scaled
+/// contraction ops.
+static FailureOr<GenericEncodingCommonInfo>
+getGenericEncodingCommonInfo(RewriterBase &rewriter,
+                             linalg::LinalgOp linalgOp) {
+  // Case 1: ContractionOpInterface
+  if (linalg::isaContractionOpInterface(linalgOp)) {
+    Type lhsElemType = getContractionInputTypeWithSignedness(
+        rewriter, linalgOp, linalgOp.getDpsInputOperand(0));
+    Type rhsElemType = getContractionInputTypeWithSignedness(
+        rewriter, linalgOp, linalgOp.getDpsInputOperand(1));
+    Type outElemType = getContractionInputTypeWithSignedness(
+        rewriter, linalgOp, linalgOp.getDpsInitOperand(0));
+    if (!lhsElemType || !rhsElemType || !outElemType) {
+      return failure();
+    }
+    return GenericEncodingCommonInfo(
+        {/*opType=*/IREE::Encoding::EncodingOpType::matmul,
+         /*elemTypes=*/{lhsElemType, rhsElemType, outElemType},
+         /*maps=*/linalgOp.getIndexingMapsArray(),
+         /*iterationSizes=*/linalgOp.getStaticLoopRanges()});
+  }
+  // Case 2: Scaled ContractionOpInterface
+  if (!IREE::LinalgExt::isaScaledContractionOpInterface(linalgOp)) {
+    return failure();
+  }
+  FailureOr<IREE::LinalgExt::ScaledContractionDimensions> cDims =
+      IREE::LinalgExt::inferScaledContractionDims(
+          linalgOp.getIndexingMapsArray());
+  Type lhsElemType =
+      getElementTypeOrSelf(linalgOp.getDpsInputOperand(0)->get().getType());
+  Type rhsElemType =
+      getElementTypeOrSelf(linalgOp.getDpsInputOperand(1)->get().getType());
+  Type lhsScalesElemType =
+      getElementTypeOrSelf(linalgOp.getDpsInputOperand(2)->get().getType());
+  Type rhsScalesElemType =
+      getElementTypeOrSelf(linalgOp.getDpsInputOperand(3)->get().getType());
+  Type outElemType =
+      getElementTypeOrSelf(linalgOp.getDpsInitOperand(0)->get().getType());
+  return GenericEncodingCommonInfo(
+      {/*opType=*/IREE::Encoding::EncodingOpType::scaled_matmul,
+       /*elemTypes=*/
+       {lhsElemType, rhsElemType, lhsScalesElemType, rhsScalesElemType,
+        outElemType},
+       /*maps=*/linalgOp.getIndexingMapsArray(),
+       /*iterationSizes=*/linalgOp.getStaticLoopRanges()});
+}
+
 static LogicalResult setDataTilingEncodings(RewriterBase &rewriter,
                                             linalg::LinalgOp linalgOp,
                                             EncodingOptions encodingOption) {
   OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPoint(linalgOp);
-
-  Value lhs = linalgOp.getDpsInputOperand(0)->get();
-  Value rhs = linalgOp.getDpsInputOperand(1)->get();
-  Value out = linalgOp.getDpsInitOperand(0)->get();
-  Type lhsElemType = getContractionInputTypeWithSignedness(
-      rewriter, linalgOp, linalgOp.getDpsInputOperand(0));
-  Type rhsElemType = getContractionInputTypeWithSignedness(
-      rewriter, linalgOp, linalgOp.getDpsInputOperand(1));
-  Type outElemType = getContractionInputTypeWithSignedness(
-      rewriter, linalgOp, linalgOp.getDpsInitOperand(0));
-  if (!lhsElemType || !rhsElemType || !outElemType) {
-    return failure();
-  }
-  SmallVector<Type> elemTypes = {lhsElemType, rhsElemType, outElemType};
-
-  // The `iteration_sizes` are the linalg op's static loop ranges. From the
-  // combination of `iteration_sizes` and `user_indexing_maps`, we can later
-  // derive information such as the iteration size of the M/N dimensions of a
-  // matmul-like operation for example.
-  FailureOr<SmallVector<int64_t>> maybeIterationSizes =
-      linalgOp.getStaticLoopRanges();
-  if (failed(maybeIterationSizes)) {
-    return failure();
-  }
-  SmallVector<int64_t> iterationSizes = std::move(maybeIterationSizes.value());
-
   Location loc = linalgOp.getLoc();
-  SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray();
 
-  auto opType = IREE::Encoding::EncodingOpType::matmul;
+  FailureOr<GenericEncodingCommonInfo> encodingInfo =
+      getGenericEncodingCommonInfo(rewriter, linalgOp);
+  if (failed(encodingInfo)) {
+    return failure();
+  }
   auto setEncodingWrapper = [&](Value src, int64_t operandIndex) -> Value {
     MLIRContext *ctx = linalgOp.getContext();
     Attribute encoding;
     switch (encodingOption) {
     case EncodingOptions::Generic: {
-      encoding = EncodingAttr::get(ctx, operandIndex, opType, elemTypes, maps,
-                                   iterationSizes);
+      encoding = EncodingAttr::get(ctx, operandIndex, encodingInfo->opType,
+                                   encodingInfo->elemTypes, encodingInfo->maps,
+                                   encodingInfo->iterationSizes);
       break;
     }
     case EncodingOptions::MatmulK: {
       SmallVector<int32_t> kDims;
-      AffineMap indexingMap = maps[operandIndex];
+      AffineMap indexingMap = encodingInfo->maps[operandIndex];
+      SmallVector<int64_t> kCDims;
       auto cDims = linalg::inferContractionDims(linalgOp);
-      for (auto k : cDims->k) {
+      if (succeeded(cDims)) {
+        kCDims.append(cDims->k.begin(), cDims->k.end());
+      }
+      FailureOr<IREE::LinalgExt::ScaledContractionDimensions> scaledCDims =
+          IREE::LinalgExt::inferScaledContractionDims(linalgOp);
+      if (succeeded(scaledCDims)) {
+        kCDims.append(scaledCDims->k.begin(), scaledCDims->k.end());
+        kCDims.append(scaledCDims->kB.begin(), scaledCDims->kB.end());
+      }
+      for (auto k : kCDims) {
         std::optional<unsigned> dimIdx =
             indexingMap.getResultPosition(rewriter.getAffineDimExpr(k));
         if (!dimIdx) {
@@ -152,16 +200,37 @@
     }
     return setEncoding(rewriter, loc, src, encoding);
   };
-  auto encodedLhs = setEncodingWrapper(lhs, IREE::Encoding::MATMUL_LHS);
-  auto encodedRhs = setEncodingWrapper(rhs, IREE::Encoding::MATMUL_RHS);
-  auto encodedOut = setEncodingWrapper(out, IREE::Encoding::MATMUL_RESULT);
-  Value opTiled = clone(rewriter, linalgOp, encodedOut.getType(),
-                        ValueRange{encodedLhs, encodedRhs, encodedOut})
-                      ->getResult(0);
+
+  SmallVector<Value> encodedInputOperands;
+  Value encodedInitOperand;
+  if (linalg::isaContractionOpInterface(linalgOp)) {
+    encodedInputOperands.push_back(setEncodingWrapper(
+        linalgOp.getDpsInputs()[0], IREE::Encoding::MATMUL_LHS));
+    encodedInputOperands.push_back(setEncodingWrapper(
+        linalgOp.getDpsInputs()[1], IREE::Encoding::MATMUL_RHS));
+    encodedInitOperand = setEncodingWrapper(linalgOp.getDpsInits()[0],
+                                            IREE::Encoding::MATMUL_RESULT);
+  } else {
+    encodedInputOperands.push_back(setEncodingWrapper(
+        linalgOp.getDpsInputs()[0], IREE::Encoding::SCALED_MATMUL_LHS));
+    encodedInputOperands.push_back(setEncodingWrapper(
+        linalgOp.getDpsInputs()[1], IREE::Encoding::SCALED_MATMUL_RHS));
+    encodedInputOperands.push_back(setEncodingWrapper(
+        linalgOp.getDpsInputs()[2], IREE::Encoding::SCALED_MATMUL_LHS_SCALES));
+    encodedInputOperands.push_back(setEncodingWrapper(
+        linalgOp.getDpsInputs()[3], IREE::Encoding::SCALED_MATMUL_RHS_SCALES));
+    encodedInitOperand = setEncodingWrapper(
+        linalgOp.getDpsInits()[0], IREE::Encoding::SCALED_MATMUL_RESULT);
+  }
+  SmallVector<Value> encodedOperands(encodedInputOperands);
+  encodedOperands.push_back(encodedInitOperand);
+  Value opTiled =
+      clone(rewriter, linalgOp, encodedInitOperand.getType(), encodedOperands)
+          ->getResult(0);
 
   // Sizes are computed by original output size.
   SmallVector<OpFoldResult> outSizes =
-      tensor::getMixedSizes(rewriter, loc, out);
+      tensor::getMixedSizes(rewriter, loc, linalgOp.getDpsInits()[0]);
   Value result = unsetEncoding(rewriter, loc, opTiled, outSizes);
 
   rewriter.replaceOp(linalgOp, result);
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/set_encoding.mlir b/compiler/src/iree/compiler/DispatchCreation/test/set_encoding.mlir
index 68d2b07..7839175 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/set_encoding.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/set_encoding.mlir
@@ -1,5 +1,7 @@
-// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-annotate-data-tiling-hints,iree-dispatch-creation-set-encoding))" %s | FileCheck %s --check-prefixes=CHECK-ALL,CHECK
-// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-annotate-data-tiling-hints,iree-dispatch-creation-set-encoding{encoding-option=matmulk}))" %s | FileCheck %s --check-prefixes=CHECK-ALL,MATMULK
+// RUN: iree-opt --split-input-file --iree-dispatch-creation-test-set-scaled-matmul-encodings \
+// RUN:   --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-annotate-data-tiling-hints,iree-dispatch-creation-set-encoding))" %s | FileCheck %s --check-prefixes=CHECK-ALL,CHECK
+// RUN: iree-opt --split-input-file --iree-dispatch-creation-test-set-scaled-matmul-encodings \
+// RUN:   --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-annotate-data-tiling-hints,iree-dispatch-creation-set-encoding{encoding-option=matmulk}))" %s | FileCheck %s --check-prefixes=CHECK-ALL,MATMULK
 // Test with `iree-dispatch-creation-annotate-data-tiling-hints` that adds the
 // data-tiling hint to all the available gemm ops. Otherwise, we have to add the
 // hint to all the gemm ops in the file.
@@ -1203,6 +1205,130 @@
   util.return %4 : !hal.buffer_view
 }
 
-// CHECK-LABEL:  util.func public @region_with_workgroup_count
-// CHECK-NOT:      iree_encoding.set_encoding
-// CHECK-NOT:      iree_encoding.unset_encoding
+// CHECK-ALL-LABEL:  util.func public @region_with_workgroup_count
+// CHECK-NOT:          iree_encoding.set_encoding
+// CHECK-NOT:          iree_encoding.unset_encoding
+
+// -----
+
+util.func public @scaled_contraction_f4_f4_f8_f8_f32(
+    %a : tensor<256x128x32xf4E2M1FN>, %b : tensor<512x128x32xf4E2M1FN>,
+    %a_scales : tensor<256x128xf8E8M0FNU>, %b_scales : tensor<512x128xf8E8M0FNU>,
+    %c : tensor<256x512xf32>) -> tensor<256x512xf32> {
+  %0 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+      ins(%a, %b, %a_scales, %b_scales : tensor<256x128x32xf4E2M1FN>, tensor<512x128x32xf4E2M1FN>, tensor<256x128xf8E8M0FNU>, tensor<512x128xf8E8M0FNU>)
+      outs(%c : tensor<256x512xf32>) {
+  ^bb0(%in: f4E2M1FN, %in_0: f4E2M1FN, %in_1: f8E8M0FNU, %in_2: f8E8M0FNU, %out: f32):
+    %12 = arith.scaling_extf %in, %in_1 : f4E2M1FN, f8E8M0FNU to f32
+    %13 = arith.scaling_extf %in_0, %in_2 : f4E2M1FN, f8E8M0FNU to f32
+    %14 = arith.mulf %12, %13 : f32
+    %15 = arith.addf %out, %14 : f32
+    linalg.yield %15 : f32
+  } -> tensor<256x512xf32>
+  util.return %0 : tensor<256x512xf32>
+}
+
+// CHECK-ALL-DAG:   #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK-ALL-DAG:   #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
+// CHECK-ALL-DAG:   #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
+// CHECK-ALL-DAG:   #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+// CHECK-ALL-DAG:   #[[$MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+
+// CHECK-DAG:       #[[$ENC0:.+]] = #iree_encoding.encoding<operand_index = 0 : index, op_type =  scaled_matmul, element_types = [f4E2M1FN, f4E2M1FN, f8E8M0FNU, f8E8M0FNU, f32], user_indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]], #[[$MAP4]]], iteration_sizes = [256, 512, 128, 32]>
+// CHECK-DAG:       #[[$ENC1:.+]] = #iree_encoding.encoding<operand_index = 1 : index, op_type =  scaled_matmul, element_types = [f4E2M1FN, f4E2M1FN, f8E8M0FNU, f8E8M0FNU, f32], user_indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]], #[[$MAP4]]], iteration_sizes = [256, 512, 128, 32]>
+// CHECK-DAG:       #[[$ENC2:.+]] = #iree_encoding.encoding<operand_index = 2 : index, op_type =  scaled_matmul, element_types = [f4E2M1FN, f4E2M1FN, f8E8M0FNU, f8E8M0FNU, f32], user_indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]], #[[$MAP4]]], iteration_sizes = [256, 512, 128, 32]>
+// CHECK-DAG:       #[[$ENC3:.+]] = #iree_encoding.encoding<operand_index = 3 : index, op_type =  scaled_matmul, element_types = [f4E2M1FN, f4E2M1FN, f8E8M0FNU, f8E8M0FNU, f32], user_indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]], #[[$MAP4]]], iteration_sizes = [256, 512, 128, 32]>
+// CHECK-DAG:       #[[$ENC4:.+]] = #iree_encoding.encoding<operand_index = 4 : index, op_type =  scaled_matmul, element_types = [f4E2M1FN, f4E2M1FN, f8E8M0FNU, f8E8M0FNU, f32], user_indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]], #[[$MAP4]]], iteration_sizes = [256, 512, 128, 32]>
+
+// MATMULK-DAG:     #[[$ENC0:.+]] = #iree_encoding.matmul_k<k_dims = [1, 2]>
+// MATMULK-DAG:     #[[$ENC1:.+]] = #iree_encoding.matmul_k<k_dims = [1]>
+// MATMULK-DAG:     #[[$ENC2:.+]] = #iree_encoding.matmul_k<k_dims = []>
+
+// CHECK-ALL:       util.func public @scaled_contraction_f4_f4_f8_f8_f32
+// CHECK-ALL-SAME:      %[[A:.*]]: tensor<256x128x32xf4E2M1FN>
+// CHECK-ALL-SAME:      %[[B:.*]]: tensor<512x128x32xf4E2M1FN>
+// CHECK-ALL-SAME:      %[[AS:.*]]: tensor<256x128xf8E8M0FNU>
+// CHECK-ALL-SAME:      %[[BS:.*]]: tensor<512x128xf8E8M0FNU>
+// CHECK-ALL-SAME:      %[[C:.*]]: tensor<256x512xf32>
+
+// CHECK-DAG:         %[[A_ENC:.*]] = iree_encoding.set_encoding %[[A]] : tensor<256x128x32xf4E2M1FN> -> tensor<256x128x32xf4E2M1FN, #[[$ENC0]]>
+// CHECK-DAG:         %[[B_ENC:.*]] = iree_encoding.set_encoding %[[B]] : tensor<512x128x32xf4E2M1FN> -> tensor<512x128x32xf4E2M1FN, #[[$ENC1]]>
+// CHECK-DAG:         %[[AS_ENC:.*]] = iree_encoding.set_encoding %[[AS]] : tensor<256x128xf8E8M0FNU> -> tensor<256x128xf8E8M0FNU, #[[$ENC2]]>
+// CHECK-DAG:         %[[BS_ENC:.*]] = iree_encoding.set_encoding %[[BS]] : tensor<512x128xf8E8M0FNU> -> tensor<512x128xf8E8M0FNU, #[[$ENC3]]>
+// CHECK-DAG:         %[[C_ENC:.*]] = iree_encoding.set_encoding %[[C]] : tensor<256x512xf32> -> tensor<256x512xf32, #[[$ENC4]]>
+
+// MATMULK-DAG:       %[[A_ENC:.*]] = iree_encoding.set_encoding %[[A]] : tensor<256x128x32xf4E2M1FN> -> tensor<256x128x32xf4E2M1FN, #[[$ENC0]]>
+// MATMULK-DAG:       %[[B_ENC:.*]] = iree_encoding.set_encoding %[[B]] : tensor<512x128x32xf4E2M1FN> -> tensor<512x128x32xf4E2M1FN, #[[$ENC0]]>
+// MATMULK-DAG:       %[[AS_ENC:.*]] = iree_encoding.set_encoding %[[AS]] : tensor<256x128xf8E8M0FNU> -> tensor<256x128xf8E8M0FNU, #[[$ENC1]]>
+// MATMULK-DAG:       %[[BS_ENC:.*]] = iree_encoding.set_encoding %[[BS]] : tensor<512x128xf8E8M0FNU> -> tensor<512x128xf8E8M0FNU, #[[$ENC1]]>
+// MATMULK-DAG:       %[[C_ENC:.*]] = iree_encoding.set_encoding %[[C]] : tensor<256x512xf32> -> tensor<256x512xf32, #[[$ENC2]]>
+
+// CHECK-ALL:         %[[GENERIC:.*]] = linalg.generic
+// CHECK-ALL-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]], #[[$MAP4]]]
+// CHECK-ALL-SAME:      iterator_types = ["parallel", "parallel", "reduction", "reduction"]
+// CHECK-ALL-SAME:      ins(%[[A_ENC]], %[[B_ENC]], %[[AS_ENC]], %[[BS_ENC]]
+// CHECK-ALL-SAME:      outs(%[[C_ENC]]
+// CHECK:             %[[RESULT:.*]] = iree_encoding.unset_encoding %[[GENERIC]] : tensor<256x512xf32, #[[$ENC4]]> -> tensor<256x512xf32>
+// MATMULK:           %[[RESULT:.*]] = iree_encoding.unset_encoding %[[GENERIC]] : tensor<256x512xf32, #[[$ENC2]]> -> tensor<256x512xf32>
+// CHECK-ALL:         util.return %[[RESULT]] : tensor<256x512xf32>
+
+// -----
+
+util.func public @scaled_contraction_multi_k_f4_f4_f8_f8_f32(
+    %a : tensor<256x2x64x32xf4E2M1FN>, %b : tensor<512x2x64x32xf4E2M1FN>,
+    %a_scales : tensor<256x2x64xf8E8M0FNU>, %b_scales : tensor<512x2x64xf8E8M0FNU>,
+    %c : tensor<256x512xf32>) -> tensor<256x512xf32> {
+  %0 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3, d4)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3, d4)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel", "reduction", "reduction", "reduction"]}
+      ins(%a, %b, %a_scales, %b_scales : tensor<256x2x64x32xf4E2M1FN>, tensor<512x2x64x32xf4E2M1FN>, tensor<256x2x64xf8E8M0FNU>, tensor<512x2x64xf8E8M0FNU>)
+      outs(%c : tensor<256x512xf32>) {
+  ^bb0(%in: f4E2M1FN, %in_0: f4E2M1FN, %in_1: f8E8M0FNU, %in_2: f8E8M0FNU, %out: f32):
+    %12 = arith.scaling_extf %in, %in_1 : f4E2M1FN, f8E8M0FNU to f32
+    %13 = arith.scaling_extf %in_0, %in_2 : f4E2M1FN, f8E8M0FNU to f32
+    %14 = arith.mulf %12, %13 : f32
+    %15 = arith.addf %out, %14 : f32
+    linalg.yield %15 : f32
+  } -> tensor<256x512xf32>
+  util.return %0 : tensor<256x512xf32>
+}
+
+//  CHECK-ALL:     util.func public @scaled_contraction_multi_k_f4_f4_f8_f8_f32
+//  CHECK-ALL-NOT:   iree_encoding.set_encoding
+
+// -----
+util.func public @scaled_contraction_multi_kb_f4_f4_f8_f8_f32(
+    %a : tensor<256x128x2x32xf4E2M1FN>, %b : tensor<512x128x2x32xf4E2M1FN>,
+    %a_scales : tensor<256x128xf8E8M0FNU>, %b_scales : tensor<512x128xf8E8M0FNU>,
+    %c : tensor<256x512xf32>) -> tensor<256x512xf32> {
+  %0 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3, d4)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3, d4)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel", "reduction", "reduction", "reduction"]}
+      ins(%a, %b, %a_scales, %b_scales : tensor<256x128x2x32xf4E2M1FN>, tensor<512x128x2x32xf4E2M1FN>, tensor<256x128xf8E8M0FNU>, tensor<512x128xf8E8M0FNU>)
+      outs(%c : tensor<256x512xf32>) {
+  ^bb0(%in: f4E2M1FN, %in_0: f4E2M1FN, %in_1: f8E8M0FNU, %in_2: f8E8M0FNU, %out: f32):
+    %12 = arith.scaling_extf %in, %in_1 : f4E2M1FN, f8E8M0FNU to f32
+    %13 = arith.scaling_extf %in_0, %in_2 : f4E2M1FN, f8E8M0FNU to f32
+    %14 = arith.mulf %12, %13 : f32
+    %15 = arith.addf %out, %14 : f32
+    linalg.yield %15 : f32
+  } -> tensor<256x512xf32>
+  util.return %0 : tensor<256x512xf32>
+}
+
+//  CHECK-ALL:     util.func public @scaled_contraction_multi_kb_f4_f4_f8_f8_f32
+//  CHECK-ALL-NOT:   iree_encoding.set_encoding