NFC: refresh external dialects under integrations/tensorflow (#11426)

diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects-c/Utils.h b/integrations/tensorflow/iree-dialects/include/iree-dialects-c/Utils.h
deleted file mode 100644
index 6132ee9..0000000
--- a/integrations/tensorflow/iree-dialects/include/iree-dialects-c/Utils.h
+++ /dev/null
@@ -1,26 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_DIALECTS_C_UTILS_H
-#define IREE_DIALECTS_C_UTILS_H
-
-#include "mlir-c/IR.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-// TODO: Upstream C/Python APIs for symbol table.
-// Looks up the referrent operation with the given flat symbol, starting from
-// a specific op.
-MLIR_CAPI_EXPORTED MlirOperation
-ireeLookupNearestSymbolFrom(MlirOperation fromOp, MlirAttribute symbolRefAttr);
-
-#ifdef __cplusplus
-}
-#endif
-
-#endif // IREE_DIALECTS_C_UTILS_H
diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td
index afb6c3a..c0530fe 100644
--- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td
+++ b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td
@@ -45,19 +45,28 @@
   : AttrDef<IREELinalgExt_Dialect, name, traits>;
 
 // List of pre-defined data layout encoding attributes.
-def GEMM_LHS
-    : I32EnumAttrCase<"GEMM_LHS", 0>;
-def GEMM_RESULT
-    : I32EnumAttrCase<"GEMM_RESULT", 1>;
-def GEMM_RHS
-    : I32EnumAttrCase<"GEMM_RHS", 2>;
-def GEMM_RHS_TRANSPOSE
-    : I32EnumAttrCase<"GEMM_RHS_TRANSPOSE", 3>;
+def MATMUL_F32F32F32_LHS
+    : I32EnumAttrCase<"MATMUL_F32F32F32_LHS", 0>;
+def MATMUL_F32F32F32_RHS
+    : I32EnumAttrCase<"MATMUL_F32F32F32_RHS", 1>;
+def MATMUL_F32F32F32_RHS_TRANSPOSE
+    : I32EnumAttrCase<"MATMUL_F32F32F32_RHS_TRANSPOSE", 2>;
+def MATMUL_F32F32F32_RESULT
+    : I32EnumAttrCase<"MATMUL_F32F32F32_RESULT", 3>;
+def MATMUL_I8I8I32_LHS
+    : I32EnumAttrCase<"MATMUL_I8I8I32_LHS", 4>;
+def MATMUL_I8I8I32_RHS
+    : I32EnumAttrCase<"MATMUL_I8I8I32_RHS", 5>;
+def MATMUL_I8I8I32_RHS_TRANSPOSE
+    : I32EnumAttrCase<"MATMUL_I8I8I32_RHS_TRANSPOSE", 6>;
+def MATMUL_I8I8I32_RESULT
+    : I32EnumAttrCase<"MATMUL_I8I8I32_RESULT", 7>;
 
 def TensorEncodingEnum
     : I32EnumAttr<"TensorEncoding",
                   "identifier for encoding used for the tensor",[
-                    GEMM_LHS, GEMM_RESULT, GEMM_RHS, GEMM_RHS_TRANSPOSE
+                    MATMUL_F32F32F32_LHS, MATMUL_F32F32F32_RHS, MATMUL_F32F32F32_RHS_TRANSPOSE, MATMUL_F32F32F32_RESULT,
+                    MATMUL_I8I8I32_LHS, MATMUL_I8I8I32_RHS, MATMUL_I8I8I32_RHS_TRANSPOSE, MATMUL_I8I8I32_RESULT,
                   ]> {
   let cppNamespace = "::mlir::iree_compiler::IREE::LinalgExt";
   let genSpecializedAttr = 0;
diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td
index 502d207..25733c7 100644
--- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -588,8 +588,7 @@
     (`outer_dims_perm` `=` $outer_dims_perm^)?
     `inner_dims_pos` `=` $inner_dims_pos
     `inner_tiles` `=`
-    custom<DynamicIndexList>($inner_tiles, $static_inner_tiles,
-                             "ShapedType::kDynamic")
+    custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
     `into` $outputs `:` `(` type($inputs) type($outputs) `)`
      (`->` type($results)^)?
   }];
@@ -693,7 +692,8 @@
      "getLoopIteratorTypes",
      "generateScalarImplementation",
      "getResultTilePosition",
-     "getTiledImplementation"]>,
+     "getTiledImplementation",
+     "generateResultTileValue"]>,
   DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
 ]>{
   let summary = "unpack operation";
@@ -740,8 +740,7 @@
     (`outer_dims_perm` `=` $outer_dims_perm^)?
     `inner_dims_pos` `=` $inner_dims_pos
     `inner_tiles` `=`
-    custom<DynamicIndexList>($inner_tiles, $static_inner_tiles,
-                             "ShapedType::kDynamic")
+    custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
     `into` $outputs `:` `(` type($inputs) type($outputs) `)`
      (`->` type($results)^)?
   }];
@@ -887,6 +886,188 @@
 }
 
 //===----------------------------------------------------------------------===//
+// Winograd ops
+//===----------------------------------------------------------------------===//
+
+def IREELinalgExt_WinogradInputTransformOp : IREELinalgExt_Op<"winograd.input_transform",
+    [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+     DeclareOpInterfaceMethods<TilingInterface,
+      ["getIterationDomain",
+       "getLoopIteratorTypes",
+       "getResultTilePosition",
+       "getTiledImplementation"]>]> {
+  let summary = "Winograd Input Transform operator";
+  let description = [{
+    This operator is the first step in converting a convolution to
+    its Winograd equivalent. Given a tile of an input image (I),
+    this operator computes matmul(transpose(B), matmul(I, B)).
+    The input tile is assumed to be square with each side of size m + r - 1,
+    where the convolutional kernel is m x m and the output tile size is r x r.
+    B is a constant 2-d square matrix of the same shape as the input tile I.
+    The input to the operator is an image of shape (N, H, W, C) and the
+    output is a tensor of shape (m + r - 1, m + r - 1, N, H', W', C)
+    where H' = ceil((H - m + 1)/r) and W' = ceil((W - m + 1)/r). The result
+    of this operator is first collapsed and then fed to a batch matmul op.
+  }];
+
+  let arguments = (ins Variadic<AnyShaped>:$inputs,
+                       Variadic<AnyShaped>:$outputs,
+                       I64Attr:$output_tile_size,
+                       I64Attr:$kernel_size,
+                       DenseI64ArrayAttr:$image_dimensions
+  );
+
+  let builders = [
+    OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
+      CArg<"int64_t", "8">:$output_tile_size, CArg<"int64_t", "3">:$kernel_size,
+      CArg<"ArrayRef<int64_t>", "{1, 2}">:$image_dimensions)>
+  ];
+
+  let results = (outs Variadic<AnyRankedTensor>:$result);
+  let hasFolder = 1;
+  let assemblyFormat = [{
+    attr-dict
+    `output_tile_size` `(` $output_tile_size `)`
+    `kernel_size` `(` $kernel_size `)`
+    `image_dimensions` `(` $image_dimensions `)`
+    `ins` `(` $inputs `:` type($inputs) `)`
+    `outs` `(` $outputs `:` type($outputs) `)`
+    (`->` type($result)^)?
+  }];
+
+  let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
+    Value input() {
+      return getInputOperand(0)->get();
+    }
+    Value output() {
+      return getOutputOperand(0)->get();
+    }
+    ShapedType getInputOperandType() {
+      return input().getType().cast<ShapedType>();
+    }
+    ShapedType getOutputOperandType() {
+      return output().getType().cast<ShapedType>();
+    }
+    int64_t getInputOperandRank() {
+      return getInputOperandType().getRank();
+    }
+    int64_t getOutputOperandRank() {
+      return getOutputOperandType().getRank();
+    }
+    int64_t getInputTileSize() {
+      return getOutputTileSize() + getKernelSize() - 1;
+    }
+    SmallVector<int64_t> imageDimensions() {
+      return llvm::to_vector(getImageDimensions());
+    }
+    int64_t getIterationDomainRank() {
+      SmallVector<int64_t> imageDims = imageDimensions();
+      return getInputOperandRank() - imageDims.size();
+    }
+    // Method to implement for specifying output range for
+    // DestinationStyleOpInterface
+    std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
+      std::pair<unsigned, unsigned> outputsIndexAndLength =
+        getODSOperandIndexAndLength(1);
+      return std::make_pair<int64_t, int64_t>(
+          outputsIndexAndLength.first,
+          outputsIndexAndLength.first + outputsIndexAndLength.second);
+    }
+  }];
+}
+
+def IREELinalgExt_WinogradOutputTransformOp : IREELinalgExt_Op<"winograd.output_transform",
+    [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+     DeclareOpInterfaceMethods<TilingInterface,
+      ["getIterationDomain",
+       "getLoopIteratorTypes",
+       "getResultTilePosition",
+       "getTiledImplementation"]>]> {
+  let summary = "Winograd Output Transform operator";
+  let description = [{
+    This operator is the last transform in converting a convolution to
+    its Winograd equivalent. After convolution in the Winograd domain
+    (which turns into an elementwise product for a single channel and
+    batch matrix multiplication for many channels), this operator converts
+    the output back into the original domain. Given a tile of the
+    output (O) in the Winograd domain, this operator computes
+    matmul(transpose(A), matmul(O, A)). The output tile is square with
+    each side of size m + r - 1, where the convolutional kernel is m x m
+    and the output tile size is r x r. A is a constant 2-d matrix of
+    shape (m + r - 1) x r. The input to the operator is a tensor of
+    shape (m + r - 1, m + r - 1, N, H', W', C) and the output is a
+    tensor of shape (N, H, W, C) where H = r H' and W = r W'. This operator
+    is followed by a tensor.extract_slice which extracts only the non-padded
+    part of the output.
+  }];
+
+  let arguments = (ins Variadic<AnyShaped>:$inputs,
+                       Variadic<AnyShaped>:$outputs,
+                       I64Attr:$output_tile_size,
+                       I64Attr:$kernel_size,
+                       DenseI64ArrayAttr:$image_dimensions
+  );
+
+  let builders = [
+    OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
+      CArg<"int64_t", "8">:$output_tile_size, CArg<"int64_t", "3">:$kernel_size,
+      CArg<"ArrayRef<int64_t>", "{1, 2}">:$image_dimensions)>
+  ];
+
+  let results = (outs Variadic<AnyRankedTensor>:$result);
+  let hasFolder = 1;
+  let assemblyFormat = [{
+    attr-dict
+    `output_tile_size` `(` $output_tile_size `)`
+    `kernel_size` `(` $kernel_size `)`
+    `image_dimensions` `(` $image_dimensions `)`
+    `ins` `(` $inputs `:` type($inputs) `)`
+    `outs` `(` $outputs `:` type($outputs) `)`
+    (`->` type($result)^)?
+  }];
+
+  let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
+    Value input() {
+      return getInputOperand(0)->get();
+    }
+    Value output() {
+      return getOutputOperand(0)->get();
+    }
+    ShapedType getInputOperandType() {
+      return input().getType().cast<ShapedType>();
+    }
+    ShapedType getOutputOperandType() {
+      return output().getType().cast<ShapedType>();
+    }
+    SmallVector<int64_t> imageDimensions() {
+      return llvm::to_vector(getImageDimensions());
+    }
+    int64_t getInputOperandRank() {
+      return getInputOperandType().getRank();
+    }
+    int64_t getOutputOperandRank() {
+      return getOutputOperandType().getRank();
+    }
+    int64_t getIterationDomainRank() {
+      SmallVector<int64_t> imageDims = imageDimensions();
+      return getOutputOperandRank() - imageDims.size();
+    }
+    int64_t getInputTileSize() {
+      return getOutputTileSize() + getKernelSize() - 1;
+    }
+    // Method to implement for specifying output range for
+    // DestinationStyleOpInterface
+    std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
+      std::pair<unsigned, unsigned> outputsIndexAndLength =
+        getODSOperandIndexAndLength(1);
+      return std::make_pair<int64_t, int64_t>(
+          outputsIndexAndLength.first,
+          outputsIndexAndLength.first + outputsIndexAndLength.second);
+    }
+  }];
+}
+
+//===----------------------------------------------------------------------===//
 // Pure ops
 //===----------------------------------------------------------------------===//
 
diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
index fd3664a..b715a4c 100644
--- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
+++ b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
@@ -10,6 +10,7 @@
 #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
@@ -93,7 +94,7 @@
   SmallVector<int64_t> outerDimsPerm;
 };
 using MaterializeEncodingFn =
-    std::function<FailureOr<MaterializeEncodingInfo>(TensorEncoding)>;
+    std::function<FailureOr<MaterializeEncodingInfo>(RankedTensorType)>;
 
 /// TypeConverter to use for materializing the encoding.
 struct MaterializeEncodingTypeConverter : public TypeConverter {
@@ -161,6 +162,16 @@
 
 std::unique_ptr<OperationPass<func::FuncOp>> createLinalgExtVectorizationPass();
 
+/// Tile and decompose the winograd transform ops into a sequence
+/// of linalg ops.
+std::unique_ptr<OperationPass<func::FuncOp>>
+createTileAndDecomposeWinogradTransformPass();
+
+// Creates a pass to convert linalg convolution ops into a sequence of
+// linalg_ext.winograd.* ops and linalg.batch_matmul ops using the winograd
+// transformation.
+std::unique_ptr<Pass> createConvertConv2DToWinogradPass();
+
 // Marker used as attribute the depth of the split reduction transformations.
 const StringLiteral kSplitReductionDepthMarker = "__split_reduction_depth__";
 
@@ -195,14 +206,14 @@
 /// Create a LinalgStrategyTileAndFusePass.
 std::unique_ptr<OperationPass<func::FuncOp>>
 createLinalgStrategyTileAndFusePass(
-    StringRef opName = "", const linalg::LinalgTilingAndFusionOptions &opt = {},
+    StringRef opName = "", const scf::SCFTileAndFuseOptions &options = {},
     const LinalgExt::LinalgTransformationFilter &filter =
         LinalgExt::LinalgTransformationFilter());
 
 /// Create a LinalgStrategyTilePass.
 std::unique_ptr<OperationPass<func::FuncOp>> createLinalgStrategyTilePass(
     StringRef opName = "",
-    const linalg::LinalgTilingOptions &opt = linalg::LinalgTilingOptions(),
+    const scf::SCFTilingOptions &options = scf::SCFTilingOptions(),
     const LinalgExt::LinalgTransformationFilter &filter =
         LinalgExt::LinalgTransformationFilter());
 
diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.td b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.td
index 51d9eef..badb658 100644
--- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.td
+++ b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.td
@@ -82,6 +82,21 @@
                     "createLinalgExtVectorizationPass()";
 }
 
+def TileAndDecomposeWinogradTransform :
+    Pass<"iree-linalg-ext-tile-and-decompose-winograd", "func::FuncOp"> {
+  let summary =
+      "Tiles and decomposes winograd transform ops into linalg ops";
+  let constructor = "mlir::iree_compiler::IREE::LinalgExt::"
+                    "createTileAndDecomposeWinogradTransformPass()";
+}
+
+def ConvertConv2DToWinograd :
+    Pass<"iree-linalg-ext-convert-conv2d-to-winograd", ""> {
+  let summary = "Convert linalg convolution ops to winograd based implementation";
+  let constructor = "mlir::iree_compiler::IREE::LinalgExt::createConvertConv2DToWinogradPass()";
+}
+
+
 //===---------------------------------------------------------------------====//
 // Codegen Strategy passes moved into IREE
 // TODO: Deprecate all this.
diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/CodegenStrategy.h b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/CodegenStrategy.h
index 7664332..d803588 100644
--- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/CodegenStrategy.h
+++ b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/CodegenStrategy.h
@@ -8,6 +8,7 @@
 #define IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_CODEGENSTRATEGY_H_
 
 #include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
+#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
 #include "mlir/Pass/PassManager.h"
 
 #include <utility>
@@ -38,7 +39,7 @@
 
 /// Represent one application of LinalgStrategyTileAndFusePass.
 struct TileAndFuse : public Transformation {
-  TileAndFuse(StringRef name, linalg::LinalgTilingAndFusionOptions options,
+  TileAndFuse(StringRef name, scf::SCFTileAndFuseOptions options,
               LinalgExt::LinalgTransformationFilter::FilterFunction f = nullptr)
       : Transformation(std::move(f)), opName(name),
         options(std::move(options)) {}
@@ -51,12 +52,12 @@
 
 private:
   std::string opName;
-  linalg::LinalgTilingAndFusionOptions options;
+  scf::SCFTileAndFuseOptions options;
 };
 
 /// Represent one application of LinalgStrategyTilePass.
 struct Tile : public Transformation {
-  Tile(StringRef name, linalg::LinalgTilingOptions options,
+  Tile(StringRef name, scf::SCFTilingOptions options,
        LinalgExt::LinalgTransformationFilter::FilterFunction f = nullptr)
       : Transformation(std::move(f)), opName(name),
         options(std::move(options)) {}
@@ -69,7 +70,7 @@
 
 private:
   std::string opName;
-  linalg::LinalgTilingOptions options;
+  scf::SCFTilingOptions options;
 };
 
 /// Represent one application of LinalgStrategyPadPass.
@@ -171,8 +172,7 @@
   /// Append a pattern to tile the Op `opName` and fuse its producers with
   /// tiling and fusion `options`.
   CodegenStrategy &
-  tileAndFuse(StringRef opName,
-              const linalg::LinalgTilingAndFusionOptions &options,
+  tileAndFuse(StringRef opName, const scf::SCFTileAndFuseOptions &options,
               const LinalgExt::LinalgTransformationFilter::FilterFunction &f =
                   nullptr) {
     transformationSequence.emplace_back(
@@ -182,14 +182,14 @@
   /// Conditionally append a pattern to tile the Op `opName` and fuse its
   /// producers with tiling and fusion `options`.
   CodegenStrategy &tileAndFuseIf(
-      bool b, StringRef opName, linalg::LinalgTilingAndFusionOptions options,
+      bool b, StringRef opName, scf::SCFTileAndFuseOptions options,
       LinalgExt::LinalgTransformationFilter::FilterFunction f = nullptr) {
     return b ? tileAndFuse(opName, std::move(options), std::move(f)) : *this;
   }
   /// Append a pattern to add a level of tiling for Op `opName` with tiling
   /// `options`.
   CodegenStrategy &
-  tile(StringRef opName, const linalg::LinalgTilingOptions &options,
+  tile(StringRef opName, const scf::SCFTilingOptions &options,
        const LinalgExt::LinalgTransformationFilter::FilterFunction &f =
            nullptr) {
     transformationSequence.emplace_back(
@@ -199,7 +199,7 @@
   /// Conditionally append a pattern to add a level of tiling for
   /// `LinalgOpType` with tiling `options`.
   CodegenStrategy &
-  tileIf(bool b, StringRef opName, linalg::LinalgTilingOptions options,
+  tileIf(bool b, StringRef opName, scf::SCFTilingOptions options,
          LinalgExt::LinalgTransformationFilter::FilterFunction f = nullptr) {
     return b ? tile(opName, std::move(options), std::move(f)) : *this;
   }
diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
index 5243cfe..34db971 100644
--- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
+++ b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
@@ -10,6 +10,7 @@
 #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
 #include "mlir/IR/PatternMatch.h"
 
 namespace mlir {
@@ -161,6 +162,43 @@
   linalg::LinalgTilingOptions options;
 };
 
+///
+/// Linalg SCF tiling pattern.
+///
+/// Apply the `tiling` transformation as a pattern.
+/// `filter` controls LinalgTransformMarker matching and update when specified.
+/// See `tiling` for more details.
+struct LinalgSCFTilingPattern
+    : public OpInterfaceRewritePattern<TilingInterface> {
+  /// Construct a generic pattern for all TilingInterface ops matching `filter`.
+  LinalgSCFTilingPattern(
+      MLIRContext *context, scf::SCFTilingOptions options,
+      LinalgTransformationFilter f = LinalgTransformationFilter(),
+      PatternBenefit benefit = 1);
+
+  /// Construct a pattern specifically applied to `opName`.
+  LinalgSCFTilingPattern(
+      StringRef opName, MLIRContext *context, scf::SCFTilingOptions options,
+      LinalgTransformationFilter f = LinalgTransformationFilter(),
+      PatternBenefit benefit = 1);
+
+  /// `matchAndRewrite` implementation that returns the significant transformed
+  /// pieces of IR.
+  LogicalResult returningMatchAndRewrite(TilingInterface op,
+                                         PatternRewriter &rewriter) const;
+
+  LogicalResult matchAndRewrite(TilingInterface op,
+                                PatternRewriter &rewriter) const override {
+    return returningMatchAndRewrite(op, rewriter);
+  }
+
+private:
+  /// LinalgTransformMarker handles special attribute manipulations.
+  LinalgTransformationFilter filter;
+  /// Options to control tiling.
+  scf::SCFTilingOptions options;
+};
+
 template <typename... OpTypes>
 class TilingPatterns;
 
@@ -185,6 +223,36 @@
 };
 
 ///
+/// Linalg SCF tile and fuse patterns.
+///
+/// `filter` controls LinalgTransformMarker matching and update when specified.
+struct LinalgSCFTileAndFusePattern
+    : public OpInterfaceRewritePattern<TilingInterface> {
+  /// Construct a generic pattern for all TilingInterface ops matching `filter`.
+  LinalgSCFTileAndFusePattern(
+      MLIRContext *context,
+      scf::SCFTileAndFuseOptions options = scf::SCFTileAndFuseOptions(),
+      LinalgTransformationFilter f = LinalgTransformationFilter(),
+      PatternBenefit benefit = 1);
+
+  /// Construct a pattern specifically applied to `opName`.
+  LinalgSCFTileAndFusePattern(
+      StringRef opName, MLIRContext *context,
+      scf::SCFTileAndFuseOptions options = scf::SCFTileAndFuseOptions(),
+      LinalgTransformationFilter f = LinalgTransformationFilter(),
+      PatternBenefit benefit = 1);
+
+  LogicalResult matchAndRewrite(TilingInterface op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// LinalgTransformMarker handles special attribute manipulations.
+  LinalgTransformationFilter filter;
+
+  scf::SCFTileAndFuseOptions options;
+};
+
+///
 /// Linalg vectorization patterns.
 ///
 /// `filter` controls LinalgTransformMarker matching and update when specified.
diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/Utils.h b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/Utils.h
index 3ad536d..722e1a9 100644
--- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/Utils.h
+++ b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/Utils.h
@@ -11,6 +11,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
 
 namespace mlir {
 namespace iree_compiler {
@@ -53,6 +54,11 @@
 SmallVector<int64_t> computeInterchangeFromDimPos(ArrayRef<int64_t> dimsPos,
                                                   int64_t rank);
 
+/// Converts a 2D float array to a constant value. The 2D array is stored as
+/// a 1D row-major array in `val` and has shape `rows` x `cols`.
+Value createValueFrom2DConstant(const float *val, int64_t rows, int64_t cols,
+                                Location loc, PatternRewriter &rewriter);
+
 } // namespace LinalgExt
 } // namespace IREE
 } // namespace iree_compiler
diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h
new file mode 100644
index 0000000..77dbb09
--- /dev/null
+++ b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h
@@ -0,0 +1,90 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECTS_DIALECT_LINALGEXT_UTILS_WINOGRAD_CONSTANTS_H_
+#define IREE_DIALECTS_DIALECT_LINALGEXT_UTILS_WINOGRAD_CONSTANTS_H_
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace LinalgExt {
+namespace Winograd {
+
+// This file contains the Winograd constant matrices for different
+// output tile sizes
+
+//===----------------------------------------------------------------------===//
+// Output tile size = 6, Kernel size = 3
+//===----------------------------------------------------------------------===//
+// These constants were obtained from this paper:
+//
+// Liu, J. et al (2021) Optimizing Winograd-Based Convolution with Tensor Cores.
+// https://dl.acm.org/doi/abs/10.1145/3472456.3472473
+//
+
+// clang-format off
+
+const float BT_6x6_3x3[] = {
+  1,      0, -21./4.,        0,  21./4.,       0, -1, 0,
+  0,      1,       1,  -17./4., -17./4.,       1,  1, 0,
+  0,     -1,       1,   17./4., -17./4.,      -1,  1, 0,
+  0,   1./2,   1./4.,   -5./2.,  -5./4.,       2,  1, 0,
+  0,  -1./2,   1./4.,    5./2.,  -5./4.,      -2,  1, 0,
+  0,      2,       4,   -5./2.,      -5,   1./2.,  1, 0,
+  0,     -2,       4,    5./2.,      -5,  -1./2.,  1, 0,
+  0,     -1,       0,   21./4.,       0, -21./4.,  0, 1
+};
+
+const float B_6x6_3x3[] = {
+        1,       0,       0,      0,      0,      0,      0,       0,
+        0,       1,      -1,   1./2,  -1./2,      2,     -2,      -1,
+  -21./4.,       1,       1,  1./4.,  1./4.,      4,      4,       0,
+        0, -17./4.,  17./4., -5./2.,  5./2., -5./2.,  5./2.,  21./4.,
+   21./4., -17./4., -17./4., -5./4., -5./4.,     -5,     -5,       0,
+        0,       1,      -1,      2,     -2,  1./2., -1./2., -21./4.,
+       -1,       1,       1,      1,      1,      1,      1,       0,
+        0,       0,       0,      0,      0,      0,      0,       1
+};
+
+const float G_6x6_3x3[] = {
+       1,       0,      0,
+  -2./9.,  -2./9., -2./9.,
+  -2./9.,   2./9., -2./9.,
+   1./90,   1./45,  2./45,
+   1./90,  -1./45,  2./45,
+  32./45,  16./45,  8./45,
+  32./45, -16./45,  8./45,
+       0,       0,      1
+};
+
+const float AT_6x6_3x3[] = {
+  1,  1,   1,   1,    1,     1,      1,  0,
+  0,  1,  -1,   2,   -2,  1./2,  -1./2,  0,
+  0,  1,   1,   4,    4,  1./4,   1./4,  0,
+  0,  1,  -1,   8,   -8,  1./8,  -1./8,  0,
+  0,  1,   1,  16,   16, 1./16,  1./16,  0,
+  0,  1,  -1,  32,  -32, 1./32, -1./32,  1
+};
+
+const float A_6x6_3x3[] = {
+  1,     0,    0,     0,     0,      0,
+  1,     1,    1,     1,     1,      1,
+  1,    -1,    1,    -1,     1,     -1,
+  1,     2,    4,     8,    16,     32,
+  1,    -2,    4,    -8,    16,    -32,
+  1,  1./2, 1./4,  1./8, 1./16,  1./32,
+  1, -1./2, 1./4, -1./8, 1./16, -1./32,
+  0,     0,    0,     0,     0,      1
+};
+
+// clang-format on
+
+} // namespace Winograd
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+#endif // IREE_DIALECTS_DIALECT_LINALGEXT_UTILS_WINOGRAD_CONSTANTS_H_
diff --git a/integrations/tensorflow/iree-dialects/lib/CAPI/CMakeLists.txt b/integrations/tensorflow/iree-dialects/lib/CAPI/CMakeLists.txt
index cba823b..17561a6 100644
--- a/integrations/tensorflow/iree-dialects/lib/CAPI/CMakeLists.txt
+++ b/integrations/tensorflow/iree-dialects/lib/CAPI/CMakeLists.txt
@@ -1,6 +1,5 @@
 add_mlir_public_c_api_library(IREEDialectsCAPI
   Dialects.cpp
-  Utils.cpp
   LINK_LIBS PUBLIC
   MLIRIR
   MLIRTransformDialect
diff --git a/integrations/tensorflow/iree-dialects/lib/CAPI/Utils.cpp b/integrations/tensorflow/iree-dialects/lib/CAPI/Utils.cpp
deleted file mode 100644
index d704f2b..0000000
--- a/integrations/tensorflow/iree-dialects/lib/CAPI/Utils.cpp
+++ /dev/null
@@ -1,20 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree-dialects-c/Utils.h"
-
-#include "mlir/CAPI/IR.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/SymbolTable.h"
-
-using namespace mlir;
-
-MlirOperation ireeLookupNearestSymbolFrom(MlirOperation fromOp,
-                                          MlirAttribute symbolRefAttr) {
-  auto symbolRefAttrCpp = unwrap(symbolRefAttr).cast<SymbolRefAttr>();
-  return wrap(
-      SymbolTable::lookupNearestSymbolFrom(unwrap(fromOp), symbolRefAttrCpp));
-}
diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 6e0a36c..fc0db72 100644
--- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -1706,8 +1706,11 @@
   SmallVector<Value> dynamicTileSizes;
   dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes,
                              ShapedType::kDynamic);
-  build(builder, state, output.getType(), source, output, outerDimsPerm,
-        innerDimsPos, dynamicTileSizes, staticTileSizes,
+  build(builder, state, output.getType(), source, output,
+        outerDimsPerm.empty() ? nullptr
+                              : builder.getDenseI64ArrayAttr(outerDimsPerm),
+        builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
+        builder.getDenseI64ArrayAttr(staticTileSizes),
         (paddingValue ? paddingValue.value() : nullptr));
 }
 
@@ -2096,8 +2099,11 @@
   SmallVector<Value> dynamicTileSizes;
   dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes,
                              ShapedType::kDynamic);
-  build(builder, state, output.getType(), source, output, outerDimsPerm,
-        innerDimsPos, dynamicTileSizes, staticTileSizes);
+  build(builder, state, output.getType(), source, output,
+        outerDimsPerm.empty() ? nullptr
+                              : builder.getDenseI64ArrayAttr(outerDimsPerm),
+        builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
+        builder.getDenseI64ArrayAttr(staticTileSizes));
 }
 
 SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
@@ -2363,6 +2369,15 @@
   return iteratorTypes;
 }
 
+FailureOr<Value>
+UnPackOp::generateResultTileValue(OpBuilder &b, unsigned resultNumber,
+                                  ArrayRef<OpFoldResult> offsets,
+                                  ArrayRef<OpFoldResult> sizes) {
+  return getTiledImplementation(b, offsets, sizes)
+      .back()
+      ->getResult(resultNumber);
+}
+
 //===----------------------------------------------------------------------===//
 // WinogradInputTransformOp
 //===----------------------------------------------------------------------===//
@@ -2531,6 +2546,165 @@
       .reifyResultShapes(b, reifiedReturnShapes);
 }
 
+//===----------------------------------------------------------------------===//
+// WinogradOutputTransformOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies operand counts, input rank 6 with output rank = input rank - 2,
+/// identical element types, the image dimensions and the output shape.
+LogicalResult WinogradOutputTransformOp::verify() {
+  Operation *op = getOperation();
+  if (getNumInputs() != 1)
+    return op->emitOpError("expected one input operand");
+  if (getNumOutputs() != 1)
+    return op->emitOpError("expected one output operand");
+  auto inputType = input().getType().cast<ShapedType>();
+  auto outputType = output().getType().cast<ShapedType>();
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  if (inputShape.size() != 6) {
+    return op->emitOpError("expected input operand to have rank 6");
+  }
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  if (outputType.getElementType() != inputType.getElementType()) {
+    return op->emitOpError(
+        "expected input/output element types to be identical");
+  }
+  if (getOutputOperandRank() != getInputOperandRank() - 2) {
+    return op->emitOpError(
+        "expected output rank to be equal to input rank - 2");
+  }
+  const SmallVector<int64_t> imageDims = imageDimensions();
+  const size_t numImageDims = imageDims.size();
+  llvm::SmallSetVector<int64_t, 2> imageDimsSet(imageDims.begin(),
+                                                imageDims.end());
+  if (imageDims.size() != 2) {
+    return op->emitOpError("expected only 2 image dimensions");
+  }
+  for (auto dim : imageDims) {
+    if ((dim < 0) || (dim > 3)) {
+      return op->emitOpError(
+          "expect image dimensions to be in the range: [0, 3]");
+    }
+  }
+  const int64_t outputTileSize = getOutputTileSize();
+  SmallVector<int64_t> expectedOutputShape(getOutputOperandRank(), 1);
+  for (int64_t i = numImageDims, e = inputShape.size(); i < e; i++) {
+    const int64_t outputIndex = i - static_cast<int64_t>(numImageDims);
+    // Only static image dimensions are scaled by the output tile size; a
+    // dynamic extent is propagated unchanged rather than multiplied.
+    const bool scale = imageDimsSet.contains(outputIndex) &&
+                       !ShapedType::isDynamic(inputShape[i]);
+    expectedOutputShape[outputIndex] =
+        scale ? outputTileSize * inputShape[i] : inputShape[i];
+  }
+  if (!areShapesCompatible(expectedOutputShape, outputShape)) {
+    return op->emitOpError("incompatible output shape");
+  }
+  return success();
+}
+
+SmallVector<Range>
+WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
+  Location loc = getLoc();
+  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
+  Value source = output();
+  SmallVector<int64_t> imageDims = imageDimensions();
+  llvm::SmallSetVector<int64_t, 2> imageDimsSet(imageDims.begin(),
+                                                imageDims.end());
+  SmallVector<Range> loopBounds(imageDims.size());
+  int count = 0;
+  for (auto dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
+    if (!imageDimsSet.contains(dim)) {
+      loopBounds[count].offset = zero;
+      loopBounds[count].size = getDimValue(builder, loc, source, dim);
+      loopBounds[count].stride = one;
+      count++;
+    }
+  }
+  return loopBounds;
+}
+
+SmallVector<utils::IteratorType>
+WinogradOutputTransformOp::getLoopIteratorTypes() {
+  SmallVector<utils::IteratorType> iteratorTypes(getIterationDomainRank(),
+                                                 utils::IteratorType::parallel);
+  return iteratorTypes;
+}
+
+SmallVector<Operation *> WinogradOutputTransformOp::getTiledImplementation(
+    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes) {
+
+  Location loc = getLoc();
+  auto one = builder.getIndexAttr(1);
+  auto zero = builder.getIndexAttr(0);
+
+  assert(offsets.size() == 2);
+  SmallVector<OpFoldResult> inputOffsets(getInputOperandRank(), zero);
+  SmallVector<OpFoldResult> outputOffsets(getOutputOperandRank(), zero);
+  inputOffsets[2] = outputOffsets[0] = offsets[0];
+  inputOffsets[5] = outputOffsets[3] = offsets[1];
+
+  SmallVector<OpFoldResult> inputStrides(getInputOperandRank(), one);
+  SmallVector<OpFoldResult> outputStrides(getOutputOperandRank(), one);
+
+  assert(sizes.size() == 2);
+  auto inputShape = input().getType().cast<ShapedType>().getShape();
+  auto outputShape = output().getType().cast<ShapedType>().getShape();
+  SmallVector<OpFoldResult> inputSizes =
+      getAsOpFoldResult(builder.getIndexArrayAttr(inputShape));
+  SmallVector<OpFoldResult> outputSizes =
+      getAsOpFoldResult(builder.getIndexArrayAttr(outputShape));
+  inputSizes[2] = outputSizes[0] = sizes[0];
+  inputSizes[5] = outputSizes[3] = sizes[1];
+
+  SmallVector<Value> tiledOperands;
+  tiledOperands.emplace_back(
+      getSlice(builder, loc, input(), inputOffsets, inputSizes, inputStrides));
+  tiledOperands.emplace_back(getSlice(builder, loc, output(), outputOffsets,
+                                      outputSizes, outputStrides));
+
+  SmallVector<Type, 4> resultTypes;
+  if (hasTensorSemantics()) {
+    resultTypes.push_back(tiledOperands[1].getType());
+  }
+
+  Operation *tiledOp =
+      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
+
+  return {tiledOp};
+}
+
+LogicalResult WinogradOutputTransformOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  if (resultNumber == 0) {
+    auto resultShape = output().getType().cast<ShapedType>().getShape();
+    resultSizes = getAsOpFoldResult(builder.getIndexArrayAttr(resultShape));
+    resultOffsets = SmallVector<OpFoldResult>(getOutputOperandRank(),
+                                              builder.getIndexAttr(0));
+    resultOffsets[0] = offsets[0];
+    resultOffsets[3] = offsets[1];
+    resultSizes[0] = sizes[0];
+    resultSizes[3] = sizes[1];
+    return success();
+  }
+  return failure();
+}
+
+LogicalResult WinogradOutputTransformOp::fold(ArrayRef<Attribute>,
+                                              SmallVectorImpl<OpFoldResult> &) {
+  return memref::foldMemRefCast(*this);
+}
+
+LogicalResult WinogradOutputTransformOp::reifyResultShapes(
+    OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  return cast<LinalgExtOp>(getOperation())
+      .reifyResultShapes(b, reifiedReturnShapes);
+}
+
 #define DEFINE_OP_GET_EFFECTS(OP_NAME)                                         \
   void OP_NAME::getEffects(                                                    \
       SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>      \
@@ -2550,6 +2724,7 @@
 DEFINE_OP_GET_EFFECTS(PackOp)
 DEFINE_OP_GET_EFFECTS(UnPackOp)
 DEFINE_OP_GET_EFFECTS(WinogradInputTransformOp)
+DEFINE_OP_GET_EFFECTS(WinogradOutputTransformOp)
 
 //===----------------------------------------------------------------------===//
 // iree_linalg_ext.set_encoding
diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/CMakeLists.txt b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/CMakeLists.txt
index 69fbd32..68fdab8 100644
--- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/CMakeLists.txt
+++ b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/CMakeLists.txt
@@ -1,10 +1,12 @@
 add_mlir_library(IREELinalgExtPasses
+  ConvertConv2DToWinograd.cpp
   ConvertToLoops.cpp
   FoldIntoPackAndUnpackOps.cpp
   MaterializeEncoding.cpp
   PadContractionToBlockSize.cpp
   Passes.cpp
   SplitReduction.cpp
+  TileAndDecomposeWinogradPass.cpp
   Tiling.cpp
 
   DEPENDS
diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertConv2DToWinograd.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertConv2DToWinograd.cpp
new file mode 100644
index 0000000..e0f289e
--- /dev/null
+++ b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertConv2DToWinograd.cpp
@@ -0,0 +1,396 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h"
+#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
+#include "iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SetVector.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace LinalgExt {
+
+static inline int index(int y, int x, int dimy, int dimx) {
+  return (x + dimx * y);
+}
+
+static inline int index(int z, int y, int x, int w, int dimz, int dimy,
+                        int dimx, int dimw) {
+  return (w + dimw * (x + dimx * (y + dimy * z)));
+}
+
+static bool hasAllOneValues(DenseIntElementsAttr attr) {
+  return llvm::all_of(attr, [](APInt element) { return element.isOne(); });
+}
+
+// TODO: Make this a user-settable parameter once we have support
+// for more tile sizes
+static constexpr int64_t outputTileSize = 6;
+
+/// This function computes the Winograd filter transform when
+/// the filter is known to be a constant. Specifically, this
+/// function computes matmul(G, matmul(F, transpose(G))) where
+/// F is a tile of the convolution filter of size m x m
+/// (single input channel, single output channel) and G has
+/// shape (m + r - 1) x m where r is the output tile size and
+/// (m + r - 1) is the input tile size.
+/// The time complexity of this function is O(ic * oc)
+/// where ic is the number of input channels and oc is the
+/// number of output channels since input tile size and kernel size
+/// are constants. So for large ic and oc, this function is
+/// time intensive.
+/// TODO: Codegen this as a kernel and run once at initialization
+static DenseElementsAttr foldFilterTransform(
+    ArrayRef<int64_t> shape, int64_t inputTileSize, int64_t kernelSize,
+    Type outputType, const float *G, bool isSplat, float splatValue,
+    DenseElementsAttr::iterator_range<APFloat> &input, Type elementType) {
+  const int &kh = shape[0];
+  const int &kw = shape[1];
+  const int &ic = shape[2];
+  const int &oc = shape[3];
+  const int64_t numElements = inputTileSize * inputTileSize * ic * oc;
+  SmallVector<APFloat> output(numElements, APFloat(0.0f));
+  for (int d0 = 0; d0 < inputTileSize; d0++) {
+    for (int d1 = 0; d1 < inputTileSize; d1++) {
+      for (int d2 = 0; d2 < ic; d2++) {
+        for (int d3 = 0; d3 < oc; d3++) {
+          APFloat accum(0.0f);
+          for (int d4 = 0; d4 < kernelSize; d4++) {
+            for (int d5 = 0; d5 < kernelSize; d5++) {
+              APFloat ival(splatValue);
+              if (!isSplat) {
+                ival = input[index(d4, d5, d2, d3, kh, kw, ic, oc)];
+              }
+              int idx0 = index(d0, d4, inputTileSize, kernelSize);
+              int idx1 = index(d1, d5, inputTileSize, kernelSize);
+              accum = accum + APFloat(G[idx0]) * ival * APFloat(G[idx1]);
+            }
+          }
+          int odx = index(d0, d1, d2, d3, inputTileSize, inputTileSize, ic, oc);
+          output[odx] = accum;
+        }
+      }
+    }
+  }
+  return DenseElementsAttr::get(outputType, output);
+}
+
+namespace {
+
+class FoldWinogradFilterTransform final
+    : public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
+                                PatternRewriter &rewriter) const override {
+    // Check that kernel size = 3x3
+    Value kernel = convOp.getInputs()[1];
+    auto kernelType = kernel.getType().dyn_cast<ShapedType>();
+    if (!kernelType)
+      return failure();
+    ArrayRef<int64_t> kernelShape = kernelType.getShape();
+    const int64_t kh = kernelShape[0];
+    const int64_t kw = kernelShape[1];
+    if ((kh != 3) || (kw != 3))
+      return failure();
+    const int64_t kernelSize = kh;
+    const int64_t inputTileSize = outputTileSize + kernelSize - 1;
+
+    DenseIntOrFPElementsAttr kernelAttr;
+    if (!matchPattern(kernel, m_Constant(&kernelAttr))) {
+      return failure();
+    }
+
+    Operation *constOp = kernel.getDefiningOp();
+    ShapedType type = constOp->getResult(0).getType().cast<ShapedType>();
+    Type elementType = type.getElementType();
+    // The fold is computed in single-precision APFloats; skip anything else.
+    if (!elementType.isF32())
+      return failure();
+    ArrayRef<int64_t> shape = type.getShape();
+    DenseElementsAttr::iterator_range<APFloat> nonSplatValues =
+        kernelAttr.getValues<APFloat>();
+    bool isSplat = kernelAttr.isSplat();
+    float splatValue =
+        isSplat ? kernelAttr.getSplatValue<APFloat>().convertToFloat() : 0.0f;
+    SmallVector<int64_t> resultShape{inputTileSize * inputTileSize, shape[2],
+                                     shape[3]};
+    auto resultType = RankedTensorType::get(resultShape, elementType);
+    auto foldedKernelAttr =
+        foldFilterTransform(shape, inputTileSize, kernelSize, resultType,
+                            IREE::LinalgExt::Winograd::G_6x6_3x3, isSplat,
+                            splatValue, nonSplatValues, elementType);
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, foldedKernelAttr);
+    return success();
+  }
+};
+
+} // namespace
+
+static Value
+createCollapse(Value tensor, Location loc, PatternRewriter &rewriter,
+               SmallVectorImpl<int64_t> &outputShape,
+               SmallVectorImpl<ReassociationIndices> &reassociations) {
+  auto tensorType = tensor.getType().cast<ShapedType>();
+  auto elementTy = tensorType.getElementType();
+  auto resultType = RankedTensorType::get(outputShape, elementTy);
+  return rewriter.create<tensor::CollapseShapeOp>(loc, resultType, tensor,
+                                                  reassociations);
+}
+
+static Value
+createExpand(Value tensor, Location loc, PatternRewriter &rewriter,
+             SmallVectorImpl<int64_t> &outputShape,
+             SmallVectorImpl<ReassociationIndices> &reassociations) {
+  auto tensorType = tensor.getType().cast<ShapedType>();
+  auto elementTy = tensorType.getElementType();
+  auto resultType = RankedTensorType::get(outputShape, elementTy);
+  return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor,
+                                                reassociations);
+}
+
+namespace {
+
+/// Convert conv2d to a sequence of ops that implement the
+/// Winograd transformation. The Winograd transformation
+/// is parameterized by the output tile size(r). The larger
+/// the tile size, the greater the computational savings,
+/// but this comes at the cost of accuracy.
+///
+/// For now, we restrict this transform to convolutions
+/// where the filter size = 3x3, though extensions to larger
+/// filter sizes are possible. We refer to the
+/// filter size as (m). The input tile size (i) is defined as
+/// m + r - 1. For a given output tile size, the Winograd
+/// transformation defines 3 constant matrices:
+///
+/// B: i x i [used in input transform]
+/// G: i x m [used in the filter transform]
+/// A: i x r [used in output transform]
+///
+/// The choice of these matrices is not unique and affects
+/// the accuracy of the approach.
+///
+/// Given a convolution of the form
+///
+/// y = conv2d(x, f)
+///
+/// where x: (N, H, W, C)
+///       f: (m, m, C, F)
+///
+/// this pattern converts the convolution to the following
+/// sequence:
+///
+/// f_winograd = winograd.filter_transform(f) [folded]
+/// x_winograd = winograd.input_transform(x)
+/// x_winograd_c = collapse(x_winograd)
+/// y_winograd = batch_matmul(x_winograd_c, f_winograd)
+/// y_winograd_e = expand(y_winograd)
+/// y_padded = winograd.output_transform(y_winograd_e)
+/// y = extract_slice(y_padded)
+///
+/// where the dimensions of the tensors above are:
+///
+/// f_winograd:   (i * i, C, F)
+/// x_winograd:   (i, i, N, H', W', C)
+/// x_winograd_c: (i * i, N * H' * W', C)
+/// y_winograd:   (i * i, N * H' * W', F)
+/// y_winograd_e: (i, i, N, H', W', F)
+/// y_padded:     (N, r * H', r * W', F)
+///
+/// H': ceil((H - m + 1) / r)
+/// W': ceil((W - m + 1) / r)
+///
+/// The winograd input transform extracts a tile of the input
+/// of size i x i and computes matmul(transpose(B), matmul(tile(x), B)).
+/// The winograd filter transform extracts a tile of the filter
+/// of size m x m and computes matmul(G, matmul(tile(f), transpose(G))).
+/// These two are then combined using elementwise multiplication
+/// (which becomes a batch matmul when combining over multiple channels).
+/// The winograd output transform extracts a tile of the result of size
+/// i x i and computes matmul(transpose(A), matmul(tile(y_winograd_e), A)).
+///
+/// For more information and additional references,
+/// see here:
+///
+/// https://github.com/nod-ai/MLIRWinogradTalk/blob/main/MLIRSummit2022.Nodai.Menon.pdf
+///
+class ConvertConv2DNhwcHwcf final
+    : public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
+                                PatternRewriter &rewriter) const override {
+    // Check that strides = 1
+    if (!hasAllOneValues(convOp.getStrides()))
+      return failure();
+
+    // Check that dilations = 1
+    if (!hasAllOneValues(convOp.getDilations()))
+      return failure();
+
+    // Check that kernel has been constant folded (by validating rank = 3)
+    Value kernel = convOp.getInputs()[1];
+    auto kernelType = kernel.getType().dyn_cast<ShapedType>();
+    if (!kernelType)
+      return failure();
+    Type elementType = kernelType.getElementType();
+    ArrayRef<int64_t> kernelShape = kernelType.getShape();
+    if (kernelShape.size() != 3)
+      return failure();
+
+    const int64_t kernelSize = 3;
+    const int64_t inputTileSize = outputTileSize + kernelSize - 1;
+
+    // Validate operands first; a failing pattern must leave the IR untouched.
+    Value input = convOp.getInputs()[0];
+    auto inputType = input.getType().dyn_cast<ShapedType>();
+    if (!inputType)
+      return failure();
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    if (llvm::any_of(inputShape, ShapedType::isDynamic))
+      return failure();
+    assert(inputShape.size() == 4);
+    Value output = convOp.getOutputs()[0];
+    auto outputType = output.getType().dyn_cast<RankedTensorType>();
+    if (!outputType || !outputType.hasStaticShape())
+      return failure();
+    ArrayRef<int64_t> outputShape = outputType.getShape();
+
+    // Create winograd input transform op
+    Location loc = convOp.getLoc();
+    Value zero = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(elementType));
+    SmallVector<int64_t, 2> imageDimensions = {1, 2};
+    const size_t numImageDims = imageDimensions.size();
+    SmallVector<int64_t> resultShape(6, inputTileSize);
+    llvm::SmallSetVector<int64_t, 2> imageDimensionsSet(imageDimensions.begin(),
+                                                        imageDimensions.end());
+    for (int64_t i = 0, e = inputShape.size(); i < e; i++) {
+      // Image dims become tile counts: ceil((dim - kernelSize + 1) / tile),
+      // computed with integer arithmetic; other dims are copied through.
+      resultShape[i + numImageDims] =
+          imageDimensionsSet.contains(i)
+              ? (inputShape[i] - kernelSize + outputTileSize) / outputTileSize
+              : inputShape[i];
+    }
+    Value emptyTensor =
+        rewriter.create<tensor::EmptyOp>(loc, resultShape, elementType);
+    auto winogradInputOp =
+        rewriter.create<IREE::LinalgExt::WinogradInputTransformOp>(
+            loc, emptyTensor.getType(), ValueRange{input},
+            ValueRange{emptyTensor}, outputTileSize, kernelSize,
+            imageDimensions);
+    Value winogradInput = winogradInputOp.getResult()[0];
+
+    // Add collapse shape
+    SmallVector<int64_t> collapsedShape = {
+        resultShape[0] * resultShape[1],
+        resultShape[2] * resultShape[3] * resultShape[4], resultShape[5]};
+    SmallVector<ReassociationIndices> reassociations = {{0, 1}, {2, 3, 4}, {5}};
+    Value collapsedWinogradInput = createCollapse(
+        winogradInput, loc, rewriter, collapsedShape, reassociations);
+
+    // Add BatchMatmulOp
+    SmallVector<int64_t> bmmShape(collapsedShape.begin(), collapsedShape.end());
+    bmmShape[2] = outputShape[3];
+    auto bmmOutputType = RankedTensorType::get(bmmShape, elementType);
+    emptyTensor = rewriter.create<tensor::EmptyOp>(loc, bmmShape, elementType);
+    auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange{zero},
+                                                  ValueRange{emptyTensor});
+    auto bmmOp = rewriter.create<linalg::BatchMatmulOp>(
+        loc, bmmOutputType, ValueRange({collapsedWinogradInput, kernel}),
+        ValueRange({fillOp.result()}));
+    Value bmmResult = bmmOp.getResult(0);
+
+    // Add expand shape
+    SmallVector<int64_t> expandedShape = {resultShape[0], resultShape[1],
+                                          resultShape[2], resultShape[3],
+                                          resultShape[4], outputShape[3]};
+    reassociations = {{0, 1}, {2, 3, 4}, {5}};
+    Value expandedBmmResult =
+        createExpand(bmmResult, loc, rewriter, expandedShape, reassociations);
+
+    // Convert back into original domain
+    SmallVector<int64_t> paddedResultShape(outputShape.size(), 0);
+    for (int64_t i = 0, e = outputShape.size(); i < e; i++) {
+      if (!imageDimensionsSet.contains(i)) {
+        paddedResultShape[i] = outputShape[i];
+      } else {
+        paddedResultShape[i] = resultShape[i + numImageDims] * outputTileSize;
+      }
+    }
+    emptyTensor =
+        rewriter.create<tensor::EmptyOp>(loc, paddedResultShape, elementType);
+    auto winogradOutputOp =
+        rewriter.create<IREE::LinalgExt::WinogradOutputTransformOp>(
+            loc, emptyTensor.getType(), ValueRange{expandedBmmResult},
+            ValueRange{emptyTensor}, outputTileSize, kernelSize,
+            imageDimensions);
+    Value paddedOutput = winogradOutputOp.getResult()[0];
+
+    // Extract slice
+    SmallVector<OpFoldResult> offsets(outputShape.size(),
+                                      rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> strides(outputShape.size(),
+                                      rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> sizes;
+    for (int64_t dim : outputShape)
+      sizes.push_back(rewriter.getIndexAttr(dim));
+    auto winogradOutput = rewriter.create<tensor::ExtractSliceOp>(
+        loc, outputType, paddedOutput, offsets, sizes, strides);
+
+    rewriter.replaceOp(convOp, winogradOutput.getResult());
+    return success();
+  }
+};
+
+struct ConvertConv2DToWinogradPass
+    : ConvertConv2DToWinogradBase<ConvertConv2DToWinogradPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry
+        .insert<linalg::LinalgDialect, IREE::LinalgExt::IREELinalgExtDialect>();
+  }
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(&getContext());
+    patterns.insert<FoldWinogradFilterTransform, ConvertConv2DNhwcHwcf>(
+        context);
+    if (failed(applyPatternsAndFoldGreedily(getOperation(),
+                                            std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> createConvertConv2DToWinogradPass() {
+  return std::make_unique<ConvertConv2DToWinogradPass>();
+}
+
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/FoldIntoPackAndUnpackOps.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/FoldIntoPackAndUnpackOps.cpp
index 0713ce6..7477bc0 100644
--- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/FoldIntoPackAndUnpackOps.cpp
+++ b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/FoldIntoPackAndUnpackOps.cpp
@@ -53,9 +53,8 @@
     Value output = rewriter.create<tensor::EmptyOp>(
         sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType);
     rewriter.replaceOpWithNewOp<UnPackOp>(
-        sliceOp, output.getType(), unpackOp.getInput(), output,
-        unpackOp.getOuterDimsPerm(), unpackOp.getInnerDimsPos(),
-        unpackOp.getInnerTiles(), unpackOp.getStaticInnerTiles());
+        sliceOp, unpackOp.getInput(), output, unpackOp.getInnerDimsPos(),
+        unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
     return success();
   }
 };
diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
index 66978aa..fc22757 100644
--- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
+++ b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
@@ -22,35 +22,6 @@
 using namespace mlir::iree_compiler::IREE::LinalgExt;
 
 //===---------------------------------------------------------------------===//
-// Methods to convert the encoding to parameters of the Pack operation
-//===---------------------------------------------------------------------===//
-
-/// Given the `encoding` return the `MaterializeEncodingInfo` to use for
-/// materializing the pack op.
-// TODO(ravishankarm): THis is currently hard-coded here for convenience. When
-// used in IREE, this will be computed based on the architecture information in
-// `hal.executable.variant`.
-static FailureOr<MaterializeEncodingInfo>
-getPackOpInfoFromEncoding(TensorEncoding encoding) {
-  switch (encoding) {
-  case TensorEncoding::GEMM_LHS:
-    return MaterializeEncodingInfo{{0, 1}, {8, 4}, {}};
-    break;
-  case TensorEncoding::GEMM_RHS:
-    return MaterializeEncodingInfo{{0, 1}, {4, 8}, {}};
-    break;
-  case TensorEncoding::GEMM_RESULT:
-    return MaterializeEncodingInfo{{0, 1}, {8, 8}, {}};
-    break;
-  case TensorEncoding::GEMM_RHS_TRANSPOSE:
-    return MaterializeEncodingInfo{{1, 0}, {8, 4}, {1, 0}};
-    break;
-  default:
-    return failure();
-  }
-}
-
-//===---------------------------------------------------------------------===//
 // Utility methods
 //===---------------------------------------------------------------------===//
 
@@ -72,7 +43,7 @@
   if (!encoding)
     return tensorType;
   FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
-      materializeEncodingFn(encoding.value());
+      materializeEncodingFn(tensorType);
   if (failed(materializeEncodingInfo)) {
     return tensorType;
   }
@@ -94,6 +65,48 @@
 }
 
 //===---------------------------------------------------------------------===//
+// Methods to convert the encoding to parameters of the Pack operation
+//===---------------------------------------------------------------------===//
+
+/// Given the `encoding` return the `MaterializeEncodingInfo` to use for
+/// materializing the pack op.
+// TODO(ravishankarm): This is currently hard-coded here for convenience. When
+// used in IREE, this will be computed based on the architecture information in
+// `hal.executable.variant`.
+// A real implementation would return tile sizes that depend on at least the
+// `tensorType`'s element type (e.g. different tile sizes for i8 vs f32, because
+// the SIMD instructions may have different shapes).
+// Moreover, in a real implementation, the tile sizes would typically also
+// depend on target information. This is demonstrated in
+// iree/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPass.cpp
+static FailureOr<MaterializeEncodingInfo>
+chooseEncodingInfo(RankedTensorType tensorType) {
+  Optional<TensorEncoding> encoding = getEncoding(tensorType);
+  if (!encoding)
+    return failure();
+  switch (*encoding) {
+  case TensorEncoding::MATMUL_F32F32F32_LHS:
+  case TensorEncoding::MATMUL_I8I8I32_LHS:
+    return MaterializeEncodingInfo{{0, 1}, {8, 4}, {}};
+    break;
+  case TensorEncoding::MATMUL_F32F32F32_RHS:
+  case TensorEncoding::MATMUL_I8I8I32_RHS:
+    return MaterializeEncodingInfo{{0, 1}, {4, 8}, {}};
+    break;
+  case TensorEncoding::MATMUL_F32F32F32_RHS_TRANSPOSE:
+  case TensorEncoding::MATMUL_I8I8I32_RHS_TRANSPOSE:
+    return MaterializeEncodingInfo{{1, 0}, {8, 4}, {1, 0}};
+    break;
+  case TensorEncoding::MATMUL_F32F32F32_RESULT:
+  case TensorEncoding::MATMUL_I8I8I32_RESULT:
+    return MaterializeEncodingInfo{{0, 1}, {8, 8}, {}};
+    break;
+  default:
+    return failure();
+  }
+}
+
+//===---------------------------------------------------------------------===//
 // Methods to convert `set_encoding` and `unset_encoding` operations
 // to `pack` and `unpack` operations respectively.
 //===---------------------------------------------------------------------===//
@@ -121,8 +134,9 @@
 lowerSetEncodingOpToPackOp(RewriterBase &rewriter, SetEncodingOp encodingOp,
                            Value source,
                            MaterializeEncodingFn materializeEncodingFn) {
+  RankedTensorType resultType = encodingOp.getResultType();
   FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
-      materializeEncodingFn(encodingOp.getResultTensorEncoding());
+      materializeEncodingFn(resultType);
   if (failed(materializeEncodingInfo)) {
     return rewriter.notifyMatchFailure(encodingOp, "unhandled result encoding");
   }
@@ -137,7 +151,7 @@
                              materializeEncodingInfo->innerDimsPos,
                              materializeEncodingInfo->outerDimsPerm);
   auto initTensor = rewriter.create<tensor::EmptyOp>(
-      loc, resultDims, encodingOp.getSourceType().getElementType());
+      loc, resultDims, resultType.getElementType());
   Optional<Value> paddingValue = getPaddingValue(source);
   return rewriter.create<PackOp>(
       loc, source, initTensor, materializeEncodingInfo->innerDimsPos,
@@ -151,8 +165,9 @@
 lowerUnsetEncodingToUnpackOp(RewriterBase &rewriter, UnsetEncodingOp encodingOp,
                              Value packedValue,
                              MaterializeEncodingFn materializeEncodingFn) {
+  RankedTensorType sourceType = encodingOp.getSourceType();
   FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
-      materializeEncodingFn(encodingOp.getSourceTensorEncoding());
+      materializeEncodingFn(sourceType);
   if (failed(materializeEncodingInfo)) {
     return rewriter.notifyMatchFailure(encodingOp, "unhandled source encoding");
   }
@@ -161,7 +176,7 @@
   SmallVector<OpFoldResult> resultDims =
       getDims(rewriter, loc, encodingOp.getSource());
   auto initTensor = rewriter.create<tensor::EmptyOp>(
-      loc, resultDims, encodingOp.getResultType().getElementType());
+      loc, resultDims, sourceType.getElementType());
 
   SmallVector<OpFoldResult> innerTileSizesOfr =
       getAsOpFoldResult(rewriter, materializeEncodingInfo->innerTileSizes);
@@ -171,9 +186,9 @@
 }
 
 /// Utility method to convert from `linalg.matmul` with
-/// - lhs encoding of GEMM_LHS
-/// - rhs encoding of GEMM_RHS_TRANSPOSE
-/// - result encoding of GEMM_RESULT
+/// - lhs encoding of MATMUL_*_LHS
+/// - rhs encoding of MATMUL_*_RHS_TRANSPOSE
+/// - result encoding of MATMUL_*_RESULT
 /// to linalg.mmt4d op.
 static FailureOr<Operation *>
 lowerOpWithEncoding(RewriterBase &rewriter, linalg::MatmulOp matmulOp,
@@ -190,11 +205,15 @@
       getEncoding(inputs[1]->get().getType().cast<RankedTensorType>());
   Optional<TensorEncoding> resultEncoding =
       getEncoding(outputs[0]->get().getType().cast<RankedTensorType>());
-  if (!lhsEncoding || lhsEncoding.value() != TensorEncoding::GEMM_LHS ||
+  if (!lhsEncoding ||
+      (lhsEncoding.value() != TensorEncoding::MATMUL_F32F32F32_LHS &&
+       lhsEncoding.value() != TensorEncoding::MATMUL_I8I8I32_LHS) ||
       !rhsEncoding ||
-      rhsEncoding.value() != TensorEncoding::GEMM_RHS_TRANSPOSE ||
+      (rhsEncoding.value() != TensorEncoding::MATMUL_F32F32F32_RHS_TRANSPOSE &&
+       rhsEncoding.value() != TensorEncoding::MATMUL_I8I8I32_RHS_TRANSPOSE) ||
       !resultEncoding ||
-      resultEncoding.value() != TensorEncoding::GEMM_RESULT) {
+      (resultEncoding.value() != TensorEncoding::MATMUL_F32F32F32_RESULT &&
+       resultEncoding.value() != TensorEncoding::MATMUL_I8I8I32_RESULT)) {
     return failure();
   }
   Operation *mmt4DOp = rewriter.create<linalg::Mmt4DOp>(
@@ -225,13 +244,8 @@
                     ValueRange convertedOperands,
                     MaterializeEncodingFn materializeEncodingFn) {
   auto resultType = emptyOp.getResult().getType().cast<RankedTensorType>();
-  Optional<TensorEncoding> encoding = getEncoding(resultType);
-  if (!encoding) {
-    return rewriter.notifyMatchFailure(emptyOp,
-                                       "result type does not have encoding");
-  }
   FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
-      materializeEncodingFn(encoding.value());
+      materializeEncodingFn(resultType);
   if (failed(materializeEncodingInfo)) {
     return rewriter.notifyMatchFailure(
         emptyOp, "failed to find materialization info for result type");
@@ -360,12 +374,12 @@
   MLIRContext *context = &getContext();
 
   {
+    Operation *op = getOperation();
     RewritePatternSet patterns(context);
-    MaterializeEncodingTypeConverter typeConverter(getPackOpInfoFromEncoding);
+    MaterializeEncodingTypeConverter typeConverter(chooseEncodingInfo);
     MaterializeEncodingConversionTarget target(*context);
     populateMaterializeEncodingPatterns(patterns, target, typeConverter);
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
+    if (failed(applyPartialConversion(op, target, std::move(patterns))))
       return signalPassFailure();
   }
 
@@ -391,6 +405,7 @@
   addConversion([](IntegerType intType) { return intType; });
   addConversion([](IndexType indexType) { return indexType; });
   addConversion([](FloatType floatType) { return floatType; });
+  addConversion([](MemRefType memrefType) { return memrefType; });
   addConversion(
       [materializeEncodingFn](RankedTensorType t) -> RankedTensorType {
         return getMaterializedType(t, materializeEncodingFn);
diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeWinogradPass.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeWinogradPass.cpp
new file mode 100644
index 0000000..7214789
--- /dev/null
+++ b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeWinogradPass.cpp
@@ -0,0 +1,381 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h"
+#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
+#include "iree-dialects/Dialect/LinalgExt/Utils/Utils.h"
+#include "iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace LinalgExt {
+
+namespace {
+
+/// Appends [0, dim) unit-step loop bounds for dims >= numImageDims of tensor.
+static void computeLoopParams(SmallVectorImpl<Value> &lbs,
+                              SmallVectorImpl<Value> &ubs,
+                              SmallVectorImpl<Value> &steps, Value tensor,
+                              int numImageDims, Location loc,
+                              OpBuilder &builder) {
+  Value lowerBound = builder.create<arith::ConstantIndexOp>(loc, 0);
+  Value unitStep = builder.create<arith::ConstantIndexOp>(loc, 1);
+  auto dims = tensor::createDimValues(builder, loc, tensor);
+  for (size_t dim = numImageDims, rank = dims.size(); dim < rank; ++dim) {
+    lbs.push_back(lowerBound);
+    ubs.push_back(getValueOrCreateConstantIndexOp(builder, loc, dims[dim]));
+    steps.push_back(unitStep);
+  }
+}
+
+/// Rewrites a WinogradInputTransformOp into an scf.for nest of tile matmuls.
+class ReifyWinogradInputTransform final
+    : public OpRewritePattern<WinogradInputTransformOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WinogradInputTransformOp inputOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = inputOp.getLoc();
+    auto funcOp = inputOp->getParentOfType<func::FuncOp>();
+    if (!funcOp)
+      return rewriter.notifyMatchFailure(
+          inputOp, "Could not find parent of type funcOp");
+
+    const float *BT{nullptr};
+    const float *B{nullptr};
+    const int64_t inputTileSize = inputOp.getInputTileSize();
+    const int64_t outputTileSize = inputOp.getOutputTileSize();
+    switch (outputTileSize) {
+    case 6:
+      B = IREE::LinalgExt::Winograd::B_6x6_3x3;
+      BT = IREE::LinalgExt::Winograd::BT_6x6_3x3;
+      break;
+    default:
+      return failure();
+    }
+    /// The two values below are the transpose(B) [BTV]
+    /// and B [BV] constant matrices that convert the input
+    /// tile to the Winograd domain.
+    Value BTV = IREE::LinalgExt::createValueFrom2DConstant(
+        BT, inputTileSize, inputTileSize, loc, rewriter);
+    Value BV = IREE::LinalgExt::createValueFrom2DConstant(
+        B, inputTileSize, inputTileSize, loc, rewriter);
+
+    Value input = inputOp.input();
+    Value output = inputOp.output();
+    auto outputType = output.getType().cast<ShapedType>();
+    auto inputType = input.getType().cast<ShapedType>();
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    Type elementType = outputType.getElementType();
+    SmallVector<int64_t> imageDims = inputOp.imageDimensions();
+    const size_t numImageDims = imageDims.size();
+    llvm::SmallSetVector<int64_t, 2> imageDimsSet(imageDims.begin(),
+                                                  imageDims.end());
+    SmallVector<int64_t> inputTileSquare(imageDims.size(), inputTileSize);
+
+    rewriter.setInsertionPointToStart(&funcOp.getBody().front());
+    Value zeroF32 = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(elementType));
+    Value scratch =
+        rewriter.create<tensor::EmptyOp>(loc, inputTileSquare, elementType);
+
+    rewriter.setInsertionPoint(inputOp);
+    SmallVector<Value> lbs, ubs, steps;
+    computeLoopParams(lbs, ubs, steps, output, numImageDims, loc, rewriter);
+    // Construct loops
+    scf::LoopNest loopNest = scf::buildLoopNest(
+        rewriter, loc, lbs, ubs, steps, ValueRange({output}),
+        [&](OpBuilder &nestedBuilder, Location loc, ValueRange outputIvs,
+            ValueRange iterArgs) -> scf::ValueVector { return {iterArgs[0]}; });
+
+    // Extract input slice
+    auto one = rewriter.getIndexAttr(1);
+    auto zero = rewriter.getIndexAttr(0);
+    auto inputTileSizeAttr = rewriter.getIndexAttr(inputTileSize);
+    SmallVector<OpFoldResult> strides(inputOp.getInputOperandRank(), one);
+    SmallVector<OpFoldResult> sizes(inputOp.getInputOperandRank(), one);
+    SmallVector<OpFoldResult> offsets(inputOp.getInputOperandRank(), zero);
+    SmallVector<Value> ivs;
+    for (scf::ForOp loop : loopNest.loops)
+      ivs.push_back(loop.getInductionVar());
+    for (int i = 0; i < inputShape.size(); i++) {
+      if (!imageDimsSet.contains(i)) {
+        offsets[i] = ivs[i];
+      } else {
+        rewriter.setInsertionPointToStart(loopNest.loops[i].getBody());
+        AffineExpr dim0;
+        auto it = rewriter.getAffineConstantExpr(inputTileSize);
+        auto ot = rewriter.getAffineConstantExpr(outputTileSize);
+        auto delta = rewriter.getAffineConstantExpr(inputShape[i]);
+        bindDims(rewriter.getContext(), dim0);
+        AffineMap scaleMap =
+            AffineMap::get(1, 0, {dim0 * ot}, rewriter.getContext());
+        offsets[i] = rewriter.createOrFold<AffineApplyOp>(loc, scaleMap,
+                                                          ValueRange{ivs[i]});
+        AffineMap minMap =
+            AffineMap::get(1, 0, {-dim0 + delta, it}, rewriter.getContext());
+        sizes[i] = rewriter.createOrFold<AffineMinOp>(
+            loc, minMap,
+            ValueRange{
+                getValueOrCreateConstantIndexOp(rewriter, loc, offsets[i])});
+      }
+    }
+    rewriter.setInsertionPointToStart(loopNest.loops.back().getBody());
+    auto tensorType = RankedTensorType::get(
+        SmallVector<int64_t>(numImageDims, ShapedType::kDynamic), elementType);
+    Value dynamicSlice = rewriter.create<tensor::ExtractSliceOp>(
+        loc, tensorType, input, offsets, sizes, strides);
+
+    // Copy input slice into zeroed padded scratch space
+    strides = SmallVector<OpFoldResult>(numImageDims, one);
+    offsets = SmallVector<OpFoldResult>(numImageDims, zero);
+    sizes = SmallVector<OpFoldResult>{sizes[1], sizes[2]};
+    linalg::FillOp fillOp = rewriter.create<linalg::FillOp>(
+        loc, ValueRange{zeroF32}, ValueRange{scratch});
+    Value inputSlice = rewriter.create<tensor::InsertSliceOp>(
+        loc, dynamicSlice, fillOp.result(), offsets, sizes, strides);
+
+    // Extract output slice
+    strides = SmallVector<OpFoldResult>(inputOp.getOutputOperandRank(), one);
+    offsets = SmallVector<OpFoldResult>(numImageDims, zero);
+    offsets.append(ivs.begin(), ivs.end());
+    sizes = SmallVector<OpFoldResult>(inputOp.getOutputOperandRank(), one);
+    sizes[0] = sizes[1] = inputTileSizeAttr;
+    tensorType = RankedTensorType::get(inputTileSquare, elementType);
+    Value iterArg = loopNest.loops.back().getRegionIterArg(0);
+    Value outputSlice = rewriter.create<tensor::ExtractSliceOp>(
+        loc, tensorType, iterArg, offsets, sizes, strides);
+
+    // Create computation
+    Value result, AMatrix, BMatrix;
+    linalg::MatmulOp matmulOp;
+    for (int i = 0; i < 2; i++) {
+      fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange{zeroF32},
+                                               ValueRange{outputSlice});
+      if (i == 0) {
+        AMatrix = inputSlice;
+        BMatrix = BV;
+      } else {
+        AMatrix = BTV;
+        BMatrix = result;
+      }
+      matmulOp = rewriter.create<linalg::MatmulOp>(
+          loc, tensorType, ValueRange{AMatrix, BMatrix}, fillOp.result());
+      result = matmulOp.getResult(0);
+    }
+
+    // Insert results into output slice
+    Value updatedOutput = rewriter.create<tensor::InsertSliceOp>(
+        loc, result, iterArg, offsets, sizes, strides);
+
+    // Yield the updated output from the innermost loop
+    if (scf::YieldOp yieldOp = dyn_cast<scf::YieldOp>(
+            loopNest.loops.back().getBody()->getTerminator())) {
+      rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, updatedOutput);
+    }
+    // Replace through the rewriter so the driver is notified and the op erased.
+    rewriter.replaceOp(inputOp, loopNest.getResults()[0]);
+    return success();
+  }
+};
+
+} // namespace
+
+namespace {
+
+/// Rewrites a WinogradOutputTransformOp into an scf.for nest of tile matmuls.
+class ReifyWinogradOutputTransform final
+    : public OpRewritePattern<WinogradOutputTransformOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WinogradOutputTransformOp outputOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = outputOp.getLoc();
+    auto funcOp = outputOp->getParentOfType<func::FuncOp>();
+    if (!funcOp)
+      return rewriter.notifyMatchFailure(
+          outputOp, "Could not find parent of type funcOp");
+
+    const float *AT{nullptr};
+    const float *A{nullptr};
+    const int64_t inputTileSize = outputOp.getInputTileSize();
+    const int64_t outputTileSize = outputOp.getOutputTileSize();
+    switch (outputTileSize) {
+    case 6:
+      A = IREE::LinalgExt::Winograd::A_6x6_3x3;
+      AT = IREE::LinalgExt::Winograd::AT_6x6_3x3;
+      break;
+    default:
+      return failure();
+    }
+    /// The two values below are the transpose(A) [ATV]
+    /// and A [AV] constant matrices that convert the output
+    /// tile from the Winograd domain to the original domain.
+    Value ATV = IREE::LinalgExt::createValueFrom2DConstant(
+        AT, outputTileSize, inputTileSize, loc, rewriter);
+    Value AV = IREE::LinalgExt::createValueFrom2DConstant(
+        A, inputTileSize, outputTileSize, loc, rewriter);
+
+    Value input = outputOp.input();
+    Value output = outputOp.output();
+    auto outputType = output.getType().cast<ShapedType>();
+    ArrayRef<int64_t> outputShape = outputType.getShape();
+    Type elementType = outputType.getElementType();
+    SmallVector<int64_t> imageDims = outputOp.imageDimensions();
+    const size_t numImageDims = imageDims.size();
+    llvm::SmallSetVector<int64_t, 2> imageDimsSet(imageDims.begin(),
+                                                  imageDims.end());
+    SmallVector<int64_t> inputTileSquare(imageDims.size(), inputTileSize);
+
+    rewriter.setInsertionPointToStart(&funcOp.getBody().front());
+    Value zeroF32 = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(elementType));
+    SmallVector<int64_t> scratchShape = {inputTileSize, outputTileSize};
+    Value scratch =
+        rewriter.create<tensor::EmptyOp>(loc, scratchShape, elementType);
+
+    rewriter.setInsertionPoint(outputOp);
+    SmallVector<Value> lbs, ubs, steps;
+    computeLoopParams(lbs, ubs, steps, input, numImageDims, loc, rewriter);
+    // Construct loops
+    scf::LoopNest loopNest = scf::buildLoopNest(
+        rewriter, loc, lbs, ubs, steps, ValueRange({output}),
+        [&](OpBuilder &nestedBuilder, Location loc, ValueRange outputIvs,
+            ValueRange iterArgs) -> scf::ValueVector { return {iterArgs[0]}; });
+
+    // Extract input slice
+    rewriter.setInsertionPointToStart(loopNest.loops.back().getBody());
+    auto one = rewriter.getIndexAttr(1);
+    auto zero = rewriter.getIndexAttr(0);
+    auto inputTileSizeAttr = rewriter.getIndexAttr(inputTileSize);
+    auto outputTileSizeAttr = rewriter.getIndexAttr(outputTileSize);
+    SmallVector<OpFoldResult> strides(outputOp.getInputOperandRank(), one);
+    SmallVector<OpFoldResult> sizes(outputOp.getInputOperandRank(), one);
+    SmallVector<OpFoldResult> offsets(numImageDims, zero);
+    sizes[0] = sizes[1] = inputTileSizeAttr;
+    SmallVector<Value> ivs;
+    for (scf::ForOp loop : loopNest.loops)
+      ivs.push_back(loop.getInductionVar());
+    offsets.append(ivs.begin(), ivs.end());
+    auto tensorType = RankedTensorType::get(inputTileSquare, elementType);
+    tensor::ExtractSliceOp extractSliceOp =
+        rewriter.create<tensor::ExtractSliceOp>(loc, tensorType, input, offsets,
+                                                sizes, strides);
+    Value inputSlice = extractSliceOp.getResult();
+
+    // Extract output slice
+    strides = SmallVector<OpFoldResult>(outputOp.getOutputOperandRank(), one);
+    offsets = SmallVector<OpFoldResult>(outputOp.getOutputOperandRank(), zero);
+    sizes = SmallVector<OpFoldResult>(outputOp.getOutputOperandRank(), one);
+    for (int i = 0; i < outputShape.size(); i++) {
+      if (!imageDimsSet.contains(i)) {
+        offsets[i] = ivs[i];
+      } else {
+        rewriter.setInsertionPointToStart(loopNest.loops[i].getBody());
+        AffineExpr dim0;
+        auto ot = rewriter.getAffineConstantExpr(outputTileSize);
+        bindDims(rewriter.getContext(), dim0);
+        AffineMap scaleMap =
+            AffineMap::get(1, 0, {dim0 * ot}, rewriter.getContext());
+        offsets[i] = rewriter.createOrFold<AffineApplyOp>(loc, scaleMap,
+                                                          ValueRange{ivs[i]});
+        sizes[i] = outputTileSizeAttr;
+      }
+    }
+    rewriter.setInsertionPointAfter(extractSliceOp);
+    tensorType = RankedTensorType::get(
+        SmallVector<int64_t>(numImageDims, outputTileSize), elementType);
+    Value iterArg = loopNest.loops.back().getRegionIterArg(0);
+    Value outputSlice = rewriter.create<tensor::ExtractSliceOp>(
+        loc, tensorType, iterArg, offsets, sizes, strides);
+
+    // Create computation
+    Value result, AMatrix, BMatrix;
+    linalg::MatmulOp matmulOp;
+    linalg::FillOp fillOp;
+    Value tmp;
+    for (int i = 0; i < 2; i++) {
+      tmp = i == 0 ? scratch : outputSlice;
+      fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange{zeroF32},
+                                               ValueRange{tmp});
+      if (i == 0) {
+        AMatrix = inputSlice;
+        BMatrix = AV;
+      } else {
+        AMatrix = ATV;
+        BMatrix = result;
+      }
+      matmulOp = rewriter.create<linalg::MatmulOp>(
+          loc, tmp.getType(), ValueRange{AMatrix, BMatrix}, fillOp.result());
+      result = matmulOp.getResult(0);
+    }
+
+    // Insert results into output slice
+    Value updatedOutput = rewriter.create<tensor::InsertSliceOp>(
+        loc, result, iterArg, offsets, sizes, strides);
+
+    // Yield the updated output from the innermost loop
+    if (scf::YieldOp yieldOp = dyn_cast<scf::YieldOp>(
+            loopNest.loops.back().getBody()->getTerminator())) {
+      rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, updatedOutput);
+    }
+    // Replace through the rewriter so the driver is notified and the op erased.
+    rewriter.replaceOp(outputOp, loopNest.getResults()[0]);
+    return success();
+  }
+};
+
+} // namespace
+
+namespace {
+struct TileAndDecomposeWinogradTransformPass
+    : public TileAndDecomposeWinogradTransformBase<
+          TileAndDecomposeWinogradTransformPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    // arith is listed because the patterns create arith.constant ops.
+    registry.insert<AffineDialect, IREE::LinalgExt::IREELinalgExtDialect,
+                    arith::ArithDialect, linalg::LinalgDialect,
+                    scf::SCFDialect, tensor::TensorDialect>();
+  }
+  void runOnOperation() override;
+};
+} // namespace
+
+/// Greedily applies both Winograd reification patterns to the operation.
+void TileAndDecomposeWinogradTransformPass::runOnOperation() {
+  MLIRContext *ctx = &getContext();
+  RewritePatternSet patterns(ctx);
+  patterns.insert<ReifyWinogradInputTransform, ReifyWinogradOutputTransform>(
+      ctx);
+  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+    signalPassFailure();
+}
+
+/// Creates a new TileAndDecomposeWinogradTransformPass instance.
+std::unique_ptr<OperationPass<func::FuncOp>>
+createTileAndDecomposeWinogradTransformPass() {
+  return std::make_unique<TileAndDecomposeWinogradTransformPass>();
+}
+
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/Tiling.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/Tiling.cpp
index f4067df..7c4cd79 100644
--- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/Tiling.cpp
+++ b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/Tiling.cpp
@@ -402,6 +402,11 @@
           StringAttr::get(context, "tiling_repeated_indices_scatter_input"),
           StringAttr::get(context, "tiling_repeated_indices_scatter_output")));
 
+  patterns.add<TilingInterfaceTilingPattern>(
+      context, linalg::LinalgTilingOptions().setTileSizes({1, 32}),
+      IREE::LinalgExt::LinalgTransformationFilter(
+          StringAttr::get(context, "tiling_winograd_input_nhwc")));
+
   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
     return signalPassFailure();
   }
diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
index 58e892f..725d2ad 100644
--- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
+++ b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
@@ -107,96 +107,6 @@
   return tileLoopNest;
 }
 
-namespace {
-///
-/// Linalg tile and fuse tensor ops pattern.
-///
-/// Apply tiling and fusion as a pattern.
-/// See `tileConsumerAndFuseProducers` for more details.
-struct LinalgTileAndFuseTensorOpsBasePattern : public RewritePattern {
-  // Entry point to match any LinalgOp.
-  LinalgTileAndFuseTensorOpsBasePattern(
-      MLIRContext *context, linalg::LinalgTilingAndFusionOptions options,
-      PatternBenefit benefit = 1)
-      : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
-        options(std::move(options)) {}
-  // Entry point to match a specific LinalgOp.
-  LinalgTileAndFuseTensorOpsBasePattern(
-      StringRef opName, MLIRContext *context,
-      linalg::LinalgTilingAndFusionOptions options, PatternBenefit benefit = 1)
-      : RewritePattern(opName, benefit, context), options(std::move(options)) {}
-
-  /// `matchAndRewrite` implementation that returns the significant transformed
-  /// pieces of IR.
-  FailureOr<linalg::TileLoopNest>
-  returningMatchAndRewrite(Operation *op, PatternRewriter &rewriter) const;
-
-  LogicalResult matchAndRewrite(Operation *op,
-                                PatternRewriter &rewriter) const override {
-    return returningMatchAndRewrite(op, rewriter);
-  }
-
-private:
-  /// Tile sizes and interchange used to tile the root operation.
-  linalg::LinalgTilingAndFusionOptions options;
-};
-} // namespace
-
-FailureOr<mlir::linalg::TileLoopNest>
-LinalgTileAndFuseTensorOpsBasePattern::returningMatchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
-  linalg::LinalgOp rootOp = dyn_cast<linalg::LinalgOp>(op);
-  if (!rootOp)
-    return failure();
-
-  // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
-  if (options.tileSizes.size() < rootOp.getNumLoops())
-    return rewriter.notifyMatchFailure(op, "expect #tile sizes >= #loops");
-
-  // Check `tileInterchange` contains no entries or as many as `tileSizes`.
-  if (!options.tileInterchange.empty() &&
-      options.tileInterchange.size() != options.tileSizes.size())
-    return rewriter.notifyMatchFailure(
-        op, "expect the number of tile sizes and interchange dims to match");
-
-  // Copy the `tileSizes` and `tileInterchange` prefixes needed for `rootOp`.
-  SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
-                                     options.tileSizes.begin() +
-                                         rootOp.getNumLoops());
-  SmallVector<int64_t> rootInterchange =
-      options.tileInterchange.empty()
-          ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
-          : SmallVector<int64_t>(options.tileInterchange.begin(),
-                                 options.tileInterchange.begin() +
-                                     rootOp.getNumLoops());
-
-  // Check `rootTileSizes` contains non-zero tile sizes.
-  if (llvm::count(rootTileSizes, 0) == static_cast<long>(rootTileSizes.size()))
-    return rewriter.notifyMatchFailure(
-        op, "expect at least one non-zero tile size");
-
-  // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
-  // It has to be a permutation since the tiling cannot tile the same loop
-  // dimension multiple times.
-  if (!linalg::isPermutation(rootInterchange))
-    return rewriter.notifyMatchFailure(
-        op, "expect the tile interchange permutes the root loops");
-
-  // Tile `rootOp` and fuse its producers.
-  FailureOr<linalg::TileLoopNest> tileLoopNest =
-      IREE::LinalgExt::tileConsumerAndFuseProducers(
-          rewriter, rootOp, rootTileSizes, rootInterchange,
-          options.tileDistribution);
-  if (failed(tileLoopNest))
-    return rewriter.notifyMatchFailure(
-        op, "tileConsumerAndFuseProducers failed unexpectedly");
-
-  // Replace all uses of the tiled loop operation.
-  rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());
-
-  return tileLoopNest;
-}
-
 /// Peel loops after tiling.
 static void peelTiledLinalgOp(RewriterBase &rewriter,
                               linalg::TiledLinalgOp &res,
@@ -256,6 +166,83 @@
   return res;
 }
 
+/// Linalg SCF tiling pattern: tiles filter-accepted TilingInterface ops.
+LinalgSCFTilingPattern::LinalgSCFTilingPattern(
+    MLIRContext *context, scf::SCFTilingOptions options,
+    LinalgExt::LinalgTransformationFilter f, PatternBenefit benefit)
+    : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+      filter(std::move(f)), options(std::move(options)) {}
+
+LinalgSCFTilingPattern::LinalgSCFTilingPattern(
+    StringRef opName, MLIRContext *context, scf::SCFTilingOptions options,
+    LinalgExt::LinalgTransformationFilter f, PatternBenefit benefit)
+    : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}
+
+LogicalResult LinalgSCFTilingPattern::returningMatchAndRewrite(
+    TilingInterface op, PatternRewriter &rewriter) const {
+  if (failed(filter.checkAndNotify(rewriter, op)))
+    return failure();
+  // Tile `op` with scf::tileUsingSCFForOp and the stored options.
+  FailureOr<scf::SCFTilingResult> tiledResults =
+      scf::tileUsingSCFForOp(rewriter, op, options);
+  if (failed(tiledResults))
+    return failure();
+
+  rewriter.replaceOp(op, tiledResults->replacements);
+
+  for (auto tiledOp : tiledResults->tiledOps) {
+    filter.replaceLinalgTransformationFilter(rewriter, tiledOp);
+  }
+
+  return success();
+}
+
+/// Linalg SCF tile-and-fuse pattern for filter-accepted TilingInterface ops.
+LinalgSCFTileAndFusePattern::LinalgSCFTileAndFusePattern(
+    MLIRContext *context, scf::SCFTileAndFuseOptions options,
+    LinalgExt::LinalgTransformationFilter f, PatternBenefit benefit)
+    : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+      filter(std::move(f)), options(std::move(options)) {}
+
+LinalgSCFTileAndFusePattern::LinalgSCFTileAndFusePattern(
+    StringRef opName, MLIRContext *context, scf::SCFTileAndFuseOptions options,
+    LinalgExt::LinalgTransformationFilter f, PatternBenefit benefit)
+    : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}
+
+LogicalResult
+LinalgSCFTileAndFusePattern::matchAndRewrite(TilingInterface op,
+                                             PatternRewriter &rewriter) const {
+  if (failed(filter.checkAndNotify(rewriter, op)))
+    return failure();
+
+  FailureOr<scf::SCFTileAndFuseResult> tiledResults =
+      tileConsumerAndFuseProducerGreedilyUsingSCFForOp(rewriter, op, options);
+  if (failed(tiledResults))
+    return rewriter.notifyMatchFailure(
+        op,
+        "tileConsumerAndFuseProducerGreedilyUsingSCFForOp failed unexpectedly");
+  // Replace `op`: each result maps to its entry in `replacements`, or to
+  // itself when absent. NOTE(review): confirm the latter leaves no live uses.
+  SmallVector<Value> replacements(op->getNumResults());
+  for (auto result : llvm::enumerate(op->getResults())) {
+    auto it = tiledResults->replacements.find(result.value());
+    if (it == tiledResults->replacements.end()) {
+      replacements[result.index()] = result.value();
+    } else {
+      replacements[result.index()] = it->getSecond();
+    }
+  }
+  rewriter.replaceOp(op, replacements);
+
+  // Apply the filter if specified.
+  for (linalg::LinalgOp linalgOp : tiledResults->tiledAndFusedOps)
+    filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
+
+  return success();
+}
+
 LinalgVectorizationPattern::LinalgVectorizationPattern(
     MLIRContext *context, LinalgExt::LinalgTransformationFilter f,
     PatternBenefit benefit)
@@ -346,62 +333,6 @@
   return success();
 }
 
-namespace {
-///
-/// Linalg tile and fuse tensor ops pattern.
-///
-/// Apply tiling and fusion as a pattern.
-/// `filter` controls LinalgTransformMarker matching and update when specified.
-/// See `tileConsumerAndFuseProducers` for more details.
-struct LinalgTileAndFuseTensorOpsPattern : public RewritePattern {
-  // Entry point to match any LinalgOp.
-  LinalgTileAndFuseTensorOpsPattern(
-      MLIRContext *context, linalg::LinalgTilingAndFusionOptions options,
-      LinalgExt::LinalgTransformationFilter f =
-          LinalgExt::LinalgTransformationFilter(),
-      PatternBenefit benefit = 1)
-      : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
-        filter(std::move(f)), options(std::move(options)) {}
-  // Entry point to match a specific LinalgOp.
-  LinalgTileAndFuseTensorOpsPattern(
-      StringRef opName, MLIRContext *context,
-      linalg::LinalgTilingAndFusionOptions options,
-      LinalgExt::LinalgTransformationFilter f =
-          LinalgExt::LinalgTransformationFilter(),
-      PatternBenefit benefit = 1)
-      : RewritePattern(opName, benefit, context), filter(std::move(f)),
-        options(std::move(options)) {}
-
-  /// `matchAndRewrite` implementation that returns the significant transformed
-  /// pieces of IR.
-  FailureOr<linalg::TileLoopNest>
-  returningMatchAndRewrite(Operation *op, PatternRewriter &rewriter) const {
-    if (failed(filter.checkAndNotify(rewriter, op)))
-      return failure();
-    LinalgTileAndFuseTensorOpsBasePattern p(op->getContext(), options);
-    auto maybeTileLoopNest = p.returningMatchAndRewrite(op, rewriter);
-    if (failed(maybeTileLoopNest))
-      return failure();
-    // Apply the filter if specified.
-    for (linalg::LinalgOp linalgOp :
-         maybeTileLoopNest->getAllTiledAndFusedOps())
-      filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
-    return maybeTileLoopNest;
-  }
-
-  LogicalResult matchAndRewrite(Operation *op,
-                                PatternRewriter &rewriter) const override {
-    return returningMatchAndRewrite(op, rewriter);
-  }
-
-private:
-  /// LinalgTransformMarker handles special attribute manipulations.
-  LinalgExt::LinalgTransformationFilter filter;
-  /// Tile sizes and interchange used to tile the root operation.
-  linalg::LinalgTilingAndFusionOptions options;
-};
-} // namespace
-  //
 /// Configurable pass to apply pattern-based tiling and fusion.
 struct LinalgStrategyTileAndFusePass
     : public LinalgStrategyTileAndFusePassBase<LinalgStrategyTileAndFusePass> {
@@ -409,9 +340,9 @@
   LinalgStrategyTileAndFusePass() = default;
 
   LinalgStrategyTileAndFusePass(StringRef opName,
-                                linalg::LinalgTilingAndFusionOptions opt,
+                                scf::SCFTileAndFuseOptions options,
                                 LinalgExt::LinalgTransformationFilter filt)
-      : options(std::move(opt)), filter(std::move(filt)) {
+      : options(std::move(options)), filter(std::move(filt)) {
     this->anchorOpName.setValue(opName.str());
   }
 
@@ -422,10 +353,10 @@
 
     RewritePatternSet tilingAndFusionPattern(funcOp.getContext());
     if (!anchorOpName.empty()) {
-      tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>(
+      tilingAndFusionPattern.add<LinalgSCFTileAndFusePattern>(
           anchorOpName, funcOp.getContext(), options, filter);
     } else {
-      tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>(
+      tilingAndFusionPattern.add<LinalgSCFTileAndFusePattern>(
           funcOp.getContext(), options, filter);
     }
     // Search the root operation using bottom up traversal.
@@ -435,7 +366,7 @@
         funcOp, std::move(tilingAndFusionPattern), config);
   }
 
-  linalg::LinalgTilingAndFusionOptions options;
+  scf::SCFTileAndFuseOptions options;
   LinalgExt::LinalgTransformationFilter filter;
 };
 
@@ -445,9 +376,9 @@
 
   LinalgStrategyTilePass() = default;
 
-  LinalgStrategyTilePass(StringRef opName, linalg::LinalgTilingOptions opt,
+  LinalgStrategyTilePass(StringRef opName, scf::SCFTilingOptions options,
                          LinalgExt::LinalgTransformationFilter filt)
-      : options(std::move(opt)), filter(std::move(filt)) {
+      : options(std::move(options)), filter(std::move(filt)) {
     this->anchorOpName.setValue(opName.str());
   }
 
@@ -459,16 +390,15 @@
     MLIRContext *ctx = funcOp.getContext();
     RewritePatternSet tilingPattern(ctx);
     if (!anchorOpName.empty())
-      tilingPattern.add<LinalgTilingPattern>(anchorOpName, ctx, options,
-                                             filter);
+      tilingPattern.add<LinalgSCFTilingPattern>(anchorOpName, ctx, options,
+                                                filter);
     else
-      tilingPattern.add<LinalgTilingPattern>(ctx, options, filter);
-    if (anchorOpName == tensor::PadOp::getOperationName())
-      populatePadTensorTilingPatterns(tilingPattern, options);
+      tilingPattern.add<LinalgSCFTilingPattern>(ctx, options, filter);
+
     (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
   }
 
-  linalg::LinalgTilingOptions options;
+  scf::SCFTilingOptions options;
   LinalgExt::LinalgTransformationFilter filter;
 };
 
@@ -757,7 +687,7 @@
 /// Create a LinalgStrategyTileAndFusePass.
 std::unique_ptr<OperationPass<func::FuncOp>>
 createLinalgStrategyTileAndFusePass(
-    StringRef opName, const linalg::LinalgTilingAndFusionOptions &options,
+    StringRef opName, const scf::SCFTileAndFuseOptions &options,
     const LinalgExt::LinalgTransformationFilter &filter) {
   return std::make_unique<LinalgStrategyTileAndFusePass>(opName, options,
                                                          filter);
@@ -765,9 +695,9 @@
 
 /// Create a LinalgStrategyTilePass.
 std::unique_ptr<OperationPass<func::FuncOp>> createLinalgStrategyTilePass(
-    StringRef opName, const linalg::LinalgTilingOptions &opt,
+    StringRef opName, const scf::SCFTilingOptions &options,
     const LinalgExt::LinalgTransformationFilter &filter) {
-  return std::make_unique<LinalgStrategyTilePass>(opName, opt, filter);
+  return std::make_unique<LinalgStrategyTilePass>(opName, options, filter);
 }
 
 /// Create a LinalgStrategyPadPass.
@@ -859,7 +789,7 @@
     // The size is less than or equal to tileSize because outer dims are all 1s.
     Optional<int64_t> tileSize =
         getConstantIntValue(tileAndPosMapping.lookup(dim));
-    assert(tileSize.hasValue() && "dynamic inner tile size is not supported");
+    assert(tileSize.has_value() && "dynamic inner tile size is not supported");
     paddedShape.push_back(tileSize.value());
   }
   auto resultType =
@@ -1022,7 +952,7 @@
     // Apply tiling to make outer dims be all 1s.
     {
       SimpleRewriter rewriter(ctx);
-      auto packTilingOptions =
+      auto packOptions = scf::SCFTileAndFuseOptions().setTilingOptions(
           scf::SCFTilingOptions().setTileSizeComputationFunction(
               [](OpBuilder &builder, Operation *op) {
                 Location loc = op->getLoc();
@@ -1030,15 +960,16 @@
                 SmallVector<Value> tileSizes(
                     inputRank, builder.create<arith::ConstantIndexOp>(loc, 1));
                 return tileSizes;
-              });
+              }));
       auto funcOp = getOperation();
       funcOp->walk([&](LinalgExt::PackOp op) {
-        FailureOr<scf::SCFTilingResult> tilingResult = scf::tileUsingSCFForOp(
-            rewriter, cast<TilingInterface>(op.getOperation()),
-            packTilingOptions);
-        if (failed(tilingResult))
+        FailureOr<scf::SCFTileAndFuseResult> tileAndFuseResult =
+            scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(rewriter, op,
+                                                                  packOptions);
+        if (failed(tileAndFuseResult))
           return signalPassFailure();
-        rewriter.replaceOp(op, tilingResult->replacements);
+        rewriter.replaceOp(op,
+                           tileAndFuseResult->replacements[op.getResult(0)]);
       });
 
       auto unpackTilingOptions =
diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Utils/Utils.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Utils/Utils.cpp
index 937e3ee..5eef5b4 100644
--- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Utils/Utils.cpp
+++ b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Utils/Utils.cpp
@@ -71,6 +71,15 @@
   return interchangeVector;
 }
 
+Value createValueFrom2DConstant(const float *val, int64_t rows, int64_t cols,
+                                Location loc, PatternRewriter &rewriter) {
+  ArrayRef<float> vector(val, rows * cols);
+  SmallVector<int64_t> shape{rows, cols};
+  return rewriter.create<arith::ConstantOp>(
+      loc, DenseFPElementsAttr::get(
+               RankedTensorType::get(shape, rewriter.getF32Type()), vector));
+}
+
 } // namespace LinalgExt
 } // namespace IREE
 } // namespace iree_compiler
diff --git a/integrations/tensorflow/iree-dialects/python/IREEDialectsModule.cpp b/integrations/tensorflow/iree-dialects/python/IREEDialectsModule.cpp
index 3c19ffb..1a85609 100644
--- a/integrations/tensorflow/iree-dialects/python/IREEDialectsModule.cpp
+++ b/integrations/tensorflow/iree-dialects/python/IREEDialectsModule.cpp
@@ -5,7 +5,6 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 #include "iree-dialects-c/Dialects.h"
-#include "iree-dialects-c/Utils.h"
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
 #include "mlir-c/Diagnostics.h"
@@ -22,28 +21,6 @@
   auto typeClass = irModule.attr("Type");
 
   //===--------------------------------------------------------------------===//
-  // Utils
-  //===--------------------------------------------------------------------===//
-
-  m.def(
-      "lookup_nearest_symbol_from",
-      [](MlirOperation fromOp, MlirAttribute symbol) {
-        if (!mlirAttributeIsASymbolRef(symbol)) {
-          throw std::invalid_argument("expected a SymbolRefAttr");
-        }
-        return ireeLookupNearestSymbolFrom(fromOp, symbol);
-      },
-      py::arg("fromOp"), py::arg("symbol"));
-
-  // TODO: Upstream this into the main Python bindings.
-  m.def(
-      "emit_error",
-      [](MlirLocation loc, std::string message) {
-        mlirEmitError(loc, message.c_str());
-      },
-      py::arg("loc"), py::arg("message"));
-
-  //===--------------------------------------------------------------------===//
   // IREEDialect
   //===--------------------------------------------------------------------===//
   auto iree_m = m.def_submodule("iree_input");
diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/conv2d_to_winograd.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/conv2d_to_winograd.mlir
new file mode 100644
index 0000000..be10a1d
--- /dev/null
+++ b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/conv2d_to_winograd.mlir
@@ -0,0 +1,77 @@
+// RUN: iree-dialects-opt --split-input-file -iree-linalg-ext-convert-conv2d-to-winograd -mlir-elide-elementsattrs-if-larger=4 %s | FileCheck %s
+
+func.func @conv_16433136(%arg0: tensor<1x16x16x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+  %c0 = arith.constant dense<0.1> : tensor<3x3x4x16xf32>
+  %0 = linalg.conv_2d_nhwc_hwcf
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+     ins(%arg0, %c0: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
+    outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+  return %0 : tensor<1x14x14x16xf32>
+}
+// CHECK:      func.func @conv_16433136(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x16x16x4xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
+// CHECK-SAME:   tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+// CHECK:        %[[CST:.+]] = arith.constant dense_resource<__elided__> : tensor<64x4x16xf32>
+// CHECK:        %[[CST_0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK:        %[[D0:.+]] = tensor.empty() : tensor<8x8x1x3x3x4xf32>
+// CHECK:        %[[D1:.+]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME:     image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<1x16x16x4xf32>) outs(%[[D0]] :
+// CHECK-SAME:     tensor<8x8x1x3x3x4xf32>) -> tensor<8x8x1x3x3x4xf32>
+// CHECK:        %[[COLLAPSED:.+]] = tensor.collapse_shape %[[D1]]
+// CHECK-SAME{LITERAL}:  [[0, 1], [2, 3, 4], [5]]
+// CHECK-SAME:           tensor<8x8x1x3x3x4xf32> into tensor<64x9x4xf32>
+// CHECK:        %[[D2:.+]] = tensor.empty() : tensor<64x9x16xf32>
+// CHECK:        %[[D3:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D2]] : tensor<64x9x16xf32>) ->
+// CHECK-SAME:     tensor<64x9x16xf32>
+// CHECK:        %[[D4:.+]] = linalg.batch_matmul ins(%[[COLLAPSED]], %[[CST]] : tensor<64x9x4xf32>,
+// CHECK-SAME:     tensor<64x4x16xf32>) outs(%[[D3]] : tensor<64x9x16xf32>) -> tensor<64x9x16xf32>
+// CHECK:        %[[EXPANDED:.+]] = tensor.expand_shape %[[D4]]
+// CHECK-SAME{LITERAL}: [[0, 1], [2, 3, 4], [5]]
+// CHECK-SAME:          tensor<64x9x16xf32> into tensor<8x8x1x3x3x16xf32>
+// CHECK:        %[[D5:.+]] = tensor.empty() : tensor<1x18x18x16xf32>
+// CHECK:        %[[D6:.+]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME:     image_dimensions([1, 2]) ins(%[[EXPANDED]] : tensor<8x8x1x3x3x16xf32>) outs(%[[D5]] :
+// CHECK-SAME:     tensor<1x18x18x16xf32>) -> tensor<1x18x18x16xf32>
+// CHECK:        %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[D6]][0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] :
+// CHECK-SAME:     tensor<1x18x18x16xf32> to tensor<1x14x14x16xf32>
+// CHECK:        return %[[EXTRACTED_SLICE]] : tensor<1x14x14x16xf32>
+// CHECK:      }
+
+// -----
+
+func.func @conv2d_non_splat_weights(%inputs : tensor<1x4x4x1xf32>, %arg2: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32> {
+  %c0 = arith.constant dense<[[ [[1.0]],  [[3.0]],  [[5.0]]  ],
+                              [ [[7.0]],  [[9.0]],  [[11.0]] ],
+                              [ [[13.0]], [[15.0]], [[17.0]] ]]> : tensor<3x3x1x1xf32>
+  %0 = linalg.conv_2d_nhwc_hwcf
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+     ins(%inputs, %c0: tensor<1x4x4x1xf32>, tensor<3x3x1x1xf32>)
+    outs(%arg2: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
+  return %0 : tensor<1x2x2x1xf32>
+}
+// CHECK:      func.func @conv2d_non_splat_weights(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x4x4x1xf32>,
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32> {
+// CHECK:        %[[CST:.+]] = arith.constant dense_resource<__elided__> : tensor<64x1x1xf32>
+// CHECK:        %[[CST_0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK:        %[[D0:.+]] = tensor.empty() : tensor<8x8x1x1x1x1xf32>
+// CHECK:        %[[D1:.+]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME:     image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<1x4x4x1xf32>) outs(%[[D0]] : tensor<8x8x1x1x1x1xf32>)
+// CHECK-SAME:     -> tensor<8x8x1x1x1x1xf32>
+// CHECK:        %[[COLLAPSED:.+]] = tensor.collapse_shape %[[D1]]
+// CHECK-SAME{LITERAL}:   [[0, 1], [2, 3, 4], [5]]
+// CHECK-SAME:            tensor<8x8x1x1x1x1xf32> into tensor<64x1x1xf32>
+// CHECK:        %[[D2:.+]] = tensor.empty() : tensor<64x1x1xf32>
+// CHECK:        %[[D3:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D2]] : tensor<64x1x1xf32>) ->
+// CHECK-SAME:     tensor<64x1x1xf32>
+// CHECK:        %[[D4:.+]] = linalg.batch_matmul ins(%[[COLLAPSED]], %[[CST]] : tensor<64x1x1xf32>, tensor<64x1x1xf32>)
+// CHECK-SAME:     outs(%[[D3]] : tensor<64x1x1xf32>) -> tensor<64x1x1xf32>
+// CHECK:        %[[EXPANDED:.+]] = tensor.expand_shape %[[D4]]
+// CHECK-SAME{LITERAL}:   [[0, 1], [2, 3, 4], [5]]
+// CHECK-SAME:            tensor<64x1x1xf32> into tensor<8x8x1x1x1x1xf32>
+// CHECK:        %[[D5:.+]] = tensor.empty() : tensor<1x6x6x1xf32>
+// CHECK:        %[[D6:.+]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME:     image_dimensions([1, 2]) ins(%[[EXPANDED]] : tensor<8x8x1x1x1x1xf32>) outs(%[[D5]] :
+// CHECK-SAME:     tensor<1x6x6x1xf32>) -> tensor<1x6x6x1xf32>
+// CHECK:        %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[D6]][0, 0, 0, 0] [1, 2, 2, 1] [1, 1, 1, 1] :
+// CHECK-SAME:     tensor<1x6x6x1xf32> to tensor<1x2x2x1xf32>
+// CHECK:        return %[[EXTRACTED_SLICE]] : tensor<1x2x2x1xf32>
+// CHECK:      }
diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/invalid.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/invalid.mlir
index 5b9b499..038f293 100644
--- a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/invalid.mlir
+++ b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/invalid.mlir
@@ -590,9 +590,9 @@
 
 // -----
 
-func.func @illegal_set_encoding_op_with_source_encoding(%arg0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>) -> tensor<?x?xf32> {
+func.func @illegal_set_encoding_op_with_source_encoding(%arg0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>) -> tensor<?x?xf32> {
   // expected-error @+1 {{source of set_encoding op cannot have a tensor encoding}}
-  %0 = iree_linalg_ext.set_encoding %arg0: tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>> -> tensor<?x?xf32>
+  %0 = iree_linalg_ext.set_encoding %arg0: tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
 
@@ -606,18 +606,18 @@
 
 // -----
 
-func.func @illegal_set_encoding_op_with_rank_change(%arg0 : tensor<?x?xf32>) -> tensor<?xf32, #iree_linalg_ext.encoding<GEMM_LHS>> {
+func.func @illegal_set_encoding_op_with_rank_change(%arg0 : tensor<?x?xf32>) -> tensor<?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> {
   // expected-error @+1 {{cannot change the rank of the tensor}}
-  %0 = iree_linalg_ext.set_encoding %arg0: tensor<?x?xf32> -> tensor<?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
-  return %0 : tensor<?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+  %0 = iree_linalg_ext.set_encoding %arg0: tensor<?x?xf32> -> tensor<?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+  return %0 : tensor<?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
 }
 
 // -----
 
-func.func @illegal_set_encoding_op_with_shape_change(%arg0 : tensor<10x20xf32>) -> tensor<20x30xf32, #iree_linalg_ext.encoding<GEMM_LHS>> {
+func.func @illegal_set_encoding_op_with_shape_change(%arg0 : tensor<10x20xf32>) -> tensor<20x30xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> {
   // expected-error @+1 {{expected to preserve the logical shape of the tensor}}
-  %0 = iree_linalg_ext.set_encoding %arg0: tensor<10x20xf32> -> tensor<20x30xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
-  return %0 : tensor<20x30xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+  %0 = iree_linalg_ext.set_encoding %arg0: tensor<10x20xf32> -> tensor<20x30xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+  return %0 : tensor<20x30xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
 }
 
 // -----
@@ -630,10 +630,10 @@
 
 // -----
 
-func.func @illegal_unset_encoding_op_with_result_encoding(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>> {
+func.func @illegal_unset_encoding_op_with_result_encoding(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> {
   // expected-error @+1 {{result of unset_encoding op cannot have a tensor encoding}}
-  %0 = iree_linalg_ext.unset_encoding %arg0: tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
-  return %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+  %0 = iree_linalg_ext.unset_encoding %arg0: tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+  return %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
 }
 
 // -----
@@ -646,16 +646,49 @@
 
 // -----
 
-func.func @illegal_unset_encoding_op_with_rank_change(%arg0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>) -> tensor<?xf32> {
+func.func @illegal_unset_encoding_op_with_rank_change(%arg0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>) -> tensor<?xf32> {
   // expected-error @+1 {{cannot change the rank of the tensor}}
-  %0 = iree_linalg_ext.unset_encoding %arg0: tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>> -> tensor<?xf32>
+  %0 = iree_linalg_ext.unset_encoding %arg0: tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
 
 // -----
 
-func.func @illegal_unset_encoding_op_with_shape_change(%arg0 : tensor<20x30xf32, #iree_linalg_ext.encoding<GEMM_LHS>>) -> tensor<10x20xf32> {
+func.func @illegal_unset_encoding_op_with_shape_change(%arg0 : tensor<20x30xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>) -> tensor<10x20xf32> {
   // expected-error @+1 {{expected to preserve the logical shape of the tensor}}
-  %0 = iree_linalg_ext.unset_encoding %arg0: tensor<20x30xf32, #iree_linalg_ext.encoding<GEMM_LHS>> -> tensor<10x20xf32>
+  %0 = iree_linalg_ext.unset_encoding %arg0: tensor<20x30xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> -> tensor<10x20xf32>
   return %0 : tensor<10x20xf32>
 }
+
+// -----
+
+func.func @illegal_winograd_input_shape(%arg0: tensor<1x10x10x32xf32>) -> tensor<8x8x1x6x6x32xf32> {
+  %0 = tensor.empty() : tensor<8x8x1x6x6x32xf32>
+  // expected-error @+1 {{incompatible output shape}}
+  %1 = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3) image_dimensions([1, 2])
+    ins(%arg0 : tensor<1x10x10x32xf32>) outs(%0 : tensor<8x8x1x6x6x32xf32>) -> tensor<8x8x1x6x6x32xf32>
+  return %1 : tensor<8x8x1x6x6x32xf32>
+}
+
+// -----
+
+func.func @illegal_winograd_input_rank(%arg0: tensor<1x10x10x32xf32>) -> tensor<8x8x1x6xf32> {
+  %0 = tensor.empty() : tensor<8x8x1x6xf32>
+  // expected-error @+1 {{expected output rank to be equal to input rank + 2}}
+  %1 = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3) image_dimensions([1, 2])
+    ins(%arg0 : tensor<1x10x10x32xf32>) outs(%0 : tensor<8x8x1x6xf32>) -> tensor<8x8x1x6xf32>
+  return %1 : tensor<8x8x1x6xf32>
+}
+
+// -----
+
+func.func @illegal_winograd_output_shape(%arg0: tensor<8x8x1x2x2x32xf32>) -> tensor<1x8x8x32xf32> {
+  %0 = tensor.empty() : tensor<1x8x8x32xf32>
+  // expected-error @+1 {{incompatible output shape}}
+  %1 = iree_linalg_ext.winograd.output_transform output_tile_size(6)
+        kernel_size(3) image_dimensions([1, 2])
+        ins(%arg0 : tensor<8x8x1x2x2x32xf32>) outs(%0 : tensor<1x8x8x32xf32>) -> tensor<1x8x8x32xf32>
+  return %1 : tensor<1x8x8x32xf32>
+}
+
+// -----
diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
index fe91dc5..35e2053 100644
--- a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
+++ b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
@@ -1,8 +1,8 @@
 // RUN: iree-dialects-opt --iree-linalg-ext-materialize-encoding -cse -split-input-file %s | FileCheck %s
 
 func.func @pack_unpack_gemm_lhs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
-  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
-  %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>> -> tensor<?x?xf32>
+  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+  %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
 //  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
@@ -24,8 +24,8 @@
 // -----
 
 func.func @pack_unpack_gemm_rhs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
-  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>>
-  %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>> -> tensor<?x?xf32>
+  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>>
+  %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>> -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
 // CHECK-LABEL: func @pack_unpack_gemm_rhs(
@@ -35,8 +35,8 @@
 // -----
 
 func.func @pack_unpack_gemm_rhs_transpose(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
-  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>
-  %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>> -> tensor<?x?xf32>
+  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>
+  %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>> -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
 // CHECK-LABEL: func @pack_unpack_gemm_rhs_transpose(
@@ -46,8 +46,8 @@
 // -----
 
 func.func @pack_unpack_gemm_result(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
-  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
-  %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>> -> tensor<?x?xf32>
+  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
+  %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>> -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
 // CHECK-LABEL: func @pack_unpack_gemm_result(
@@ -62,20 +62,20 @@
     ^bb0(%b0: index, %b1 : index):
       tensor.yield %pad_value : f32
     } : tensor<100x250xf32> to tensor<104x252xf32>
-  %lhs = iree_linalg_ext.set_encoding %pad_lhs : tensor<104x252xf32> -> tensor<104x252xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+  %lhs = iree_linalg_ext.set_encoding %pad_lhs : tensor<104x252xf32> -> tensor<104x252xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
   %pad_rhs = tensor.pad %arg1 low[0, 0] high[2, 4] {
     ^bb0(%b0: index, %b1 : index):
       tensor.yield %pad_value : f32
     } : tensor<250x500xf32> to tensor<252x504xf32>
-  %rhs = iree_linalg_ext.set_encoding %pad_rhs : tensor<252x504xf32> -> tensor<252x504xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>
+  %rhs = iree_linalg_ext.set_encoding %pad_rhs : tensor<252x504xf32> -> tensor<252x504xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>
   %pad_output = tensor.pad %arg2 low[0, 0] high[4, 4] {
     ^bb0(%b0: index, %b1 : index):
       tensor.yield %pad_value : f32
     } : tensor<100x500xf32> to tensor<104x504xf32>
-  %output = iree_linalg_ext.set_encoding %pad_output : tensor<104x504xf32> -> tensor<104x504xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
-  %gemm_packed = linalg.matmul ins(%lhs, %rhs : tensor<104x252xf32, #iree_linalg_ext.encoding<GEMM_LHS>>, tensor<252x504xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>)
-      outs(%output : tensor<104x504xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>) -> tensor<104x504xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
-  %gemm = iree_linalg_ext.unset_encoding %gemm_packed : tensor<104x504xf32, #iree_linalg_ext.encoding<GEMM_RESULT>> -> tensor<104x504xf32>
+  %output = iree_linalg_ext.set_encoding %pad_output : tensor<104x504xf32> -> tensor<104x504xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
+  %gemm_packed = linalg.matmul ins(%lhs, %rhs : tensor<104x252xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>, tensor<252x504xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>)
+      outs(%output : tensor<104x504xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>) -> tensor<104x504xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
+  %gemm = iree_linalg_ext.unset_encoding %gemm_packed : tensor<104x504xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>> -> tensor<104x504xf32>
   %result = tensor.extract_slice %gemm[0, 0] [100, 500] [1, 1] : tensor<104x504xf32> to tensor<100x500xf32>
   return %result : tensor<100x500xf32>
 }
@@ -102,12 +102,12 @@
 // -----
 
 func.func @pack_gemm_dynamic(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
-  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
-  %1 = iree_linalg_ext.set_encoding %arg1 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>
-  %2 = iree_linalg_ext.set_encoding %arg2 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
-  %3 = linalg.matmul ins(%0, %1 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>, tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>)
-      outs(%2 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>) -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
-  %4 = iree_linalg_ext.unset_encoding %3 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>> -> tensor<?x?xf32>
+  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+  %1 = iree_linalg_ext.set_encoding %arg1 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>
+  %2 = iree_linalg_ext.set_encoding %arg2 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
+  %3 = linalg.matmul ins(%0, %1 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>, tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>)
+      outs(%2 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>) -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
+  %4 = iree_linalg_ext.unset_encoding %3 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>> -> tensor<?x?xf32>
   return %4 : tensor<?x?xf32>
 }
 //  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
@@ -133,14 +133,14 @@
   %cst = arith.constant 0.0 : f32
   %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
   %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
-  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
-  %1 = iree_linalg_ext.set_encoding %arg1 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>
-  %2 = tensor.empty(%d0, %d1) : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
-  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>)
-      -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
-  %4 = linalg.matmul ins(%0, %1 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>, tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>)
-      outs(%3 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>) -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
-  %5 = iree_linalg_ext.unset_encoding %4 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>> -> tensor<?x?xf32>
+  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+  %1 = iree_linalg_ext.set_encoding %arg1 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>
+  %2 = tensor.empty(%d0, %d1) : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
+  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>)
+      -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
+  %4 = linalg.matmul ins(%0, %1 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>, tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>)
+      outs(%3 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>) -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
+  %5 = iree_linalg_ext.unset_encoding %4 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>> -> tensor<?x?xf32>
   return %5 : tensor<?x?xf32>
 }
 //  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/resolve-shaped-type-result-dims.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/resolve-shaped-type-result-dims.mlir
index f5619f7..41bd0f8 100644
--- a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/resolve-shaped-type-result-dims.mlir
+++ b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/resolve-shaped-type-result-dims.mlir
@@ -3,9 +3,9 @@
 func.func @pack_static(%arg0 : tensor<100x250xf32>) -> (index, index) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
-  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<100x250xf32> -> tensor<100x250xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
-  %1 = tensor.dim %0, %c0 : tensor<100x250xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
-  %2 = tensor.dim %0, %c1 : tensor<100x250xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<100x250xf32> -> tensor<100x250xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+  %1 = tensor.dim %0, %c0 : tensor<100x250xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+  %2 = tensor.dim %0, %c1 : tensor<100x250xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
   return %1, %2 : index, index
 }
 // CHECK-LABEL: func @pack_static(
@@ -18,9 +18,9 @@
 func.func @pack_dynamic(%arg0 : tensor<?x?xf32>) -> (index, index) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
-  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
-  %1 = tensor.dim %0, %c0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
-  %2 = tensor.dim %0, %c1 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+  %1 = tensor.dim %0, %c0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+  %2 = tensor.dim %0, %c1 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
   return %1, %2 : index, index
 }
 //       CHECK: func @pack_dynamic(%[[ARG0:.+]]: tensor<?x?xf32>)
diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/roundtrip.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/roundtrip.mlir
index aef7ffa..4cb706c 100644
--- a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/roundtrip.mlir
+++ b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/roundtrip.mlir
@@ -869,36 +869,36 @@
 // -----
 
 // CHECK: @set_encoding_ops(%[[ARG0:.+]]: tensor<?x?xf32>)
-func.func @set_encoding_ops(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>> {
-  // CHECK: iree_linalg_ext.set_encoding %[[ARG0]] : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
-  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
-  return %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+func.func @set_encoding_ops(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> {
+  // CHECK: iree_linalg_ext.set_encoding %[[ARG0]] : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+  return %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
 }
 
 // -----
 
 // CHECK: @set_encoding_ops_mixed_dynamic_static(%[[ARG0:.+]]: tensor<?x10xf32>)
-func.func @set_encoding_ops_mixed_dynamic_static(%arg0: tensor<?x10xf32>) -> tensor<20x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>> {
-  // CHECK: iree_linalg_ext.set_encoding %[[ARG0]] : tensor<?x10xf32> -> tensor<20x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
-  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x10xf32> -> tensor<20x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
-  return %0 : tensor<20x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+func.func @set_encoding_ops_mixed_dynamic_static(%arg0: tensor<?x10xf32>) -> tensor<20x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> {
+  // CHECK: iree_linalg_ext.set_encoding %[[ARG0]] : tensor<?x10xf32> -> tensor<20x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x10xf32> -> tensor<20x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+  return %0 : tensor<20x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
 }
 
 // -----
 
-// CHECK: @unset_encoding_ops(%[[ARG0:.+]]: tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>>)
-func.func @unset_encoding_ops(%arg0: tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>>) -> tensor<?x?xf32> {
-  // CHECK: iree_linalg_ext.unset_encoding %[[ARG0]] : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>> -> tensor<?x?xf32>
-  %0 = iree_linalg_ext.unset_encoding %arg0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>> -> tensor<?x?xf32>
+// CHECK: @unset_encoding_ops(%[[ARG0:.+]]: tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>>)
+func.func @unset_encoding_ops(%arg0: tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>>) -> tensor<?x?xf32> {
+  // CHECK: iree_linalg_ext.unset_encoding %[[ARG0]] : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>> -> tensor<?x?xf32>
+  %0 = iree_linalg_ext.unset_encoding %arg0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>> -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
 
 // -----
 
-// CHECK: @unset_encoding_ops_mixed_dynamic_static(%[[ARG0:.+]]: tensor<10x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>>)
-func.func @unset_encoding_ops_mixed_dynamic_static(%arg0: tensor<10x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>>) -> tensor<?x20xf32> {
-  // CHECK: iree_linalg_ext.unset_encoding %[[ARG0]] : tensor<10x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>>
-  %0 = iree_linalg_ext.unset_encoding %arg0 : tensor<10x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>> -> tensor<?x20xf32>
+// CHECK: @unset_encoding_ops_mixed_dynamic_static(%[[ARG0:.+]]: tensor<10x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>>)
+func.func @unset_encoding_ops_mixed_dynamic_static(%arg0: tensor<10x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>>) -> tensor<?x20xf32> {
+  // CHECK: iree_linalg_ext.unset_encoding %[[ARG0]] : tensor<10x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>>
+  %0 = iree_linalg_ext.unset_encoding %arg0 : tensor<10x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>> -> tensor<?x20xf32>
   return %0 : tensor<?x20xf32>
 }
 
@@ -906,14 +906,14 @@
 
 func.func @encoding_tensors_with_ops(%arg0 : tensor<?x?xf32>,
     %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
-  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
-  %1 = iree_linalg_ext.set_encoding %arg1 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>>
-  %2 = iree_linalg_ext.set_encoding %arg2 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
+  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+  %1 = iree_linalg_ext.set_encoding %arg1 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>>
+  %2 = iree_linalg_ext.set_encoding %arg2 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
   %3 = linalg.matmul
-      ins(%0, %1 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>, tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>>)
-      outs(%2 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>)
-      -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
-  %4 = iree_linalg_ext.unset_encoding %3 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>> -> tensor<?x?xf32>
+      ins(%0, %1 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>, tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>>)
+      outs(%2 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>)
+      -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
+  %4 = iree_linalg_ext.unset_encoding %3 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>> -> tensor<?x?xf32>
   return %4 : tensor<?x?xf32>
 }
 // CHECK-LABEL: func.func @encoding_tensors_with_ops
@@ -921,13 +921,80 @@
 //  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
 //  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
 //       CHECK:   %[[LHS:.+]] = iree_linalg_ext.set_encoding %[[ARG0]]
-//  CHECK-SAME:       tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+//  CHECK-SAME:       tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
 //       CHECK:   %[[RHS:.+]] = iree_linalg_ext.set_encoding %[[ARG1]]
-//  CHECK-SAME:       tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>>
+//  CHECK-SAME:       tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>>
 //       CHECK:   %[[OUT:.+]] = iree_linalg_ext.set_encoding %[[ARG2]]
-//  CHECK-SAME:       tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
+//  CHECK-SAME:       tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
 //       CHECK:   %[[GEMM:.+]] = linalg.matmul
 //  CHECK-SAME:       ins(%[[LHS]], %[[RHS]] :
 //  CHECK-SAME:       outs(%[[OUT]] :
 //       CHECK:   %[[RESULT:.+]] = iree_linalg_ext.unset_encoding %[[GEMM]]
 //       CHECK:   return %[[RESULT]]
+
+// -----
+// Verifies the printed form of winograd.input_transform on statically shaped tensors (init from tensor.empty).
+func.func @winograd_input_transform(%arg0: tensor<1x10x10x1280xf32>) -> tensor<8x8x1x2x2x1280xf32> {
+  %0 = tensor.empty() : tensor<8x8x1x2x2x1280xf32>
+  %1 = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3) image_dimensions([1, 2])
+    ins(%arg0 : tensor<1x10x10x1280xf32>) outs(%0 : tensor<8x8x1x2x2x1280xf32>) -> tensor<8x8x1x2x2x1280xf32>
+  return %1 : tensor<8x8x1x2x2x1280xf32>
+}
+// CHECK:      func.func @winograd_input_transform(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x10x10x1280xf32>) ->
+// CHECK-SAME:   tensor<8x8x1x2x2x1280xf32> {
+// CHECK:        %[[D0:.+]] = tensor.empty() : tensor<8x8x1x2x2x1280xf32>
+// CHECK:        %[[D1:.+]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME:     image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<1x10x10x1280xf32>) outs(%[[D0]] :
+// CHECK-SAME:     tensor<8x8x1x2x2x1280xf32>) -> tensor<8x8x1x2x2x1280xf32>
+// CHECK:        return %[[D1]] : tensor<8x8x1x2x2x1280xf32>
+// CHECK:      }
+
+// -----
+// Verifies the printed form of winograd.input_transform with dynamic dims; the init tensor is passed in as %arg1.
+func.func @winograd_input_transform_dynamic(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<8x8x?x?x?x?xf32>) -> tensor<8x8x?x?x?x?xf32> {
+  %1 = iree_linalg_ext.winograd.input_transform
+    output_tile_size(6) kernel_size(3) image_dimensions([1, 2])
+    ins(%arg0 : tensor<?x?x?x?xf32>) outs(%arg1 : tensor<8x8x?x?x?x?xf32>) -> tensor<8x8x?x?x?x?xf32>
+  return %1 : tensor<8x8x?x?x?x?xf32>
+}
+// CHECK:      func.func @winograd_input_transform_dynamic(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?x?xf32>,
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<8x8x?x?x?x?xf32>) -> tensor<8x8x?x?x?x?xf32> {
+// CHECK:        %[[D0:.+]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME:     image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<?x?x?x?xf32>) outs(%[[ARG1]] :
+// CHECK-SAME:     tensor<8x8x?x?x?x?xf32>) -> tensor<8x8x?x?x?x?xf32>
+// CHECK:        return %[[D0]] : tensor<8x8x?x?x?x?xf32>
+// CHECK:      }
+
+// -----
+// Verifies the printed form of winograd.output_transform on statically shaped tensors (init from tensor.empty).
+func.func @winograd_output_transform(%arg0: tensor<8x8x1x2x2x1280xf32>) -> tensor<1x12x12x1280xf32> {
+  %0 = tensor.empty() : tensor<1x12x12x1280xf32>
+  %1 = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3) image_dimensions([1, 2])
+    ins(%arg0 : tensor<8x8x1x2x2x1280xf32>) outs(%0 : tensor<1x12x12x1280xf32>) -> tensor<1x12x12x1280xf32>
+  return %1 : tensor<1x12x12x1280xf32>
+}
+// CHECK:      func.func @winograd_output_transform(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<8x8x1x2x2x1280xf32>) ->
+// CHECK-SAME:   tensor<1x12x12x1280xf32> {
+// CHECK:        %[[D0:.+]] = tensor.empty() : tensor<1x12x12x1280xf32>
+// CHECK:        %[[D1:.+]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME:     image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<8x8x1x2x2x1280xf32>) outs(%[[D0]] :
+// CHECK-SAME:     tensor<1x12x12x1280xf32>) -> tensor<1x12x12x1280xf32>
+// CHECK:        return %[[D1]] : tensor<1x12x12x1280xf32>
+// CHECK:      }
+
+// -----
+// Verifies the printed form of winograd.output_transform with dynamic dims; named *_dynamic like the input-transform test.
+func.func @winograd_output_transform_dynamic(%arg0: tensor<8x8x?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %1 = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3) image_dimensions([1, 2])
+    ins(%arg0 : tensor<8x8x?x?x?x?xf32>) outs(%arg1 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %1 : tensor<?x?x?x?xf32>
+}
+// CHECK:      func.func @winograd_output_transform_dynamic(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<8x8x?x?x?x?xf32>,
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+// CHECK:        %[[D0:.+]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME:     image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<8x8x?x?x?x?xf32>) outs(%[[ARG1]] :
+// CHECK-SAME:     tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+// CHECK:        return %[[D0]] : tensor<?x?x?x?xf32>
+// CHECK:      }
+
+// -----
diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_winograd.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_winograd.mlir
new file mode 100644
index 0000000..aabe3d9
--- /dev/null
+++ b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_winograd.mlir
@@ -0,0 +1,195 @@
+// RUN: iree-dialects-opt --iree-linalg-ext-tile-and-decompose-winograd --split-input-file %s | FileCheck %s
+// Input: winograd.input_transform already nested in two scf.for loops (step 1, then step 32 up to 1280) over extracted slices.
+#map = affine_map<(d0)[s0, s1] -> (1, -d0 + s1)>
+#map1 = affine_map<(d0)[s0, s1] -> (32, -d0 + s1)>
+module {
+  func.func @winograd_input_transform(%arg0: tensor<1x10x10x1280xf32>) -> tensor<8x8x1x2x2x1280xf32> {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c1280 = arith.constant 1280 : index
+    %c32 = arith.constant 32 : index
+    %0 = tensor.empty() : tensor<8x8x1x2x2x1280xf32>
+    %1 = scf.for %arg1 = %c0 to %c1 step %c1 iter_args(%arg2 = %0) -> (tensor<8x8x1x2x2x1280xf32>) {
+      %2 = affine.min #map(%arg1)[%c1, %c1]
+      %3 = scf.for %arg3 = %c0 to %c1280 step %c32 iter_args(%arg4 = %arg2) -> (tensor<8x8x1x2x2x1280xf32>) {
+        %4 = affine.min #map1(%arg3)[%c32, %c1280]
+        %extracted_slice = tensor.extract_slice %arg0[%arg1, 0, 0, %arg3] [%2, 10, 10, %4] [1, 1, 1, 1] : tensor<1x10x10x1280xf32> to tensor<?x10x10x?xf32>
+        %extracted_slice_0 = tensor.extract_slice %0[0, 0, %arg1, 0, 0, %arg3] [8, 8, %2, 2, 2, %4] [1, 1, 1, 1, 1, 1] : tensor<8x8x1x2x2x1280xf32> to tensor<8x8x?x2x2x?xf32>
+        %5 = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3) image_dimensions([1, 2]) ins(%extracted_slice : tensor<?x10x10x?xf32>) outs(%extracted_slice_0 : tensor<8x8x?x2x2x?xf32>) -> tensor<8x8x?x2x2x?xf32>
+        %inserted_slice = tensor.insert_slice %5 into %arg4[0, 0, %arg1, 0, 0, %arg3] [8, 8, %2, 2, 2, %4] [1, 1, 1, 1, 1, 1] : tensor<8x8x?x2x2x?xf32> into tensor<8x8x1x2x2x1280xf32>
+        scf.yield %inserted_slice : tensor<8x8x1x2x2x1280xf32>
+      }
+      scf.yield %3 : tensor<8x8x1x2x2x1280xf32>
+    }
+    return %1 : tensor<8x8x1x2x2x1280xf32>
+  }
+}
+// CHECK-DAG:  #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (1, -d0 + s1)>
+// CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (32, -d0 + s1)>
+// CHECK-DAG:  #[[MAP2:.+]] = affine_map<(d0) -> (d0 * 6)>
+// CHECK-DAG:  #[[MAP3:.+]] = affine_map<(d0) -> (-d0 + 10, 8)>
+// CHECK:      func.func @winograd_input_transform(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x10x10x1280xf32>) ->
+// CHECK-SAME:   tensor<8x8x1x2x2x1280xf32> {
+// CHECK:        %[[C32:.+]] = arith.constant 32 : index
+// CHECK:        %[[C1280:.+]] = arith.constant 1280 : index
+// CHECK:        %[[C1:.+]] = arith.constant 1 : index
+// CHECK:        %[[C0:.+]] = arith.constant 0 : index
+// CHECK:        %[[CST:.+]] = arith.constant dense<
+// CHECK:        %[[CST_0:.+]] = arith.constant dense<
+// CHECK:        %[[CST_1:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK:        %[[C2:.+]] = arith.constant 2 : index
+// CHECK:        %[[D0:.+]] = tensor.empty() : tensor<8x8xf32>
+// CHECK:        %[[D1:.+]] = tensor.empty() : tensor<8x8x1x2x2x1280xf32>
+// CHECK:        %[[D2:.+]] = scf.for %[[ARG1:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1]] step %[[C1]]
+// CHECK-SAME:     iter_args(%[[ARG2:[a-zA-Z0-9_]+]] = %[[D1]]) -> (tensor<8x8x1x2x2x1280xf32>) {
+// CHECK-DAG:        %[[D3:.+]] = affine.min #[[MAP]](%[[ARG1]])[%[[C1]], %[[C1]]]
+// CHECK:          %[[D4:.+]] = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1280]] step %[[C32]]
+// CHECK-SAME:       iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[ARG2]]) -> (tensor<8x8x1x2x2x1280xf32>) {
+// CHECK-DAG:          %[[D5:.+]] = affine.min #[[MAP1]](%[[ARG3]])[%[[C32]], %[[C1280]]]
+// CHECK:            %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG1]], 0, 0, %[[ARG3]]] [%[[D3]], 10,
+// CHECK-SAME:         10, %[[D5]]] [1, 1, 1, 1] : tensor<1x10x10x1280xf32> to tensor<?x10x10x?xf32>
+// CHECK:            %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[D1]][0, 0, %[[ARG1]], 0, 0, %[[ARG3]]] [8, 8,
+// CHECK-SAME:         %[[D3]], 2, 2, %[[D5]]] [1, 1, 1, 1, 1, 1] : tensor<8x8x1x2x2x1280xf32> to
+// CHECK-SAME:         tensor<8x8x?x2x2x?xf32>
+// CHECK:            %[[D6:.+]] = scf.for %[[ARG5:[a-zA-Z0-9_]+]] = %[[C0]] to %[[D3]] step %[[C1]]
+// CHECK-SAME:         iter_args(%[[ARG6:[a-zA-Z0-9_]+]] = %[[EXTRACTED_SLICE_2]]) -> (tensor<8x8x?x2x2x?xf32>) {
+// CHECK:              %[[D7:.+]] = scf.for %[[ARG7:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK-SAME:           iter_args(%[[ARG8:[a-zA-Z0-9_]+]] = %[[ARG6]]) -> (tensor<8x8x?x2x2x?xf32>) {
+// CHECK-DAG:              %[[D8:.+]] = affine.apply #[[MAP2]](%[[ARG7]])
+// CHECK-DAG:              %[[D9:.+]] = affine.min #[[MAP3]](%[[D8]])
+// CHECK:                %[[D10:.+]] = scf.for %[[ARG9:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK-SAME:             iter_args(%[[ARG10:[a-zA-Z0-9_]+]] = %[[ARG8]]) -> (tensor<8x8x?x2x2x?xf32>) {
+// CHECK-DAG:                %[[D11:.+]] = affine.apply #[[MAP2]](%[[ARG9]])
+// CHECK-DAG:                %[[D12:.+]] = affine.min #[[MAP3]](%[[D11]])
+// CHECK:                  %[[D13:.+]] = scf.for %[[ARG11:[a-zA-Z0-9_]+]] = %[[C0]] to %[[D5]] step %[[C1]]
+// CHECK-SAME:               iter_args(%[[ARG12:[a-zA-Z0-9_]+]] = %[[ARG10]]) -> (tensor<8x8x?x2x2x?xf32>) {
+// CHECK:                    %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG5]], %[[D8]],
+// CHECK-SAME:                 %[[D11]], %[[ARG11]]] [1, %[[D9]], %[[D12]], 1] [1, 1, 1, 1] : tensor<?x10x10x?xf32> to
+// CHECK-SAME:                 tensor<?x?xf32>
+// CHECK:                    %[[D14:.+]] = linalg.fill ins(%[[CST_1]] : f32) outs(%[[D0]] : tensor<8x8xf32>) ->
+// CHECK-SAME:                 tensor<8x8xf32>
+// CHECK:                    %[[INSERTED_SLICE_4:.+]] = tensor.insert_slice %[[EXTRACTED_SLICE_3]] into %[[D14]][0, 0]
+// CHECK-SAME:                 [%[[D9]], %[[D12]]] [1, 1] : tensor<?x?xf32> into tensor<8x8xf32>
+// CHECK:                    %[[EXTRACTED_SLICE_5:.+]] = tensor.extract_slice %[[ARG12]][0, 0, %[[ARG5]], %[[ARG7]],
+// CHECK-SAME:                 %[[ARG9]], %[[ARG11]]] [8, 8, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<8x8x?x2x2x?xf32> to
+// CHECK-SAME:                 tensor<8x8xf32>
+// CHECK:                    %[[D15:.+]] = linalg.fill ins(%[[CST_1]] : f32) outs(%[[EXTRACTED_SLICE_5]] :
+// CHECK-SAME:                 tensor<8x8xf32>) -> tensor<8x8xf32>
+// CHECK:                    %[[D16:.+]] = linalg.matmul ins(%[[INSERTED_SLICE_4]], %[[CST_0]] : tensor<8x8xf32>,
+// CHECK-SAME:                 tensor<8x8xf32>) outs(%[[D15]] : tensor<8x8xf32>) -> tensor<8x8xf32>
+// CHECK:                    %[[D17:.+]] = linalg.fill ins(%[[CST_1]] : f32) outs(%[[EXTRACTED_SLICE_5]] :
+// CHECK-SAME:                 tensor<8x8xf32>) -> tensor<8x8xf32>
+// CHECK:                    %[[D18:.+]] = linalg.matmul ins(%[[CST]], %[[D16]] : tensor<8x8xf32>, tensor<8x8xf32>)
+// CHECK-SAME:                 outs(%[[D17]] : tensor<8x8xf32>) -> tensor<8x8xf32>
+// CHECK:                    %[[INSERTED_SLICE_6:.+]] = tensor.insert_slice %[[D18]] into %[[ARG12]][0, 0, %[[ARG5]],
+// CHECK-SAME:                 %[[ARG7]], %[[ARG9]], %[[ARG11]]] [8, 8, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<8x8xf32>
+// CHECK-SAME:                 into tensor<8x8x?x2x2x?xf32>
+// CHECK:                    scf.yield %[[INSERTED_SLICE_6]] : tensor<8x8x?x2x2x?xf32>
+// CHECK:                  }
+// CHECK:                  scf.yield %[[D13]] : tensor<8x8x?x2x2x?xf32>
+// CHECK:                }
+// CHECK:                scf.yield %[[D10]] : tensor<8x8x?x2x2x?xf32>
+// CHECK:              }
+// CHECK:              scf.yield %[[D7]] : tensor<8x8x?x2x2x?xf32>
+// CHECK:            }
+// CHECK:            %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D6]] into %[[ARG4]][0, 0, %[[ARG1]], 0, 0,
+// CHECK-SAME:         %[[ARG3]]] [8, 8, %[[D3]], 2, 2, %[[D5]]] [1, 1, 1, 1, 1, 1] : tensor<8x8x?x2x2x?xf32> into
+// CHECK-SAME:         tensor<8x8x1x2x2x1280xf32>
+// CHECK:            scf.yield %[[INSERTED_SLICE]] : tensor<8x8x1x2x2x1280xf32>
+// CHECK:          }
+// CHECK:          scf.yield %[[D4]] : tensor<8x8x1x2x2x1280xf32>
+// CHECK:        }
+// CHECK:        return %[[D2]] : tensor<8x8x1x2x2x1280xf32>
+// CHECK:      }
+
+// -----
+// Input: winograd.output_transform already nested in two scf.for loops (step 1, then step 32 up to 32) over extracted slices.
+#map = affine_map<(d0)[s0, s1] -> (1, -d0 + s1)>
+#map1 = affine_map<(d0)[s0, s1] -> (32, -d0 + s1)>
+module {
+  func.func @winograd_output_transform(%arg0: tensor<8x8x1x2x2x32xf32>) -> tensor<1x12x12x32xf32> {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c32 = arith.constant 32 : index
+    %0 = tensor.empty() : tensor<1x12x12x32xf32>
+    %1 = scf.for %arg1 = %c0 to %c1 step %c1 iter_args(%arg2 = %0) -> (tensor<1x12x12x32xf32>) {
+      %2 = affine.min #map(%arg1)[%c1, %c1]
+      %3 = scf.for %arg3 = %c0 to %c32 step %c32 iter_args(%arg4 = %arg2) -> (tensor<1x12x12x32xf32>) {
+        %4 = affine.min #map1(%arg3)[%c32, %c32]
+        %extracted_slice = tensor.extract_slice %arg0[0, 0, %arg1, 0, 0, %arg3] [8, 8, %2, 2, 2, %4] [1, 1, 1, 1, 1, 1] : tensor<8x8x1x2x2x32xf32> to tensor<8x8x?x2x2x?xf32>
+        %extracted_slice_0 = tensor.extract_slice %0[%arg1, 0, 0, %arg3] [%2, 12, 12, %4] [1, 1, 1, 1] : tensor<1x12x12x32xf32> to tensor<?x12x12x?xf32>
+        %5 = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3) image_dimensions([1, 2]) ins(%extracted_slice : tensor<8x8x?x2x2x?xf32>) outs(%extracted_slice_0 : tensor<?x12x12x?xf32>) -> tensor<?x12x12x?xf32>
+        %inserted_slice = tensor.insert_slice %5 into %arg4[%arg1, 0, 0, %arg3] [%2, 12, 12, %4] [1, 1, 1, 1] : tensor<?x12x12x?xf32> into tensor<1x12x12x32xf32>
+        scf.yield %inserted_slice : tensor<1x12x12x32xf32>
+      }
+      scf.yield %3 : tensor<1x12x12x32xf32>
+    }
+    return %1 : tensor<1x12x12x32xf32>
+  }
+}
+// CHECK-DAG:  #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (1, -d0 + s1)>
+// CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (32, -d0 + s1)>
+// CHECK-DAG:  #[[MAP2:.+]] = affine_map<(d0) -> (d0 * 6)>
+// CHECK:      func.func @winograd_output_transform(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<8x8x1x2x2x32xf32>) ->
+// CHECK-SAME:   tensor<1x12x12x32xf32> {
+// CHECK:        %[[C32:.+]] = arith.constant 32 : index
+// CHECK:        %[[C1:.+]] = arith.constant 1 : index
+// CHECK:        %[[C0:.+]] = arith.constant 0 : index
+// CHECK:        %[[CST:.+]] = arith.constant dense<
+// CHECK:        %[[CST_0:.+]] = arith.constant dense<
+// CHECK:        %[[CST_1:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK:        %[[C2:.+]] = arith.constant 2 : index
+// CHECK:        %[[D0:.+]] = tensor.empty() : tensor<8x6xf32>
+// CHECK:        %[[D1:.+]] = tensor.empty() : tensor<1x12x12x32xf32>
+// CHECK:        %[[D2:.+]] = scf.for %[[ARG1:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1]] step %[[C1]]
+// CHECK-SAME:     iter_args(%[[ARG2:[a-zA-Z0-9_]+]] = %[[D1]]) -> (tensor<1x12x12x32xf32>) {
+// CHECK-DAG:        %[[D3:.+]] = affine.min #[[MAP]](%[[ARG1]])[%[[C1]], %[[C1]]]
+// CHECK:          %[[D4:.+]] = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C32]] step %[[C32]]
+// CHECK-SAME:       iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[ARG2]]) -> (tensor<1x12x12x32xf32>) {
+// CHECK-DAG:          %[[D5:.+]] = affine.min #[[MAP1]](%[[ARG3]])[%[[C32]], %[[C32]]]
+// CHECK:            %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG1]], 0, 0, %[[ARG3]]] [8, 8,
+// CHECK-SAME:         %[[D3]], 2, 2, %[[D5]]] [1, 1, 1, 1, 1, 1] : tensor<8x8x1x2x2x32xf32> to tensor<8x8x?x2x2x?xf32>
+// CHECK:            %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[D1]][%[[ARG1]], 0, 0, %[[ARG3]]] [%[[D3]], 12,
+// CHECK-SAME:         12, %[[D5]]] [1, 1, 1, 1] : tensor<1x12x12x32xf32> to tensor<?x12x12x?xf32>
+// CHECK:            %[[D6:.+]] = scf.for %[[ARG5:[a-zA-Z0-9_]+]] = %[[C0]] to %[[D3]] step %[[C1]]
+// CHECK-SAME:         iter_args(%[[ARG6:[a-zA-Z0-9_]+]] = %[[EXTRACTED_SLICE_2]]) -> (tensor<?x12x12x?xf32>) {
+// CHECK:              %[[D7:.+]] = scf.for %[[ARG7:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK-SAME:           iter_args(%[[ARG8:[a-zA-Z0-9_]+]] = %[[ARG6]]) -> (tensor<?x12x12x?xf32>) {
+// CHECK-DAG:              %[[D8:.+]] = affine.apply #[[MAP2]](%[[ARG7]])
+// CHECK:                %[[D9:.+]] = scf.for %[[ARG9:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK-SAME:             iter_args(%[[ARG10:[a-zA-Z0-9_]+]] = %[[ARG8]]) -> (tensor<?x12x12x?xf32>) {
+// CHECK-DAG:                %[[D10:.+]] = affine.apply #[[MAP2]](%[[ARG9]])
+// CHECK:                  %[[D11:.+]] = scf.for %[[ARG11:[a-zA-Z0-9_]+]] = %[[C0]] to %[[D5]] step %[[C1]]
+// CHECK-SAME:               iter_args(%[[ARG12:[a-zA-Z0-9_]+]] = %[[ARG10]]) -> (tensor<?x12x12x?xf32>) {
+// CHECK:                    %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, %[[ARG5]],
+// CHECK-SAME:                 %[[ARG7]], %[[ARG9]], %[[ARG11]]] [8, 8, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] :
+// CHECK-SAME:                 tensor<8x8x?x2x2x?xf32> to tensor<8x8xf32>
+// CHECK:                    %[[EXTRACTED_SLICE_4:.+]] = tensor.extract_slice %[[ARG12]][%[[ARG5]], %[[D8]], %[[D10]],
+// CHECK-SAME:                 %[[ARG11]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<?x12x12x?xf32> to tensor<6x6xf32>
+// CHECK:                    %[[D12:.+]] = linalg.fill ins(%[[CST_1]] : f32) outs(%[[D0]] : tensor<8x6xf32>) ->
+// CHECK-SAME:                 tensor<8x6xf32>
+// CHECK:                    %[[D13:.+]] = linalg.matmul ins(%[[EXTRACTED_SLICE_3]], %[[CST_0]] : tensor<8x8xf32>,
+// CHECK-SAME:                 tensor<8x6xf32>) outs(%[[D12]] : tensor<8x6xf32>) -> tensor<8x6xf32>
+// CHECK:                    %[[D14:.+]] = linalg.fill ins(%[[CST_1]] : f32) outs(%[[EXTRACTED_SLICE_4]] :
+// CHECK-SAME:                 tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK:                    %[[D15:.+]] = linalg.matmul ins(%[[CST]], %[[D13]] : tensor<6x8xf32>, tensor<8x6xf32>)
+// CHECK-SAME:                 outs(%[[D14]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK:                    %[[INSERTED_SLICE_5:.+]] = tensor.insert_slice %[[D15]] into %[[ARG12]][%[[ARG5]], %[[D8]],
+// CHECK-SAME:                 %[[D10]], %[[ARG11]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<6x6xf32> into
+// CHECK-SAME:                 tensor<?x12x12x?xf32>
+// CHECK:                    scf.yield %[[INSERTED_SLICE_5]] : tensor<?x12x12x?xf32>
+// CHECK:                  }
+// CHECK:                  scf.yield %[[D11]] : tensor<?x12x12x?xf32>
+// CHECK:                }
+// CHECK:                scf.yield %[[D9]] : tensor<?x12x12x?xf32>
+// CHECK:              }
+// CHECK:              scf.yield %[[D7]] : tensor<?x12x12x?xf32>
+// CHECK:            }
+// CHECK:            %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D6]] into %[[ARG4]][%[[ARG1]], 0, 0, %[[ARG3]]]
+// CHECK-SAME:         [%[[D3]], 12, 12, %[[D5]]] [1, 1, 1, 1] : tensor<?x12x12x?xf32> into tensor<1x12x12x32xf32>
+// CHECK:            scf.yield %[[INSERTED_SLICE]] : tensor<1x12x12x32xf32>
+// CHECK:          }
+// CHECK:          scf.yield %[[D4]] : tensor<1x12x12x32xf32>
+// CHECK:        }
+// CHECK:        return %[[D2]] : tensor<1x12x12x32xf32>
+// CHECK:      }
diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/tiling.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/tiling.mlir
index 463134f..44cab77 100644
--- a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/tiling.mlir
+++ b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/tiling.mlir
@@ -1205,3 +1205,224 @@
 // CHECK-SAME:          into %{{.+}}[%[[K]], %[[C]]] [%[[C2]], %[[C4]]]
 // CHECK:             scf.yield %[[RES]]
 
+// -----
+
+func.func @perfect_CKkc_to_KC(%arg0: tensor<32x4x2x4xf32>, %arg1: tensor<8x128xf32>) -> tensor<8x128xf32> {
+  %0 = iree_linalg_ext.unpack {__internal_linalg_transform__ = "tiling_pack_input"} %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [2, 4] into %arg1 : (tensor<32x4x2x4xf32> tensor<8x128xf32>) -> tensor<8x128xf32>
+  return %0 : tensor<8x128xf32>
+}
+// CHECK-DAG:   #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 2)>
+// CHECK-DAG:   #[[MAP1:.+]] = affine_map<(d0) -> (d0 floordiv 4)>
+// CHECK-LABEL: func.func @perfect_CKkc_to_KC
+// CHECK-SAME:    %[[IN:[A-Za-z0-9]+]]:
+// CHECK-SAME:    %[[OUT:[A-Za-z0-9]+]]:
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:     %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG:     %[[C128:.*]] = arith.constant 128 : index
+// CHECK:         %{{.+}} = scf.for %[[K:.+]] = %[[C0]] to %[[C8]] step %[[C2]]
+// CHECK:           %{{.+}} = scf.for %[[C:.+]] = %[[C0]] to %[[C128]] step %[[C4]]
+// CHECK-DAG:         %[[IN_K:.+]] = affine.apply #[[MAP0]](%[[K]])
+// CHECK-DAG:         %[[IN_C:.+]] = affine.apply #[[MAP1]](%[[C]])
+// CHECK:             %[[IN_SLICE:.+]] = tensor.extract_slice %[[IN]]
+// CHECK:               [%[[IN_C]], %[[IN_K]], 0, 0] [1, 1, 2, 4]
+// CHECK:             %[[ITER_SLICE:.+]] = tensor.extract_slice %{{.+}}[%[[K]], %[[C]]] [%[[C2]], %[[C4]]]
+// CHECK:             %[[UNPACK:.+]] = iree_linalg_ext.unpack
+// CHECK-SAME:          {__internal_linalg_transform__ = "tiling_pack_output"
+// CHECK-SAME:          %[[IN_SLICE]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [2, 4]
+// CHECK-SAME:          into %[[ITER_SLICE]]
+// CHECK:             %[[RES:.+]] = tensor.insert_slice %[[UNPACK]]
+// CHECK-SAME:          into %{{.+}}[%[[K]], %[[C]]] [%[[C2]], %[[C4]]]
+// CHECK:             scf.yield %[[RES]]
+
+// -----
+
+func.func @dynamic_perfect_CKkc_to_KC(%arg0: tensor<?x?x2x2xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = iree_linalg_ext.unpack {__internal_linalg_transform__ = "tiling_pack_input"} %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [2, 2] into %arg1 : (tensor<?x?x2x2xf32> tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-DAG:   #[[MAP0:.+]] = affine_map<(d0)[s0] -> (2, -d0 + s0)>
+// CHECK-DAG:   #[[MAP1:.+]] = affine_map<(d0)[s0] -> (4, -d0 + s0)>
+// CHECK-DAG:   #[[MAP2:.+]] = affine_map<(d0) -> (d0 floordiv 2)>
+// CHECK-DAG:   #[[MAP3:.+]] = affine_map<(d0) -> (d0 ceildiv 2)>
+// CHECK-LABEL: func.func @dynamic_perfect_CKkc_to_KC
+// CHECK-SAME:    %[[IN:[A-Za-z0-9]+]]:
+// CHECK-SAME:    %[[OUT:[A-Za-z0-9]+]]:
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:     %[[DIM_0:.+]] = tensor.dim %[[OUT]], %[[C0]]
+// CHECK-DAG:     %[[DIM_1:.+]] = tensor.dim %[[OUT]], %[[C1]]
+// CHECK:         %{{.+}} = scf.for %[[K:.+]] = %[[C0]] to %[[DIM_0]] step %[[C2:[A-Za-z0-9_]+]]
+// CHECK-DAG:       %[[OUT_K_SZ:.+]] = affine.min #[[MAP0]](%[[K]])[%[[DIM_0]]]
+// CHECK:           %{{.+}} = scf.for %[[C:.+]] = %[[C0]] to %[[DIM_1]] step %[[C4]]
+// CHECK-DAG:         %[[OUT_C_SZ:.+]] = affine.min #[[MAP1]](%[[C]])[%[[DIM_1]]]
+// CHECK-DAG:         %[[IN_K:.+]] = affine.apply #[[MAP2]](%[[K]])
+// CHECK-DAG:         %[[IN_C:.+]] = affine.apply #[[MAP2]](%[[C]])
+// CHECK-DAG:         %[[IN_C_SZ:.+]] = affine.apply #[[MAP3]](%[[OUT_C_SZ]])
+// CHECK:             %[[IN_SLICE:.+]] = tensor.extract_slice %[[IN]]
+// CHECK:               [%[[IN_C]], %[[IN_K]], 0, 0] [%[[IN_C_SZ]], 1, 2, 2]
+// CHECK:             %[[ITER_SLICE:.+]] = tensor.extract_slice %{{.+}}[%[[K]], %[[C]]] [%[[OUT_K_SZ]], %[[OUT_C_SZ]]]
+// CHECK:             %[[UNPACK:.+]] = iree_linalg_ext.unpack
+// CHECK-SAME:          {__internal_linalg_transform__ = "tiling_pack_output"
+// CHECK-SAME:          %[[IN_SLICE]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [2, 2]
+// CHECK-SAME:          into %[[ITER_SLICE]]
+// CHECK:             %[[RES:.+]] = tensor.insert_slice %[[UNPACK]]
+// CHECK-SAME:          into %{{.+}}[%[[K]], %[[C]]] [%[[OUT_K_SZ]], %[[OUT_C_SZ]]]
+// CHECK:             scf.yield %[[RES]]
+
+// -----
+
+func.func @winograd_input_transform(%arg0: tensor<1x10x10x1280xf32>) -> tensor<8x8x1x2x2x1280xf32> {
+  %0 = tensor.empty() : tensor<8x8x1x2x2x1280xf32>
+  %1 = iree_linalg_ext.winograd.input_transform {__internal_linalg_transform__ = "tiling_winograd_input_nhwc"}
+    output_tile_size(6) kernel_size(3) image_dimensions([1, 2])
+    ins(%arg0 : tensor<1x10x10x1280xf32>) outs(%0 : tensor<8x8x1x2x2x1280xf32>) -> tensor<8x8x1x2x2x1280xf32>
+  return %1 : tensor<8x8x1x2x2x1280xf32>
+}
+// CHECK-DAG:  #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (1, -d0 + s1)>
+// CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (32, -d0 + s1)>
+// CHECK:      func.func @winograd_input_transform(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x10x10x1280xf32>) ->
+// CHECK-SAME:   tensor<8x8x1x2x2x1280xf32> {
+// CHECK:        %[[C0:.+]] = arith.constant 0 : index
+// CHECK:        %[[C1:.+]] = arith.constant 1 : index
+// CHECK:        %[[C1280:.+]] = arith.constant 1280 : index
+// CHECK:        %[[C32:.+]] = arith.constant 32 : index
+// CHECK:        %[[D0:.+]] = tensor.empty() : tensor<8x8x1x2x2x1280xf32>
+// CHECK:        %[[D1:.+]] = scf.for %[[ARG1:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1]] step %[[C1]]
+// CHECK-SAME:     iter_args(%[[ARG2:[a-zA-Z0-9_]+]] = %[[D0]]) -> (tensor<8x8x1x2x2x1280xf32>) {
+// CHECK-DAG:        %[[D2:.+]] = affine.min #[[MAP]](%[[ARG1]])[%[[C1]], %[[C1]]]
+// CHECK:          %[[D3:.+]] = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1280]] step %[[C32]]
+// CHECK-SAME:       iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[ARG2]]) -> (tensor<8x8x1x2x2x1280xf32>) {
+// CHECK-DAG:          %[[D4:.+]] = affine.min #[[MAP1]](%[[ARG3]])[%[[C32]], %[[C1280]]]
+// CHECK:            %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG1]], 0, 0, %[[ARG3]]] [%[[D2]], 10,
+// CHECK-SAME:         10, %[[D4]]] [1, 1, 1, 1] : tensor<1x10x10x1280xf32> to tensor<?x10x10x?xf32>
+// CHECK:            %[[EXTRACTED_SLICE_0:.+]] = tensor.extract_slice %[[D0]][0, 0, %[[ARG1]], 0, 0, %[[ARG3]]] [8, 8,
+// CHECK-SAME:         %[[D2]], 2, 2, %[[D4]]] [1, 1, 1, 1, 1, 1] : tensor<8x8x1x2x2x1280xf32> to
+// CHECK-SAME:         tensor<8x8x?x2x2x?xf32>
+// CHECK:            %[[D5:.+]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME:         image_dimensions([1, 2]) ins(%[[EXTRACTED_SLICE]] : tensor<?x10x10x?xf32>)
+// CHECK-SAME:         outs(%[[EXTRACTED_SLICE_0]] : tensor<8x8x?x2x2x?xf32>) -> tensor<8x8x?x2x2x?xf32>
+// CHECK:            %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D5]] into %[[ARG4]][0, 0, %[[ARG1]], 0, 0,
+// CHECK-SAME:         %[[ARG3]]] [8, 8, %[[D2]], 2, 2, %[[D4]]] [1, 1, 1, 1, 1, 1] : tensor<8x8x?x2x2x?xf32> into
+// CHECK-SAME:         tensor<8x8x1x2x2x1280xf32>
+// CHECK:            scf.yield %[[INSERTED_SLICE]] : tensor<8x8x1x2x2x1280xf32>
+// CHECK:          }
+// CHECK:          scf.yield %[[D3]] : tensor<8x8x1x2x2x1280xf32>
+// CHECK:        }
+// CHECK:        return %[[D1]] : tensor<8x8x1x2x2x1280xf32>
+// CHECK:      }
+
+// -----
+
+func.func @winograd_input_transform_memref(%arg0: memref<1x10x10x1280xf32>, %arg1: memref<8x8x1x2x2x1280xf32>) {
+  iree_linalg_ext.winograd.input_transform {__internal_linalg_transform__ = "tiling_winograd_input_nhwc"}
+    output_tile_size(6) kernel_size(3) image_dimensions([1, 2])
+    ins(%arg0 : memref<1x10x10x1280xf32>) outs(%arg1 : memref<8x8x1x2x2x1280xf32>)
+  return
+}
+// CHECK-DAG:  #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK-DAG:  #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (1, -d0 + s1)>
+// CHECK-DAG:  #[[MAP3:.+]] = affine_map<(d0)[s0, s1] -> (32, -d0 + s1)>
+// CHECK:      func.func @winograd_input_transform_memref(%[[ARG0:[a-zA-Z0-9_]+]]: memref<1x10x10x1280xf32>,
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: memref<8x8x1x2x2x1280xf32>) {
+// CHECK:        %[[C0:.+]] = arith.constant 0 : index
+// CHECK:        %[[C1:.+]] = arith.constant 1 : index
+// CHECK:        %[[C1280:.+]] = arith.constant 1280 : index
+// CHECK:        %[[C32:.+]] = arith.constant 32 : index
+// CHECK:        scf.for %[[ARG2:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1]] step %[[C1]] {
+// CHECK-DAG:        %[[D0:.+]] = affine.min #[[MAP2]](%[[ARG2]])[%[[C1]], %[[C1]]]
+// CHECK:          scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1280]] step %[[C32]] {
+// CHECK-DAG:          %[[D1:.+]] = affine.min #[[MAP3]](%[[ARG3]])[%[[C32]], %[[C1280]]]
+// CHECK:            %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG3]]] [%[[D0]], 10, 10, %[[D1]]]
+// CHECK-SAME:         [1, 1, 1, 1] : memref<1x10x10x1280xf32> to memref<?x10x10x?xf32, strided<[128000, 12800, 1280,
+// CHECK-SAME:         1], offset: ?>>
+// CHECK:            %[[SUBVIEW_0:.+]] = memref.subview %[[ARG1]][0, 0, %[[ARG2]], 0, 0, %[[ARG3]]] [8, 8, %[[D0]], 2,
+// CHECK-SAME:         2, %[[D1]]] [1, 1, 1, 1, 1, 1] : memref<8x8x1x2x2x1280xf32> to memref<8x8x?x2x2x?xf32,
+// CHECK-SAME:         strided<[40960, 5120, 5120, 2560, 1280, 1], offset: ?>>
+// CHECK:            iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3) image_dimensions([1,
+// CHECK-SAME:         2]) ins(%[[SUBVIEW]] : memref<?x10x10x?xf32, strided<[128000, 12800, 1280, 1], offset: ?>>)
+// CHECK-SAME:         outs(%[[SUBVIEW_0]] : memref<8x8x?x2x2x?xf32, strided<[40960, 5120, 5120, 2560, 1280, 1], offset:
+// CHECK-SAME:         ?>>)
+// CHECK:          }
+// CHECK:        }
+// CHECK:        return
+// CHECK:      }
+
+// -----
+
+func.func @winograd_output_transform(%arg0: tensor<8x8x1x2x2x32xf32>) -> tensor<1x12x12x32xf32> {
+  %0 = tensor.empty() : tensor<1x12x12x32xf32>
+  %1 = iree_linalg_ext.winograd.output_transform {__internal_linalg_transform__ = "tiling_winograd_input_nhwc"}
+        output_tile_size(6) kernel_size(3) image_dimensions([1, 2])
+        ins(%arg0 : tensor<8x8x1x2x2x32xf32>) outs(%0 : tensor<1x12x12x32xf32>) -> tensor<1x12x12x32xf32>
+  return %1 : tensor<1x12x12x32xf32>
+}
+// CHECK-DAG:  #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (1, -d0 + s1)>
+// CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (32, -d0 + s1)>
+// CHECK:      func.func @winograd_output_transform(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<8x8x1x2x2x32xf32>) ->
+// CHECK-SAME:   tensor<1x12x12x32xf32> {
+// CHECK:        %[[C0:.+]] = arith.constant 0 : index
+// CHECK:        %[[C1:.+]] = arith.constant 1 : index
+// CHECK:        %[[C32:.+]] = arith.constant 32 : index
+// CHECK:        %[[D0:.+]] = tensor.empty() : tensor<1x12x12x32xf32>
+// CHECK:        %[[D1:.+]] = scf.for %[[ARG1:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1]] step %[[C1]]
+// CHECK-SAME:     iter_args(%[[ARG2:[a-zA-Z0-9_]+]] = %[[D0]]) -> (tensor<1x12x12x32xf32>) {
+// CHECK-DAG:        %[[D2:.+]] = affine.min #[[MAP]](%[[ARG1]])[%[[C1]], %[[C1]]]
+// CHECK:          %[[D3:.+]] = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C32]] step %[[C32]]
+// CHECK-SAME:       iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[ARG2]]) -> (tensor<1x12x12x32xf32>) {
+// CHECK-DAG:          %[[D4:.+]] = affine.min #[[MAP1]](%[[ARG3]])[%[[C32]], %[[C32]]]
+// CHECK:            %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG1]], 0, 0, %[[ARG3]]] [8, 8,
+// CHECK-SAME:         %[[D2]], 2, 2, %[[D4]]] [1, 1, 1, 1, 1, 1] : tensor<8x8x1x2x2x32xf32> to tensor<8x8x?x2x2x?xf32>
+// CHECK:            %[[EXTRACTED_SLICE_0:.+]] = tensor.extract_slice %[[D0]][%[[ARG1]], 0, 0, %[[ARG3]]] [%[[D2]], 12,
+// CHECK-SAME:         12, %[[D4]]] [1, 1, 1, 1] : tensor<1x12x12x32xf32> to tensor<?x12x12x?xf32>
+// CHECK:            %[[D5:.+]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME:         image_dimensions([1, 2]) ins(%[[EXTRACTED_SLICE]] : tensor<8x8x?x2x2x?xf32>)
+// CHECK-SAME:         outs(%[[EXTRACTED_SLICE_0]] : tensor<?x12x12x?xf32>) -> tensor<?x12x12x?xf32>
+// CHECK:            %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D5]] into %[[ARG4]][%[[ARG1]], 0, 0, %[[ARG3]]]
+// CHECK-SAME:         [%[[D2]], 12, 12, %[[D4]]] [1, 1, 1, 1] : tensor<?x12x12x?xf32> into tensor<1x12x12x32xf32>
+// CHECK:            scf.yield %[[INSERTED_SLICE]] : tensor<1x12x12x32xf32>
+// CHECK:          }
+// CHECK:          scf.yield %[[D3]] : tensor<1x12x12x32xf32>
+// CHECK:        }
+// CHECK:        return %[[D1]] : tensor<1x12x12x32xf32>
+// CHECK:      }
+
+// -----
+
+func.func @winograd_output_transform_memref(%arg0: memref<8x8x1x2x2x32xf32>, %arg1: memref<1x12x12x32xf32>) {
+  iree_linalg_ext.winograd.output_transform {__internal_linalg_transform__ = "tiling_winograd_input_nhwc"}
+   output_tile_size(6) kernel_size(3) image_dimensions([1, 2])
+   ins(%arg0 : memref<8x8x1x2x2x32xf32>) outs(%arg1 : memref<1x12x12x32xf32>)
+  return
+}
+// CHECK-DAG:  #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG:  #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (1, -d0 + s1)>
+// CHECK-DAG:  #[[MAP3:.+]] = affine_map<(d0)[s0, s1] -> (32, -d0 + s1)>
+// CHECK:      func.func @winograd_output_transform_memref(%[[ARG0:[a-zA-Z0-9_]+]]: memref<8x8x1x2x2x32xf32>,
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: memref<1x12x12x32xf32>) {
+// CHECK:        %[[C0:.+]] = arith.constant 0 : index
+// CHECK:        %[[C1:.+]] = arith.constant 1 : index
+// CHECK:        %[[C32:.+]] = arith.constant 32 : index
+// CHECK:        scf.for %[[ARG2:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1]] step %[[C1]] {
+// CHECK-DAG:        %[[D0:.+]] = affine.min #[[MAP2]](%[[ARG2]])[%[[C1]], %[[C1]]]
+// CHECK:          scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C32]] step %[[C32]] {
+// CHECK-DAG:          %[[D1:.+]] = affine.min #[[MAP3]](%[[ARG3]])[%[[C32]], %[[C32]]]
+// CHECK:            %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][0, 0, %[[ARG2]], 0, 0, %[[ARG3]]] [8, 8, %[[D0]], 2, 2,
+// CHECK-SAME:         %[[D1]]] [1, 1, 1, 1, 1, 1] : memref<8x8x1x2x2x32xf32> to memref<8x8x?x2x2x?xf32, strided<[1024,
+// CHECK-SAME:         128, 128, 64, 32, 1], offset: ?>>
+// CHECK:            %[[SUBVIEW_0:.+]] = memref.subview %[[ARG1]][%[[ARG2]], 0, 0, %[[ARG3]]] [%[[D0]], 12, 12, %[[D1]]]
+// CHECK-SAME:         [1, 1, 1, 1] : memref<1x12x12x32xf32> to memref<?x12x12x?xf32, strided<[4608, 384, 32, 1],
+// CHECK-SAME:         offset: ?>>
+// CHECK:            iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3) image_dimensions([1,
+// CHECK-SAME:         2]) ins(%[[SUBVIEW]] : memref<8x8x?x2x2x?xf32, strided<[1024, 128, 128, 64, 32, 1], offset: ?>>)
+// CHECK-SAME:         outs(%[[SUBVIEW_0]] : memref<?x12x12x?xf32, strided<[4608, 384, 32, 1], offset: ?>>)
+// CHECK:          }
+// CHECK:        }
+// CHECK:        return
+// CHECK:      }
+
+// -----
diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/vectorization.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/vectorization.mlir
index 8490a4f..17671aa 100644
--- a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/vectorization.mlir
+++ b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/vectorization.mlir
@@ -183,13 +183,12 @@
 // CHECK-SAME:    %[[OUT:[A-Za-z0-9]+]]:
 // CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
 // CHECK-DAG:     %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG:     %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x32x8xf32>
 // CHECK:         %[[READ:.+]] = vector.transfer_read %[[IN]]
 // CHECK-SAME:      [%[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[ZERO]]
 // CHECK-SAME:      {in_bounds = [true, true]} : tensor<1x1x1x1x8x32xf32>, vector<8x32xf32>
 // CHECK:         %[[TRANSP:.+]] = vector.transpose %[[READ]], [1, 0]
 // CHECK:         %[[WRITE:.+]] = vector.transfer_write %[[TRANSP]]
-// CHECK-SAME:      %[[EMPTY]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+// CHECK-SAME:      %[[OUT]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
 // CHECK-SAME:      {in_bounds = [true, true]} : vector<32x8xf32>, tensor<1x1x32x8xf32>
 // CHECK:         return %[[WRITE]]
 
@@ -241,9 +240,7 @@
   return %0 : tensor<1x1x128x64xf32>
 }
 // CHECK-DAG:   #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 32)>
-// CHECK-DAG:   #[[MAP1:.+]] = affine_map<(d0) -> (d0 mod 32)>
-// CHECK-DAG:   #[[MAP2:.+]] = affine_map<(d0) -> (d0 floordiv 8)>
-// CHECK-DAG:   #[[MAP3:.+]] = affine_map<(d0) -> (d0 mod 8)>
+// CHECK-DAG:   #[[MAP1:.+]] = affine_map<(d0) -> (d0 floordiv 8)>
 // CHECK-LABEL: func.func @KCRSsr_to_KCRS
 // CHECK-SAME:    %[[IN:[A-Za-z0-9]+]]:
 // CHECK-SAME:    %[[OUT:[A-Za-z0-9]+]]:
@@ -258,20 +255,17 @@
 // CHECK:           %[[RES1:.+]] = scf.for %[[S:.+]] = %[[C0]] to %[[C64]] step %[[C8]]
 // CHECK-SAME:        iter_args(%[[ITER1:.+]] = %[[ITER0]])
 // CHECK-DAG:         %[[IN_R:.+]] = affine.apply #[[MAP0]](%[[R]])
-// CHECK-DAG:         %[[OUT_R:.+]] = affine.apply #[[MAP1]](%[[R]])
-// CHECK-DAG:         %[[IN_S:.+]] = affine.apply #[[MAP2]](%[[S]])
-// CHECK-DAG:         %[[OUT_S:.+]] = affine.apply #[[MAP3]](%[[S]])
-// CHECK-DAG:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x32x8xf32>
+// CHECK-DAG:         %[[IN_S:.+]] = affine.apply #[[MAP1]](%[[S]])
+// CHECK-DAG:         %[[ITER1_SLICE:.+]] = tensor.extract_slice %[[ITER1]]
+// CHECK-SAME:          [0, 0, %[[R]], %[[S]]] [1, 1, 32, 8] [1, 1, 1, 1]
 // CHECK:             %[[READ:.+]] = vector.transfer_read %[[IN]]
 // CHECK-SAME:          [%[[C0]], %[[C0]], %[[IN_R]], %[[IN_S]], %[[C0]], %[[C0]]], %[[ZERO]]
 // CHECK-SAME:          {in_bounds = [true, true]} : tensor<1x1x4x8x8x32xf32>, vector<8x32xf32>
 // CHECK:             %[[TRANSP:.+]] = vector.transpose %[[READ]], [1, 0]
 // CHECK:             %[[WRITE:.+]] = vector.transfer_write %[[TRANSP]]
-// CHECK-SAME:          %[[EMPTY]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+// CHECK-SAME:          %[[ITER1_SLICE]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
 // CHECK-SAME:          {in_bounds = [true, true]} : vector<32x8xf32>, tensor<1x1x32x8xf32>
-// CHECK:             %[[WRITE_SLICE:.+]] = tensor.extract_slice %[[WRITE]]
-// CHECK-SAME:          [0, 0, %[[OUT_R]], %[[OUT_S]]] [1, 1, 32, 8] [1, 1, 1, 1]
-// CHECK:             %[[INSERT:.+]] = tensor.insert_slice %[[WRITE_SLICE]]
+// CHECK:             %[[INSERT:.+]] = tensor.insert_slice %[[WRITE]]
 // CHECK-SAME:          into %[[ITER1]][0, 0, %[[R]], %[[S]]] [1, 1, 32, 8] [1, 1, 1, 1]
 // CHECK:             scf.yield %[[INSERT]]
 // CHECK:           }
@@ -288,9 +282,7 @@
 //CHECK-DAG:    #[[MAP0:.+]] = affine_map<(d0) -> (-d0 + 13, 8)>
 //CHECK-DAG:    #[[MAP1:.+]] = affine_map<(d0) -> (-d0 + 15, 2)>
 //CHECK-DAG:    #[[MAP2:.+]] = affine_map<(d0) -> (d0 floordiv 8)>
-//CHECK-DAG:    #[[MAP3:.+]] = affine_map<(d0) -> (d0 mod 8)>
-//CHECK-DAG:    #[[MAP4:.+]] = affine_map<(d0) -> (d0 floordiv 2)>
-//CHECK-DAG:    #[[MAP5:.+]] = affine_map<(d0) -> (d0 mod 2)>
+//CHECK-DAG:    #[[MAP3:.+]] = affine_map<(d0) -> (d0 floordiv 2)>
 // CHECK-LABEL: func.func @unpack_and_extract_slice
 // CHECK-SAME:    %[[IN:[A-Za-z0-9]+]]:
 // CHECK-SAME:    %[[OUT:[A-Za-z0-9]+]]:
@@ -307,20 +299,21 @@
 // CHECK-SAME:        iter_args(%[[ITER1:.+]] = %[[ITER0]])
 // CHECK-DAG:         %[[OUT_J_SZ:.+]] = affine.min #[[MAP1]](%[[J]])
 // CHECK-DAG:         %[[IN_I:.+]] = affine.apply #[[MAP2]](%[[I]])
-// CHECK-DAG:         %[[OUT_I:.+]] = affine.apply #[[MAP3]](%[[I]])
-// CHECK-DAG:         %[[IN_J:.+]] = affine.apply #[[MAP4]](%[[J]])
-// CHECK-DAG:         %[[OUT_J:.+]] = affine.apply #[[MAP5]](%[[J]])
-// CHECK-DAG:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xf32>
+// CHECK-DAG:         %[[IN_J:.+]] = affine.apply #[[MAP3]](%[[J]])
+// CHECK-DAG:         %[[ITER1_SLICE1:.+]] = tensor.extract_slice %[[ITER1]]
+// CHECK-SAME:          [%[[I]], %[[J]]] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1]
 // CHECK-DAG:         %[[READ:.+]] = vector.transfer_read %[[IN]]
 // CHECK-SAME:          [%[[IN_I]], %[[IN_J]], %[[C0]], %[[C0]]], %[[ZERO]]
 // CHECK-SAME:          {in_bounds = [true, true]} : tensor<2x8x8x2xf32>, vector<8x2xf32>
+// CHECK-DAG:         %[[ITER1_SLICE2:.+]] = tensor.extract_slice %[[ITER1_SLICE1]]
+// CHECK-SAME:          [0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1]
 // CHECK:             %[[WRITE:.+]] = vector.transfer_write %[[READ]]
-// CHECK-SAME:          %[[EMPTY]][%[[C0]], %[[C0]]] {in_bounds = [true, true]}
-// CHECK:             %[[WRITE_SLICE:.+]] = tensor.extract_slice %[[WRITE]]
-// CHECK-SAME:          [%[[OUT_I]], %[[OUT_J]]] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1]
-// CHECK:             %[[INSERT:.+]] = tensor.insert_slice %[[WRITE_SLICE]]
+// CHECK-SAME:          %[[ITER1_SLICE2]][%[[C0]], %[[C0]]]
+// CHECK:             %[[INSERT1:.+]] = tensor.insert_slice %[[WRITE]]
+// CHECK-SAME:          into %[[ITER1_SLICE1]][0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1]
+// CHECK:             %[[INSERT2:.+]] = tensor.insert_slice %[[INSERT1]]
 // CHECK-SAME:          into %[[ITER1]][%[[I]], %[[J]]] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1]
-// CHECK:             scf.yield %[[INSERT]]
+// CHECK:             scf.yield %[[INSERT2]]
 // CHECK:           }
 // CHECK:           scf.yield %[[RES1]]
 // CHECK:         }
@@ -333,9 +326,7 @@
   return %0 : tensor<128x256xf32>
 }
 //CHECK-DAG:    #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 32)>
-//CHECK-DAG:    #[[MAP1:.+]] = affine_map<(d0) -> (d0 mod 32)>
-//CHECK-DAG:    #[[MAP2:.+]] = affine_map<(d0) -> (d0 floordiv 8)>
-//CHECK-DAG:    #[[MAP3:.+]] = affine_map<(d0) -> (d0 mod 8)>
+//CHECK-DAG:    #[[MAP1:.+]] = affine_map<(d0) -> (d0 floordiv 8)>
 // CHECK-LABEL: func.func @CKck_to_KC
 // CHECK-SAME:    %[[IN:[A-Za-z0-9]+]]:
 // CHECK-SAME:    %[[OUT:[A-Za-z0-9]+]]:
@@ -350,20 +341,13 @@
 // CHECK:           %[[RES1:.+]] = scf.for %[[C:.+]] = %[[C0]] to %[[C256]] step %[[C8]]
 // CHECK-SAME:        iter_args(%[[ITER1:.+]] = %[[ITER0]])
 // CHECK-DAG:         %[[IN_K:.+]] = affine.apply #[[MAP0]](%[[K]])
-// CHECK-DAG:         %[[OUT_K:.+]] = affine.apply #[[MAP1]](%[[K]])
-// CHECK-DAG:         %[[IN_C:.+]] = affine.apply #[[MAP2]](%[[C]])
-// CHECK-DAG:         %[[OUT_C:.+]] = affine.apply #[[MAP3]](%[[C]])
-// CHECK-DAG:         %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
+// CHECK-DAG:         %[[IN_C:.+]] = affine.apply #[[MAP1]](%[[C]])
 // CHECK-DAG:         %[[READ:.+]] = vector.transfer_read %[[IN]]
 // CHECK-SAME:          [%[[IN_C]], %[[IN_K]], %[[C0]], %[[C0]]], %[[ZERO]]
 // CHECK-SAME:          {in_bounds = [true, true]} : tensor<32x4x32x8xf32>, vector<32x8xf32>
 // CHECK:             %[[WRITE:.+]] = vector.transfer_write %[[READ]]
-// CHECK-SAME:          %[[EMPTY]][%[[C0]], %[[C0]]] {in_bounds = [true, true]}
-// CHECK:             %[[WRITE_SLICE:.+]] = tensor.extract_slice %[[WRITE]]
-// CHECK-SAME:          [%[[OUT_K]], %[[OUT_C]]] [32, 8] [1, 1]
-// CHECK:             %[[INSERT:.+]] = tensor.insert_slice %[[WRITE_SLICE]]
-// CHECK-SAME:          into %[[ITER1]][%[[K]], %[[C]]] [32, 8] [1, 1]
-// CHECK:             scf.yield %[[INSERT]]
+// CHECK-SAME:          %[[ITER1]][%[[K]], %[[C]]]
+// CHECK:             scf.yield %[[WRITE]]
 // CHECK:           }
 // CHECK:           scf.yield %[[RES1]]
 // CHECK:         }