NFC: refresh external dialects under integrations/tensorflow (#11426)
diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects-c/Utils.h b/integrations/tensorflow/iree-dialects/include/iree-dialects-c/Utils.h
deleted file mode 100644
index 6132ee9..0000000
--- a/integrations/tensorflow/iree-dialects/include/iree-dialects-c/Utils.h
+++ /dev/null
@@ -1,26 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_DIALECTS_C_UTILS_H
-#define IREE_DIALECTS_C_UTILS_H
-
-#include "mlir-c/IR.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-// TODO: Upstream C/Python APIs for symbol table.
-// Looks up the referrent operation with the given flat symbol, starting from
-// a specific op.
-MLIR_CAPI_EXPORTED MlirOperation
-ireeLookupNearestSymbolFrom(MlirOperation fromOp, MlirAttribute symbolRefAttr);
-
-#ifdef __cplusplus
-}
-#endif
-
-#endif // IREE_DIALECTS_C_UTILS_H
diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td
index afb6c3a..c0530fe 100644
--- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td
+++ b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td
@@ -45,19 +45,28 @@
: AttrDef<IREELinalgExt_Dialect, name, traits>;
// List of pre-defined data layout encoding attributes.
-def GEMM_LHS
- : I32EnumAttrCase<"GEMM_LHS", 0>;
-def GEMM_RESULT
- : I32EnumAttrCase<"GEMM_RESULT", 1>;
-def GEMM_RHS
- : I32EnumAttrCase<"GEMM_RHS", 2>;
-def GEMM_RHS_TRANSPOSE
- : I32EnumAttrCase<"GEMM_RHS_TRANSPOSE", 3>;
+def MATMUL_F32F32F32_LHS
+ : I32EnumAttrCase<"MATMUL_F32F32F32_LHS", 0>;
+def MATMUL_F32F32F32_RHS
+ : I32EnumAttrCase<"MATMUL_F32F32F32_RHS", 1>;
+def MATMUL_F32F32F32_RHS_TRANSPOSE
+ : I32EnumAttrCase<"MATMUL_F32F32F32_RHS_TRANSPOSE", 2>;
+def MATMUL_F32F32F32_RESULT
+ : I32EnumAttrCase<"MATMUL_F32F32F32_RESULT", 3>;
+def MATMUL_I8I8I32_LHS
+ : I32EnumAttrCase<"MATMUL_I8I8I32_LHS", 4>;
+def MATMUL_I8I8I32_RHS
+ : I32EnumAttrCase<"MATMUL_I8I8I32_RHS", 5>;
+def MATMUL_I8I8I32_RHS_TRANSPOSE
+ : I32EnumAttrCase<"MATMUL_I8I8I32_RHS_TRANSPOSE", 6>;
+def MATMUL_I8I8I32_RESULT
+ : I32EnumAttrCase<"MATMUL_I8I8I32_RESULT", 7>;
def TensorEncodingEnum
: I32EnumAttr<"TensorEncoding",
"identifier for encoding used for the tensor",[
- GEMM_LHS, GEMM_RESULT, GEMM_RHS, GEMM_RHS_TRANSPOSE
+ MATMUL_F32F32F32_LHS, MATMUL_F32F32F32_RHS, MATMUL_F32F32F32_RHS_TRANSPOSE, MATMUL_F32F32F32_RESULT,
+ MATMUL_I8I8I32_LHS, MATMUL_I8I8I32_RHS, MATMUL_I8I8I32_RHS_TRANSPOSE, MATMUL_I8I8I32_RESULT,
]> {
let cppNamespace = "::mlir::iree_compiler::IREE::LinalgExt";
let genSpecializedAttr = 0;
diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td
index 502d207..25733c7 100644
--- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -588,8 +588,7 @@
(`outer_dims_perm` `=` $outer_dims_perm^)?
`inner_dims_pos` `=` $inner_dims_pos
`inner_tiles` `=`
- custom<DynamicIndexList>($inner_tiles, $static_inner_tiles,
- "ShapedType::kDynamic")
+ custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
`into` $outputs `:` `(` type($inputs) type($outputs) `)`
(`->` type($results)^)?
}];
@@ -693,7 +692,8 @@
"getLoopIteratorTypes",
"generateScalarImplementation",
"getResultTilePosition",
- "getTiledImplementation"]>,
+ "getTiledImplementation",
+ "generateResultTileValue"]>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]>{
let summary = "unpack operation";
@@ -740,8 +740,7 @@
(`outer_dims_perm` `=` $outer_dims_perm^)?
`inner_dims_pos` `=` $inner_dims_pos
`inner_tiles` `=`
- custom<DynamicIndexList>($inner_tiles, $static_inner_tiles,
- "ShapedType::kDynamic")
+ custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
`into` $outputs `:` `(` type($inputs) type($outputs) `)`
(`->` type($results)^)?
}];
@@ -887,6 +886,188 @@
}
//===----------------------------------------------------------------------===//
+// Winograd ops
+//===----------------------------------------------------------------------===//
+
+def IREELinalgExt_WinogradInputTransformOp : IREELinalgExt_Op<"winograd.input_transform",
+ [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+ DeclareOpInterfaceMethods<TilingInterface,
+ ["getIterationDomain",
+ "getLoopIteratorTypes",
+ "getResultTilePosition",
+ "getTiledImplementation"]>]> {
+ let summary = "Winograd Input Transform operator";
+ let description = [{
+ This operator is the first step in converting a convolution to
+ its Winograd equivalent. Given a tile of an input image (I),
+ this operator computes matmul(tranpose(B), matmul(I, B)).
+ The input tile is assumed to be square with each side of size m + r - 1,
+ where the convolutional kernel is m x m and the output tile size is r x r.
+ B is a constant 2-d square matrix of the same shape as the input tile I.
+ The input to the operator is an image of shape (N, H, W, C) and the
+ output is an operator of shape (m + r - 1, m + r - 1, N, H', W', C)
+ where H' = ceil((H - m + 1)/r) and W' = ceil((W - m + 1)/r). The result
+ of this operator is first collapsed and then fed to a batch matmul op.
+ }];
+
+ let arguments = (ins Variadic<AnyShaped>:$inputs,
+ Variadic<AnyShaped>:$outputs,
+ I64Attr:$output_tile_size,
+ I64Attr:$kernel_size,
+ DenseI64ArrayAttr:$image_dimensions
+ );
+
+ let builders = [
+ OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
+ CArg<"int64_t", "8">:$output_tile_size, CArg<"int64_t", "3">:$kernel_size,
+ CArg<"ArrayRef<int64_t>", "{1, 2}">:$image_dimensions)>
+ ];
+
+ let results = (outs Variadic<AnyRankedTensor>:$result);
+ let hasFolder = 1;
+ let assemblyFormat = [{
+ attr-dict
+ `output_tile_size` `(` $output_tile_size `)`
+ `kernel_size` `(` $kernel_size `)`
+ `image_dimensions` `(` $image_dimensions `)`
+ `ins` `(` $inputs `:` type($inputs) `)`
+ `outs` `(` $outputs `:` type($outputs) `)`
+ (`->` type($result)^)?
+ }];
+
+ let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
+ Value input() {
+ return getInputOperand(0)->get();
+ }
+ Value output() {
+ return getOutputOperand(0)->get();
+ }
+ ShapedType getInputOperandType() {
+ return input().getType().cast<ShapedType>();
+ }
+ ShapedType getOutputOperandType() {
+ return output().getType().cast<ShapedType>();
+ }
+ int64_t getInputOperandRank() {
+ return getInputOperandType().getRank();
+ }
+ int64_t getOutputOperandRank() {
+ return getOutputOperandType().getRank();
+ }
+ int64_t getInputTileSize() {
+ return getOutputTileSize() + getKernelSize() - 1;
+ }
+ SmallVector<int64_t> imageDimensions() {
+ return llvm::to_vector(getImageDimensions());
+ }
+ int64_t getIterationDomainRank() {
+ SmallVector<int64_t> imageDims = imageDimensions();
+ return getInputOperandRank() - imageDims.size();
+ }
+ // Method to implement for specifying output range for
+ // DestinationStyleOpInterface
+ std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
+ std::pair<unsigned, unsigned> outputsIndexAndLength =
+ getODSOperandIndexAndLength(1);
+ return std::make_pair<int64_t, int64_t>(
+ outputsIndexAndLength.first,
+ outputsIndexAndLength.first + outputsIndexAndLength.second);
+ }
+ }];
+}
+
+def IREELinalgExt_WinogradOutputTransformOp : IREELinalgExt_Op<"winograd.output_transform",
+ [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+ DeclareOpInterfaceMethods<TilingInterface,
+ ["getIterationDomain",
+ "getLoopIteratorTypes",
+ "getResultTilePosition",
+ "getTiledImplementation"]>]> {
+ let summary = "Winograd Output Transform operator";
+ let description = [{
+ This operator is the last transform in converting a convolution to
+ its Winograd equivalent. After convolution in the Winograd domain
+ (which turns into an elementwise product for a single channel and
+ batch matrix multiplication for many channels), this operator converts
+ the output back into the original domain. Given a tile of the
+ output (O) in the Winograd domain, this operator computes
+ matmul(transpose(A), matmul(O, A)). The output tile is square with
+ each side of size m + r - 1, where the convolutional kernel is m x m
+ and the output tile size is r x r. A is a constant 2-d matrix of
+ shape (m + r - 1) x r. The input to the operator is a tensor of
+ shape (m + r - 1, m + r - 1, N, H', W', C) and the output is a
+ tensor of shape (N, H, W, C) where H = r H' and W = r W'. This operator
+ is followed by a tensor.extract_slice which extracts only the non-padded
+ part of the output.
+ }];
+
+ let arguments = (ins Variadic<AnyShaped>:$inputs,
+ Variadic<AnyShaped>:$outputs,
+ I64Attr:$output_tile_size,
+ I64Attr:$kernel_size,
+ DenseI64ArrayAttr:$image_dimensions
+ );
+
+ let builders = [
+ OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
+ CArg<"int64_t", "8">:$output_tile_size, CArg<"int64_t", "3">:$kernel_size,
+ CArg<"ArrayRef<int64_t>", "{1, 2}">:$image_dimensions)>
+ ];
+
+ let results = (outs Variadic<AnyRankedTensor>:$result);
+ let hasFolder = 1;
+ let assemblyFormat = [{
+ attr-dict
+ `output_tile_size` `(` $output_tile_size `)`
+ `kernel_size` `(` $kernel_size `)`
+ `image_dimensions` `(` $image_dimensions `)`
+ `ins` `(` $inputs `:` type($inputs) `)`
+ `outs` `(` $outputs `:` type($outputs) `)`
+ (`->` type($result)^)?
+ }];
+
+ let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
+ Value input() {
+ return getInputOperand(0)->get();
+ }
+ Value output() {
+ return getOutputOperand(0)->get();
+ }
+ ShapedType getInputOperandType() {
+ return input().getType().cast<ShapedType>();
+ }
+ ShapedType getOutputOperandType() {
+ return output().getType().cast<ShapedType>();
+ }
+ SmallVector<int64_t> imageDimensions() {
+ return llvm::to_vector(getImageDimensions());
+ }
+ int64_t getInputOperandRank() {
+ return getInputOperandType().getRank();
+ }
+ int64_t getOutputOperandRank() {
+ return getOutputOperandType().getRank();
+ }
+ int64_t getIterationDomainRank() {
+ SmallVector<int64_t> imageDims = imageDimensions();
+ return getOutputOperandRank() - imageDims.size();
+ }
+ int64_t getInputTileSize() {
+ return getOutputTileSize() + getKernelSize() - 1;
+ }
+ // Method to implement for specifying output range for
+ // DestinationStyleOpInterface
+ std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
+ std::pair<unsigned, unsigned> outputsIndexAndLength =
+ getODSOperandIndexAndLength(1);
+ return std::make_pair<int64_t, int64_t>(
+ outputsIndexAndLength.first,
+ outputsIndexAndLength.first + outputsIndexAndLength.second);
+ }
+ }];
+}
+
+//===----------------------------------------------------------------------===//
// Pure ops
//===----------------------------------------------------------------------===//
diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
index fd3664a..b715a4c 100644
--- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
+++ b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
@@ -10,6 +10,7 @@
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
@@ -93,7 +94,7 @@
SmallVector<int64_t> outerDimsPerm;
};
using MaterializeEncodingFn =
- std::function<FailureOr<MaterializeEncodingInfo>(TensorEncoding)>;
+ std::function<FailureOr<MaterializeEncodingInfo>(RankedTensorType)>;
/// TypeConverter to use for materializing the encoding.
struct MaterializeEncodingTypeConverter : public TypeConverter {
@@ -161,6 +162,16 @@
std::unique_ptr<OperationPass<func::FuncOp>> createLinalgExtVectorizationPass();
+/// Tile and decompose the winograd transform ops into a sequence
+/// of linalg ops.
+std::unique_ptr<OperationPass<func::FuncOp>>
+createTileAndDecomposeWinogradTransformPass();
+
+// Creates a pass to convert linalg convolution ops into a sequence of
+// linalg_ext.winograd.* ops and linalg.batch_matmul ops using the winograd
+// tranformation.
+std::unique_ptr<Pass> createConvertConv2DToWinogradPass();
+
// Marker used as attribute the depth of the split reduction transformations.
const StringLiteral kSplitReductionDepthMarker = "__split_reduction_depth__";
@@ -195,14 +206,14 @@
/// Create a LinalgStrategyTileAndFusePass.
std::unique_ptr<OperationPass<func::FuncOp>>
createLinalgStrategyTileAndFusePass(
- StringRef opName = "", const linalg::LinalgTilingAndFusionOptions &opt = {},
+ StringRef opName = "", const scf::SCFTileAndFuseOptions &options = {},
const LinalgExt::LinalgTransformationFilter &filter =
LinalgExt::LinalgTransformationFilter());
/// Create a LinalgStrategyTilePass.
std::unique_ptr<OperationPass<func::FuncOp>> createLinalgStrategyTilePass(
StringRef opName = "",
- const linalg::LinalgTilingOptions &opt = linalg::LinalgTilingOptions(),
+ const scf::SCFTilingOptions &options = scf::SCFTilingOptions(),
const LinalgExt::LinalgTransformationFilter &filter =
LinalgExt::LinalgTransformationFilter());
diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.td b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.td
index 51d9eef..badb658 100644
--- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.td
+++ b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.td
@@ -82,6 +82,21 @@
"createLinalgExtVectorizationPass()";
}
+def TileAndDecomposeWinogradTransform :
+ Pass<"iree-linalg-ext-tile-and-decompose-winograd", "func::FuncOp"> {
+ let summary =
+ "Tiles and decomposes winograd transform ops into linalg ops";
+ let constructor = "mlir::iree_compiler::IREE::LinalgExt::"
+ "createTileAndDecomposeWinogradTransformPass()";
+}
+
+def ConvertConv2DToWinograd :
+ Pass<"iree-linalg-ext-convert-conv2d-to-winograd", ""> {
+ let summary = "Convert linalg convolution ops to winograd based implementation";
+ let constructor = "mlir::iree_compiler::IREE::LinalgExt::createConvertConv2DToWinogradPass()";
+}
+
+
//===---------------------------------------------------------------------====//
// Codegen Strategy passes moved into IREE
// TODO: Deprecate all this.
diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/CodegenStrategy.h b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/CodegenStrategy.h
index 7664332..d803588 100644
--- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/CodegenStrategy.h
+++ b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/CodegenStrategy.h
@@ -8,6 +8,7 @@
#define IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_CODEGENSTRATEGY_H_
#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
+#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Pass/PassManager.h"
#include <utility>
@@ -38,7 +39,7 @@
/// Represent one application of LinalgStrategyTileAndFusePass.
struct TileAndFuse : public Transformation {
- TileAndFuse(StringRef name, linalg::LinalgTilingAndFusionOptions options,
+ TileAndFuse(StringRef name, scf::SCFTileAndFuseOptions options,
LinalgExt::LinalgTransformationFilter::FilterFunction f = nullptr)
: Transformation(std::move(f)), opName(name),
options(std::move(options)) {}
@@ -51,12 +52,12 @@
private:
std::string opName;
- linalg::LinalgTilingAndFusionOptions options;
+ scf::SCFTileAndFuseOptions options;
};
/// Represent one application of LinalgStrategyTilePass.
struct Tile : public Transformation {
- Tile(StringRef name, linalg::LinalgTilingOptions options,
+ Tile(StringRef name, scf::SCFTilingOptions options,
LinalgExt::LinalgTransformationFilter::FilterFunction f = nullptr)
: Transformation(std::move(f)), opName(name),
options(std::move(options)) {}
@@ -69,7 +70,7 @@
private:
std::string opName;
- linalg::LinalgTilingOptions options;
+ scf::SCFTilingOptions options;
};
/// Represent one application of LinalgStrategyPadPass.
@@ -171,8 +172,7 @@
/// Append a pattern to tile the Op `opName` and fuse its producers with
/// tiling and fusion `options`.
CodegenStrategy &
- tileAndFuse(StringRef opName,
- const linalg::LinalgTilingAndFusionOptions &options,
+ tileAndFuse(StringRef opName, const scf::SCFTileAndFuseOptions &options,
const LinalgExt::LinalgTransformationFilter::FilterFunction &f =
nullptr) {
transformationSequence.emplace_back(
@@ -182,14 +182,14 @@
/// Conditionally append a pattern to tile the Op `opName` and fuse its
/// producers with tiling and fusion `options`.
CodegenStrategy &tileAndFuseIf(
- bool b, StringRef opName, linalg::LinalgTilingAndFusionOptions options,
+ bool b, StringRef opName, scf::SCFTileAndFuseOptions options,
LinalgExt::LinalgTransformationFilter::FilterFunction f = nullptr) {
return b ? tileAndFuse(opName, std::move(options), std::move(f)) : *this;
}
/// Append a pattern to add a level of tiling for Op `opName` with tiling
/// `options`.
CodegenStrategy &
- tile(StringRef opName, const linalg::LinalgTilingOptions &options,
+ tile(StringRef opName, const scf::SCFTilingOptions &options,
const LinalgExt::LinalgTransformationFilter::FilterFunction &f =
nullptr) {
transformationSequence.emplace_back(
@@ -199,7 +199,7 @@
/// Conditionally append a pattern to add a level of tiling for
/// `LinalgOpType` with tiling `options`.
CodegenStrategy &
- tileIf(bool b, StringRef opName, linalg::LinalgTilingOptions options,
+ tileIf(bool b, StringRef opName, scf::SCFTilingOptions options,
LinalgExt::LinalgTransformationFilter::FilterFunction f = nullptr) {
return b ? tile(opName, std::move(options), std::move(f)) : *this;
}
diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
index 5243cfe..34db971 100644
--- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
+++ b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
@@ -10,6 +10,7 @@
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/IR/PatternMatch.h"
namespace mlir {
@@ -161,6 +162,43 @@
linalg::LinalgTilingOptions options;
};
+///
+/// Linalg SCF tiling pattern.
+///
+/// Apply the `tiling` transformation as a pattern.
+/// `filter` controls LinalgTransformMarker matching and update when specified.
+/// See `tiling` for more details.
+struct LinalgSCFTilingPattern
+ : public OpInterfaceRewritePattern<TilingInterface> {
+ /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
+ LinalgSCFTilingPattern(
+ MLIRContext *context, scf::SCFTilingOptions options,
+ LinalgTransformationFilter f = LinalgTransformationFilter(),
+ PatternBenefit benefit = 1);
+
+ /// Construct a pattern specifically applied to `opName`.
+ LinalgSCFTilingPattern(
+ StringRef opName, MLIRContext *context, scf::SCFTilingOptions options,
+ LinalgTransformationFilter f = LinalgTransformationFilter(),
+ PatternBenefit benefit = 1);
+
+ /// `matchAndRewrite` implementation that returns the significant transformed
+ /// pieces of IR.
+ LogicalResult returningMatchAndRewrite(TilingInterface op,
+ PatternRewriter &rewriter) const;
+
+ LogicalResult matchAndRewrite(TilingInterface op,
+ PatternRewriter &rewriter) const override {
+ return returningMatchAndRewrite(op, rewriter);
+ }
+
+private:
+ /// LinalgTransformMarker handles special attribute manipulations.
+ LinalgTransformationFilter filter;
+ /// Options to control tiling;
+ scf::SCFTilingOptions options;
+};
+
template <typename... OpTypes>
class TilingPatterns;
@@ -185,6 +223,36 @@
};
///
+/// Linalg SCF tile and fuse patterns.
+///
+/// `filter` controls LinalgTransformMarker matching and update when specified.
+struct LinalgSCFTileAndFusePattern
+ : public OpInterfaceRewritePattern<TilingInterface> {
+ /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
+ LinalgSCFTileAndFusePattern(
+ MLIRContext *context,
+ scf::SCFTileAndFuseOptions options = scf::SCFTileAndFuseOptions(),
+ LinalgTransformationFilter f = LinalgTransformationFilter(),
+ PatternBenefit benefit = 1);
+
+ /// Construct a pattern specifically applied to `opName`.
+ LinalgSCFTileAndFusePattern(
+ StringRef opName, MLIRContext *context,
+ scf::SCFTileAndFuseOptions options = scf::SCFTileAndFuseOptions(),
+ LinalgTransformationFilter f = LinalgTransformationFilter(),
+ PatternBenefit benefit = 1);
+
+ LogicalResult matchAndRewrite(TilingInterface op,
+ PatternRewriter &rewriter) const override;
+
+private:
+ /// LinalgTransformMarker handles special attribute manipulations.
+ LinalgTransformationFilter filter;
+
+ scf::SCFTileAndFuseOptions options;
+};
+
+///
/// Linalg vectorization patterns.
///
/// `filter` controls LinalgTransformMarker matching and update when specified.
diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/Utils.h b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/Utils.h
index 3ad536d..722e1a9 100644
--- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/Utils.h
+++ b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/Utils.h
@@ -11,6 +11,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
namespace mlir {
namespace iree_compiler {
@@ -53,6 +54,11 @@
SmallVector<int64_t> computeInterchangeFromDimPos(ArrayRef<int64_t> dimsPos,
int64_t rank);
+/// Converts a 2D float array to a constant value. The 2D array is stored as
+/// a 1D row-major array in `val` and has shape `rows` x `cols`.
+Value createValueFrom2DConstant(const float *val, int64_t rows, int64_t cols,
+ Location loc, PatternRewriter &rewriter);
+
} // namespace LinalgExt
} // namespace IREE
} // namespace iree_compiler
diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h
new file mode 100644
index 0000000..77dbb09
--- /dev/null
+++ b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h
@@ -0,0 +1,90 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECTS_DIALECT_LINALGEXT_UTILS_WINOGRAD_CONSTANTS_H_
+#define IREE_DIALECTS_DIALECT_LINALGEXT_UTILS_WINOGRAD_CONSTANTS_H_
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace LinalgExt {
+namespace Winograd {
+
+// This file contains the Winograd constant matrices for different
+// output tile sizes
+
+//===----------------------------------------------------------------------===//
+// Output tile size = 6, Kernel size = 3
+//===----------------------------------------------------------------------===//
+// These constants were obtained from this paper:
+//
+// Liu, J. et al (2021) Optimizing Winograd-Based Convolution with Tensor Cores.
+// https://dl.acm.org/doi/abs/10.1145/3472456.3472473
+//
+
+// clang-format off
+
+const float BT_6x6_3x3[] = {
+ 1, 0, -21./4., 0, 21./4., 0, -1, 0,
+ 0, 1, 1, -17./4., -17./4., 1, 1, 0,
+ 0, -1, 1, 17./4., -17./4., -1, 1, 0,
+ 0, 1./2, 1./4., -5./2., -5./4., 2, 1, 0,
+ 0, -1./2, 1./4., 5./2., -5./4., -2, 1, 0,
+ 0, 2, 4, -5./2., -5, 1./2., 1, 0,
+ 0, -2, 4, 5./2., -5, -1./2., 1, 0,
+ 0, -1, 0, 21./4., 0, -21./4., 0, 1
+};
+
+const float B_6x6_3x3[] = {
+ 1, 0, 0, 0, 0, 0, 0, 0,
+ 0, 1, -1, 1./2, -1./2, 2, -2, -1,
+ -21./4., 1, 1, 1./4., 1./4., 4, 4, 0,
+ 0, -17./4., 17./4., -5./2., 5./2., -5./2., 5./2., 21./4.,
+ 21./4., -17./4., -17./4., -5./4., -5./4., -5, -5, 0,
+ 0, 1, -1, 2, -2, 1./2., -1./2., -21./4.,
+ -1, 1, 1, 1, 1, 1, 1, 0,
+ 0, 0, 0, 0, 0, 0, 0, 1
+};
+
+const float G_6x6_3x3[] = {
+ 1, 0, 0,
+ -2./9., -2./9., -2./9.,
+ -2./9., 2./9., -2./9.,
+ 1./90, 1./45, 2./45,
+ 1./90, -1./45, 2./45,
+ 32./45, 16./45, 8./45,
+ 32./45, -16./45, 8./45,
+ 0, 0, 1
+};
+
+const float AT_6x6_3x3[] = {
+ 1, 1, 1, 1, 1, 1, 1, 0,
+ 0, 1, -1, 2, -2, 1./2, -1./2, 0,
+ 0, 1, 1, 4, 4, 1./4, 1./4, 0,
+ 0, 1, -1, 8, -8, 1./8, -1./8, 0,
+ 0, 1, 1, 16, 16, 1./16, 1./16, 0,
+ 0, 1, -1, 32, -32, 1./32, -1./32, 1
+};
+
+const float A_6x6_3x3[] = {
+ 1, 0, 0, 0, 0, 0,
+ 1, 1, 1, 1, 1, 1,
+ 1, -1, 1, -1, 1, -1,
+ 1, 2, 4, 8, 16, 32,
+ 1, -2, 4, -8, 16, -32,
+ 1, 1./2, 1./4, 1./8, 1./16, 1./32,
+ 1, -1./2, 1./4, -1./8, 1./16, -1./32,
+ 0, 0, 0, 0, 0, 1
+};
+
+// clang-format on
+
+} // namespace Winograd
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+#endif // IREE_DIALECTS_DIALECT_LINALGEXT_UTILS_WINOGRAD_CONSTANTS_H_
diff --git a/integrations/tensorflow/iree-dialects/lib/CAPI/CMakeLists.txt b/integrations/tensorflow/iree-dialects/lib/CAPI/CMakeLists.txt
index cba823b..17561a6 100644
--- a/integrations/tensorflow/iree-dialects/lib/CAPI/CMakeLists.txt
+++ b/integrations/tensorflow/iree-dialects/lib/CAPI/CMakeLists.txt
@@ -1,6 +1,5 @@
add_mlir_public_c_api_library(IREEDialectsCAPI
Dialects.cpp
- Utils.cpp
LINK_LIBS PUBLIC
MLIRIR
MLIRTransformDialect
diff --git a/integrations/tensorflow/iree-dialects/lib/CAPI/Utils.cpp b/integrations/tensorflow/iree-dialects/lib/CAPI/Utils.cpp
deleted file mode 100644
index d704f2b..0000000
--- a/integrations/tensorflow/iree-dialects/lib/CAPI/Utils.cpp
+++ /dev/null
@@ -1,20 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree-dialects-c/Utils.h"
-
-#include "mlir/CAPI/IR.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/SymbolTable.h"
-
-using namespace mlir;
-
-MlirOperation ireeLookupNearestSymbolFrom(MlirOperation fromOp,
- MlirAttribute symbolRefAttr) {
- auto symbolRefAttrCpp = unwrap(symbolRefAttr).cast<SymbolRefAttr>();
- return wrap(
- SymbolTable::lookupNearestSymbolFrom(unwrap(fromOp), symbolRefAttrCpp));
-}
diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 6e0a36c..fc0db72 100644
--- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -1706,8 +1706,11 @@
SmallVector<Value> dynamicTileSizes;
dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes,
ShapedType::kDynamic);
- build(builder, state, output.getType(), source, output, outerDimsPerm,
- innerDimsPos, dynamicTileSizes, staticTileSizes,
+ build(builder, state, output.getType(), source, output,
+ outerDimsPerm.empty() ? nullptr
+ : builder.getDenseI64ArrayAttr(outerDimsPerm),
+ builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
+ builder.getDenseI64ArrayAttr(staticTileSizes),
(paddingValue ? paddingValue.value() : nullptr));
}
@@ -2096,8 +2099,11 @@
SmallVector<Value> dynamicTileSizes;
dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes,
ShapedType::kDynamic);
- build(builder, state, output.getType(), source, output, outerDimsPerm,
- innerDimsPos, dynamicTileSizes, staticTileSizes);
+ build(builder, state, output.getType(), source, output,
+ outerDimsPerm.empty() ? nullptr
+ : builder.getDenseI64ArrayAttr(outerDimsPerm),
+ builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
+ builder.getDenseI64ArrayAttr(staticTileSizes));
}
SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
@@ -2363,6 +2369,15 @@
return iteratorTypes;
}
+FailureOr<Value>
+UnPackOp::generateResultTileValue(OpBuilder &b, unsigned resultNumber,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes) {
+ return getTiledImplementation(b, offsets, sizes)
+ .back()
+ ->getResult(resultNumber);
+}
+
//===----------------------------------------------------------------------===//
// WinogradInputTransformOp
//===----------------------------------------------------------------------===//
@@ -2531,6 +2546,165 @@
.reifyResultShapes(b, reifiedReturnShapes);
}
+//===----------------------------------------------------------------------===//
+// WinogradOutputTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradOutputTransformOp::verify() {
+ Operation *op = getOperation();
+ if (getNumInputs() != 1) {
+ return op->emitOpError("expected one input operand");
+ }
+ if (getNumOutputs() != 1) {
+ return op->emitOpError("expected one output operand");
+ }
+ auto inputType = input().getType().cast<ShapedType>();
+ auto outputType = output().getType().cast<ShapedType>();
+ ArrayRef<int64_t> inputShape = inputType.getShape();
+ if (inputShape.size() != 6) {
+ return op->emitOpError("expected input operand to have rank 6");
+ }
+ ArrayRef<int64_t> outputShape = outputType.getShape();
+ if (outputType.getElementType() != inputType.getElementType()) {
+ return op->emitOpError(
+ "expected input/output element types to be identical");
+ }
+ if (getOutputOperandRank() != getInputOperandRank() - 2) {
+ return op->emitOpError(
+ "expected output rank to be equal to input rank - 2");
+ }
+ const SmallVector<int64_t> imageDims = imageDimensions();
+ const size_t numImageDims = imageDims.size();
+ llvm::SmallSetVector<int64_t, 2> imageDimsSet(imageDims.begin(),
+ imageDims.end());
+ if (imageDims.size() != 2) {
+ return op->emitOpError("expected only 2 image dimensions");
+ }
+ for (auto dim : imageDims) {
+ if ((dim < 0) || (dim > 3)) {
+ return op->emitOpError(
+ "expect image dimensions to be in the range: [0, 3]");
+ }
+ }
+ const int64_t outputTileSize = getOutputTileSize();
+ SmallVector<int64_t> expectedOutputShape(getOutputOperandRank(), 1);
+ int outputIndex;
+ for (int i = numImageDims; i < inputShape.size(); i++) {
+ outputIndex = i - numImageDims;
+ if (!imageDimsSet.contains(outputIndex)) {
+ expectedOutputShape[outputIndex] = inputShape[i];
+ } else {
+ expectedOutputShape[outputIndex] = outputTileSize * inputShape[i];
+ }
+ }
+ if (!areShapesCompatible(expectedOutputShape, outputShape)) {
+ return op->emitOpError("incompatible output shape");
+ }
+ return success();
+}
+
+SmallVector<Range>
+WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
+ Location loc = getLoc();
+ Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+ Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
+ Value source = output();
+ SmallVector<int64_t> imageDims = imageDimensions();
+ llvm::SmallSetVector<int64_t, 2> imageDimsSet(imageDims.begin(),
+ imageDims.end());
+ SmallVector<Range> loopBounds(imageDims.size());
+ int count = 0;
+ for (auto dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
+ if (!imageDimsSet.contains(dim)) {
+ loopBounds[count].offset = zero;
+ loopBounds[count].size = getDimValue(builder, loc, source, dim);
+ loopBounds[count].stride = one;
+ count++;
+ }
+ }
+ return loopBounds;
+}
+
+SmallVector<utils::IteratorType>
+WinogradOutputTransformOp::getLoopIteratorTypes() {
+ SmallVector<utils::IteratorType> iteratorTypes(getIterationDomainRank(),
+ utils::IteratorType::parallel);
+ return iteratorTypes;
+}
+
+SmallVector<Operation *> WinogradOutputTransformOp::getTiledImplementation(
+ OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes) {
+
+ Location loc = getLoc();
+ auto one = builder.getIndexAttr(1);
+ auto zero = builder.getIndexAttr(0);
+
+ assert(offsets.size() == 2);
+ SmallVector<OpFoldResult> inputOffsets(getInputOperandRank(), zero);
+ SmallVector<OpFoldResult> outputOffsets(getOutputOperandRank(), zero);
+ inputOffsets[2] = outputOffsets[0] = offsets[0];
+ inputOffsets[5] = outputOffsets[3] = offsets[1];
+
+ SmallVector<OpFoldResult> inputStrides(getInputOperandRank(), one);
+ SmallVector<OpFoldResult> outputStrides(getOutputOperandRank(), one);
+
+ assert(sizes.size() == 2);
+ auto inputShape = input().getType().cast<ShapedType>().getShape();
+ auto outputShape = output().getType().cast<ShapedType>().getShape();
+ SmallVector<OpFoldResult> inputSizes =
+ getAsOpFoldResult(builder.getIndexArrayAttr(inputShape));
+ SmallVector<OpFoldResult> outputSizes =
+ getAsOpFoldResult(builder.getIndexArrayAttr(outputShape));
+ inputSizes[2] = outputSizes[0] = sizes[0];
+ inputSizes[5] = outputSizes[3] = sizes[1];
+
+ SmallVector<Value> tiledOperands;
+ tiledOperands.emplace_back(
+ getSlice(builder, loc, input(), inputOffsets, inputSizes, inputStrides));
+ tiledOperands.emplace_back(getSlice(builder, loc, output(), outputOffsets,
+ outputSizes, outputStrides));
+
+ SmallVector<Type, 4> resultTypes;
+ if (hasTensorSemantics()) {
+ resultTypes.push_back(tiledOperands[1].getType());
+ }
+
+ Operation *tiledOp =
+ mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
+
+ return {tiledOp};
+}
+
+LogicalResult WinogradOutputTransformOp::getResultTilePosition(
+ OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+ SmallVector<OpFoldResult> &resultSizes) {
+ if (resultNumber == 0) {
+ auto resultShape = output().getType().cast<ShapedType>().getShape();
+ resultSizes = getAsOpFoldResult(builder.getIndexArrayAttr(resultShape));
+ resultOffsets = SmallVector<OpFoldResult>(getOutputOperandRank(),
+ builder.getIndexAttr(0));
+ resultOffsets[0] = offsets[0];
+ resultOffsets[3] = offsets[1];
+ resultSizes[0] = sizes[0];
+ resultSizes[3] = sizes[1];
+ return success();
+ }
+ return failure();
+}
+
+LogicalResult WinogradOutputTransformOp::fold(ArrayRef<Attribute>,
+ SmallVectorImpl<OpFoldResult> &) {
+ return memref::foldMemRefCast(*this);
+}
+
+LogicalResult WinogradOutputTransformOp::reifyResultShapes(
+ OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+ return cast<LinalgExtOp>(getOperation())
+ .reifyResultShapes(b, reifiedReturnShapes);
+}
+
#define DEFINE_OP_GET_EFFECTS(OP_NAME) \
void OP_NAME::getEffects( \
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> \
@@ -2550,6 +2724,7 @@
DEFINE_OP_GET_EFFECTS(PackOp)
DEFINE_OP_GET_EFFECTS(UnPackOp)
DEFINE_OP_GET_EFFECTS(WinogradInputTransformOp)
+DEFINE_OP_GET_EFFECTS(WinogradOutputTransformOp)
//===----------------------------------------------------------------------===//
// iree_linalg_ext.set_encoding
diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/CMakeLists.txt b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/CMakeLists.txt
index 69fbd32..68fdab8 100644
--- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/CMakeLists.txt
+++ b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/CMakeLists.txt
@@ -1,10 +1,12 @@
add_mlir_library(IREELinalgExtPasses
+ ConvertConv2DToWinograd.cpp
ConvertToLoops.cpp
FoldIntoPackAndUnpackOps.cpp
MaterializeEncoding.cpp
PadContractionToBlockSize.cpp
Passes.cpp
SplitReduction.cpp
+ TileAndDecomposeWinogradPass.cpp
Tiling.cpp
DEPENDS
diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertConv2DToWinograd.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertConv2DToWinograd.cpp
new file mode 100644
index 0000000..e0f289e
--- /dev/null
+++ b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertConv2DToWinograd.cpp
@@ -0,0 +1,396 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h"
+#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
+#include "iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SetVector.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace LinalgExt {
+
+static inline int index(int y, int x, int dimy, int dimx) {
+ return (x + dimx * y);
+}
+
+static inline int index(int z, int y, int x, int w, int dimz, int dimy,
+ int dimx, int dimw) {
+ return (w + dimw * (x + dimx * (y + dimy * z)));
+}
+
+static bool hasAllOneValues(DenseIntElementsAttr attr) {
+ return llvm::all_of(attr, [](APInt element) { return element.isOne(); });
+}
+
+// TODO: Make this a user-settable parameter once we have support
+// for more tile sizes
+static constexpr int64_t outputTileSize = 6;
+
+/// This function computes the Winograd filter transform when
+/// the filter is known to be a constant. Specifically, this
+/// function computes matmul(G, matmul(F, transpose(G))) where
+/// F is a tile of the convolution filter of size m x m
+/// (single input channel, single output channel) and G has
+/// shape m x (m + r - 1) where r is the output tile size and
+/// (m + r - 1) is the input tile size.
+/// The time complexity of this function is O(ic * oc)
+/// where ic is the number of input channels and oc is the
+/// number of output channels since input tile size and kernel size
+/// are constants. So for large ic and oc, this function is
+/// time intensive.
+/// TODO: Codegen this as a kernel and run once at initialization
+static DenseElementsAttr foldFilterTransform(
+ ArrayRef<int64_t> shape, int64_t inputTileSize, int64_t kernelSize,
+ Type outputType, const float *G, bool isSplat, float splatValue,
+ DenseElementsAttr::iterator_range<APFloat> &input, Type elementType) {
+ const int &kh = shape[0];
+ const int &kw = shape[1];
+ const int &ic = shape[2];
+ const int &oc = shape[3];
+ const int64_t numElements = inputTileSize * inputTileSize * ic * oc;
+ SmallVector<APFloat> output(numElements, APFloat(0.0f));
+ for (int d0 = 0; d0 < inputTileSize; d0++) {
+ for (int d1 = 0; d1 < inputTileSize; d1++) {
+ for (int d2 = 0; d2 < ic; d2++) {
+ for (int d3 = 0; d3 < oc; d3++) {
+ APFloat accum(0.0f);
+ for (int d4 = 0; d4 < kernelSize; d4++) {
+ for (int d5 = 0; d5 < kernelSize; d5++) {
+ APFloat ival(splatValue);
+ if (!isSplat) {
+ ival = input[index(d4, d5, d2, d3, kh, kw, ic, oc)];
+ }
+ int idx0 = index(d0, d4, inputTileSize, kernelSize);
+ int idx1 = index(d1, d5, inputTileSize, kernelSize);
+ accum = accum + APFloat(G[idx0]) * ival * APFloat(G[idx1]);
+ }
+ }
+ int odx = index(d0, d1, d2, d3, inputTileSize, inputTileSize, ic, oc);
+ output[odx] = accum;
+ }
+ }
+ }
+ }
+ return DenseElementsAttr::get(outputType, output);
+}
+
+namespace {
+
+class FoldWinogradFilterTransform final
+ : public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
+ PatternRewriter &rewriter) const override {
+ // Check that kernel size = 3x3
+ Value kernel = convOp.getInputs()[1];
+ auto kernelType = kernel.getType().cast<ShapedType>();
+ if (!kernelType)
+ return failure();
+ ArrayRef<int64_t> kernelShape = kernelType.getShape();
+ const int64_t kh = kernelShape[0];
+ const int64_t kw = kernelShape[1];
+ if ((kh != 3) || (kw != 3))
+ return failure();
+ const int64_t kernelSize = kh;
+ const int64_t inputTileSize = outputTileSize + kernelSize - 1;
+
+ DenseIntOrFPElementsAttr kernelAttr;
+ if (!matchPattern(kernel, m_Constant(&kernelAttr))) {
+ return failure();
+ }
+
+ Operation *constOp = kernel.getDefiningOp();
+ ShapedType type = constOp->getResult(0).getType().cast<ShapedType>();
+ Type elementType = type.getElementType();
+ assert(elementType.isa<FloatType>());
+ ArrayRef<int64_t> shape = type.getShape();
+ DenseElementsAttr::iterator_range<APFloat> nonSplatValues =
+ kernelAttr.getValues<APFloat>();
+ bool isSplat = kernelAttr.isSplat();
+ float splatValue{0.0};
+ if (isSplat) {
+ splatValue = kernelAttr.getSplatValue<APFloat>().convertToFloat();
+ }
+ SmallVector<int64_t> resultShape{inputTileSize * inputTileSize, shape[2],
+ shape[3]};
+ auto resultType = RankedTensorType::get(resultShape, elementType);
+ auto foldedKernelAttr =
+ foldFilterTransform(shape, inputTileSize, kernelSize, resultType,
+ IREE::LinalgExt::Winograd::G_6x6_3x3, isSplat,
+ splatValue, nonSplatValues, elementType);
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, foldedKernelAttr);
+ return success();
+ }
+};
+
+} // namespace
+
+static Value
+createCollapse(Value tensor, Location loc, PatternRewriter &rewriter,
+ SmallVectorImpl<int64_t> &outputShape,
+ SmallVectorImpl<ReassociationIndices> &reassociations) {
+ auto tensorType = tensor.getType().cast<ShapedType>();
+ auto elementTy = tensorType.getElementType();
+ auto resultType = RankedTensorType::get(outputShape, elementTy);
+ return rewriter.create<tensor::CollapseShapeOp>(loc, resultType, tensor,
+ reassociations);
+}
+
+static Value
+createExpand(Value tensor, Location loc, PatternRewriter &rewriter,
+ SmallVectorImpl<int64_t> &outputShape,
+ SmallVectorImpl<ReassociationIndices> &reassociations) {
+ auto tensorType = tensor.getType().cast<ShapedType>();
+ auto elementTy = tensorType.getElementType();
+ auto resultType = RankedTensorType::get(outputShape, elementTy);
+ return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor,
+ reassociations);
+}
+
+namespace {
+
+/// Convert conv2d to a sequence of ops that implement the
+/// Winograd transformation. The Winograd transformation
+/// is parameterized by the output tile size(r). The larger
+/// the tile size, the greater the computational savings,
+/// but this comes at the cost of accuracy.
+///
+/// For now, we restrict this transform to convolutions
+/// where the filter size = 3x3, though extensions to larger
+/// filter sizes are possible. We refer to the
+/// filter size as (m). The input tile size (i) is defined as
+/// m + r - 1. For a given output tile size, the Winograd
+/// transformation defines 3 constant matrices:
+///
+/// B: i x i [used in input transform]
+/// G: m x i [used in the filter transform]
+/// A: i x r [used in output transform]
+///
+/// The choice of these matrices is not unique and affects
+/// the accuracy of the approach.
+///
+/// Given a convolution of the form
+///
+/// y = conv2d(x, f)
+///
+/// where x: (N, H, W, C)
+/// f: (H, W, C, F)
+///
+/// this pattern converts the convolution to the following
+/// sequence:
+///
+/// f_winograd = winograd.filter_transform(f) [folded]
+/// x_winograd = winograd.input_transform(x)
+/// x_winograd_c = collapse(x_winograd)
+/// y_winograd = batch_matmul(x_winograd_c, f_winograd)
+/// y_winograd_e = expand(y_winograd)
+/// y_padded = winograd.output_transform(y_winograd_e)
+/// y = extract_slice(y_padded)
+///
+/// where the dimensions of the tensors above are:
+///
+/// f_winograd: (i * i, C, F)
+/// x_winograd: (i, i, N, H', W', C)
+/// x_winograd_c: (i * i, N * H' * W', C)
+/// y_winograd: (i * i, N * H' * W', F)
+/// y_winograd_e: (i, i, N, H', W', F)
+/// y_padded: (N, r * H', r * W', F)
+///
+/// H': ceil((H - m + 1) / r)
+/// W': ceil((W - m + 1) / r)
+///
+/// The winograd input transform extracts a tile of the input
+/// of size i x i and computes matmul(transpose(B), matmul(tile(x), B)).
+/// The winograd filter transform extracts a tile of the filter
+/// of size m x m and computes matmul(G, matmul(tile(f), transpose(G)).
+/// These two are then combined using elementwise multiplication
+/// (which becomes a batch matmul when combining over multiple channels).
+/// The winograd output filter extracts a tile of the result of size
+/// i x i and computes matmul(transpose(A), matmul(tile(y_winograd_e), A)).
+///
+/// For more information and additional references,
+/// see here:
+///
+/// https://github.com/nod-ai/MLIRWinogradTalk/blob/main/MLIRSummit2022.Nodai.Menon.pdf
+///
+class ConvertConv2DNhwcHwcf final
+ : public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
+ PatternRewriter &rewriter) const override {
+ // Check that strides = 1
+ if (!hasAllOneValues(convOp.getStrides()))
+ return failure();
+
+ // Check that dilations = 1
+ if (!hasAllOneValues(convOp.getDilations()))
+ return failure();
+
+ // Check that kernel has been constant folded (by validating rank = 3)
+ Value kernel = convOp.getInputs()[1];
+ auto kernelType = kernel.getType().cast<ShapedType>();
+ if (!kernelType)
+ return failure();
+ Type elementType = kernelType.getElementType();
+ ArrayRef<int64_t> kernelShape = kernelType.getShape();
+ if (kernelShape.size() != 3)
+ return failure();
+
+ const int64_t kernelSize = 3;
+ const int64_t inputTileSize = outputTileSize + kernelSize - 1;
+
+ // Create winograd input transform op
+ Location loc = convOp.getLoc();
+ Value zero = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(elementType));
+ Value input = convOp.getInputs()[0];
+ auto inputType = input.getType().cast<ShapedType>();
+ if (!inputType)
+ return failure();
+ ArrayRef<int64_t> inputShape = inputType.getShape();
+ if (llvm::any_of(inputShape, ShapedType::isDynamic))
+ return failure();
+ assert(inputShape.size() == 4);
+
+ SmallVector<int64_t, 2> imageDimensions = {1, 2};
+ const size_t numImageDims = imageDimensions.size();
+ SmallVector<int64_t> resultShape(6, inputTileSize);
+ llvm::SmallSetVector<int64_t, 2> imageDimensionsSet(imageDimensions.begin(),
+ imageDimensions.end());
+ int outputIndex;
+ for (int i = 0; i < inputShape.size(); i++) {
+ outputIndex = i + numImageDims;
+ if (!imageDimensionsSet.contains(i)) {
+ resultShape[outputIndex] = inputShape[i];
+ } else {
+ resultShape[outputIndex] =
+ std::ceil((float)(inputShape[i] - kernelSize + 1) / outputTileSize);
+ }
+ }
+ Value emptyTensor =
+ rewriter.create<tensor::EmptyOp>(loc, resultShape, elementType);
+ auto winogradInputOp =
+ rewriter.create<IREE::LinalgExt::WinogradInputTransformOp>(
+ loc, emptyTensor.getType(), ValueRange{input},
+ ValueRange{emptyTensor}, outputTileSize, kernelSize,
+ imageDimensions);
+ Value winogradInput = winogradInputOp.getResult()[0];
+
+ // Add collapse shape
+ SmallVector<int64_t> collapsedShape = {
+ resultShape[0] * resultShape[1],
+ resultShape[2] * resultShape[3] * resultShape[4], resultShape[5]};
+ SmallVector<ReassociationIndices> reassociations = {{0, 1}, {2, 3, 4}, {5}};
+ Value collapsedWinogradInput = createCollapse(
+ winogradInput, loc, rewriter, collapsedShape, reassociations);
+
+ // Add BatchMatmulOp
+ SmallVector<int64_t> bmmShape(collapsedShape.begin(), collapsedShape.end());
+ Value output = convOp.getOutputs()[0];
+ auto outputType = output.getType().cast<RankedTensorType>();
+ ArrayRef<int64_t> outputShape = outputType.getShape();
+ bmmShape[2] = outputShape[3];
+ auto bmmOutputType = RankedTensorType::get(bmmShape, elementType);
+ emptyTensor = rewriter.create<tensor::EmptyOp>(loc, bmmShape, elementType);
+ auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange{zero},
+ ValueRange{emptyTensor});
+ auto bmmOp = rewriter.create<linalg::BatchMatmulOp>(
+ loc, bmmOutputType, ValueRange({collapsedWinogradInput, kernel}),
+ ValueRange({fillOp.result()}));
+ Value bmmResult = bmmOp.getResult(0);
+
+ // Add expand shape
+ SmallVector<int64_t> expandedShape = {resultShape[0], resultShape[1],
+ resultShape[2], resultShape[3],
+ resultShape[4], outputShape[3]};
+ reassociations = {{0, 1}, {2, 3, 4}, {5}};
+ Value expandedBmmResult =
+ createExpand(bmmResult, loc, rewriter, expandedShape, reassociations);
+
+ // Convert back into original domain
+ SmallVector<int64_t> paddedResultShape(outputShape.size(), 0);
+ for (int i = 0; i < outputShape.size(); i++) {
+ if (!imageDimensionsSet.contains(i)) {
+ paddedResultShape[i] = outputShape[i];
+ } else {
+ paddedResultShape[i] = resultShape[i + numImageDims] * outputTileSize;
+ }
+ }
+ emptyTensor =
+ rewriter.create<tensor::EmptyOp>(loc, paddedResultShape, elementType);
+ auto winogradOutputOp =
+ rewriter.create<IREE::LinalgExt::WinogradOutputTransformOp>(
+ loc, emptyTensor.getType(), ValueRange{expandedBmmResult},
+ ValueRange{emptyTensor}, outputTileSize, kernelSize,
+ imageDimensions);
+ Value paddedOutput = winogradOutputOp.getResult()[0];
+
+ // Extract slice
+ SmallVector<OpFoldResult> offsets(outputShape.size(),
+ rewriter.getIndexAttr(0));
+ SmallVector<OpFoldResult> strides(outputShape.size(),
+ rewriter.getIndexAttr(1));
+ SmallVector<OpFoldResult> sizes;
+ for (int i = 0; i < outputShape.size(); i++)
+ sizes.push_back(rewriter.getIndexAttr(outputShape[i]));
+ auto winogradOutput = rewriter.create<tensor::ExtractSliceOp>(
+ loc, outputType, paddedOutput, offsets, sizes, strides);
+
+ Value result = convOp.getResult(0);
+ result.replaceAllUsesWith(winogradOutput);
+ return success();
+ }
+};
+
+struct ConvertConv2DToWinogradPass
+ : ConvertConv2DToWinogradBase<ConvertConv2DToWinogradPass> {
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry
+ .insert<linalg::LinalgDialect, IREE::LinalgExt::IREELinalgExtDialect>();
+ }
+ void runOnOperation() override {
+ MLIRContext *context = &getContext();
+ RewritePatternSet patterns(&getContext());
+ patterns.insert<FoldWinogradFilterTransform, ConvertConv2DNhwcHwcf>(
+ context);
+ if (failed(applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
+ return signalPassFailure();
+ }
+ }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> createConvertConv2DToWinogradPass() {
+ return std::make_unique<ConvertConv2DToWinogradPass>();
+}
+
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/FoldIntoPackAndUnpackOps.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/FoldIntoPackAndUnpackOps.cpp
index 0713ce6..7477bc0 100644
--- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/FoldIntoPackAndUnpackOps.cpp
+++ b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/FoldIntoPackAndUnpackOps.cpp
@@ -53,9 +53,8 @@
Value output = rewriter.create<tensor::EmptyOp>(
sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType);
rewriter.replaceOpWithNewOp<UnPackOp>(
- sliceOp, output.getType(), unpackOp.getInput(), output,
- unpackOp.getOuterDimsPerm(), unpackOp.getInnerDimsPos(),
- unpackOp.getInnerTiles(), unpackOp.getStaticInnerTiles());
+ sliceOp, unpackOp.getInput(), output, unpackOp.getInnerDimsPos(),
+ unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
return success();
}
};
diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
index 66978aa..fc22757 100644
--- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
+++ b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
@@ -22,35 +22,6 @@
using namespace mlir::iree_compiler::IREE::LinalgExt;
//===---------------------------------------------------------------------===//
-// Methods to convert the encoding to parameters of the Pack operation
-//===---------------------------------------------------------------------===//
-
-/// Given the `encoding` return the `MaterializeEncodingInfo` to use for
-/// materializing the pack op.
-// TODO(ravishankarm): THis is currently hard-coded here for convenience. When
-// used in IREE, this will be computed based on the architecture information in
-// `hal.executable.variant`.
-static FailureOr<MaterializeEncodingInfo>
-getPackOpInfoFromEncoding(TensorEncoding encoding) {
- switch (encoding) {
- case TensorEncoding::GEMM_LHS:
- return MaterializeEncodingInfo{{0, 1}, {8, 4}, {}};
- break;
- case TensorEncoding::GEMM_RHS:
- return MaterializeEncodingInfo{{0, 1}, {4, 8}, {}};
- break;
- case TensorEncoding::GEMM_RESULT:
- return MaterializeEncodingInfo{{0, 1}, {8, 8}, {}};
- break;
- case TensorEncoding::GEMM_RHS_TRANSPOSE:
- return MaterializeEncodingInfo{{1, 0}, {8, 4}, {1, 0}};
- break;
- default:
- return failure();
- }
-}
-
-//===---------------------------------------------------------------------===//
// Utility methods
//===---------------------------------------------------------------------===//
@@ -72,7 +43,7 @@
if (!encoding)
return tensorType;
FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
- materializeEncodingFn(encoding.value());
+ materializeEncodingFn(tensorType);
if (failed(materializeEncodingInfo)) {
return tensorType;
}
@@ -94,6 +65,48 @@
}
//===---------------------------------------------------------------------===//
+// Methods to convert the encoding to parameters of the Pack operation
+//===---------------------------------------------------------------------===//
+
+/// Given the `encoding` return the `MaterializeEncodingInfo` to use for
+/// materializing the pack op.
+// TODO(ravishankarm): This is currently hard-coded here for convenience. When
+// used in IREE, this will be computed based on the architecture information in
+// `hal.executable.variant`.
+// A real implementation would return tile sizes that depend on at least the
+// `tensorType`'s element type (e.g. different tile sizes for i8 vs f32, because
+// the SIMD instructions may have different shapes).
+// Moreover, in a real implementation, the tile sizes would typically also
+// depend on target information. This is demonstrated in
+// iree/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPass.cpp
+static FailureOr<MaterializeEncodingInfo>
+chooseEncodingInfo(RankedTensorType tensorType) {
+ Optional<TensorEncoding> encoding = getEncoding(tensorType);
+ if (!encoding)
+ return failure();
+ switch (*encoding) {
+ case TensorEncoding::MATMUL_F32F32F32_LHS:
+ case TensorEncoding::MATMUL_I8I8I32_LHS:
+ return MaterializeEncodingInfo{{0, 1}, {8, 4}, {}};
+ break;
+ case TensorEncoding::MATMUL_F32F32F32_RHS:
+ case TensorEncoding::MATMUL_I8I8I32_RHS:
+ return MaterializeEncodingInfo{{0, 1}, {4, 8}, {}};
+ break;
+ case TensorEncoding::MATMUL_F32F32F32_RHS_TRANSPOSE:
+ case TensorEncoding::MATMUL_I8I8I32_RHS_TRANSPOSE:
+ return MaterializeEncodingInfo{{1, 0}, {8, 4}, {1, 0}};
+ break;
+ case TensorEncoding::MATMUL_F32F32F32_RESULT:
+ case TensorEncoding::MATMUL_I8I8I32_RESULT:
+ return MaterializeEncodingInfo{{0, 1}, {8, 8}, {}};
+ break;
+ default:
+ return failure();
+ }
+}
+
+//===---------------------------------------------------------------------===//
// Methods to convert `set_encoding` and `unset_encoding` operations
// to `pack` and `unpack` operations respectively.
//===---------------------------------------------------------------------===//
@@ -121,8 +134,9 @@
lowerSetEncodingOpToPackOp(RewriterBase &rewriter, SetEncodingOp encodingOp,
Value source,
MaterializeEncodingFn materializeEncodingFn) {
+ RankedTensorType resultType = encodingOp.getResultType();
FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
- materializeEncodingFn(encodingOp.getResultTensorEncoding());
+ materializeEncodingFn(resultType);
if (failed(materializeEncodingInfo)) {
return rewriter.notifyMatchFailure(encodingOp, "unhandled result encoding");
}
@@ -137,7 +151,7 @@
materializeEncodingInfo->innerDimsPos,
materializeEncodingInfo->outerDimsPerm);
auto initTensor = rewriter.create<tensor::EmptyOp>(
- loc, resultDims, encodingOp.getSourceType().getElementType());
+ loc, resultDims, resultType.getElementType());
Optional<Value> paddingValue = getPaddingValue(source);
return rewriter.create<PackOp>(
loc, source, initTensor, materializeEncodingInfo->innerDimsPos,
@@ -151,8 +165,9 @@
lowerUnsetEncodingToUnpackOp(RewriterBase &rewriter, UnsetEncodingOp encodingOp,
Value packedValue,
MaterializeEncodingFn materializeEncodingFn) {
+ RankedTensorType sourceType = encodingOp.getSourceType();
FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
- materializeEncodingFn(encodingOp.getSourceTensorEncoding());
+ materializeEncodingFn(sourceType);
if (failed(materializeEncodingInfo)) {
return rewriter.notifyMatchFailure(encodingOp, "unhandled source encoding");
}
@@ -161,7 +176,7 @@
SmallVector<OpFoldResult> resultDims =
getDims(rewriter, loc, encodingOp.getSource());
auto initTensor = rewriter.create<tensor::EmptyOp>(
- loc, resultDims, encodingOp.getResultType().getElementType());
+ loc, resultDims, sourceType.getElementType());
SmallVector<OpFoldResult> innerTileSizesOfr =
getAsOpFoldResult(rewriter, materializeEncodingInfo->innerTileSizes);
@@ -171,9 +186,9 @@
}
/// Utility method to convert from `linalg.matmul` with
-/// - lhs encoding of GEMM_LHS
-/// - rhs encoding of GEMM_RHS_TRANSPOSE
-/// - result encoding of GEMM_RESULT
+/// - lhs encoding of MATMUL_*_LHS
+/// - rhs encoding of MATMUL_*_RHS_TRANSPOSE
+/// - result encoding of MATMUL_*_RESULT
/// to linalg.mmt4d op.
static FailureOr<Operation *>
lowerOpWithEncoding(RewriterBase &rewriter, linalg::MatmulOp matmulOp,
@@ -190,11 +205,15 @@
getEncoding(inputs[1]->get().getType().cast<RankedTensorType>());
Optional<TensorEncoding> resultEncoding =
getEncoding(outputs[0]->get().getType().cast<RankedTensorType>());
- if (!lhsEncoding || lhsEncoding.value() != TensorEncoding::GEMM_LHS ||
+ if (!lhsEncoding ||
+ (lhsEncoding.value() != TensorEncoding::MATMUL_F32F32F32_LHS &&
+ lhsEncoding.value() != TensorEncoding::MATMUL_I8I8I32_LHS) ||
!rhsEncoding ||
- rhsEncoding.value() != TensorEncoding::GEMM_RHS_TRANSPOSE ||
+ (rhsEncoding.value() != TensorEncoding::MATMUL_F32F32F32_RHS_TRANSPOSE &&
+ rhsEncoding.value() != TensorEncoding::MATMUL_I8I8I32_RHS_TRANSPOSE) ||
!resultEncoding ||
- resultEncoding.value() != TensorEncoding::GEMM_RESULT) {
+ (resultEncoding.value() != TensorEncoding::MATMUL_F32F32F32_RESULT &&
+ resultEncoding.value() != TensorEncoding::MATMUL_I8I8I32_RESULT)) {
return failure();
}
Operation *mmt4DOp = rewriter.create<linalg::Mmt4DOp>(
@@ -225,13 +244,8 @@
ValueRange convertedOperands,
MaterializeEncodingFn materializeEncodingFn) {
auto resultType = emptyOp.getResult().getType().cast<RankedTensorType>();
- Optional<TensorEncoding> encoding = getEncoding(resultType);
- if (!encoding) {
- return rewriter.notifyMatchFailure(emptyOp,
- "result type does not have encoding");
- }
FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
- materializeEncodingFn(encoding.value());
+ materializeEncodingFn(resultType);
if (failed(materializeEncodingInfo)) {
return rewriter.notifyMatchFailure(
emptyOp, "failed to find materialization info for result type");
@@ -360,12 +374,12 @@
MLIRContext *context = &getContext();
{
+ Operation *op = getOperation();
RewritePatternSet patterns(context);
- MaterializeEncodingTypeConverter typeConverter(getPackOpInfoFromEncoding);
+ MaterializeEncodingTypeConverter typeConverter(chooseEncodingInfo);
MaterializeEncodingConversionTarget target(*context);
populateMaterializeEncodingPatterns(patterns, target, typeConverter);
- if (failed(applyPartialConversion(getOperation(), target,
- std::move(patterns))))
+ if (failed(applyPartialConversion(op, target, std::move(patterns))))
return signalPassFailure();
}
@@ -391,6 +405,7 @@
addConversion([](IntegerType intType) { return intType; });
addConversion([](IndexType indexType) { return indexType; });
addConversion([](FloatType floatType) { return floatType; });
+ addConversion([](MemRefType memrefType) { return memrefType; });
addConversion(
[materializeEncodingFn](RankedTensorType t) -> RankedTensorType {
return getMaterializedType(t, materializeEncodingFn);
diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeWinogradPass.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeWinogradPass.cpp
new file mode 100644
index 0000000..7214789
--- /dev/null
+++ b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeWinogradPass.cpp
@@ -0,0 +1,381 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h"
+#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
+#include "iree-dialects/Dialect/LinalgExt/Utils/Utils.h"
+#include "iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace LinalgExt {
+
+namespace {
+
+static void computeLoopParams(SmallVectorImpl<Value> &lbs,
+ SmallVectorImpl<Value> &ubs,
+ SmallVectorImpl<Value> &steps, Value tensor,
+ int numImageDims, Location loc,
+ OpBuilder &builder) {
+ Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+ Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
+ SmallVector<OpFoldResult> dimValues =
+ tensor::createDimValues(builder, loc, tensor);
+ for (int i = numImageDims; i < dimValues.size(); i++) {
+ lbs.push_back(zero);
+ ubs.push_back(getValueOrCreateConstantIndexOp(builder, loc, dimValues[i]));
+ steps.push_back(one);
+ }
+}
+
+class ReifyWinogradInputTransform final
+ : public OpRewritePattern<WinogradInputTransformOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(WinogradInputTransformOp inputOp,
+ PatternRewriter &rewriter) const override {
+ Location loc = inputOp.getLoc();
+ auto funcOp = inputOp->getParentOfType<func::FuncOp>();
+ if (!funcOp) {
+ return rewriter.notifyMatchFailure(
+ inputOp, "Could not find parent of type funcOp");
+ }
+
+ const float *BT{nullptr};
+ const float *B{nullptr};
+ const int64_t inputTileSize = inputOp.getInputTileSize();
+ const int64_t outputTileSize = inputOp.getOutputTileSize();
+ switch (outputTileSize) {
+ case 6:
+ B = IREE::LinalgExt::Winograd::B_6x6_3x3;
+ BT = IREE::LinalgExt::Winograd::BT_6x6_3x3;
+ break;
+ default:
+ return failure();
+ }
+ /// The two values below are the transpose(B) [BTV]
+ /// and B [BV] constant matrices that convert the input
+ /// tile to the Winograd domain.
+ Value BTV = IREE::LinalgExt::createValueFrom2DConstant(
+ BT, inputTileSize, inputTileSize, loc, rewriter);
+ Value BV = IREE::LinalgExt::createValueFrom2DConstant(
+ B, inputTileSize, inputTileSize, loc, rewriter);
+
+ Value input = inputOp.input();
+ Value output = inputOp.output();
+ auto outputType = output.getType().cast<ShapedType>();
+ auto inputType = input.getType().cast<ShapedType>();
+ ArrayRef<int64_t> inputShape = inputType.getShape();
+ Type elementType = outputType.getElementType();
+ SmallVector<int64_t> imageDims = inputOp.imageDimensions();
+ const size_t numImageDims = imageDims.size();
+ llvm::SmallSetVector<int64_t, 2> imageDimsSet(imageDims.begin(),
+ imageDims.end());
+ SmallVector<int64_t> inputTileSquare(imageDims.size(), inputTileSize);
+
+ rewriter.setInsertionPointToStart(&funcOp.getBody().front());
+ Value zeroF32 = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(elementType));
+ Value scratch =
+ rewriter.create<tensor::EmptyOp>(loc, inputTileSquare, elementType);
+
+ rewriter.setInsertionPoint(inputOp);
+ SmallVector<Value> lbs, ubs, steps;
+ computeLoopParams(lbs, ubs, steps, output, numImageDims, loc, rewriter);
+ // Construct loops
+ scf::LoopNest loopNest = scf::buildLoopNest(
+ rewriter, loc, lbs, ubs, steps, ValueRange({output}),
+ [&](OpBuilder &nestedBuilder, Location loc, ValueRange outputIvs,
+ ValueRange iterArgs) -> scf::ValueVector { return {iterArgs[0]}; });
+
+ // Extract input slice
+ auto one = rewriter.getIndexAttr(1);
+ auto zero = rewriter.getIndexAttr(0);
+ auto inputTileSizeAttr = rewriter.getIndexAttr(inputTileSize);
+ SmallVector<OpFoldResult> strides(inputOp.getInputOperandRank(), one);
+ SmallVector<OpFoldResult> sizes(inputOp.getInputOperandRank(), one);
+ SmallVector<OpFoldResult> offsets(inputOp.getInputOperandRank(), zero);
+ SmallVector<Value> ivs;
+ for (scf::ForOp loop : loopNest.loops) {
+ ivs.push_back(loop.getInductionVar());
+ }
+ for (int i = 0; i < inputShape.size(); i++) {
+ if (!imageDimsSet.contains(i)) {
+ offsets[i] = ivs[i];
+ } else {
+ rewriter.setInsertionPointToStart(loopNest.loops[i].getBody());
+ AffineExpr dim0;
+ auto it = rewriter.getAffineConstantExpr(inputTileSize);
+ auto ot = rewriter.getAffineConstantExpr(outputTileSize);
+ auto delta = rewriter.getAffineConstantExpr(inputShape[i]);
+ bindDims(rewriter.getContext(), dim0);
+ AffineMap scaleMap =
+ AffineMap::get(1, 0, {dim0 * ot}, rewriter.getContext());
+ offsets[i] = rewriter.createOrFold<AffineApplyOp>(loc, scaleMap,
+ ValueRange{ivs[i]});
+ AffineMap minMap =
+ AffineMap::get(1, 0, {-dim0 + delta, it}, rewriter.getContext());
+ sizes[i] = rewriter.createOrFold<AffineMinOp>(
+ loc, minMap,
+ ValueRange{
+ getValueOrCreateConstantIndexOp(rewriter, loc, offsets[i])});
+ }
+ }
+ rewriter.setInsertionPointToStart(loopNest.loops.back().getBody());
+ auto tensorType = RankedTensorType::get(
+ SmallVector<int64_t>(numImageDims, ShapedType::kDynamic), elementType);
+ Value dynamicSlice = rewriter.create<tensor::ExtractSliceOp>(
+ loc, tensorType, input, offsets, sizes, strides);
+
+ // Copy input slice into zeroed padded scratch space
+ strides = SmallVector<OpFoldResult>(numImageDims, one);
+ offsets = SmallVector<OpFoldResult>(numImageDims, zero);
+ sizes = SmallVector<OpFoldResult>{sizes[1], sizes[2]};
+ linalg::FillOp fillOp = rewriter.create<linalg::FillOp>(
+ loc, ValueRange{zeroF32}, ValueRange{scratch});
+ Value inputSlice = rewriter.create<tensor::InsertSliceOp>(
+ loc, dynamicSlice, fillOp.result(), offsets, sizes, strides);
+
+ // Extract output slice
+ strides = SmallVector<OpFoldResult>(inputOp.getOutputOperandRank(), one);
+ offsets = SmallVector<OpFoldResult>(numImageDims, zero);
+ offsets.append(ivs.begin(), ivs.end());
+ sizes = SmallVector<OpFoldResult>(inputOp.getOutputOperandRank(), one);
+ sizes[0] = sizes[1] = inputTileSizeAttr;
+ tensorType = RankedTensorType::get(inputTileSquare, elementType);
+ Value iterArg = loopNest.loops.back().getRegionIterArg(0);
+ Value outputSlice = rewriter.create<tensor::ExtractSliceOp>(
+ loc, tensorType, iterArg, offsets, sizes, strides);
+
+ // Create computation
+ Value result, AMatrix, BMatrix;
+ linalg::MatmulOp matmulOp;
+ for (int i = 0; i < 2; i++) {
+ fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange{zeroF32},
+ ValueRange{outputSlice});
+ if (i == 0) {
+ AMatrix = inputSlice;
+ BMatrix = BV;
+ } else {
+ AMatrix = BTV;
+ BMatrix = result;
+ }
+ matmulOp = rewriter.create<linalg::MatmulOp>(
+ loc, tensorType, ValueRange{AMatrix, BMatrix}, fillOp.result());
+ result = matmulOp.getResult(0);
+ }
+
+ // Insert results into output slice
+ Value updatedOutput = rewriter.create<tensor::InsertSliceOp>(
+ loc, result, iterArg, offsets, sizes, strides);
+
+ // Replace returned value
+ if (scf::YieldOp yieldOp = dyn_cast<scf::YieldOp>(
+ loopNest.loops.back().getBody()->getTerminator())) {
+ rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, updatedOutput);
+ }
+ inputOp.getResults()[0].replaceAllUsesWith(loopNest.getResults()[0]);
+ return success();
+ }
+};
+
+} // namespace
+
+namespace {
+
+class ReifyWinogradOutputTransform final
+ : public OpRewritePattern<WinogradOutputTransformOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(WinogradOutputTransformOp outputOp,
+ PatternRewriter &rewriter) const override {
+ Location loc = outputOp.getLoc();
+ auto funcOp = outputOp->getParentOfType<func::FuncOp>();
+ if (!funcOp) {
+ return rewriter.notifyMatchFailure(
+ outputOp, "Could not find parent of type funcOp");
+ }
+
+ const float *AT{nullptr};
+ const float *A{nullptr};
+ const int64_t inputTileSize = outputOp.getInputTileSize();
+ const int64_t outputTileSize = outputOp.getOutputTileSize();
+ switch (outputTileSize) {
+ case 6:
+ A = IREE::LinalgExt::Winograd::A_6x6_3x3;
+ AT = IREE::LinalgExt::Winograd::AT_6x6_3x3;
+ break;
+ default:
+ return failure();
+ }
+ /// The two values below are the transpose(A) [ATV]
+ /// and A [AV] constant matrices that convert the output
+ /// tile from the Winograd domain to the original domain.
+ Value ATV = IREE::LinalgExt::createValueFrom2DConstant(
+ AT, outputTileSize, inputTileSize, loc, rewriter);
+ Value AV = IREE::LinalgExt::createValueFrom2DConstant(
+ A, inputTileSize, outputTileSize, loc, rewriter);
+
+ Value input = outputOp.input();
+ Value output = outputOp.output();
+ auto outputType = output.getType().cast<ShapedType>();
+ ArrayRef<int64_t> outputShape = outputType.getShape();
+ Type elementType = outputType.getElementType();
+ SmallVector<int64_t> imageDims = outputOp.imageDimensions();
+ const size_t numImageDims = imageDims.size();
+ llvm::SmallSetVector<int64_t, 2> imageDimsSet(imageDims.begin(),
+ imageDims.end());
+ SmallVector<int64_t> inputTileSquare(imageDims.size(), inputTileSize);
+
+ rewriter.setInsertionPointToStart(&funcOp.getBody().front());
+ Value zeroF32 = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(elementType));
+ SmallVector<int64_t> scratchShape = {inputTileSize, outputTileSize};
+ Value scratch =
+ rewriter.create<tensor::EmptyOp>(loc, scratchShape, elementType);
+
+ rewriter.setInsertionPoint(outputOp);
+ SmallVector<Value> lbs, ubs, steps;
+ computeLoopParams(lbs, ubs, steps, input, numImageDims, loc, rewriter);
+ // Construct loops
+ scf::LoopNest loopNest = scf::buildLoopNest(
+ rewriter, loc, lbs, ubs, steps, ValueRange({output}),
+ [&](OpBuilder &nestedBuilder, Location loc, ValueRange outputIvs,
+ ValueRange iterArgs) -> scf::ValueVector { return {iterArgs[0]}; });
+
+ // Extract input slice
+ rewriter.setInsertionPointToStart(loopNest.loops.back().getBody());
+ auto one = rewriter.getIndexAttr(1);
+ auto zero = rewriter.getIndexAttr(0);
+ auto inputTileSizeAttr = rewriter.getIndexAttr(inputTileSize);
+ auto outputTileSizeAttr = rewriter.getIndexAttr(outputTileSize);
+ SmallVector<OpFoldResult> strides(outputOp.getInputOperandRank(), one);
+ SmallVector<OpFoldResult> sizes(outputOp.getInputOperandRank(), one);
+ SmallVector<OpFoldResult> offsets(numImageDims, zero);
+ sizes[0] = sizes[1] = inputTileSizeAttr;
+ SmallVector<Value> ivs;
+ for (scf::ForOp loop : loopNest.loops) {
+ ivs.push_back(loop.getInductionVar());
+ }
+ offsets.append(ivs.begin(), ivs.end());
+ auto tensorType = RankedTensorType::get(inputTileSquare, elementType);
+ tensor::ExtractSliceOp extractSliceOp =
+ rewriter.create<tensor::ExtractSliceOp>(loc, tensorType, input, offsets,
+ sizes, strides);
+ Value inputSlice = extractSliceOp.getResult();
+
+ // Extract output slice
+ strides = SmallVector<OpFoldResult>(outputOp.getOutputOperandRank(), one);
+ offsets = SmallVector<OpFoldResult>(outputOp.getOutputOperandRank(), zero);
+ sizes = SmallVector<OpFoldResult>(outputOp.getOutputOperandRank(), one);
+ for (int i = 0; i < outputShape.size(); i++) {
+ if (!imageDimsSet.contains(i)) {
+ offsets[i] = ivs[i];
+ } else {
+ rewriter.setInsertionPointToStart(loopNest.loops[i].getBody());
+ AffineExpr dim0;
+ auto ot = rewriter.getAffineConstantExpr(outputTileSize);
+ bindDims(rewriter.getContext(), dim0);
+ AffineMap scaleMap =
+ AffineMap::get(1, 0, {dim0 * ot}, rewriter.getContext());
+ offsets[i] = rewriter.createOrFold<AffineApplyOp>(loc, scaleMap,
+ ValueRange{ivs[i]});
+ sizes[i] = outputTileSizeAttr;
+ }
+ }
+ rewriter.setInsertionPointAfter(extractSliceOp);
+ tensorType = RankedTensorType::get(
+ SmallVector<int64_t>(numImageDims, outputTileSize), elementType);
+ Value iterArg = loopNest.loops.back().getRegionIterArg(0);
+ Value outputSlice = rewriter.create<tensor::ExtractSliceOp>(
+ loc, tensorType, iterArg, offsets, sizes, strides);
+
+ // Create computation
+ Value result, AMatrix, BMatrix;
+ linalg::MatmulOp matmulOp;
+ linalg::FillOp fillOp;
+ Value tmp;
+ for (int i = 0; i < 2; i++) {
+ tmp = i == 0 ? scratch : outputSlice;
+ fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange{zeroF32},
+ ValueRange{tmp});
+ if (i == 0) {
+ AMatrix = inputSlice;
+ BMatrix = AV;
+ } else {
+ AMatrix = ATV;
+ BMatrix = result;
+ }
+ matmulOp = rewriter.create<linalg::MatmulOp>(
+ loc, tmp.getType(), ValueRange{AMatrix, BMatrix}, fillOp.result());
+ result = matmulOp.getResult(0);
+ }
+
+ // Insert results into output slice
+ Value updatedOutput = rewriter.create<tensor::InsertSliceOp>(
+ loc, result, iterArg, offsets, sizes, strides);
+
+ // Replace returned value
+ if (scf::YieldOp yieldOp = dyn_cast<scf::YieldOp>(
+ loopNest.loops.back().getBody()->getTerminator())) {
+ rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, updatedOutput);
+ }
+ outputOp.getResults()[0].replaceAllUsesWith(loopNest.getResults()[0]);
+ return success();
+ }
+};
+
+} // namespace
+
+namespace {
+struct TileAndDecomposeWinogradTransformPass
+ : public TileAndDecomposeWinogradTransformBase<
+ TileAndDecomposeWinogradTransformPass> {
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<AffineDialect, IREE::LinalgExt::IREELinalgExtDialect,
+ linalg::LinalgDialect, scf::SCFDialect,
+ tensor::TensorDialect>();
+ }
+
+ void runOnOperation() override;
+};
+} // namespace
+
+void TileAndDecomposeWinogradTransformPass::runOnOperation() {
+ MLIRContext *context = &getContext();
+ RewritePatternSet patterns(&getContext());
+ patterns.insert<ReifyWinogradInputTransform, ReifyWinogradOutputTransform>(
+ context);
+ if (failed(
+ applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
+ return signalPassFailure();
+ }
+}
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+createTileAndDecomposeWinogradTransformPass() {
+ return std::make_unique<TileAndDecomposeWinogradTransformPass>();
+}
+
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/Tiling.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/Tiling.cpp
index f4067df..7c4cd79 100644
--- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/Tiling.cpp
+++ b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/Tiling.cpp
@@ -402,6 +402,11 @@
StringAttr::get(context, "tiling_repeated_indices_scatter_input"),
StringAttr::get(context, "tiling_repeated_indices_scatter_output")));
+ patterns.add<TilingInterfaceTilingPattern>(
+ context, linalg::LinalgTilingOptions().setTileSizes({1, 32}),
+ IREE::LinalgExt::LinalgTransformationFilter(
+ StringAttr::get(context, "tiling_winograd_input_nhwc")));
+
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
index 58e892f..725d2ad 100644
--- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
+++ b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
@@ -107,96 +107,6 @@
return tileLoopNest;
}
-namespace {
-///
-/// Linalg tile and fuse tensor ops pattern.
-///
-/// Apply tiling and fusion as a pattern.
-/// See `tileConsumerAndFuseProducers` for more details.
-struct LinalgTileAndFuseTensorOpsBasePattern : public RewritePattern {
- // Entry point to match any LinalgOp.
- LinalgTileAndFuseTensorOpsBasePattern(
- MLIRContext *context, linalg::LinalgTilingAndFusionOptions options,
- PatternBenefit benefit = 1)
- : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
- options(std::move(options)) {}
- // Entry point to match a specific LinalgOp.
- LinalgTileAndFuseTensorOpsBasePattern(
- StringRef opName, MLIRContext *context,
- linalg::LinalgTilingAndFusionOptions options, PatternBenefit benefit = 1)
- : RewritePattern(opName, benefit, context), options(std::move(options)) {}
-
- /// `matchAndRewrite` implementation that returns the significant transformed
- /// pieces of IR.
- FailureOr<linalg::TileLoopNest>
- returningMatchAndRewrite(Operation *op, PatternRewriter &rewriter) const;
-
- LogicalResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override {
- return returningMatchAndRewrite(op, rewriter);
- }
-
-private:
- /// Tile sizes and interchange used to tile the root operation.
- linalg::LinalgTilingAndFusionOptions options;
-};
-} // namespace
-
-FailureOr<mlir::linalg::TileLoopNest>
-LinalgTileAndFuseTensorOpsBasePattern::returningMatchAndRewrite(
- Operation *op, PatternRewriter &rewriter) const {
- linalg::LinalgOp rootOp = dyn_cast<linalg::LinalgOp>(op);
- if (!rootOp)
- return failure();
-
- // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
- if (options.tileSizes.size() < rootOp.getNumLoops())
- return rewriter.notifyMatchFailure(op, "expect #tile sizes >= #loops");
-
- // Check `tileInterchange` contains no entries or as many as `tileSizes`.
- if (!options.tileInterchange.empty() &&
- options.tileInterchange.size() != options.tileSizes.size())
- return rewriter.notifyMatchFailure(
- op, "expect the number of tile sizes and interchange dims to match");
-
- // Copy the `tileSizes` and `tileInterchange` prefixes needed for `rootOp`.
- SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
- options.tileSizes.begin() +
- rootOp.getNumLoops());
- SmallVector<int64_t> rootInterchange =
- options.tileInterchange.empty()
- ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
- : SmallVector<int64_t>(options.tileInterchange.begin(),
- options.tileInterchange.begin() +
- rootOp.getNumLoops());
-
- // Check `rootTileSizes` contains non-zero tile sizes.
- if (llvm::count(rootTileSizes, 0) == static_cast<long>(rootTileSizes.size()))
- return rewriter.notifyMatchFailure(
- op, "expect at least one non-zero tile size");
-
- // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
- // It has to be a permutation since the tiling cannot tile the same loop
- // dimension multiple times.
- if (!linalg::isPermutation(rootInterchange))
- return rewriter.notifyMatchFailure(
- op, "expect the tile interchange permutes the root loops");
-
- // Tile `rootOp` and fuse its producers.
- FailureOr<linalg::TileLoopNest> tileLoopNest =
- IREE::LinalgExt::tileConsumerAndFuseProducers(
- rewriter, rootOp, rootTileSizes, rootInterchange,
- options.tileDistribution);
- if (failed(tileLoopNest))
- return rewriter.notifyMatchFailure(
- op, "tileConsumerAndFuseProducers failed unexpectedly");
-
- // Replace all uses of the tiled loop operation.
- rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());
-
- return tileLoopNest;
-}
-
/// Peel loops after tiling.
static void peelTiledLinalgOp(RewriterBase &rewriter,
linalg::TiledLinalgOp &res,
@@ -256,6 +166,83 @@
return res;
}
+/// Linalg SCF tiling pattern.
+LinalgSCFTilingPattern::LinalgSCFTilingPattern(
+ MLIRContext *context, scf::SCFTilingOptions options,
+ LinalgExt::LinalgTransformationFilter f, PatternBenefit benefit)
+ : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+ filter(std::move(f)), options(std::move(options)) {}
+
+LinalgSCFTilingPattern::LinalgSCFTilingPattern(
+ StringRef opName, MLIRContext *context, scf::SCFTilingOptions options,
+ LinalgExt::LinalgTransformationFilter f, PatternBenefit benefit)
+ : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+ filter(f.addOpNameFilter(opName)), options(std::move(options)) {}
+
+LogicalResult LinalgSCFTilingPattern::returningMatchAndRewrite(
+ TilingInterface op, PatternRewriter &rewriter) const {
+ if (failed(filter.checkAndNotify(rewriter, op)))
+ return failure();
+
+ FailureOr<scf::SCFTilingResult> tiledResults =
+ scf::tileUsingSCFForOp(rewriter, op, options);
+ if (failed(tiledResults))
+ return failure();
+
+ rewriter.replaceOp(op, tiledResults->replacements);
+
+ for (auto tiledOp : tiledResults->tiledOps) {
+ filter.replaceLinalgTransformationFilter(rewriter, tiledOp);
+ }
+
+ return success();
+}
+
+/// Linalg tile and fuse tensor ops pattern.
+LinalgSCFTileAndFusePattern::LinalgSCFTileAndFusePattern(
+ MLIRContext *context, scf::SCFTileAndFuseOptions options,
+ LinalgExt::LinalgTransformationFilter f, PatternBenefit benefit)
+ : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+ filter(std::move(f)), options(std::move(options)) {}
+
+LinalgSCFTileAndFusePattern::LinalgSCFTileAndFusePattern(
+ StringRef opName, MLIRContext *context, scf::SCFTileAndFuseOptions options,
+ LinalgExt::LinalgTransformationFilter f, PatternBenefit benefit)
+ : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+ filter(f.addOpNameFilter(opName)), options(std::move(options)) {}
+
+LogicalResult
+LinalgSCFTileAndFusePattern::matchAndRewrite(TilingInterface op,
+ PatternRewriter &rewriter) const {
+ if (failed(filter.checkAndNotify(rewriter, op)))
+ return failure();
+
+ FailureOr<scf::SCFTileAndFuseResult> tiledResults =
+ tileConsumerAndFuseProducerGreedilyUsingSCFForOp(rewriter, op, options);
+ if (failed(tiledResults))
+ return rewriter.notifyMatchFailure(
+ op,
+ "tileConsumerAndFuseProducerGreedilyUsingSCFForOp failed unexpectedly");
+
+ // Replace all uses of the tiled loop operation.
+ SmallVector<Value> replacements(op->getNumResults());
+ for (auto result : llvm::enumerate(op->getResults())) {
+ auto it = tiledResults->replacements.find(result.value());
+ if (it == tiledResults->replacements.end()) {
+ replacements[result.index()] = result.value();
+ } else {
+ replacements[result.index()] = it->getSecond();
+ }
+ }
+ rewriter.replaceOp(op, replacements);
+
+ // Apply the filter if specified.
+ for (linalg::LinalgOp linalgOp : tiledResults->tiledAndFusedOps)
+ filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
+
+ return success();
+}
+
LinalgVectorizationPattern::LinalgVectorizationPattern(
MLIRContext *context, LinalgExt::LinalgTransformationFilter f,
PatternBenefit benefit)
@@ -346,62 +333,6 @@
return success();
}
-namespace {
-///
-/// Linalg tile and fuse tensor ops pattern.
-///
-/// Apply tiling and fusion as a pattern.
-/// `filter` controls LinalgTransformMarker matching and update when specified.
-/// See `tileConsumerAndFuseProducers` for more details.
-struct LinalgTileAndFuseTensorOpsPattern : public RewritePattern {
- // Entry point to match any LinalgOp.
- LinalgTileAndFuseTensorOpsPattern(
- MLIRContext *context, linalg::LinalgTilingAndFusionOptions options,
- LinalgExt::LinalgTransformationFilter f =
- LinalgExt::LinalgTransformationFilter(),
- PatternBenefit benefit = 1)
- : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
- filter(std::move(f)), options(std::move(options)) {}
- // Entry point to match a specific LinalgOp.
- LinalgTileAndFuseTensorOpsPattern(
- StringRef opName, MLIRContext *context,
- linalg::LinalgTilingAndFusionOptions options,
- LinalgExt::LinalgTransformationFilter f =
- LinalgExt::LinalgTransformationFilter(),
- PatternBenefit benefit = 1)
- : RewritePattern(opName, benefit, context), filter(std::move(f)),
- options(std::move(options)) {}
-
- /// `matchAndRewrite` implementation that returns the significant transformed
- /// pieces of IR.
- FailureOr<linalg::TileLoopNest>
- returningMatchAndRewrite(Operation *op, PatternRewriter &rewriter) const {
- if (failed(filter.checkAndNotify(rewriter, op)))
- return failure();
- LinalgTileAndFuseTensorOpsBasePattern p(op->getContext(), options);
- auto maybeTileLoopNest = p.returningMatchAndRewrite(op, rewriter);
- if (failed(maybeTileLoopNest))
- return failure();
- // Apply the filter if specified.
- for (linalg::LinalgOp linalgOp :
- maybeTileLoopNest->getAllTiledAndFusedOps())
- filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
- return maybeTileLoopNest;
- }
-
- LogicalResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override {
- return returningMatchAndRewrite(op, rewriter);
- }
-
-private:
- /// LinalgTransformMarker handles special attribute manipulations.
- LinalgExt::LinalgTransformationFilter filter;
- /// Tile sizes and interchange used to tile the root operation.
- linalg::LinalgTilingAndFusionOptions options;
-};
-} // namespace
- //
/// Configurable pass to apply pattern-based tiling and fusion.
struct LinalgStrategyTileAndFusePass
: public LinalgStrategyTileAndFusePassBase<LinalgStrategyTileAndFusePass> {
@@ -409,9 +340,9 @@
LinalgStrategyTileAndFusePass() = default;
LinalgStrategyTileAndFusePass(StringRef opName,
- linalg::LinalgTilingAndFusionOptions opt,
+ scf::SCFTileAndFuseOptions options,
LinalgExt::LinalgTransformationFilter filt)
- : options(std::move(opt)), filter(std::move(filt)) {
+ : options(std::move(options)), filter(std::move(filt)) {
this->anchorOpName.setValue(opName.str());
}
@@ -422,10 +353,10 @@
RewritePatternSet tilingAndFusionPattern(funcOp.getContext());
if (!anchorOpName.empty()) {
- tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>(
+ tilingAndFusionPattern.add<LinalgSCFTileAndFusePattern>(
anchorOpName, funcOp.getContext(), options, filter);
} else {
- tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>(
+ tilingAndFusionPattern.add<LinalgSCFTileAndFusePattern>(
funcOp.getContext(), options, filter);
}
// Search the root operation using bottom up traversal.
@@ -435,7 +366,7 @@
funcOp, std::move(tilingAndFusionPattern), config);
}
- linalg::LinalgTilingAndFusionOptions options;
+ scf::SCFTileAndFuseOptions options;
LinalgExt::LinalgTransformationFilter filter;
};
@@ -445,9 +376,9 @@
LinalgStrategyTilePass() = default;
- LinalgStrategyTilePass(StringRef opName, linalg::LinalgTilingOptions opt,
+ LinalgStrategyTilePass(StringRef opName, scf::SCFTilingOptions options,
LinalgExt::LinalgTransformationFilter filt)
- : options(std::move(opt)), filter(std::move(filt)) {
+ : options(std::move(options)), filter(std::move(filt)) {
this->anchorOpName.setValue(opName.str());
}
@@ -459,16 +390,15 @@
MLIRContext *ctx = funcOp.getContext();
RewritePatternSet tilingPattern(ctx);
if (!anchorOpName.empty())
- tilingPattern.add<LinalgTilingPattern>(anchorOpName, ctx, options,
- filter);
+ tilingPattern.add<LinalgSCFTilingPattern>(anchorOpName, ctx, options,
+ filter);
else
- tilingPattern.add<LinalgTilingPattern>(ctx, options, filter);
- if (anchorOpName == tensor::PadOp::getOperationName())
- populatePadTensorTilingPatterns(tilingPattern, options);
+ tilingPattern.add<LinalgSCFTilingPattern>(ctx, options, filter);
+
(void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
}
- linalg::LinalgTilingOptions options;
+ scf::SCFTilingOptions options;
LinalgExt::LinalgTransformationFilter filter;
};
@@ -757,7 +687,7 @@
/// Create a LinalgStrategyTileAndFusePass.
std::unique_ptr<OperationPass<func::FuncOp>>
createLinalgStrategyTileAndFusePass(
- StringRef opName, const linalg::LinalgTilingAndFusionOptions &options,
+ StringRef opName, const scf::SCFTileAndFuseOptions &options,
const LinalgExt::LinalgTransformationFilter &filter) {
return std::make_unique<LinalgStrategyTileAndFusePass>(opName, options,
filter);
@@ -765,9 +695,9 @@
/// Create a LinalgStrategyTilePass.
std::unique_ptr<OperationPass<func::FuncOp>> createLinalgStrategyTilePass(
- StringRef opName, const linalg::LinalgTilingOptions &opt,
+ StringRef opName, const scf::SCFTilingOptions &options,
const LinalgExt::LinalgTransformationFilter &filter) {
- return std::make_unique<LinalgStrategyTilePass>(opName, opt, filter);
+ return std::make_unique<LinalgStrategyTilePass>(opName, options, filter);
}
/// Create a LinalgStrategyPadPass.
@@ -859,7 +789,7 @@
// The size is less than or equal to tileSize because outer dims are all 1s.
Optional<int64_t> tileSize =
getConstantIntValue(tileAndPosMapping.lookup(dim));
- assert(tileSize.hasValue() && "dynamic inner tile size is not supported");
+ assert(tileSize.has_value() && "dynamic inner tile size is not supported");
paddedShape.push_back(tileSize.value());
}
auto resultType =
@@ -1022,7 +952,7 @@
// Apply tiling to make outer dims be all 1s.
{
SimpleRewriter rewriter(ctx);
- auto packTilingOptions =
+ auto packOptions = scf::SCFTileAndFuseOptions().setTilingOptions(
scf::SCFTilingOptions().setTileSizeComputationFunction(
[](OpBuilder &builder, Operation *op) {
Location loc = op->getLoc();
@@ -1030,15 +960,16 @@
SmallVector<Value> tileSizes(
inputRank, builder.create<arith::ConstantIndexOp>(loc, 1));
return tileSizes;
- });
+ }));
auto funcOp = getOperation();
funcOp->walk([&](LinalgExt::PackOp op) {
- FailureOr<scf::SCFTilingResult> tilingResult = scf::tileUsingSCFForOp(
- rewriter, cast<TilingInterface>(op.getOperation()),
- packTilingOptions);
- if (failed(tilingResult))
+ FailureOr<scf::SCFTileAndFuseResult> tileAndFuseResult =
+ scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(rewriter, op,
+ packOptions);
+ if (failed(tileAndFuseResult))
return signalPassFailure();
- rewriter.replaceOp(op, tilingResult->replacements);
+ rewriter.replaceOp(op,
+ tileAndFuseResult->replacements[op.getResult(0)]);
});
auto unpackTilingOptions =
diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Utils/Utils.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Utils/Utils.cpp
index 937e3ee..5eef5b4 100644
--- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Utils/Utils.cpp
+++ b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Utils/Utils.cpp
@@ -71,6 +71,15 @@
return interchangeVector;
}
+Value createValueFrom2DConstant(const float *val, int64_t rows, int64_t cols,
+ Location loc, PatternRewriter &rewriter) {
+ ArrayRef<float> vector(val, rows * cols);
+ SmallVector<int64_t> shape{rows, cols};
+ return rewriter.create<arith::ConstantOp>(
+ loc, DenseFPElementsAttr::get(
+ RankedTensorType::get(shape, rewriter.getF32Type()), vector));
+}
+
} // namespace LinalgExt
} // namespace IREE
} // namespace iree_compiler
diff --git a/integrations/tensorflow/iree-dialects/python/IREEDialectsModule.cpp b/integrations/tensorflow/iree-dialects/python/IREEDialectsModule.cpp
index 3c19ffb..1a85609 100644
--- a/integrations/tensorflow/iree-dialects/python/IREEDialectsModule.cpp
+++ b/integrations/tensorflow/iree-dialects/python/IREEDialectsModule.cpp
@@ -5,7 +5,6 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree-dialects-c/Dialects.h"
-#include "iree-dialects-c/Utils.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
@@ -22,28 +21,6 @@
auto typeClass = irModule.attr("Type");
//===--------------------------------------------------------------------===//
- // Utils
- //===--------------------------------------------------------------------===//
-
- m.def(
- "lookup_nearest_symbol_from",
- [](MlirOperation fromOp, MlirAttribute symbol) {
- if (!mlirAttributeIsASymbolRef(symbol)) {
- throw std::invalid_argument("expected a SymbolRefAttr");
- }
- return ireeLookupNearestSymbolFrom(fromOp, symbol);
- },
- py::arg("fromOp"), py::arg("symbol"));
-
- // TODO: Upstream this into the main Python bindings.
- m.def(
- "emit_error",
- [](MlirLocation loc, std::string message) {
- mlirEmitError(loc, message.c_str());
- },
- py::arg("loc"), py::arg("message"));
-
- //===--------------------------------------------------------------------===//
// IREEDialect
//===--------------------------------------------------------------------===//
auto iree_m = m.def_submodule("iree_input");
diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/conv2d_to_winograd.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/conv2d_to_winograd.mlir
new file mode 100644
index 0000000..be10a1d
--- /dev/null
+++ b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/conv2d_to_winograd.mlir
@@ -0,0 +1,77 @@
+// RUN: iree-dialects-opt --split-input-file -iree-linalg-ext-convert-conv2d-to-winograd -mlir-elide-elementsattrs-if-larger=4 %s | FileCheck %s
+
+func.func @conv_16433136(%arg0: tensor<1x16x16x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+ %c0 = arith.constant dense<0.1> : tensor<3x3x4x16xf32>
+ %0 = linalg.conv_2d_nhwc_hwcf
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%arg0, %c0: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
+ outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+ return %0 : tensor<1x14x14x16xf32>
+}
+// CHECK: func.func @conv_16433136(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x16x16x4xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
+// CHECK-SAME: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+// CHECK: %[[CST:.+]] = arith.constant dense_resource<__elided__> : tensor<64x4x16xf32>
+// CHECK: %[[CST_0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[D0:.+]] = tensor.empty() : tensor<8x8x1x3x3x4xf32>
+// CHECK: %[[D1:.+]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME: image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<1x16x16x4xf32>) outs(%[[D0]] :
+// CHECK-SAME: tensor<8x8x1x3x3x4xf32>) -> tensor<8x8x1x3x3x4xf32>
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[D1]]
+// CHECK-SAME{LITERAL}: [[0, 1], [2, 3, 4], [5]]
+// CHECK-SAME: tensor<8x8x1x3x3x4xf32> into tensor<64x9x4xf32>
+// CHECK: %[[D2:.+]] = tensor.empty() : tensor<64x9x16xf32>
+// CHECK: %[[D3:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D2]] : tensor<64x9x16xf32>) ->
+// CHECK-SAME: tensor<64x9x16xf32>
+// CHECK: %[[D4:.+]] = linalg.batch_matmul ins(%[[COLLAPSED]], %[[CST]] : tensor<64x9x4xf32>,
+// CHECK-SAME: tensor<64x4x16xf32>) outs(%[[D3]] : tensor<64x9x16xf32>) -> tensor<64x9x16xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[D4]]
+// CHECK-SAME{LITERAL}: [[0, 1], [2, 3, 4], [5]]
+// CHECK-SAME: tensor<64x9x16xf32> into tensor<8x8x1x3x3x16xf32>
+// CHECK: %[[D5:.+]] = tensor.empty() : tensor<1x18x18x16xf32>
+// CHECK: %[[D6:.+]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME: image_dimensions([1, 2]) ins(%[[EXPANDED]] : tensor<8x8x1x3x3x16xf32>) outs(%[[D5]] :
+// CHECK-SAME: tensor<1x18x18x16xf32>) -> tensor<1x18x18x16xf32>
+// CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[D6]][0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] :
+// CHECK-SAME: tensor<1x18x18x16xf32> to tensor<1x14x14x16xf32>
+// CHECK: return %[[EXTRACTED_SLICE]] : tensor<1x14x14x16xf32>
+// CHECK: }
+
+// -----
+
+func.func @conv2d_non_splat_weights(%inputs : tensor<1x4x4x1xf32>, %arg2: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32> {
+ %c0 = arith.constant dense<[[ [[1.0]], [[3.0]], [[5.0]] ],
+ [ [[7.0]], [[9.0]], [[11.0]] ],
+ [ [[13.0]], [[15.0]], [[17.0]] ]]> : tensor<3x3x1x1xf32>
+ %0 = linalg.conv_2d_nhwc_hwcf
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%inputs, %c0: tensor<1x4x4x1xf32>, tensor<3x3x1x1xf32>)
+ outs(%arg2: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
+ return %0 : tensor<1x2x2x1xf32>
+}
+// CHECK: func.func @conv2d_non_splat_weights(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x4x4x1xf32>,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32> {
+// CHECK: %[[CST:.+]] = arith.constant dense_resource<__elided__> : tensor<64x1x1xf32>
+// CHECK: %[[CST_0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[D0:.+]] = tensor.empty() : tensor<8x8x1x1x1x1xf32>
+// CHECK: %[[D1:.+]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME: image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<1x4x4x1xf32>) outs(%[[D0]] : tensor<8x8x1x1x1x1xf32>)
+// CHECK-SAME: -> tensor<8x8x1x1x1x1xf32>
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[D1]]
+// CHECK-SAME{LITERAL}: [[0, 1], [2, 3, 4], [5]]
+// CHECK-SAME: tensor<8x8x1x1x1x1xf32> into tensor<64x1x1xf32>
+// CHECK: %[[D2:.+]] = tensor.empty() : tensor<64x1x1xf32>
+// CHECK: %[[D3:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D2]] : tensor<64x1x1xf32>) ->
+// CHECK-SAME: tensor<64x1x1xf32>
+// CHECK: %[[D4:.+]] = linalg.batch_matmul ins(%[[COLLAPSED]], %[[CST]] : tensor<64x1x1xf32>, tensor<64x1x1xf32>)
+// CHECK-SAME: outs(%[[D3]] : tensor<64x1x1xf32>) -> tensor<64x1x1xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[D4]]
+// CHECK-SAME{LITERAL}: [[0, 1], [2, 3, 4], [5]]
+// CHECK-SAME: tensor<64x1x1xf32> into tensor<8x8x1x1x1x1xf32>
+// CHECK: %[[D5:.+]] = tensor.empty() : tensor<1x6x6x1xf32>
+// CHECK: %[[D6:.+]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME: image_dimensions([1, 2]) ins(%[[EXPANDED]] : tensor<8x8x1x1x1x1xf32>) outs(%[[D5]] :
+// CHECK-SAME: tensor<1x6x6x1xf32>) -> tensor<1x6x6x1xf32>
+// CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[D6]][0, 0, 0, 0] [1, 2, 2, 1] [1, 1, 1, 1] :
+// CHECK-SAME: tensor<1x6x6x1xf32> to tensor<1x2x2x1xf32>
+// CHECK: return %[[EXTRACTED_SLICE]] : tensor<1x2x2x1xf32>
+// CHECK: }
diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/invalid.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/invalid.mlir
index 5b9b499..038f293 100644
--- a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/invalid.mlir
+++ b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/invalid.mlir
@@ -590,9 +590,9 @@
// -----
-func.func @illegal_set_encoding_op_with_source_encoding(%arg0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>) -> tensor<?x?xf32> {
+func.func @illegal_set_encoding_op_with_source_encoding(%arg0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>) -> tensor<?x?xf32> {
// expected-error @+1 {{source of set_encoding op cannot have a tensor encoding}}
- %0 = iree_linalg_ext.set_encoding %arg0: tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>> -> tensor<?x?xf32>
+ %0 = iree_linalg_ext.set_encoding %arg0: tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
@@ -606,18 +606,18 @@
// -----
-func.func @illegal_set_encoding_op_with_rank_change(%arg0 : tensor<?x?xf32>) -> tensor<?xf32, #iree_linalg_ext.encoding<GEMM_LHS>> {
+func.func @illegal_set_encoding_op_with_rank_change(%arg0 : tensor<?x?xf32>) -> tensor<?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> {
// expected-error @+1 {{cannot change the rank of the tensor}}
- %0 = iree_linalg_ext.set_encoding %arg0: tensor<?x?xf32> -> tensor<?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- return %0 : tensor<?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+ %0 = iree_linalg_ext.set_encoding %arg0: tensor<?x?xf32> -> tensor<?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ return %0 : tensor<?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
}
// -----
-func.func @illegal_set_encoding_op_with_shape_change(%arg0 : tensor<10x20xf32>) -> tensor<20x30xf32, #iree_linalg_ext.encoding<GEMM_LHS>> {
+func.func @illegal_set_encoding_op_with_shape_change(%arg0 : tensor<10x20xf32>) -> tensor<20x30xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> {
// expected-error @+1 {{expected to preserve the logical shape of the tensor}}
- %0 = iree_linalg_ext.set_encoding %arg0: tensor<10x20xf32> -> tensor<20x30xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- return %0 : tensor<20x30xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+ %0 = iree_linalg_ext.set_encoding %arg0: tensor<10x20xf32> -> tensor<20x30xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ return %0 : tensor<20x30xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
}
// -----
@@ -630,10 +630,10 @@
// -----
-func.func @illegal_unset_encoding_op_with_result_encoding(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>> {
+func.func @illegal_unset_encoding_op_with_result_encoding(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> {
// expected-error @+1 {{result of unset_encoding op cannot have a tensor encoding}}
- %0 = iree_linalg_ext.unset_encoding %arg0: tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- return %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+ %0 = iree_linalg_ext.unset_encoding %arg0: tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ return %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
}
// -----
@@ -646,16 +646,49 @@
// -----
-func.func @illegal_unset_encoding_op_with_rank_change(%arg0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>) -> tensor<?xf32> {
+func.func @illegal_unset_encoding_op_with_rank_change(%arg0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>) -> tensor<?xf32> {
// expected-error @+1 {{cannot change the rank of the tensor}}
- %0 = iree_linalg_ext.unset_encoding %arg0: tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>> -> tensor<?xf32>
+ %0 = iree_linalg_ext.unset_encoding %arg0: tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
-func.func @illegal_unset_encoding_op_with_shape_change(%arg0 : tensor<20x30xf32, #iree_linalg_ext.encoding<GEMM_LHS>>) -> tensor<10x20xf32> {
+func.func @illegal_unset_encoding_op_with_shape_change(%arg0 : tensor<20x30xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>) -> tensor<10x20xf32> {
// expected-error @+1 {{expected to preserve the logical shape of the tensor}}
- %0 = iree_linalg_ext.unset_encoding %arg0: tensor<20x30xf32, #iree_linalg_ext.encoding<GEMM_LHS>> -> tensor<10x20xf32>
+ %0 = iree_linalg_ext.unset_encoding %arg0: tensor<20x30xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> -> tensor<10x20xf32>
return %0 : tensor<10x20xf32>
}
+
+// -----
+
+func.func @illegal_winograd_input_shape(%arg0: tensor<1x10x10x32xf32>) -> tensor<8x8x1x6x6x32xf32> {
+ %0 = tensor.empty() : tensor<8x8x1x6x6x32xf32>
+ // expected-error @+1 {{incompatible output shape}}
+ %1 = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3) image_dimensions([1, 2])
+ ins(%arg0 : tensor<1x10x10x32xf32>) outs(%0 : tensor<8x8x1x6x6x32xf32>) -> tensor<8x8x1x6x6x32xf32>
+ return %1 : tensor<8x8x1x6x6x32xf32>
+}
+
+// -----
+
+func.func @illegal_winograd_input_rank(%arg0: tensor<1x10x10x32xf32>) -> tensor<8x8x1x6xf32> {
+ %0 = tensor.empty() : tensor<8x8x1x6xf32>
+ // expected-error @+1 {{expected output rank to be equal to input rank + 2}}
+ %1 = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3) image_dimensions([1, 2])
+ ins(%arg0 : tensor<1x10x10x32xf32>) outs(%0 : tensor<8x8x1x6xf32>) -> tensor<8x8x1x6xf32>
+ return %1 : tensor<8x8x1x6xf32>
+}
+
+// -----
+
+func.func @illegal_winograd_output_shape(%arg0: tensor<8x8x1x2x2x32xf32>) -> tensor<1x8x8x32xf32> {
+ %0 = tensor.empty() : tensor<1x8x8x32xf32>
+ // expected-error @+1 {{incompatible output shape}}
+ %1 = iree_linalg_ext.winograd.output_transform output_tile_size(6)
+ kernel_size(3) image_dimensions([1, 2])
+ ins(%arg0 : tensor<8x8x1x2x2x32xf32>) outs(%0 : tensor<1x8x8x32xf32>) -> tensor<1x8x8x32xf32>
+ return %1 : tensor<1x8x8x32xf32>
+}
+
+// -----
diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
index fe91dc5..35e2053 100644
--- a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
+++ b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
@@ -1,8 +1,8 @@
// RUN: iree-dialects-opt --iree-linalg-ext-materialize-encoding -cse -split-input-file %s | FileCheck %s
func.func @pack_unpack_gemm_lhs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
- %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>> -> tensor<?x?xf32>
+ %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
@@ -24,8 +24,8 @@
// -----
func.func @pack_unpack_gemm_rhs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
- %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>>
- %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>> -> tensor<?x?xf32>
+ %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>>
+ %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>> -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: func @pack_unpack_gemm_rhs(
@@ -35,8 +35,8 @@
// -----
func.func @pack_unpack_gemm_rhs_transpose(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
- %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>
- %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>> -> tensor<?x?xf32>
+ %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>
+ %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>> -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: func @pack_unpack_gemm_rhs_transpose(
@@ -46,8 +46,8 @@
// -----
func.func @pack_unpack_gemm_result(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
- %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
- %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>> -> tensor<?x?xf32>
+ %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
+ %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>> -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: func @pack_unpack_gemm_result(
@@ -62,20 +62,20 @@
^bb0(%b0: index, %b1 : index):
tensor.yield %pad_value : f32
} : tensor<100x250xf32> to tensor<104x252xf32>
- %lhs = iree_linalg_ext.set_encoding %pad_lhs : tensor<104x252xf32> -> tensor<104x252xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+ %lhs = iree_linalg_ext.set_encoding %pad_lhs : tensor<104x252xf32> -> tensor<104x252xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
%pad_rhs = tensor.pad %arg1 low[0, 0] high[2, 4] {
^bb0(%b0: index, %b1 : index):
tensor.yield %pad_value : f32
} : tensor<250x500xf32> to tensor<252x504xf32>
- %rhs = iree_linalg_ext.set_encoding %pad_rhs : tensor<252x504xf32> -> tensor<252x504xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>
+ %rhs = iree_linalg_ext.set_encoding %pad_rhs : tensor<252x504xf32> -> tensor<252x504xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>
%pad_output = tensor.pad %arg2 low[0, 0] high[4, 4] {
^bb0(%b0: index, %b1 : index):
tensor.yield %pad_value : f32
} : tensor<100x500xf32> to tensor<104x504xf32>
- %output = iree_linalg_ext.set_encoding %pad_output : tensor<104x504xf32> -> tensor<104x504xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
- %gemm_packed = linalg.matmul ins(%lhs, %rhs : tensor<104x252xf32, #iree_linalg_ext.encoding<GEMM_LHS>>, tensor<252x504xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>)
- outs(%output : tensor<104x504xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>) -> tensor<104x504xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
- %gemm = iree_linalg_ext.unset_encoding %gemm_packed : tensor<104x504xf32, #iree_linalg_ext.encoding<GEMM_RESULT>> -> tensor<104x504xf32>
+ %output = iree_linalg_ext.set_encoding %pad_output : tensor<104x504xf32> -> tensor<104x504xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
+ %gemm_packed = linalg.matmul ins(%lhs, %rhs : tensor<104x252xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>, tensor<252x504xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>)
+ outs(%output : tensor<104x504xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>) -> tensor<104x504xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
+ %gemm = iree_linalg_ext.unset_encoding %gemm_packed : tensor<104x504xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>> -> tensor<104x504xf32>
%result = tensor.extract_slice %gemm[0, 0] [100, 500] [1, 1] : tensor<104x504xf32> to tensor<100x500xf32>
return %result : tensor<100x500xf32>
}
@@ -102,12 +102,12 @@
// -----
func.func @pack_gemm_dynamic(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
- %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- %1 = iree_linalg_ext.set_encoding %arg1 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>
- %2 = iree_linalg_ext.set_encoding %arg2 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
- %3 = linalg.matmul ins(%0, %1 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>, tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>)
- outs(%2 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>) -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
- %4 = iree_linalg_ext.unset_encoding %3 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>> -> tensor<?x?xf32>
+ %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ %1 = iree_linalg_ext.set_encoding %arg1 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>
+ %2 = iree_linalg_ext.set_encoding %arg2 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
+ %3 = linalg.matmul ins(%0, %1 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>, tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>)
+ outs(%2 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>) -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
+ %4 = iree_linalg_ext.unset_encoding %3 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>> -> tensor<?x?xf32>
return %4 : tensor<?x?xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
@@ -133,14 +133,14 @@
%cst = arith.constant 0.0 : f32
%d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
- %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- %1 = iree_linalg_ext.set_encoding %arg1 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>
- %2 = tensor.empty(%d0, %d1) : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
- %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>)
- -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
- %4 = linalg.matmul ins(%0, %1 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>, tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS_TRANSPOSE>>)
- outs(%3 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>) -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
- %5 = iree_linalg_ext.unset_encoding %4 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>> -> tensor<?x?xf32>
+ %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ %1 = iree_linalg_ext.set_encoding %arg1 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>
+ %2 = tensor.empty(%d0, %d1) : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
+ %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>)
+ -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
+ %4 = linalg.matmul ins(%0, %1 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>, tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS_TRANSPOSE>>)
+ outs(%3 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>) -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
+ %5 = iree_linalg_ext.unset_encoding %4 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>> -> tensor<?x?xf32>
return %5 : tensor<?x?xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/resolve-shaped-type-result-dims.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/resolve-shaped-type-result-dims.mlir
index f5619f7..41bd0f8 100644
--- a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/resolve-shaped-type-result-dims.mlir
+++ b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/resolve-shaped-type-result-dims.mlir
@@ -3,9 +3,9 @@
func.func @pack_static(%arg0 : tensor<100x250xf32>) -> (index, index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
- %0 = iree_linalg_ext.set_encoding %arg0 : tensor<100x250xf32> -> tensor<100x250xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- %1 = tensor.dim %0, %c0 : tensor<100x250xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- %2 = tensor.dim %0, %c1 : tensor<100x250xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+ %0 = iree_linalg_ext.set_encoding %arg0 : tensor<100x250xf32> -> tensor<100x250xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ %1 = tensor.dim %0, %c0 : tensor<100x250xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ %2 = tensor.dim %0, %c1 : tensor<100x250xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
return %1, %2 : index, index
}
// CHECK-LABEL: func @pack_static(
@@ -18,9 +18,9 @@
func.func @pack_dynamic(%arg0 : tensor<?x?xf32>) -> (index, index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
- %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- %1 = tensor.dim %0, %c0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- %2 = tensor.dim %0, %c1 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+ %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ %1 = tensor.dim %0, %c0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ %2 = tensor.dim %0, %c1 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
return %1, %2 : index, index
}
// CHECK: func @pack_dynamic(%[[ARG0:.+]]: tensor<?x?xf32>)
diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/roundtrip.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/roundtrip.mlir
index aef7ffa..4cb706c 100644
--- a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/roundtrip.mlir
+++ b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/roundtrip.mlir
@@ -869,36 +869,36 @@
// -----
// CHECK: @set_encoding_ops(%[[ARG0:.+]]: tensor<?x?xf32>)
-func.func @set_encoding_ops(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>> {
- // CHECK: iree_linalg_ext.set_encoding %[[ARG0]] : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- return %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+func.func @set_encoding_ops(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> {
+ // CHECK: iree_linalg_ext.set_encoding %[[ARG0]] : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ return %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
}
// -----
// CHECK: @set_encoding_ops_mixed_dynamic_static(%[[ARG0:.+]]: tensor<?x10xf32>)
-func.func @set_encoding_ops_mixed_dynamic_static(%arg0: tensor<?x10xf32>) -> tensor<20x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>> {
- // CHECK: iree_linalg_ext.set_encoding %[[ARG0]] : tensor<?x10xf32> -> tensor<20x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x10xf32> -> tensor<20x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- return %0 : tensor<20x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+func.func @set_encoding_ops_mixed_dynamic_static(%arg0: tensor<?x10xf32>) -> tensor<20x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> {
+ // CHECK: iree_linalg_ext.set_encoding %[[ARG0]] : tensor<?x10xf32> -> tensor<20x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x10xf32> -> tensor<20x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ return %0 : tensor<20x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
}
// -----
-// CHECK: @unset_encoding_ops(%[[ARG0:.+]]: tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>>)
-func.func @unset_encoding_ops(%arg0: tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>>) -> tensor<?x?xf32> {
- // CHECK: iree_linalg_ext.unset_encoding %[[ARG0]] : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>> -> tensor<?x?xf32>
- %0 = iree_linalg_ext.unset_encoding %arg0 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>> -> tensor<?x?xf32>
+// CHECK: @unset_encoding_ops(%[[ARG0:.+]]: tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>>)
+func.func @unset_encoding_ops(%arg0: tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>>) -> tensor<?x?xf32> {
+ // CHECK: iree_linalg_ext.unset_encoding %[[ARG0]] : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>> -> tensor<?x?xf32>
+ %0 = iree_linalg_ext.unset_encoding %arg0 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>> -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// -----
-// CHECK: @unset_encoding_ops_mixed_dynamic_static(%[[ARG0:.+]]: tensor<10x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>>)
-func.func @unset_encoding_ops_mixed_dynamic_static(%arg0: tensor<10x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>>) -> tensor<?x20xf32> {
- // CHECK: iree_linalg_ext.unset_encoding %[[ARG0]] : tensor<10x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>>
- %0 = iree_linalg_ext.unset_encoding %arg0 : tensor<10x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>> -> tensor<?x20xf32>
+// CHECK: @unset_encoding_ops_mixed_dynamic_static(%[[ARG0:.+]]: tensor<10x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>>)
+func.func @unset_encoding_ops_mixed_dynamic_static(%arg0: tensor<10x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>>) -> tensor<?x20xf32> {
+ // CHECK: iree_linalg_ext.unset_encoding %[[ARG0]] : tensor<10x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>>
+ %0 = iree_linalg_ext.unset_encoding %arg0 : tensor<10x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>> -> tensor<?x20xf32>
return %0 : tensor<?x20xf32>
}
@@ -906,14 +906,14 @@
func.func @encoding_tensors_with_ops(%arg0 : tensor<?x?xf32>,
%arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
- %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
- %1 = iree_linalg_ext.set_encoding %arg1 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>>
- %2 = iree_linalg_ext.set_encoding %arg2 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
+ %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+ %1 = iree_linalg_ext.set_encoding %arg1 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>>
+ %2 = iree_linalg_ext.set_encoding %arg2 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
%3 = linalg.matmul
- ins(%0, %1 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>, tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>>)
- outs(%2 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>)
- -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
- %4 = iree_linalg_ext.unset_encoding %3 : tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>> -> tensor<?x?xf32>
+ ins(%0, %1 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>, tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>>)
+ outs(%2 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>)
+ -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
+ %4 = iree_linalg_ext.unset_encoding %3 : tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>> -> tensor<?x?xf32>
return %4 : tensor<?x?xf32>
}
// CHECK-LABEL: func.func @encoding_tensors_with_ops
@@ -921,13 +921,80 @@
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK: %[[LHS:.+]] = iree_linalg_ext.set_encoding %[[ARG0]]
-// CHECK-SAME: tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_LHS>>
+// CHECK-SAME: tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
// CHECK: %[[RHS:.+]] = iree_linalg_ext.set_encoding %[[ARG1]]
-// CHECK-SAME: tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RHS>>
+// CHECK-SAME: tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RHS>>
// CHECK: %[[OUT:.+]] = iree_linalg_ext.set_encoding %[[ARG2]]
-// CHECK-SAME: tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<GEMM_RESULT>>
+// CHECK-SAME: tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>>
// CHECK: %[[GEMM:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
// CHECK-SAME: outs(%[[OUT]] :
// CHECK: %[[RESULT:.+]] = iree_linalg_ext.unset_encoding %[[GEMM]]
// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @winograd_input_transform(%arg0: tensor<1x10x10x1280xf32>) -> tensor<8x8x1x2x2x1280xf32> {
+ %0 = tensor.empty() : tensor<8x8x1x2x2x1280xf32>
+ %1 = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3) image_dimensions([1, 2])
+ ins(%arg0 : tensor<1x10x10x1280xf32>) outs(%0 : tensor<8x8x1x2x2x1280xf32>) -> tensor<8x8x1x2x2x1280xf32>
+ return %1 : tensor<8x8x1x2x2x1280xf32>
+}
+// CHECK: func.func @winograd_input_transform(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x10x10x1280xf32>) ->
+// CHECK-SAME: tensor<8x8x1x2x2x1280xf32> {
+// CHECK: %[[D0:.+]] = tensor.empty() : tensor<8x8x1x2x2x1280xf32>
+// CHECK: %[[D1:.+]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME: image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<1x10x10x1280xf32>) outs(%[[D0]] :
+// CHECK-SAME: tensor<8x8x1x2x2x1280xf32>) -> tensor<8x8x1x2x2x1280xf32>
+// CHECK: return %[[D1]] : tensor<8x8x1x2x2x1280xf32>
+// CHECK: }
+
+// -----
+
+func.func @winograd_input_transform_dynamic(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<8x8x?x?x?x?xf32>) -> tensor<8x8x?x?x?x?xf32> {
+ %1 = iree_linalg_ext.winograd.input_transform
+ output_tile_size(6) kernel_size(3) image_dimensions([1, 2])
+ ins(%arg0 : tensor<?x?x?x?xf32>) outs(%arg1 : tensor<8x8x?x?x?x?xf32>) -> tensor<8x8x?x?x?x?xf32>
+ return %1 : tensor<8x8x?x?x?x?xf32>
+}
+// CHECK: func.func @winograd_input_transform_dynamic(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?x?xf32>,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<8x8x?x?x?x?xf32>) -> tensor<8x8x?x?x?x?xf32> {
+// CHECK: %[[D0:.+]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME: image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<?x?x?x?xf32>) outs(%[[ARG1]] :
+// CHECK-SAME: tensor<8x8x?x?x?x?xf32>) -> tensor<8x8x?x?x?x?xf32>
+// CHECK: return %[[D0]] : tensor<8x8x?x?x?x?xf32>
+// CHECK: }
+
+// -----
+
+func.func @winograd_output_transform(%arg0: tensor<8x8x1x2x2x1280xf32>) -> tensor<1x12x12x1280xf32> {
+ %0 = tensor.empty() : tensor<1x12x12x1280xf32>
+ %1 = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3) image_dimensions([1, 2])
+ ins(%arg0 : tensor<8x8x1x2x2x1280xf32>) outs(%0 : tensor<1x12x12x1280xf32>) -> tensor<1x12x12x1280xf32>
+ return %1 : tensor<1x12x12x1280xf32>
+}
+// CHECK: func.func @winograd_output_transform(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<8x8x1x2x2x1280xf32>) ->
+// CHECK-SAME: tensor<1x12x12x1280xf32> {
+// CHECK: %[[D0:.+]] = tensor.empty() : tensor<1x12x12x1280xf32>
+// CHECK: %[[D1:.+]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME: image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<8x8x1x2x2x1280xf32>) outs(%[[D0]] :
+// CHECK-SAME: tensor<1x12x12x1280xf32>) -> tensor<1x12x12x1280xf32>
+// CHECK: return %[[D1]] : tensor<1x12x12x1280xf32>
+// CHECK: }
+
+// -----
+
+func.func @winograd_output_transform(%arg0: tensor<8x8x?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %1 = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3) image_dimensions([1, 2])
+ ins(%arg0 : tensor<8x8x?x?x?x?xf32>) outs(%arg1 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %1 : tensor<?x?x?x?xf32>
+}
+// CHECK: func.func @winograd_output_transform(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<8x8x?x?x?x?xf32>,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+// CHECK: %[[D0:.+]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME: image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<8x8x?x?x?x?xf32>) outs(%[[ARG1]] :
+// CHECK-SAME: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+// CHECK: return %[[D0]] : tensor<?x?x?x?xf32>
+// CHECK: }
+
+// -----
diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_winograd.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_winograd.mlir
new file mode 100644
index 0000000..aabe3d9
--- /dev/null
+++ b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_winograd.mlir
@@ -0,0 +1,195 @@
+// RUN: iree-dialects-opt --iree-linalg-ext-tile-and-decompose-winograd --split-input-file %s | FileCheck %s
+
+#map = affine_map<(d0)[s0, s1] -> (1, -d0 + s1)>
+#map1 = affine_map<(d0)[s0, s1] -> (32, -d0 + s1)>
+module {
+ func.func @winograd_input_transform(%arg0: tensor<1x10x10x1280xf32>) -> tensor<8x8x1x2x2x1280xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c1280 = arith.constant 1280 : index
+ %c32 = arith.constant 32 : index
+ %0 = tensor.empty() : tensor<8x8x1x2x2x1280xf32>
+ %1 = scf.for %arg1 = %c0 to %c1 step %c1 iter_args(%arg2 = %0) -> (tensor<8x8x1x2x2x1280xf32>) {
+ %2 = affine.min #map(%arg1)[%c1, %c1]
+ %3 = scf.for %arg3 = %c0 to %c1280 step %c32 iter_args(%arg4 = %arg2) -> (tensor<8x8x1x2x2x1280xf32>) {
+ %4 = affine.min #map1(%arg3)[%c32, %c1280]
+ %extracted_slice = tensor.extract_slice %arg0[%arg1, 0, 0, %arg3] [%2, 10, 10, %4] [1, 1, 1, 1] : tensor<1x10x10x1280xf32> to tensor<?x10x10x?xf32>
+ %extracted_slice_0 = tensor.extract_slice %0[0, 0, %arg1, 0, 0, %arg3] [8, 8, %2, 2, 2, %4] [1, 1, 1, 1, 1, 1] : tensor<8x8x1x2x2x1280xf32> to tensor<8x8x?x2x2x?xf32>
+ %5 = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3) image_dimensions([1, 2]) ins(%extracted_slice : tensor<?x10x10x?xf32>) outs(%extracted_slice_0 : tensor<8x8x?x2x2x?xf32>) -> tensor<8x8x?x2x2x?xf32>
+ %inserted_slice = tensor.insert_slice %5 into %arg4[0, 0, %arg1, 0, 0, %arg3] [8, 8, %2, 2, 2, %4] [1, 1, 1, 1, 1, 1] : tensor<8x8x?x2x2x?xf32> into tensor<8x8x1x2x2x1280xf32>
+ scf.yield %inserted_slice : tensor<8x8x1x2x2x1280xf32>
+ }
+ scf.yield %3 : tensor<8x8x1x2x2x1280xf32>
+ }
+ return %1 : tensor<8x8x1x2x2x1280xf32>
+ }
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (1, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (32, -d0 + s1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 * 6)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (-d0 + 10, 8)>
+// CHECK: func.func @winograd_input_transform(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x10x10x1280xf32>) ->
+// CHECK-SAME: tensor<8x8x1x2x2x1280xf32> {
+// CHECK: %[[C32:.+]] = arith.constant 32 : index
+// CHECK: %[[C1280:.+]] = arith.constant 1280 : index
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<
+// CHECK: %[[CST_0:.+]] = arith.constant dense<
+// CHECK: %[[CST_1:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[D0:.+]] = tensor.empty() : tensor<8x8xf32>
+// CHECK: %[[D1:.+]] = tensor.empty() : tensor<8x8x1x2x2x1280xf32>
+// CHECK: %[[D2:.+]] = scf.for %[[ARG1:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1]] step %[[C1]]
+// CHECK-SAME: iter_args(%[[ARG2:[a-zA-Z0-9_]+]] = %[[D1]]) -> (tensor<8x8x1x2x2x1280xf32>) {
+// CHECK-DAG: %[[D3:.+]] = affine.min #[[MAP]](%[[ARG1]])[%[[C1]], %[[C1]]]
+// CHECK: %[[D4:.+]] = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1280]] step %[[C32]]
+// CHECK-SAME: iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[ARG2]]) -> (tensor<8x8x1x2x2x1280xf32>) {
+// CHECK-DAG: %[[D5:.+]] = affine.min #[[MAP1]](%[[ARG3]])[%[[C32]], %[[C1280]]]
+// CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG1]], 0, 0, %[[ARG3]]] [%[[D3]], 10,
+// CHECK-SAME: 10, %[[D5]]] [1, 1, 1, 1] : tensor<1x10x10x1280xf32> to tensor<?x10x10x?xf32>
+// CHECK: %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[D1]][0, 0, %[[ARG1]], 0, 0, %[[ARG3]]] [8, 8,
+// CHECK-SAME: %[[D3]], 2, 2, %[[D5]]] [1, 1, 1, 1, 1, 1] : tensor<8x8x1x2x2x1280xf32> to
+// CHECK-SAME: tensor<8x8x?x2x2x?xf32>
+// CHECK: %[[D6:.+]] = scf.for %[[ARG5:[a-zA-Z0-9_]+]] = %[[C0]] to %[[D3]] step %[[C1]]
+// CHECK-SAME: iter_args(%[[ARG6:[a-zA-Z0-9_]+]] = %[[EXTRACTED_SLICE_2]]) -> (tensor<8x8x?x2x2x?xf32>) {
+// CHECK: %[[D7:.+]] = scf.for %[[ARG7:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK-SAME: iter_args(%[[ARG8:[a-zA-Z0-9_]+]] = %[[ARG6]]) -> (tensor<8x8x?x2x2x?xf32>) {
+// CHECK-DAG: %[[D8:.+]] = affine.apply #[[MAP2]](%[[ARG7]])
+// CHECK-DAG: %[[D9:.+]] = affine.min #[[MAP3]](%[[D8]])
+// CHECK: %[[D10:.+]] = scf.for %[[ARG9:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK-SAME: iter_args(%[[ARG10:[a-zA-Z0-9_]+]] = %[[ARG8]]) -> (tensor<8x8x?x2x2x?xf32>) {
+// CHECK-DAG: %[[D11:.+]] = affine.apply #[[MAP2]](%[[ARG9]])
+// CHECK-DAG: %[[D12:.+]] = affine.min #[[MAP3]](%[[D11]])
+// CHECK: %[[D13:.+]] = scf.for %[[ARG11:[a-zA-Z0-9_]+]] = %[[C0]] to %[[D5]] step %[[C1]]
+// CHECK-SAME: iter_args(%[[ARG12:[a-zA-Z0-9_]+]] = %[[ARG10]]) -> (tensor<8x8x?x2x2x?xf32>) {
+// CHECK: %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG5]], %[[D8]],
+// CHECK-SAME: %[[D11]], %[[ARG11]]] [1, %[[D9]], %[[D12]], 1] [1, 1, 1, 1] : tensor<?x10x10x?xf32> to
+// CHECK-SAME: tensor<?x?xf32>
+// CHECK: %[[D14:.+]] = linalg.fill ins(%[[CST_1]] : f32) outs(%[[D0]] : tensor<8x8xf32>) ->
+// CHECK-SAME: tensor<8x8xf32>
+// CHECK: %[[INSERTED_SLICE_4:.+]] = tensor.insert_slice %[[EXTRACTED_SLICE_3]] into %[[D14]][0, 0]
+// CHECK-SAME: [%[[D9]], %[[D12]]] [1, 1] : tensor<?x?xf32> into tensor<8x8xf32>
+// CHECK: %[[EXTRACTED_SLICE_5:.+]] = tensor.extract_slice %[[ARG12]][0, 0, %[[ARG5]], %[[ARG7]],
+// CHECK-SAME: %[[ARG9]], %[[ARG11]]] [8, 8, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<8x8x?x2x2x?xf32> to
+// CHECK-SAME: tensor<8x8xf32>
+// CHECK: %[[D15:.+]] = linalg.fill ins(%[[CST_1]] : f32) outs(%[[EXTRACTED_SLICE_5]] :
+// CHECK-SAME: tensor<8x8xf32>) -> tensor<8x8xf32>
+// CHECK: %[[D16:.+]] = linalg.matmul ins(%[[INSERTED_SLICE_4]], %[[CST_0]] : tensor<8x8xf32>,
+// CHECK-SAME: tensor<8x8xf32>) outs(%[[D15]] : tensor<8x8xf32>) -> tensor<8x8xf32>
+// CHECK: %[[D17:.+]] = linalg.fill ins(%[[CST_1]] : f32) outs(%[[EXTRACTED_SLICE_5]] :
+// CHECK-SAME: tensor<8x8xf32>) -> tensor<8x8xf32>
+// CHECK: %[[D18:.+]] = linalg.matmul ins(%[[CST]], %[[D16]] : tensor<8x8xf32>, tensor<8x8xf32>)
+// CHECK-SAME: outs(%[[D17]] : tensor<8x8xf32>) -> tensor<8x8xf32>
+// CHECK: %[[INSERTED_SLICE_6:.+]] = tensor.insert_slice %[[D18]] into %[[ARG12]][0, 0, %[[ARG5]],
+// CHECK-SAME: %[[ARG7]], %[[ARG9]], %[[ARG11]]] [8, 8, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<8x8xf32>
+// CHECK-SAME: into tensor<8x8x?x2x2x?xf32>
+// CHECK: scf.yield %[[INSERTED_SLICE_6]] : tensor<8x8x?x2x2x?xf32>
+// CHECK: }
+// CHECK: scf.yield %[[D13]] : tensor<8x8x?x2x2x?xf32>
+// CHECK: }
+// CHECK: scf.yield %[[D10]] : tensor<8x8x?x2x2x?xf32>
+// CHECK: }
+// CHECK: scf.yield %[[D7]] : tensor<8x8x?x2x2x?xf32>
+// CHECK: }
+// CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D6]] into %[[ARG4]][0, 0, %[[ARG1]], 0, 0,
+// CHECK-SAME: %[[ARG3]]] [8, 8, %[[D3]], 2, 2, %[[D5]]] [1, 1, 1, 1, 1, 1] : tensor<8x8x?x2x2x?xf32> into
+// CHECK-SAME: tensor<8x8x1x2x2x1280xf32>
+// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<8x8x1x2x2x1280xf32>
+// CHECK: }
+// CHECK: scf.yield %[[D4]] : tensor<8x8x1x2x2x1280xf32>
+// CHECK: }
+// CHECK: return %[[D2]] : tensor<8x8x1x2x2x1280xf32>
+// CHECK: }
+
+// -----
+
+#map = affine_map<(d0)[s0, s1] -> (1, -d0 + s1)>
+#map1 = affine_map<(d0)[s0, s1] -> (32, -d0 + s1)>
+module {
+ func.func @winograd_output_transform(%arg0: tensor<8x8x1x2x2x32xf32>) -> tensor<1x12x12x32xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c32 = arith.constant 32 : index
+ %0 = tensor.empty() : tensor<1x12x12x32xf32>
+ %1 = scf.for %arg1 = %c0 to %c1 step %c1 iter_args(%arg2 = %0) -> (tensor<1x12x12x32xf32>) {
+ %2 = affine.min #map(%arg1)[%c1, %c1]
+ %3 = scf.for %arg3 = %c0 to %c32 step %c32 iter_args(%arg4 = %arg2) -> (tensor<1x12x12x32xf32>) {
+ %4 = affine.min #map1(%arg3)[%c32, %c32]
+ %extracted_slice = tensor.extract_slice %arg0[0, 0, %arg1, 0, 0, %arg3] [8, 8, %2, 2, 2, %4] [1, 1, 1, 1, 1, 1] : tensor<8x8x1x2x2x32xf32> to tensor<8x8x?x2x2x?xf32>
+ %extracted_slice_0 = tensor.extract_slice %0[%arg1, 0, 0, %arg3] [%2, 12, 12, %4] [1, 1, 1, 1] : tensor<1x12x12x32xf32> to tensor<?x12x12x?xf32>
+ %5 = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3) image_dimensions([1, 2]) ins(%extracted_slice : tensor<8x8x?x2x2x?xf32>) outs(%extracted_slice_0 : tensor<?x12x12x?xf32>) -> tensor<?x12x12x?xf32>
+ %inserted_slice = tensor.insert_slice %5 into %arg4[%arg1, 0, 0, %arg3] [%2, 12, 12, %4] [1, 1, 1, 1] : tensor<?x12x12x?xf32> into tensor<1x12x12x32xf32>
+ scf.yield %inserted_slice : tensor<1x12x12x32xf32>
+ }
+ scf.yield %3 : tensor<1x12x12x32xf32>
+ }
+ return %1 : tensor<1x12x12x32xf32>
+ }
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (1, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (32, -d0 + s1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 * 6)>
+// CHECK: func.func @winograd_output_transform(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<8x8x1x2x2x32xf32>) ->
+// CHECK-SAME: tensor<1x12x12x32xf32> {
+// CHECK: %[[C32:.+]] = arith.constant 32 : index
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<
+// CHECK: %[[CST_0:.+]] = arith.constant dense<
+// CHECK: %[[CST_1:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[D0:.+]] = tensor.empty() : tensor<8x6xf32>
+// CHECK: %[[D1:.+]] = tensor.empty() : tensor<1x12x12x32xf32>
+// CHECK: %[[D2:.+]] = scf.for %[[ARG1:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1]] step %[[C1]]
+// CHECK-SAME: iter_args(%[[ARG2:[a-zA-Z0-9_]+]] = %[[D1]]) -> (tensor<1x12x12x32xf32>) {
+// CHECK-DAG: %[[D3:.+]] = affine.min #[[MAP]](%[[ARG1]])[%[[C1]], %[[C1]]]
+// CHECK: %[[D4:.+]] = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C32]] step %[[C32]]
+// CHECK-SAME: iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[ARG2]]) -> (tensor<1x12x12x32xf32>) {
+// CHECK-DAG: %[[D5:.+]] = affine.min #[[MAP1]](%[[ARG3]])[%[[C32]], %[[C32]]]
+// CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG1]], 0, 0, %[[ARG3]]] [8, 8,
+// CHECK-SAME: %[[D3]], 2, 2, %[[D5]]] [1, 1, 1, 1, 1, 1] : tensor<8x8x1x2x2x32xf32> to tensor<8x8x?x2x2x?xf32>
+// CHECK: %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[D1]][%[[ARG1]], 0, 0, %[[ARG3]]] [%[[D3]], 12,
+// CHECK-SAME: 12, %[[D5]]] [1, 1, 1, 1] : tensor<1x12x12x32xf32> to tensor<?x12x12x?xf32>
+// CHECK: %[[D6:.+]] = scf.for %[[ARG5:[a-zA-Z0-9_]+]] = %[[C0]] to %[[D3]] step %[[C1]]
+// CHECK-SAME: iter_args(%[[ARG6:[a-zA-Z0-9_]+]] = %[[EXTRACTED_SLICE_2]]) -> (tensor<?x12x12x?xf32>) {
+// CHECK: %[[D7:.+]] = scf.for %[[ARG7:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK-SAME: iter_args(%[[ARG8:[a-zA-Z0-9_]+]] = %[[ARG6]]) -> (tensor<?x12x12x?xf32>) {
+// CHECK-DAG: %[[D8:.+]] = affine.apply #[[MAP2]](%[[ARG7]])
+// CHECK: %[[D9:.+]] = scf.for %[[ARG9:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK-SAME: iter_args(%[[ARG10:[a-zA-Z0-9_]+]] = %[[ARG8]]) -> (tensor<?x12x12x?xf32>) {
+// CHECK-DAG: %[[D10:.+]] = affine.apply #[[MAP2]](%[[ARG9]])
+// CHECK: %[[D11:.+]] = scf.for %[[ARG11:[a-zA-Z0-9_]+]] = %[[C0]] to %[[D5]] step %[[C1]]
+// CHECK-SAME: iter_args(%[[ARG12:[a-zA-Z0-9_]+]] = %[[ARG10]]) -> (tensor<?x12x12x?xf32>) {
+// CHECK: %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, %[[ARG5]],
+// CHECK-SAME: %[[ARG7]], %[[ARG9]], %[[ARG11]]] [8, 8, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] :
+// CHECK-SAME: tensor<8x8x?x2x2x?xf32> to tensor<8x8xf32>
+// CHECK: %[[EXTRACTED_SLICE_4:.+]] = tensor.extract_slice %[[ARG12]][%[[ARG5]], %[[D8]], %[[D10]],
+// CHECK-SAME: %[[ARG11]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<?x12x12x?xf32> to tensor<6x6xf32>
+// CHECK: %[[D12:.+]] = linalg.fill ins(%[[CST_1]] : f32) outs(%[[D0]] : tensor<8x6xf32>) ->
+// CHECK-SAME: tensor<8x6xf32>
+// CHECK: %[[D13:.+]] = linalg.matmul ins(%[[EXTRACTED_SLICE_3]], %[[CST_0]] : tensor<8x8xf32>,
+// CHECK-SAME: tensor<8x6xf32>) outs(%[[D12]] : tensor<8x6xf32>) -> tensor<8x6xf32>
+// CHECK: %[[D14:.+]] = linalg.fill ins(%[[CST_1]] : f32) outs(%[[EXTRACTED_SLICE_4]] :
+// CHECK-SAME: tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK: %[[D15:.+]] = linalg.matmul ins(%[[CST]], %[[D13]] : tensor<6x8xf32>, tensor<8x6xf32>)
+// CHECK-SAME: outs(%[[D14]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK: %[[INSERTED_SLICE_5:.+]] = tensor.insert_slice %[[D15]] into %[[ARG12]][%[[ARG5]], %[[D8]],
+// CHECK-SAME: %[[D10]], %[[ARG11]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<6x6xf32> into
+// CHECK-SAME: tensor<?x12x12x?xf32>
+// CHECK: scf.yield %[[INSERTED_SLICE_5]] : tensor<?x12x12x?xf32>
+// CHECK: }
+// CHECK: scf.yield %[[D11]] : tensor<?x12x12x?xf32>
+// CHECK: }
+// CHECK: scf.yield %[[D9]] : tensor<?x12x12x?xf32>
+// CHECK: }
+// CHECK: scf.yield %[[D7]] : tensor<?x12x12x?xf32>
+// CHECK: }
+// CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D6]] into %[[ARG4]][%[[ARG1]], 0, 0, %[[ARG3]]]
+// CHECK-SAME: [%[[D3]], 12, 12, %[[D5]]] [1, 1, 1, 1] : tensor<?x12x12x?xf32> into tensor<1x12x12x32xf32>
+// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<1x12x12x32xf32>
+// CHECK: }
+// CHECK: scf.yield %[[D4]] : tensor<1x12x12x32xf32>
+// CHECK: }
+// CHECK: return %[[D2]] : tensor<1x12x12x32xf32>
+// CHECK: }
diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/tiling.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/tiling.mlir
index 463134f..44cab77 100644
--- a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/tiling.mlir
+++ b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/tiling.mlir
@@ -1205,3 +1205,224 @@
// CHECK-SAME: into %{{.+}}[%[[K]], %[[C]]] [%[[C2]], %[[C4]]]
// CHECK: scf.yield %[[RES]]
+// -----
+
+func.func @perfect_CKkc_to_KC(%arg0: tensor<32x4x2x4xf32>, %arg1: tensor<8x128xf32>) -> tensor<8x128xf32> {
+ %0 = iree_linalg_ext.unpack {__internal_linalg_transform__ = "tiling_pack_input"} %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [2, 4] into %arg1 : (tensor<32x4x2x4xf32> tensor<8x128xf32>) -> tensor<8x128xf32>
+ return %0 : tensor<8x128xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 floordiv 4)>
+// CHECK-LABEL: func.func @perfect_CKkc_to_KC
+// CHECK-SAME: %[[IN:[A-Za-z0-9]+]]:
+// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]:
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index
+// CHECK: %{{.+}} = scf.for %[[K:.+]] = %[[C0]] to %[[C8]] step %[[C2]]
+// CHECK: %{{.+}} = scf.for %[[C:.+]] = %[[C0]] to %[[C128]] step %[[C4]]
+// CHECK-DAG: %[[IN_K:.+]] = affine.apply #[[MAP0]](%[[K]])
+// CHECK-DAG: %[[IN_C:.+]] = affine.apply #[[MAP1]](%[[C]])
+// CHECK: %[[IN_SLICE:.+]] = tensor.extract_slice %[[IN]]
+// CHECK: [%[[IN_C]], %[[IN_K]], 0, 0] [1, 1, 2, 4]
+// CHECK: %[[ITER_SLICE:.+]] = tensor.extract_slice %{{.+}}[%[[K]], %[[C]]] [%[[C2]], %[[C4]]]
+// CHECK: %[[UNPACK:.+]] = iree_linalg_ext.unpack
+// CHECK-SAME: {__internal_linalg_transform__ = "tiling_pack_output"
+// CHECK-SAME: %[[IN_SLICE]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [2, 4]
+// CHECK-SAME: into %[[ITER_SLICE]]
+// CHECK: %[[RES:.+]] = tensor.insert_slice %[[UNPACK]]
+// CHECK-SAME: into %{{.+}}[%[[K]], %[[C]]] [%[[C2]], %[[C4]]]
+// CHECK: scf.yield %[[RES]]
+
+// -----
+
+func.func @dynamic_perfect_CKkc_to_KC(%arg0: tensor<?x?x2x2xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = iree_linalg_ext.unpack {__internal_linalg_transform__ = "tiling_pack_input"} %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [2, 2] into %arg1 : (tensor<?x?x2x2xf32> tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (2, -d0 + s0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (4, -d0 + s0)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 floordiv 2)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (d0 ceildiv 2)>
+// CHECK-LABEL: func.func @dynamic_perfect_CKkc_to_KC
+// CHECK-SAME: %[[IN:[A-Za-z0-9]+]]:
+// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]:
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C0]]
+// CHECK-DAG: %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
+// CHECK: %{{.+}} = scf.for %[[K:.+]] = %[[C0]] to %[[DIM_0]] step %[[C2]]
+// CHECK-DAG: %[[OUT_K_SZ:.+]] = affine.min #[[MAP0]](%[[K]])[%[[DIM_0]]]
+// CHECK: %{{.+}} = scf.for %[[C:.+]] = %[[C0]] to %[[DIM_1]] step %[[C4]]
+// CHECK-DAG: %[[OUT_C_SZ:.+]] = affine.min #[[MAP1]](%[[C]])[%[[DIM_1]]]
+// CHECK-DAG: %[[IN_K:.+]] = affine.apply #[[MAP2]](%[[K]])
+// CHECK-DAG: %[[IN_C:.+]] = affine.apply #[[MAP2]](%[[C]])
+// CHECK-DAG: %[[IN_C_SZ:.+]] = affine.apply #[[MAP3]](%[[OUT_C_SZ]])
+// CHECK: %[[IN_SLICE:.+]] = tensor.extract_slice %[[IN]]
+// CHECK: [%[[IN_C]], %[[IN_K]], 0, 0] [%[[IN_C_SZ]], 1, 2, 2]
+// CHECK: %[[ITER_SLICE:.+]] = tensor.extract_slice %{{.+}}[%[[K]], %[[C]]] [%[[OUT_K_SZ]], %[[OUT_C_SZ]]]
+// CHECK: %[[UNPACK:.+]] = iree_linalg_ext.unpack
+// CHECK-SAME: {__internal_linalg_transform__ = "tiling_pack_output"
+// CHECK-SAME: %[[IN_SLICE]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [2, 2]
+// CHECK-SAME: into %[[ITER_SLICE]]
+// CHECK: %[[RES:.+]] = tensor.insert_slice %[[UNPACK]]
+// CHECK-SAME: into %{{.+}}[%[[K]], %[[C]]] [%[[OUT_K_SZ]], %[[OUT_C_SZ]]]
+// CHECK: scf.yield %[[RES]]
+
+// -----
+
+func.func @winograd_input_transform(%arg0: tensor<1x10x10x1280xf32>) -> tensor<8x8x1x2x2x1280xf32> {
+ %0 = tensor.empty() : tensor<8x8x1x2x2x1280xf32>
+ %1 = iree_linalg_ext.winograd.input_transform {__internal_linalg_transform__ = "tiling_winograd_input_nhwc"}
+ output_tile_size(6) kernel_size(3) image_dimensions([1, 2])
+ ins(%arg0 : tensor<1x10x10x1280xf32>) outs(%0 : tensor<8x8x1x2x2x1280xf32>) -> tensor<8x8x1x2x2x1280xf32>
+ return %1 : tensor<8x8x1x2x2x1280xf32>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (1, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (32, -d0 + s1)>
+// CHECK: func.func @winograd_input_transform(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x10x10x1280xf32>) ->
+// CHECK-SAME: tensor<8x8x1x2x2x1280xf32> {
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[C1280:.+]] = arith.constant 1280 : index
+// CHECK: %[[C32:.+]] = arith.constant 32 : index
+// CHECK: %[[D0:.+]] = tensor.empty() : tensor<8x8x1x2x2x1280xf32>
+// CHECK: %[[D1:.+]] = scf.for %[[ARG1:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1]] step %[[C1]]
+// CHECK-SAME: iter_args(%[[ARG2:[a-zA-Z0-9_]+]] = %[[D0]]) -> (tensor<8x8x1x2x2x1280xf32>) {
+// CHECK-DAG: %[[D2:.+]] = affine.min #[[MAP]](%[[ARG1]])[%[[C1]], %[[C1]]]
+// CHECK: %[[D3:.+]] = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1280]] step %[[C32]]
+// CHECK-SAME: iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[ARG2]]) -> (tensor<8x8x1x2x2x1280xf32>) {
+// CHECK-DAG: %[[D4:.+]] = affine.min #[[MAP1]](%[[ARG3]])[%[[C32]], %[[C1280]]]
+// CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG1]], 0, 0, %[[ARG3]]] [%[[D2]], 10,
+// CHECK-SAME: 10, %[[D4]]] [1, 1, 1, 1] : tensor<1x10x10x1280xf32> to tensor<?x10x10x?xf32>
+// CHECK: %[[EXTRACTED_SLICE_0:.+]] = tensor.extract_slice %[[D0]][0, 0, %[[ARG1]], 0, 0, %[[ARG3]]] [8, 8,
+// CHECK-SAME: %[[D2]], 2, 2, %[[D4]]] [1, 1, 1, 1, 1, 1] : tensor<8x8x1x2x2x1280xf32> to
+// CHECK-SAME: tensor<8x8x?x2x2x?xf32>
+// CHECK: %[[D5:.+]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME: image_dimensions([1, 2]) ins(%[[EXTRACTED_SLICE]] : tensor<?x10x10x?xf32>)
+// CHECK-SAME: outs(%[[EXTRACTED_SLICE]]_0 : tensor<8x8x?x2x2x?xf32>) -> tensor<8x8x?x2x2x?xf32>
+// CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D5]] into %[[ARG4]][0, 0, %[[ARG1]], 0, 0,
+// CHECK-SAME: %[[ARG3]]] [8, 8, %[[D2]], 2, 2, %[[D4]]] [1, 1, 1, 1, 1, 1] : tensor<8x8x?x2x2x?xf32> into
+// CHECK-SAME: tensor<8x8x1x2x2x1280xf32>
+// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<8x8x1x2x2x1280xf32>
+// CHECK: }
+// CHECK: scf.yield %[[D3]] : tensor<8x8x1x2x2x1280xf32>
+// CHECK: }
+// CHECK: return %[[D1]] : tensor<8x8x1x2x2x1280xf32>
+// CHECK: }
+
+// -----
+
+func.func @winograd_input_transform_memref(%arg0: memref<1x10x10x1280xf32>, %arg1: memref<8x8x1x2x2x1280xf32>) {
+ iree_linalg_ext.winograd.input_transform {__internal_linalg_transform__ = "tiling_winograd_input_nhwc"}
+ output_tile_size(6) kernel_size(3) image_dimensions([1, 2])
+ ins(%arg0 : memref<1x10x10x1280xf32>) outs(%arg1 : memref<8x8x1x2x2x1280xf32>)
+ return
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (1, -d0 + s1)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0, s1] -> (32, -d0 + s1)>
+// CHECK: func.func @winograd_input_transform_memref(%[[ARG0:[a-zA-Z0-9_]+]]: memref<1x10x10x1280xf32>,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<8x8x1x2x2x1280xf32>) {
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[C1280:.+]] = arith.constant 1280 : index
+// CHECK: %[[C32:.+]] = arith.constant 32 : index
+// CHECK: scf.for %[[ARG2:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1]] step %[[C1]] {
+// CHECK-DAG: %[[D0:.+]] = affine.min #[[MAP2]](%[[ARG2]])[%[[C1]], %[[C1]]]
+// CHECK: scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1280]] step %[[C32]] {
+// CHECK-DAG: %[[D1:.+]] = affine.min #[[MAP3]](%[[ARG3]])[%[[C32]], %[[C1280]]]
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG3]]] [%[[D0]], 10, 10, %[[D1]]]
+// CHECK-SAME: [1, 1, 1, 1] : memref<1x10x10x1280xf32> to memref<?x10x10x?xf32, strided<[128000, 12800, 1280,
+// CHECK-SAME: 1], offset: ?>>
+// CHECK: %[[SUBVIEW_0:.+]] = memref.subview %[[ARG1]][0, 0, %[[ARG2]], 0, 0, %[[ARG3]]] [8, 8, %[[D0]], 2,
+// CHECK-SAME: 2, %[[D1]]] [1, 1, 1, 1, 1, 1] : memref<8x8x1x2x2x1280xf32> to memref<8x8x?x2x2x?xf32,
+// CHECK-SAME: strided<[40960, 5120, 5120, 2560, 1280, 1], offset: ?>>
+// CHECK: iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3) image_dimensions([1,
+// CHECK-SAME: 2]) ins(%[[SUBVIEW]] : memref<?x10x10x?xf32, strided<[128000, 12800, 1280, 1], offset: ?>>)
+// CHECK-SAME: outs(%[[SUBVIEW]]_0 : memref<8x8x?x2x2x?xf32, strided<[40960, 5120, 5120, 2560, 1280, 1], offset:
+// CHECK-SAME: ?>>)
+// CHECK: }
+// CHECK: }
+// CHECK: return
+// CHECK: }
+
+// -----
+
+func.func @winograd_output_transform(%arg0: tensor<8x8x1x2x2x32xf32>) -> tensor<1x12x12x32xf32> {
+ %0 = tensor.empty() : tensor<1x12x12x32xf32>
+ %1 = iree_linalg_ext.winograd.output_transform {__internal_linalg_transform__ = "tiling_winograd_input_nhwc"}
+ output_tile_size(6) kernel_size(3) image_dimensions([1, 2])
+ ins(%arg0 : tensor<8x8x1x2x2x32xf32>) outs(%0 : tensor<1x12x12x32xf32>) -> tensor<1x12x12x32xf32>
+ return %1 : tensor<1x12x12x32xf32>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (1, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (32, -d0 + s1)>
+// CHECK: func.func @winograd_output_transform(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<8x8x1x2x2x32xf32>) ->
+// CHECK-SAME: tensor<1x12x12x32xf32> {
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[C32:.+]] = arith.constant 32 : index
+// CHECK: %[[D0:.+]] = tensor.empty() : tensor<1x12x12x32xf32>
+// CHECK: %[[D1:.+]] = scf.for %[[ARG1:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1]] step %[[C1]]
+// CHECK-SAME: iter_args(%[[ARG2:[a-zA-Z0-9_]+]] = %[[D0]]) -> (tensor<1x12x12x32xf32>) {
+// CHECK-DAG: %[[D2:.+]] = affine.min #[[MAP]](%[[ARG1]])[%[[C1]], %[[C1]]]
+// CHECK: %[[D3:.+]] = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C32]] step %[[C32]]
+// CHECK-SAME: iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[ARG2]]) -> (tensor<1x12x12x32xf32>) {
+// CHECK-DAG: %[[D4:.+]] = affine.min #[[MAP1]](%[[ARG3]])[%[[C32]], %[[C32]]]
+// CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG1]], 0, 0, %[[ARG3]]] [8, 8,
+// CHECK-SAME: %[[D2]], 2, 2, %[[D4]]] [1, 1, 1, 1, 1, 1] : tensor<8x8x1x2x2x32xf32> to tensor<8x8x?x2x2x?xf32>
+// CHECK: %[[EXTRACTED_SLICE_0:.+]] = tensor.extract_slice %[[D0]][%[[ARG1]], 0, 0, %[[ARG3]]] [%[[D2]], 12,
+// CHECK-SAME: 12, %[[D4]]] [1, 1, 1, 1] : tensor<1x12x12x32xf32> to tensor<?x12x12x?xf32>
+// CHECK: %[[D5:.+]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME: image_dimensions([1, 2]) ins(%[[EXTRACTED_SLICE]] : tensor<8x8x?x2x2x?xf32>)
+// CHECK-SAME: outs(%[[EXTRACTED_SLICE]]_0 : tensor<?x12x12x?xf32>) -> tensor<?x12x12x?xf32>
+// CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D5]] into %[[ARG4]][%[[ARG1]], 0, 0, %[[ARG3]]]
+// CHECK-SAME: [%[[D2]], 12, 12, %[[D4]]] [1, 1, 1, 1] : tensor<?x12x12x?xf32> into tensor<1x12x12x32xf32>
+// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<1x12x12x32xf32>
+// CHECK: }
+// CHECK: scf.yield %[[D3]] : tensor<1x12x12x32xf32>
+// CHECK: }
+// CHECK: return %[[D1]] : tensor<1x12x12x32xf32>
+// CHECK: }
+
+// -----
+
+func.func @winograd_output_transform_memref(%arg0: memref<8x8x1x2x2x32xf32>, %arg1: memref<1x12x12x32xf32>) {
+ iree_linalg_ext.winograd.output_transform {__internal_linalg_transform__ = "tiling_winograd_input_nhwc"}
+ output_tile_size(6) kernel_size(3) image_dimensions([1, 2])
+ ins(%arg0 : memref<8x8x1x2x2x32xf32>) outs(%arg1 : memref<1x12x12x32xf32>)
+ return
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (1, -d0 + s1)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0, s1] -> (32, -d0 + s1)>
+// CHECK: func.func @winograd_output_transform_memref(%[[ARG0:[a-zA-Z0-9_]+]]: memref<8x8x1x2x2x32xf32>,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<1x12x12x32xf32>) {
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[C32:.+]] = arith.constant 32 : index
+// CHECK: scf.for %[[ARG2:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1]] step %[[C1]] {
+// CHECK-DAG: %[[D0:.+]] = affine.min #[[MAP2]](%[[ARG2]])[%[[C1]], %[[C1]]]
+// CHECK: scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C32]] step %[[C32]] {
+// CHECK-DAG: %[[D1:.+]] = affine.min #[[MAP3]](%[[ARG3]])[%[[C32]], %[[C32]]]
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][0, 0, %[[ARG2]], 0, 0, %[[ARG3]]] [8, 8, %[[D0]], 2, 2,
+// CHECK-SAME: %[[D1]]] [1, 1, 1, 1, 1, 1] : memref<8x8x1x2x2x32xf32> to memref<8x8x?x2x2x?xf32, strided<[1024,
+// CHECK-SAME: 128, 128, 64, 32, 1], offset: ?>>
+// CHECK: %[[SUBVIEW_0:.+]] = memref.subview %[[ARG1]][%[[ARG2]], 0, 0, %[[ARG3]]] [%[[D0]], 12, 12, %[[D1]]]
+// CHECK-SAME: [1, 1, 1, 1] : memref<1x12x12x32xf32> to memref<?x12x12x?xf32, strided<[4608, 384, 32, 1],
+// CHECK-SAME: offset: ?>>
+// CHECK: iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3) image_dimensions([1,
+// CHECK-SAME: 2]) ins(%[[SUBVIEW]] : memref<8x8x?x2x2x?xf32, strided<[1024, 128, 128, 64, 32, 1], offset: ?>>)
+// CHECK-SAME: outs(%[[SUBVIEW]]_0 : memref<?x12x12x?xf32, strided<[4608, 384, 32, 1], offset: ?>>)
+// CHECK: }
+// CHECK: }
+// CHECK: return
+// CHECK: }
+
+// -----
diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/vectorization.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/vectorization.mlir
index 8490a4f..17671aa 100644
--- a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/vectorization.mlir
+++ b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/vectorization.mlir
@@ -183,13 +183,12 @@
// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]:
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x32x8xf32>
// CHECK: %[[READ:.+]] = vector.transfer_read %[[IN]]
// CHECK-SAME: [%[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[ZERO]]
// CHECK-SAME: {in_bounds = [true, true]} : tensor<1x1x1x1x8x32xf32>, vector<8x32xf32>
// CHECK: %[[TRANSP:.+]] = vector.transpose %[[READ]], [1, 0]
// CHECK: %[[WRITE:.+]] = vector.transfer_write %[[TRANSP]]
-// CHECK-SAME: %[[EMPTY]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+// CHECK-SAME: %[[OUT]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
// CHECK-SAME: {in_bounds = [true, true]} : vector<32x8xf32>, tensor<1x1x32x8xf32>
// CHECK: return %[[WRITE]]
@@ -241,9 +240,7 @@
return %0 : tensor<1x1x128x64xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 32)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 mod 32)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 floordiv 8)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (d0 mod 8)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 floordiv 8)>
// CHECK-LABEL: func.func @KCRSsr_to_KCRS
// CHECK-SAME: %[[IN:[A-Za-z0-9]+]]:
// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]:
@@ -258,20 +255,17 @@
// CHECK: %[[RES1:.+]] = scf.for %[[S:.+]] = %[[C0]] to %[[C64]] step %[[C8]]
// CHECK-SAME: iter_args(%[[ITER1:.+]] = %[[ITER0]])
// CHECK-DAG: %[[IN_R:.+]] = affine.apply #[[MAP0]](%[[R]])
-// CHECK-DAG: %[[OUT_R:.+]] = affine.apply #[[MAP1]](%[[R]])
-// CHECK-DAG: %[[IN_S:.+]] = affine.apply #[[MAP2]](%[[S]])
-// CHECK-DAG: %[[OUT_S:.+]] = affine.apply #[[MAP3]](%[[S]])
-// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x32x8xf32>
+// CHECK-DAG: %[[IN_S:.+]] = affine.apply #[[MAP1]](%[[S]])
+// CHECK-DAG: %[[ITER1_SLICE:.+]] = tensor.extract_slice %[[ITER1]]
+// CHECK-SAME: [0, 0, %[[R]], %[[S]]] [1, 1, 32, 8] [1, 1, 1, 1]
// CHECK: %[[READ:.+]] = vector.transfer_read %[[IN]]
// CHECK-SAME: [%[[C0]], %[[C0]], %[[IN_R]], %[[IN_S]], %[[C0]], %[[C0]]], %[[ZERO]]
// CHECK-SAME: {in_bounds = [true, true]} : tensor<1x1x4x8x8x32xf32>, vector<8x32xf32>
// CHECK: %[[TRANSP:.+]] = vector.transpose %[[READ]], [1, 0]
// CHECK: %[[WRITE:.+]] = vector.transfer_write %[[TRANSP]]
-// CHECK-SAME: %[[EMPTY]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+// CHECK-SAME: %[[ITER1_SLICE]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
// CHECK-SAME: {in_bounds = [true, true]} : vector<32x8xf32>, tensor<1x1x32x8xf32>
-// CHECK: %[[WRITE_SLICE:.+]] = tensor.extract_slice %[[WRITE]]
-// CHECK-SAME: [0, 0, %[[OUT_R]], %[[OUT_S]]] [1, 1, 32, 8] [1, 1, 1, 1]
-// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[WRITE_SLICE]]
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[WRITE]]
// CHECK-SAME: into %[[ITER1]][0, 0, %[[R]], %[[S]]] [1, 1, 32, 8] [1, 1, 1, 1]
// CHECK: scf.yield %[[INSERT]]
// CHECK: }
@@ -288,9 +282,7 @@
//CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (-d0 + 13, 8)>
//CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (-d0 + 15, 2)>
//CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 floordiv 8)>
-//CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (d0 mod 8)>
-//CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0) -> (d0 floordiv 2)>
-//CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0) -> (d0 mod 2)>
+//CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (d0 floordiv 2)>
// CHECK-LABEL: func.func @unpack_and_extract_slice
// CHECK-SAME: %[[IN:[A-Za-z0-9]+]]:
// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]:
@@ -307,20 +299,21 @@
// CHECK-SAME: iter_args(%[[ITER1:.+]] = %[[ITER0]])
// CHECK-DAG: %[[OUT_J_SZ:.+]] = affine.min #[[MAP1]](%[[J]])
// CHECK-DAG: %[[IN_I:.+]] = affine.apply #[[MAP2]](%[[I]])
-// CHECK-DAG: %[[OUT_I:.+]] = affine.apply #[[MAP3]](%[[I]])
-// CHECK-DAG: %[[IN_J:.+]] = affine.apply #[[MAP4]](%[[J]])
-// CHECK-DAG: %[[OUT_J:.+]] = affine.apply #[[MAP5]](%[[J]])
-// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xf32>
+// CHECK-DAG: %[[IN_J:.+]] = affine.apply #[[MAP3]](%[[J]])
+// CHECK-DAG: %[[ITER1_SLICE1:.+]] = tensor.extract_slice %[[ITER1]]
+// CHECK-SAME: [%[[I]], %[[J]]] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1]
// CHECK-DAG: %[[READ:.+]] = vector.transfer_read %[[IN]]
// CHECK-SAME: [%[[IN_I]], %[[IN_J]], %[[C0]], %[[C0]]], %[[ZERO]]
// CHECK-SAME: {in_bounds = [true, true]} : tensor<2x8x8x2xf32>, vector<8x2xf32>
+// CHECK-DAG: %[[ITER1_SLICE2:.+]] = tensor.extract_slice %[[ITER1_SLICE1]]
+// CHECK-SAME: [0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1]
// CHECK: %[[WRITE:.+]] = vector.transfer_write %[[READ]]
-// CHECK-SAME: %[[EMPTY]][%[[C0]], %[[C0]]] {in_bounds = [true, true]}
-// CHECK: %[[WRITE_SLICE:.+]] = tensor.extract_slice %[[WRITE]]
-// CHECK-SAME: [%[[OUT_I]], %[[OUT_J]]] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1]
-// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[WRITE_SLICE]]
+// CHECK-SAME: %[[ITER1_SLICE2]][%[[C0]], %[[C0]]]
+// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[WRITE]]
+// CHECK-SAME: into %[[ITER1_SLICE1]][0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1]
+// CHECK: %[[INSERT2:.+]] = tensor.insert_slice %[[INSERT1]]
// CHECK-SAME: into %[[ITER1]][%[[I]], %[[J]]] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1]
-// CHECK: scf.yield %[[INSERT]]
+// CHECK: scf.yield %[[INSERT2]]
// CHECK: }
// CHECK: scf.yield %[[RES1]]
// CHECK: }
@@ -333,9 +326,7 @@
return %0 : tensor<128x256xf32>
}
//CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 32)>
-//CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 mod 32)>
-//CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 floordiv 8)>
-//CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (d0 mod 8)>
+//CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 floordiv 8)>
// CHECK-LABEL: func.func @CKck_to_KC
// CHECK-SAME: %[[IN:[A-Za-z0-9]+]]:
// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]:
@@ -350,20 +341,13 @@
// CHECK: %[[RES1:.+]] = scf.for %[[C:.+]] = %[[C0]] to %[[C256]] step %[[C8]]
// CHECK-SAME: iter_args(%[[ITER1:.+]] = %[[ITER0]])
// CHECK-DAG: %[[IN_K:.+]] = affine.apply #[[MAP0]](%[[K]])
-// CHECK-DAG: %[[OUT_K:.+]] = affine.apply #[[MAP1]](%[[K]])
-// CHECK-DAG: %[[IN_C:.+]] = affine.apply #[[MAP2]](%[[C]])
-// CHECK-DAG: %[[OUT_C:.+]] = affine.apply #[[MAP3]](%[[C]])
-// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
+// CHECK-DAG: %[[IN_C:.+]] = affine.apply #[[MAP1]](%[[C]])
// CHECK-DAG: %[[READ:.+]] = vector.transfer_read %[[IN]]
// CHECK-SAME: [%[[IN_C]], %[[IN_K]], %[[C0]], %[[C0]]], %[[ZERO]]
// CHECK-SAME: {in_bounds = [true, true]} : tensor<32x4x32x8xf32>, vector<32x8xf32>
// CHECK: %[[WRITE:.+]] = vector.transfer_write %[[READ]]
-// CHECK-SAME: %[[EMPTY]][%[[C0]], %[[C0]]] {in_bounds = [true, true]}
-// CHECK: %[[WRITE_SLICE:.+]] = tensor.extract_slice %[[WRITE]]
-// CHECK-SAME: [%[[OUT_K]], %[[OUT_C]]] [32, 8] [1, 1]
-// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[WRITE_SLICE]]
-// CHECK-SAME: into %[[ITER1]][%[[K]], %[[C]]] [32, 8] [1, 1]
-// CHECK: scf.yield %[[INSERT]]
+// CHECK-SAME: %[[ITER1]][%[[K]], %[[C]]]
+// CHECK: scf.yield %[[WRITE]]
// CHECK: }
// CHECK: scf.yield %[[RES1]]
// CHECK: }