Adopt TransformDialectExtension and add iree_bufferize + iree_set_num_workgroups_to_one transform ops (#8821)
* Adopt TransformDialectExtension
This revision adds support for dynamic registration of ops into the transform dialect (see https://reviews.llvm.org/D123135 for more context).
This PR also adds iree_set_num_workgroups_to_one and iree_bufferize as transform dialect operations.
With this layering improvement we can now run all transformations as a transform and connect
to vector lowering all the way to LLVM (inclusive or not depending on the needs).
For registration, we piggy-back on Interfaces.cpp::registerCodegenInterfaces for now. This layering should be improved in the future.
The C++-specific parts of setnumthreads and bufferization can now be removed from the interp codegen path.
The --linalg-transform-interp-disable-bufferization flag is now obsolete and has been deleted.
* Extract the implementation of SetNumWorkgroup in a free function.
This allows reusing the implementation from the transform dialect and ditching the pass application.
Passes, nesting and non-Module root ops behave very surprisingly and may easily do nothing and not report anything.
As a consequence, we would hit bug #8823 and get flaky single-thread results.
diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
index 1454782..6316cc6 100644
--- a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
+++ b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
@@ -13,6 +13,9 @@
"//build_tools:dl": ["${CMAKE_DL_LIBS}"],
# IREE llvm-external-projects
+ "//llvm-external-projects/iree-dialects:IREEDialectsTransforms": [
+ "IREEDialectsTransforms"
+ ],
"//llvm-external-projects/iree-dialects:IREEInputDialect": [
"IREEInputDialect"
],
diff --git a/iree/compiler/Codegen/Common/BUILD b/iree/compiler/Codegen/Common/BUILD
index 175581b..a3359f6 100644
--- a/iree/compiler/Codegen/Common/BUILD
+++ b/iree/compiler/Codegen/Common/BUILD
@@ -51,7 +51,6 @@
"PolynomialApproximationPass.cpp",
"RemoveTrivialLoops.cpp",
"RewriteLinalgDestructiveUpdatesPass.cpp",
- "SetNumWorkgroupsFromLinalgExtPass.cpp",
"SetNumWorkgroupsPass.cpp",
"TileAndDistributeToWorkgroupsPass.cpp",
"TypePropagationPass.cpp",
@@ -61,6 +60,7 @@
hdrs = [
"BufferizationAnalysis.h",
"DestructiveUpdateUtils.h",
+ "Transforms.h",
],
deps = [
"//iree/compiler/Codegen:PassHeaders",
diff --git a/iree/compiler/Codegen/Common/CMakeLists.txt b/iree/compiler/Codegen/Common/CMakeLists.txt
index dffea2a..af26185 100644
--- a/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -25,6 +25,7 @@
HDRS
"BufferizationAnalysis.h"
"DestructiveUpdateUtils.h"
+ "Transforms.h"
SRCS
"BufferizationAnalysis.cpp"
"BufferizeCopyOnlyDispatchesPass.cpp"
@@ -43,7 +44,6 @@
"PolynomialApproximationPass.cpp"
"RemoveTrivialLoops.cpp"
"RewriteLinalgDestructiveUpdatesPass.cpp"
- "SetNumWorkgroupsFromLinalgExtPass.cpp"
"SetNumWorkgroupsPass.cpp"
"TileAndDistributeToWorkgroupsPass.cpp"
"TypePropagationPass.cpp"
diff --git a/iree/compiler/Codegen/Common/SetNumWorkgroupsFromLinalgExtPass.cpp b/iree/compiler/Codegen/Common/SetNumWorkgroupsFromLinalgExtPass.cpp
deleted file mode 100644
index 796a1c4..0000000
--- a/iree/compiler/Codegen/Common/SetNumWorkgroupsFromLinalgExtPass.cpp
+++ /dev/null
@@ -1,190 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
-#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
-#include "iree-dialects/Dialect/LinalgExt/Transforms/Utils.h"
-#include "iree/compiler/Codegen/Dialect/LoweringConfig.h"
-#include "iree/compiler/Codegen/PassDetail.h"
-#include "iree/compiler/Codegen/Passes.h"
-#include "iree/compiler/Codegen/Transforms/Transforms.h"
-#include "iree/compiler/Codegen/Utils/Utils.h"
-#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
-#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
-#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "llvm/Support/Debug.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-#define DEBUG_TYPE "iree-codegen-set-num-workgroups-from-linalg-ext"
-
-using namespace mlir;
-using namespace mlir::iree_compiler::IREE::LinalgExt;
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-class SetNumWorkgroupsFromLinalgExtPass
- : public SetNumWorkgroupsFromLinalgExtBase<
- SetNumWorkgroupsFromLinalgExtPass> {
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<AffineDialect, IREE::HAL::HALDialect, IREELinalgExtDialect,
- linalg::LinalgDialect>();
- }
- void runOnOperation() override;
-};
-
-int64_t getBuilderArgs(HALInterfaceWorkgroupIDOp op) {
- return op.dimension().getZExtValue();
-}
-
-int64_t getBuilderArgs(HALInterfaceWorkgroupCountOp op) {
- return op.dimension().getZExtValue();
-}
-
-ValueRange getBuilderArgs(HALReturnOp op) { return op.getOperands(); }
-
-/// Generic implementation of one-to-one conversion from "SourceOp" to
-/// "TargetOp".
-template <typename SourceOp, typename TargetOp>
-class OneToOneRewritePattern : public OpRewritePattern<SourceOp> {
- public:
- using OpRewritePattern<SourceOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(SourceOp op,
- PatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<TargetOp>(op, getBuilderArgs(op));
- return success();
- }
-};
-
-/// Forward LinalgExt::InParallel -> Tensor::InsertSlice -> Flow::TensorStore.
-/// This pattern is necessary for correctness, it accounts for the fact that
-/// InParallel is distributed across multiple workgroups when lowering to HAL
-/// but it then connects to a sequential tensor.insert_slice and then to
-/// flow.dispatch.tensor_store.
-///
-// TODO: All the rewrites in this file this should be done as part of InParallel
-// -> HAL rewrite. But because of dialect dependencies and layering, we have
-// some phase ordering that prevents it atm.
-class ForwardInParallelResultToFlow
- : public OpRewritePattern<IREE::Flow::DispatchTensorStoreOp> {
- public:
- using OpRewritePattern<IREE::Flow::DispatchTensorStoreOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(IREE::Flow::DispatchTensorStoreOp op,
- PatternRewriter &rewriter) const override {
- auto insertSliceOp = op.value().getDefiningOp<tensor::InsertSliceOp>();
- if (!insertSliceOp) return failure();
-
- // TODO: this should be done as part of InParallel -> HAL rewrite.
- // But because of dialect dependencies and layering, we have some phase
- // ordering that prevents it atm. It does not make sense to move the pattern
- // because of this temporary layering problem, so we just ignore the
- // condition for now.
- //
- // auto inParallelOp =
- // insertSliceOp.source().getDefiningOp<IREE::LinalgExt::InParallelOp>();
- // if (!inParallelOp) return failure();
-
- SmallVector<OpFoldResult> offsets, sizes, strides;
- // `tensor.insert_slice` (i.e. the producer) folds **into**
- // `flow.dispatch.tensor.store` (i.e. the consumer).
- if (failed(foldOffsetsSizesAndStrides(rewriter, op.getLoc(), insertSliceOp,
- op, offsets, sizes, strides)))
- return failure();
- rewriter.replaceOpWithNewOp<IREE::Flow::DispatchTensorStoreOp>(
- op, insertSliceOp.source(), op.target(), op.target_dims(), offsets,
- sizes, strides);
-
- return success();
- }
-};
-
-} // namespace
-
-void SetNumWorkgroupsFromLinalgExtPass::runOnOperation() {
- MLIRContext *context = &getContext();
- IREE::HAL::ExecutableVariantOp variantOp = getOperation();
- ModuleOp module = variantOp.getInnerModule();
-
- // Perform 1-1 rewrites first: after the ExecutableEntryPointOp is
- // modified this will be more annoying to track.
- RewritePatternSet oneToOneRewrites(context);
- oneToOneRewrites
- .insert<OneToOneRewritePattern<HALInterfaceWorkgroupIDOp,
- IREE::HAL::InterfaceWorkgroupIDOp>,
- OneToOneRewritePattern<HALInterfaceWorkgroupCountOp,
- IREE::HAL::InterfaceWorkgroupCountOp>,
- OneToOneRewritePattern<HALReturnOp, IREE::HAL::ReturnOp>>(
- context);
- if (failed(applyPatternsAndFoldGreedily(module, std::move(oneToOneRewrites))))
- return signalPassFailure();
-
- // Perform forwarding patterns to bridge the tensor / flow gap.
- // This is necessary for correctness.
- // TODO: given existing bufferization tricks, this may trigger unnecessary
- // copies that need to be further investigated.
- RewritePatternSet forwardPatterns(context);
- forwardPatterns.insert<ForwardInParallelResultToFlow>(context);
- if (failed(applyPatternsAndFoldGreedily(module, std::move(forwardPatterns))))
- return signalPassFailure();
-
- llvm::StringMap<IREE::HAL::ExecutableEntryPointOp> entryPoints =
- getAllEntryPoints(module);
- for (auto funcOp : module.getOps<FuncOp>()) {
- auto entryPointOp = entryPoints.lookup(funcOp.getName());
- if (!entryPointOp) continue;
-
- bool numWorkgroupIsSet = false;
- assert(!entryPointOp.getWorkgroupCountBody() &&
- "Expected a single entryPoint op with no workgroup_count body");
-
- funcOp->walk([&](HALExecutableEntryPointOp op) {
- assert(!numWorkgroupIsSet);
- numWorkgroupIsSet = true;
- IRRewriter rewriter(op->getContext());
- rewriter.setInsertionPoint(entryPointOp);
- auto clonedEntryPointOp =
- rewriter.create<IREE::HAL::ExecutableEntryPointOp>(
- entryPointOp.getLoc(), entryPointOp.sym_nameAttr(),
- entryPointOp.ordinalAttr(), entryPointOp.layoutAttr(),
- entryPointOp.workgroup_sizeAttr(),
- entryPointOp.workgroup_local_memoryAttr());
- Block &block = clonedEntryPointOp.workgroup_count_region().emplaceBlock();
- rewriter.mergeBlocks(&op.workgroup_count_region().front(), &block);
- // TODO: Don't add args post-hoc and instead replace them during
- // `mergeBlocks`.
- for (int64_t i = 0, e = HALExecutableEntryPointOp::getNumWorkgroupDims();
- i < e; ++i) {
- block.addArgument(rewriter.getIndexType(), op->getLoc());
- }
- op->erase();
- entryPointOp.erase();
- });
- }
-
- // Apply post-distribution canonicalization passes.
- RewritePatternSet canonicalization(context);
- AffineApplyOp::getCanonicalizationPatterns(canonicalization, context);
- AffineMinOp::getCanonicalizationPatterns(canonicalization, context);
- populateAffineMinSCFCanonicalizationPattern(canonicalization);
- IREE::Flow::populateFlowDispatchCanonicalizationPatterns(canonicalization,
- context);
- if (failed(
- applyPatternsAndFoldGreedily(module, std::move(canonicalization)))) {
- return signalPassFailure();
- }
-}
-
-std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
-createSetNumWorkgroupsFromLinalgExtPass() {
- return std::make_unique<SetNumWorkgroupsFromLinalgExtPass>();
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Codegen/Common/SetNumWorkgroupsPass.cpp b/iree/compiler/Codegen/Common/SetNumWorkgroupsPass.cpp
index 2899e66..c89fc41 100644
--- a/iree/compiler/Codegen/Common/SetNumWorkgroupsPass.cpp
+++ b/iree/compiler/Codegen/Common/SetNumWorkgroupsPass.cpp
@@ -75,9 +75,9 @@
};
} // namespace
-void SetNumWorkgroupsPass::runOnOperation() {
- MLIRContext *context = &getContext();
- IREE::HAL::ExecutableVariantOp variantOp = getOperation();
+LogicalResult setNumWorkgroupsImpl(IREE::HAL::ExecutableVariantOp variantOp,
+ ArrayRef<int64_t> workloadPerWorkgroup) {
+ MLIRContext *context = variantOp.getContext();
ModuleOp module = variantOp.getInnerModule();
llvm::StringMap<IREE::HAL::ExecutableEntryPointOp> entryPoints =
@@ -102,9 +102,8 @@
RewritePatternSet patterns(funcOp.getContext());
patterns.insert<SetWorkgroupSizePattern>(funcOp.getContext(),
currWorkloadPerWorkgroup);
- if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
- return signalPassFailure();
- }
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
+ return failure();
}
// The workgroup count region might already be set by op-specific
@@ -140,8 +139,10 @@
}
OpBuilder builder(context);
- if (failed(defineWorkgroupCountRegion(builder, funcOp, regionBuilder)))
- return signalPassFailure();
+ // Propagate failures only: an unconditional return here would exit after
+ // the first entry point and skip the canonicalization applied below.
+ if (failed(defineWorkgroupCountRegion(builder, funcOp, regionBuilder)))
+ return failure();
}
// Apply post distribution canonicalization passes.
@@ -150,10 +151,13 @@
populateAffineMinSCFCanonicalizationPattern(canonicalization);
IREE::Flow::populateFlowDispatchCanonicalizationPatterns(canonicalization,
context);
- if (failed(
- applyPatternsAndFoldGreedily(module, std::move(canonicalization)))) {
- return signalPassFailure();
- }
+ return applyPatternsAndFoldGreedily(module, std::move(canonicalization));
+}
+
+void SetNumWorkgroupsPass::runOnOperation() {
+ IREE::HAL::ExecutableVariantOp variantOp = getOperation();
+ if (failed(setNumWorkgroupsImpl(variantOp, workloadPerWorkgroup)))
+ signalPassFailure();
}
std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
diff --git a/iree/compiler/Codegen/Common/Transforms.h b/iree/compiler/Codegen/Common/Transforms.h
new file mode 100644
index 0000000..cc4276a
--- /dev/null
+++ b/iree/compiler/Codegen/Common/Transforms.h
@@ -0,0 +1,19 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Expose the implementation of the set num workgroups pass as a free function
+/// because passes are surprisingly hard to apply reliably when they need to
+/// anchor on special (i.e. non-Module) ops.
+LogicalResult setNumWorkgroupsImpl(IREE::HAL::ExecutableVariantOp variantOp,
+ ArrayRef<int64_t> workloadPerWorkgroup);
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Codegen/Interfaces/BUILD b/iree/compiler/Codegen/Interfaces/BUILD
index 3d168b6..7e26d02 100644
--- a/iree/compiler/Codegen/Interfaces/BUILD
+++ b/iree/compiler/Codegen/Interfaces/BUILD
@@ -41,6 +41,9 @@
deps = [
":BufferizationInterfaces",
":ProcessorOpInterfaces",
+ # TODO: Remove this dependency once the transform dialect extensions
+ # have a better registration mechanism.
+ "//iree/compiler/Codegen/TransformDialectExtensions",
],
)
diff --git a/iree/compiler/Codegen/Interfaces/CMakeLists.txt b/iree/compiler/Codegen/Interfaces/CMakeLists.txt
index fe86298..7235cf6 100644
--- a/iree/compiler/Codegen/Interfaces/CMakeLists.txt
+++ b/iree/compiler/Codegen/Interfaces/CMakeLists.txt
@@ -20,6 +20,7 @@
DEPS
::BufferizationInterfaces
::ProcessorOpInterfaces
+ iree::compiler::Codegen::TransformDialectExtensions
PUBLIC
)
diff --git a/iree/compiler/Codegen/Interfaces/Interfaces.cpp b/iree/compiler/Codegen/Interfaces/Interfaces.cpp
index 64a1698..e682c69 100644
--- a/iree/compiler/Codegen/Interfaces/Interfaces.cpp
+++ b/iree/compiler/Codegen/Interfaces/Interfaces.cpp
@@ -8,6 +8,9 @@
#include "iree/compiler/Codegen/Interfaces/BufferizationInterfaces.h"
#include "iree/compiler/Codegen/Interfaces/ProcessorOpInterfaces.h"
+// TODO: Remove this dependency once the transform dialect extensions
+// have a better registration mechanism.
+#include "iree/compiler/Codegen/TransformDialectExtensions/TransformDialectExtensions.h"
namespace mlir {
namespace iree_compiler {
@@ -15,6 +18,10 @@
void registerCodegenInterfaces(DialectRegistry &registry) {
registerProcessorOpInterfaceExternalModels(registry);
registerBufferizationInterfaces(registry);
+ // TODO: Remove this dependency once the transform dialect extensions
+ // have a better registration mechanism.
+ // TODO: when warranted, move to its own file.
+ registerLinalgTransformDialectExtension(registry);
}
} // namespace iree_compiler
diff --git a/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index 2973bac..6ab54ac 100644
--- a/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -35,13 +35,6 @@
"before conversion to LLVM IR"),
llvm::cl::init(false));
-// TODO: Remove this flag once we can call bufferize from the transform dialect.
-static llvm::cl::opt<bool> clDisableLinalgTransformInterpBufferization(
- "linalg-transform-interp-disable-bufferization",
- llvm::cl::desc("Disables bufferization when running the linalg transform "
- "interp pass (testing only)."),
- llvm::cl::init(false));
-
//===---------------------------------------------------------------------===//
// Default allocation functions for CPU backend
//===---------------------------------------------------------------------===//
@@ -426,28 +419,12 @@
void addLinalgTransformInterpPasses(OpPassManager &passManager) {
// Give control to the linalg_transform dialect.
passManager.addPass(createLinalgTransformInterpreterPass());
+
// Dropping the schedule is only needed if we want to embed the transform in
// the module: we should drop the schedule once applied.
// This pass does nothing in the case where we apply a separate policy
// through a file.
passManager.addPass(createDropSchedulePass());
-
- // Sets the number of workgroups using kFakeHAL op information.
- passManager.addPass(createSetNumWorkgroupsFromLinalgExtPass());
-
- // TODO: Remove this flag and the code below once we can call bufferize from
- // the transform dialect.
- if (clDisableLinalgTransformInterpBufferization) return;
-
- OpPassManager &modulePM = passManager.nest<ModuleOp>();
- // Bufferize the dispatch.
- BufferizationOptions::AllocationFn allocationFn =
- cpuComprehensiveBufferizeAllocationFn;
- BufferizationOptions::DeallocationFn deallocationFn =
- cpuComprehensiveBufferizeDeallocationFn;
- BufferizationOptions::MemCpyFn memcpyFn = cpuComprehensiveBufferizeCopyFn;
- addIREEComprehensiveBufferizePasses(modulePM, allocationFn, deallocationFn,
- memcpyFn);
}
static void addLowerToLLVMPasses(OpPassManager &passManager) {
diff --git a/iree/compiler/Codegen/LLVMCPU/test/BUILD b/iree/compiler/Codegen/LLVMCPU/test/BUILD
index 5fe380c..910be59 100644
--- a/iree/compiler/Codegen/LLVMCPU/test/BUILD
+++ b/iree/compiler/Codegen/LLVMCPU/test/BUILD
@@ -25,7 +25,6 @@
"hal_interface_constants.mlir",
"hal_interface_workgroup_info.mlir",
"illegal_configuration.mlir",
- "linalg_ext_hal_to_hal.mlir",
"linalg_transform.mlir",
"materialize_launch_configuration.mlir",
"synchronize_symbol_visibility.mlir",
diff --git a/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt b/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt
index d65c024..6d7d25a 100644
--- a/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt
+++ b/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt
@@ -19,7 +19,6 @@
"hal_interface_constants.mlir"
"hal_interface_workgroup_info.mlir"
"illegal_configuration.mlir"
- "linalg_ext_hal_to_hal.mlir"
"linalg_transform.mlir"
"materialize_launch_configuration.mlir"
"synchronize_symbol_visibility.mlir"
diff --git a/iree/compiler/Codegen/LLVMCPU/test/linalg_ext_hal_to_hal.mlir b/iree/compiler/Codegen/LLVMCPU/test/linalg_ext_hal_to_hal.mlir
deleted file mode 100644
index d0fe6ae..0000000
--- a/iree/compiler/Codegen/LLVMCPU/test/linalg_ext_hal_to_hal.mlir
+++ /dev/null
@@ -1,45 +0,0 @@
-// RUN: iree-opt %s -pass-pipeline='hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target))' -iree-set-num-workgroups-from-linalg-ext | FileCheck %s
-
-hal.executable @_matmul_static_dispatch_0 {
-hal.executable.variant public @embedded_elf_x86_64, target = <"llvm", "embedded-elf-x86_64", {cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 16 : index, target_triple = "x86_64-unknown-unknown-eabi-elf"}> {
- // CHECK: hal.executable.entry_point public @_matmul_static_dispatch_0
- // CHECK: %[[C1:.*]] = arith.constant 1 : index
- // CHECK: %[[C3:.*]] = arith.constant 3 : index
- // CHECK: hal.return %[[C3]], %[[C1]], %[[C1]] : index, index, index
- // CHECK: }
- hal.executable.entry_point public @_matmul_static_dispatch_0 ordinal(0) layout(#hal.executable.layout<push_constants = 0, sets = [#hal.descriptor_set.layout<0, bindings = [#hal.descriptor_set.binding<0, storage_buffer>]>]>) {translation_info = #iree_codegen.translation_info<LinalgTransformInterpCodegen>}
- builtin.module {
- // CHECK: func @_matmul_static_dispatch_0
- func @_matmul_static_dispatch_0() {
- %cst = arith.constant dense<[[1.500000e+01, 1.400000e+01, 1.300000e+01], [1.200000e+01, 1.100000e+01, 1.000000e+01], [9.000000e+00, 8.000000e+00, 7.000000e+00], [6.000000e+00, 5.000000e+00, 4.000000e+00], [3.000000e+00, 2.000000e+00, 1.000000e+00]]> : tensor<5x3xf32>
- %cst_0 = arith.constant dense<[[1.500000e+01, 1.400000e+01, 1.300000e+01, 1.200000e+01, 1.100000e+01], [1.000000e+01, 9.000000e+00, 8.000000e+00, 7.000000e+00, 6.000000e+00], [5.000000e+00, 4.000000e+00, 3.000000e+00, 2.000000e+00, 1.000000e+00]]> : tensor<3x5xf32>
- %c0 = arith.constant 0 : index
- %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readwrite:5x5xf32>
- %1 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [5, 5], strides = [1, 1] : !flow.dispatch.tensor<readwrite:5x5xf32> -> tensor<5x5xf32>
-
- // The body of this op is folded into the top-level hal.executable.entry_point body.
- // CHECK-NOT: iree_linalg_ext.hal.executable.entry_point
- "iree_linalg_ext.hal.executable.entry_point"() ({
- %c1 = arith.constant 1 : index
- %c3 = arith.constant 3 : index
- // CHECK-NOT: iree_linalg_ext.hal.return
- iree_linalg_ext.hal.return %c3, %c1, %c1 : index, index, index
- }) : () -> ()
-
- // CHECK: = hal.interface.workgroup.id[0] : index
- %2 = iree_linalg_ext.hal.interface.workgroup.id[0] : index
- // Unused, just goes away.
- // CHECK-NOT: = hal.interface.workgroup.id[0] : index
- %3 = iree_linalg_ext.hal.interface.workgroup.count[0] : index
- %4 = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%2]
- %5 = affine.min affine_map<()[s0] -> (s0 * -2 + 5, 2)>()[%2]
- %6 = tensor.extract_slice %1[%4, 0] [%5, 5] [1, 1] : tensor<5x5xf32> to tensor<?x5xf32>
- %7 = tensor.extract_slice %cst[%4, 0] [%5, 3] [1, 1] : tensor<5x3xf32> to tensor<?x3xf32>
- %8 = linalg.matmul {iree_linalg_transform.matched} ins(%7, %cst_0 : tensor<?x3xf32>, tensor<3x5xf32>) outs(%6 : tensor<?x5xf32>) -> tensor<?x5xf32>
- %9 = tensor.insert_slice %8 into %1[%4, 0] [%5, 5] [1, 1] : tensor<?x5xf32> into tensor<5x5xf32>
- flow.dispatch.tensor.store %9, %0, offsets = [0, 0], sizes = [5, 5], strides = [1, 1] : tensor<5x5xf32> -> !flow.dispatch.tensor<readwrite:5x5xf32>
- return
- }
- }
-}
-}
diff --git a/iree/compiler/Codegen/LLVMCPU/test/linalg_transform.mlir b/iree/compiler/Codegen/LLVMCPU/test/linalg_transform.mlir
index bcc69a5..0f6b716 100644
--- a/iree/compiler/Codegen/LLVMCPU/test/linalg_transform.mlir
+++ b/iree/compiler/Codegen/LLVMCPU/test/linalg_transform.mlir
@@ -1,14 +1,9 @@
-// RUN: iree-opt %s -pass-pipeline='hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target))' --iree-codegen-use-linalg-transform-interp --linalg-transform-interp-disable-bufferization --linalg-transform-file-name=%p/linalg_transform_spec.mlir | FileCheck %s
+// RUN: iree-opt %s -pass-pipeline='hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target))' --iree-codegen-use-linalg-transform-interp --linalg-transform-file-name=%p/linalg_transform_spec.mlir | FileCheck %s
#device_target_cpu = #hal.device.target<"cpu", {executable_targets = [#hal.executable.target<"llvm", "embedded-elf-x86_64", {cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 16 : index, target_triple = "x86_64-unknown-unknown-eabi-elf"}>]}>
#executable_layout = #hal.executable.layout<push_constants = 0, sets = [#hal.descriptor_set.layout<0, bindings = [#hal.descriptor_set.binding<0, storage_buffer>, #hal.descriptor_set.binding<1, storage_buffer>, #hal.descriptor_set.binding<2, storage_buffer>]>]>
#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm", "embedded-elf-x86_64", {cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 16 : index, target_triple = "x86_64-unknown-unknown-eabi-elf"}>
-// CHECK-DAG: #[[$map0:.*]] = affine_map<()[s0] -> (s0 * 2)>
-// CHECK-DAG: #[[$map1:.*]] = affine_map<()[s0] -> (s0 * -2 + 250, 2)>
-// CHECK-DAG: #[[$map2:.*]] = affine_map<()[s0] -> (s0 * 4)>
-// CHECK-DAG: #[[$map3:.*]] = affine_map<()[s0] -> (s0 * -4 + 1020, 4)>
-
hal.executable private @pad_matmul_static_dispatch_0 {
hal.executable.variant public @embedded_elf_x86_64, target = #executable_target_embedded_elf_x86_64_ {
hal.executable.entry_point public @pad_matmul_static_dispatch_0 ordinal(0) layout(#executable_layout)
@@ -21,29 +16,15 @@
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [250, 500], strides = [1, 1] : !flow.dispatch.tensor<readonly:250x500xf32> -> tensor<250x500xf32>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [500, 1020], strides = [1, 1] : !flow.dispatch.tensor<readonly:500x1020xf32> -> tensor<500x1020xf32>
- // CHECK: hal.executable.entry_point public @pad_matmul_static_dispatch_0 ordinal(0) layout(#executable_layout) {
- // CHECK: %[[C1:.*]] = arith.constant 1 : index
- // CHECK-DAG: %[[C125:.*]] = arith.constant 125 : index
- // CHECK-DAG: %[[C255:.*]] = arith.constant 255 : index
- // CHECK: hal.return %[[C125]], %[[C255]], %[[C1]] : index, index, index
%50 = linalg.init_tensor [250, 1020] : tensor<250x1020xf32>
%cst = arith.constant 0.000000e+00 : f32
%5 = linalg.fill ins(%cst : f32) outs(%50 : tensor<250x1020xf32>) -> tensor<250x1020xf32>
- // CHECK-NOT: iree_linalg_ext
- // CHECK: %[[IDX:.*]] = hal.interface.workgroup.id[0] : index
- // CHECK: %[[OFFX:.*]] = affine.apply #[[$map0]]()[%[[IDX]]]
- // CHECK: %[[SZX:.*]] = affine.min #[[$map1]]()[%[[IDX]]]
- // CHECK: tensor.extract_slice {{.*}}[%[[OFFX]]{{.*}}[%[[SZX]]
- // CHECK: %[[IDY:.*]] = hal.interface.workgroup.id[1] : index
- // CHECK: %[[OFFY:.*]] = affine.apply #[[$map2]]()[%[[IDY]]]
- // CHECK: %[[SZY:.*]] = affine.min #[[$map3]]()[%[[IDY]]]
- // CHECK: tensor.extract_slice {{.*}}[{{.*}}, %[[OFFY]]] [{{.*}}, %[[SZY]]]
- // CHECK: %[[MM:.*]] = linalg.matmul{{.*}}ins{{.*}}outs
- // CHECK: %[[RES_OFFX:.*]] = affine.apply #[[$map0]]()[%[[IDX]]]
- // CHECK: %[[RES_OFFY:.*]] = affine.apply #[[$map2]]()[%[[IDY]]]
- // CHECK: flow.dispatch.tensor.store %[[MM]], %{{.*}}, offsets = [%[[RES_OFFX]], %[[RES_OFFY]]]{{.*}} : tensor<?x?xf32> -> !flow.dispatch.tensor<readwrite:250x1020xf32>
- // CHECK: return
+ // CHECK: memref.assume_alignment %{{.*}}, 64 : memref<250x1020xf32>
+ // CHECK-NEXT: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : memref<250x1020xf32>)
+ // CHECK-NEXT: linalg.matmul{{.*}}ins(%{{.*}} : memref<250x500xf32>, memref<500x1020xf32>) outs(%{{.*}} : memref<250x1020xf32>)
+ // CHECK-NEXT: return
+
%6 = linalg.matmul ins(%3, %4 : tensor<250x500xf32>, tensor<500x1020xf32>) outs(%5 : tensor<250x1020xf32>) -> tensor<250x1020xf32>
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [250, 1020], strides = [1, 1] : tensor<250x1020xf32> -> !flow.dispatch.tensor<readwrite:250x1020xf32>
return
diff --git a/iree/compiler/Codegen/LLVMCPU/test/linalg_transform_spec.mlir b/iree/compiler/Codegen/LLVMCPU/test/linalg_transform_spec.mlir
index dd59910..e2c3023 100644
--- a/iree/compiler/Codegen/LLVMCPU/test/linalg_transform_spec.mlir
+++ b/iree/compiler/Codegen/LLVMCPU/test/linalg_transform_spec.mlir
@@ -9,17 +9,6 @@
}
iree_linalg_transform.sequence {
- %0 = match @pdl_matmul_target
- // %res contains the tiled op and the linalg_ext.tile op.
- %tiling_1_result:2 = tile_to_iree_linalg_ext_tile_op %0 {sizes = [2]}
- %tiling_2_result:2 = tile_to_iree_linalg_ext_tile_op %tiling_1_result#0 {sizes = [0, 4]}
- %inp_2 = rewrite_iree_linalg_ext_tile_to_in_parallel %tiling_2_result#1
- %inp_1 = rewrite_iree_linalg_ext_tile_to_in_parallel %tiling_1_result#1
- // TODO: Ideally we would bufferize here but we can't atm.
- rewrite_iree_linalg_ext_in_parallel_to_hal %inp_2
- rewrite_iree_linalg_ext_in_parallel_to_hal %inp_1
- // Bufferize happens at the IREE level on HAL operations, we cannot just
- // call the linalg_transform.bufferize operation here.
- // Instead it happens automatically at the end of the linalg-transform-interp
- // pass.
+ %0 = match @pdl_matmul_target
+ iree_bufferize
}
diff --git a/iree/compiler/Codegen/Passes.h b/iree/compiler/Codegen/Passes.h
index 5594cc3..fa25554 100644
--- a/iree/compiler/Codegen/Passes.h
+++ b/iree/compiler/Codegen/Passes.h
@@ -144,11 +144,6 @@
std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
createSetNumWorkgroupsPass(ArrayRef<int64_t> workgroupSize = {});
-/// Propagates the number of workgroups to use for each entry point in the
-/// dispatch region.
-std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
-createSetNumWorkgroupsFromLinalgExtPass();
-
/// Pass to optimize vector transfer_read and transfer_write.
std::unique_ptr<OperationPass<func::FuncOp>> createOptimizeVectorTransferPass();
diff --git a/iree/compiler/Codegen/Passes.td b/iree/compiler/Codegen/Passes.td
index c93fdb4..eab90a1 100644
--- a/iree/compiler/Codegen/Passes.td
+++ b/iree/compiler/Codegen/Passes.td
@@ -126,16 +126,6 @@
let constructor = "mlir::iree_compiler::createSetNumWorkgroupsPass()";
}
-// TODO: Consider removing or moving to HAL/Transforms in order to avoid
-// polluting common pass declarations with HAL specific ops.
-def SetNumWorkgroupsFromLinalgExt :
- Pass<"iree-set-num-workgroups-from-linalg-ext",
- "mlir::iree_compiler::IREE::HAL::ExecutableVariantOp"> {
- let summary =
- "Propagate the number of workgroups for entry point functions";
- let constructor = "mlir::iree_compiler::createSetNumWorkgroupsFromLinalgExtPass()";
-}
-
// TODO: Rename argument to be fully qualified.
def LinalgToVectorVectorizeConv :
Pass<"iree-codegen-vectorize-linalg-conv", "func::FuncOp"> {
diff --git a/iree/compiler/Codegen/TransformDialectExtensions/BUILD b/iree/compiler/Codegen/TransformDialectExtensions/BUILD
new file mode 100644
index 0000000..dad1ffa
--- /dev/null
+++ b/iree/compiler/Codegen/TransformDialectExtensions/BUILD
@@ -0,0 +1,48 @@
+# Copyright 2019 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "TransformDialectExtensions",
+ srcs = [
+ "TransformDialectExtensions.cpp",
+ ],
+ hdrs = [
+ "TransformDialectExtensions.h",
+ ],
+ deps = [
+ "//iree/compiler/Codegen",
+ "//iree/compiler/Codegen:PassHeaders",
+ "//iree/compiler/Codegen/Common",
+ "//iree/compiler/Codegen/Interfaces:BufferizationInterfaces",
+ "//iree/compiler/Codegen/Utils",
+ "//iree/compiler/Dialect/HAL/IR",
+ "//llvm-external-projects/iree-dialects:IREEDialectsTransforms",
+ "//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
+ "//llvm-external-projects/iree-dialects:IREELinalgExtTransforms",
+ "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:Affine",
+ "@llvm-project//mlir:AffineUtils",
+ "@llvm-project//mlir:ArithmeticDialect",
+ "@llvm-project//mlir:BufferizationDialect",
+ "@llvm-project//mlir:BufferizationTransforms",
+ "@llvm-project//mlir:FuncDialect",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:LinalgInterfaces",
+ "@llvm-project//mlir:LinalgOps",
+ "@llvm-project//mlir:PDLDialect",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:SCFDialect",
+ "@llvm-project//mlir:TensorDialect",
+ "@llvm-project//mlir:TransformUtils",
+ ],
+)
diff --git a/iree/compiler/Codegen/TransformDialectExtensions/CMakeLists.txt b/iree/compiler/Codegen/TransformDialectExtensions/CMakeLists.txt
new file mode 100644
index 0000000..788fa1c
--- /dev/null
+++ b/iree/compiler/Codegen/TransformDialectExtensions/CMakeLists.txt
@@ -0,0 +1,48 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# iree/compiler/Codegen/TransformDialectExtensions/BUILD #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ TransformDialectExtensions
+ HDRS
+ "TransformDialectExtensions.h"
+ SRCS
+ "TransformDialectExtensions.cpp"
+ DEPS
+ IREEDialectsTransforms
+ IREELinalgExtDialect
+ IREELinalgExtTransforms
+ IREELinalgTransformDialect
+ LLVMSupport
+ MLIRAffine
+ MLIRAffineUtils
+ MLIRArithmetic
+ MLIRBufferization
+ MLIRBufferizationTransforms
+ MLIRFunc
+ MLIRIR
+ MLIRLinalg
+ MLIRPDL
+ MLIRPass
+ MLIRSCF
+ MLIRTensor
+ MLIRTransformUtils
+ iree::compiler::Codegen
+ iree::compiler::Codegen::Common
+ iree::compiler::Codegen::Interfaces::BufferizationInterfaces
+ iree::compiler::Codegen::PassHeaders
+ iree::compiler::Codegen::Utils
+ iree::compiler::Dialect::HAL::IR
+ PUBLIC
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Codegen/TransformDialectExtensions/TransformDialectExtensions.cpp b/iree/compiler/Codegen/TransformDialectExtensions/TransformDialectExtensions.cpp
new file mode 100644
index 0000000..955b9fd
--- /dev/null
+++ b/iree/compiler/Codegen/TransformDialectExtensions/TransformDialectExtensions.cpp
@@ -0,0 +1,187 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "TransformDialectExtensions.h"
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
+#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h"
+#include "iree-dialects/Dialect/LinalgTransform/TransformOpInterface.h"
+#include "iree-dialects/Transforms/Functional.h"
+#include "iree/compiler/Codegen/Common/Transforms.h"
+#include "iree/compiler/Codegen/Interfaces/BufferizationInterfaces.h"
+#include "iree/compiler/Codegen/Passes.h"
+#include "iree/compiler/Codegen/Utils/Utils.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Pass/PassManager.h"
+
+using namespace mlir;
+using namespace mlir::iree_compiler::IREE;
+
+//===---------------------------------------------------------------------===//
+// Default allocation functions for CPU backend
+// TODO: register the bufferization behavior in a target-specific way.
+//===---------------------------------------------------------------------===//
+
+// Default allocation function to use with IREE's bufferization.
+//
+// Creates a `memref::AllocaOp` at `loc` whose result type is the MemRefType
+// built from `staticShape` and `elementType`, forwarding `dynamicSizes` to the
+// op builder, and returns the op's result.
+//
+// NOTE(review): nothing in this translation unit references this helper;
+// confirm whether it is still needed.
+static Value cpuAllocationFunction(OpBuilder &builder, Location loc,
+                                   ArrayRef<int64_t> staticShape,
+                                   Type elementType,
+                                   ArrayRef<Value> dynamicSizes) {
+  auto allocaOp = builder.create<memref::AllocaOp>(
+      loc, MemRefType::get(staticShape, elementType), dynamicSizes);
+  return allocaOp.getResult();
+}
+
+// Allocation callback for upstream comprehensive bufferization.
+//
+// Creates a `memref::AllocaOp` of type `memRefType` at `loc`, passing
+// `dynamicSizes` and `alignment` (wrapped in an i64 integer attribute) to the
+// op builder, and returns the op's result.
+static FailureOr<Value> cpuComprehensiveBufferizeAllocationFn(
+    OpBuilder &builder, Location loc, MemRefType memRefType,
+    ValueRange dynamicSizes, unsigned alignment) {
+  IntegerAttr alignmentAttr = builder.getI64IntegerAttr(alignment);
+  auto allocaOp = builder.create<memref::AllocaOp>(loc, memRefType,
+                                                   dynamicSizes, alignmentAttr);
+  return allocaOp.getResult();
+}
+
+// Deallocation callback for upstream comprehensive bufferization.
+//
+// Emits no IR and always reports success; none of the arguments are used.
+// NOTE(review): presumably a no-op because the matching allocation callback
+// creates `memref::AllocaOp`s -- confirm.
+static LogicalResult cpuComprehensiveBufferizeDeallocationFn(
+    OpBuilder & /*builder*/, Location /*loc*/, Value /*allocation*/) {
+  return success();
+}
+
+// Copy callback for upstream comprehensive bufferization.
+//
+// Emits a copy of `from` into `to` at `loc` through `createLinalgCopyOp` and
+// reports success unconditionally.
+static LogicalResult cpuComprehensiveBufferizeCopyFn(OpBuilder &builder,
+                                                     Location loc, Value from,
+                                                     Value to) {
+  // TODO: the preferred form is linalg.copy, recently reintroduced as an OpDSL
+  // named op:
+  //   builder.create<linalg::CopyOp>(loc, from, to);
+  // It is not used yet because the IREE-specific patterns that clean up
+  // spurious post-bufferization copies do not trigger properly on it, so we
+  // stay on `createLinalgCopyOp`, which builds a GenericOp.
+  iree_compiler::createLinalgCopyOp(builder, loc, from, to);
+  return success();
+}
+
+//===---------------------------------------------------------------------===//
+// IREE-specific transformations defined outside of iree_linalg_transform.
+//===---------------------------------------------------------------------===//
+
+// Note: with the recent TypeID changes, hiding these classes inside an
+// anonymous namespace would require specific `MLIR_DECLARE_EXPLICIT_TYPE_ID`
+// for each class.
+
+// namespace {
+
+// TODO: Move to tablegen. Until this stabilizes upstream, simple C++ is enough.
+//
+// Transform op `iree_linalg_transform.iree_bufferize`.
+//
+// Its custom assembly consists only of an optional attribute dictionary.
+// Applying it runs the IREE comprehensive bufferization passes, configured
+// with the CPU allocation/deallocation/copy callbacks defined in this file, on
+// every `ModuleOp` found by walking the transform state's top-level op.
+class IREEBufferizeOp
+    : public Op<IREEBufferizeOp,
+                linalg::transform::TransformOpInterface::Trait> {
+ public:
+  using Op::Op;
+
+  // The op declares no attribute names.
+  static ArrayRef<StringRef> getAttributeNames() { return {}; }
+
+  static constexpr llvm::StringLiteral getOperationName() {
+    return llvm::StringLiteral("iree_linalg_transform.iree_bufferize");
+  }
+
+  // Always returns a null Value: this op does not operate on a target handle.
+  Value target() { return nullptr; }
+
+  // Builds a pass manager holding the IREE comprehensive bufferization passes
+  // and runs it on each `ModuleOp` visited by the walk. The walk is
+  // interrupted at the first module on which the pass manager fails.
+  //
+  // Returns failure if the walk was interrupted, success otherwise. `results`
+  // is left untouched.
+  LogicalResult apply(linalg::transform::TransformResults &results,
+                      linalg::transform::TransformState &state) {
+    PassManager pm(getContext());
+    // Bufferize the dispatch.
+    using mlir::bufferization::BufferizationOptions;
+    BufferizationOptions::AllocationFn allocationFn =
+        cpuComprehensiveBufferizeAllocationFn;
+    BufferizationOptions::DeallocationFn deallocationFn =
+        cpuComprehensiveBufferizeDeallocationFn;
+    BufferizationOptions::MemCpyFn memcpyFn = cpuComprehensiveBufferizeCopyFn;
+    mlir::iree_compiler::addIREEComprehensiveBufferizePasses(
+        pm, allocationFn, deallocationFn, memcpyFn);
+    WalkResult res = state.getTopLevel()->walk([&](ModuleOp moduleOp) {
+      if (failed(pm.run(moduleOp))) return WalkResult::interrupt();
+      return WalkResult::advance();
+    });
+    return failure(res.wasInterrupted());
+  }
+
+  // let assemblyFormat = "attr-dict";
+  //
+  // Propagates the parser's result so that a malformed attribute dictionary
+  // makes parsing of the op fail instead of being reported as a success.
+  static ParseResult parse(OpAsmParser &parser, OperationState &state) {
+    return parser.parseOptionalAttrDict(state.attributes);
+  }
+
+  // let assemblyFormat = "attr-dict";
+  void print(OpAsmPrinter &printer) {
+    printer.printOptionalAttrDict((*this)->getAttrs());
+  }
+};
+
+// TODO: Move to tablegen. Until this stabilizes upstream, simple C++ is enough.
+//
+// Transform op `iree_linalg_transform.iree_set_num_workgroups_to_one`.
+//
+// Its custom assembly consists only of an optional attribute dictionary.
+// Applying it calls `setNumWorkgroupsImpl` with an empty second argument on
+// the transform state's top-level op, which must be a
+// `HAL::ExecutableVariantOp`.
+class IREESetNumWorkgroupToOneOp
+    : public Op<IREESetNumWorkgroupToOneOp,
+                linalg::transform::TransformOpInterface::Trait> {
+ public:
+  using Op::Op;
+
+  // The op declares no attribute names.
+  static ArrayRef<StringRef> getAttributeNames() { return {}; }
+
+  static constexpr llvm::StringLiteral getOperationName() {
+    return llvm::StringLiteral(
+        "iree_linalg_transform.iree_set_num_workgroups_to_one");
+  }
+
+  // Always returns a null Value: this op does not operate on a target handle.
+  Value target() { return nullptr; }
+
+  // Fails with a diagnostic when the top-level op of `state` is not a
+  // `HAL::ExecutableVariantOp`; otherwise returns whatever
+  // `setNumWorkgroupsImpl` returns. `results` is left untouched.
+  LogicalResult apply(linalg::transform::TransformResults &results,
+                      linalg::transform::TransformState &state) {
+    auto variantOp = dyn_cast<HAL::ExecutableVariantOp>(state.getTopLevel());
+    if (!variantOp) {
+      // Report why the transform did nothing instead of failing silently.
+      return emitOpError(
+          "requires the top-level op to be a HAL::ExecutableVariantOp");
+    }
+    return iree_compiler::setNumWorkgroupsImpl(variantOp, {});
+  }
+
+  // let assemblyFormat = "attr-dict";
+  //
+  // Propagates the parser's result so that a malformed attribute dictionary
+  // makes parsing of the op fail instead of being reported as a success.
+  static ParseResult parse(OpAsmParser &parser, OperationState &state) {
+    return parser.parseOptionalAttrDict(state.attributes);
+  }
+
+  // let assemblyFormat = "attr-dict";
+  void print(OpAsmPrinter &printer) {
+    printer.printOptionalAttrDict((*this)->getAttrs());
+  }
+};
+
+/// Extension of the LinalgTransform dialect that injects the IREE-specific
+/// transform ops defined in this file (`IREEBufferizeOp` and
+/// `IREESetNumWorkgroupToOneOp`) and declares the PDL dialect as a dependent
+/// dialect.
+class LinalgTransformDialectExtension
+    : public mlir::linalg::transform::TransformDialectExtension<
+          LinalgTransformDialectExtension> {
+ public:
+  LinalgTransformDialectExtension() {
+    declareDependentDialect<pdl::PDLDialect>();
+    // Ops are registered in the order in which they are listed.
+    registerTransformOps<IREEBufferizeOp, IREESetNumWorkgroupToOneOp>();
+    // TODO: hook up to Tablegen.
+    // registerTransformOps<
+    // #define GET_OP_LIST
+    // #include "LinalgTransformDialectExtension.cpp.inc"
+    // >();
+  }
+};
+
+// } // namespace anonymous
+
+/// Adds `LinalgTransformDialectExtension` to `registry`.
+void mlir::iree_compiler::registerLinalgTransformDialectExtension(
+    DialectRegistry &registry) {
+  registry.addExtensions<LinalgTransformDialectExtension>();
+}
diff --git a/iree/compiler/Codegen/TransformDialectExtensions/TransformDialectExtensions.h b/iree/compiler/Codegen/TransformDialectExtensions/TransformDialectExtensions.h
new file mode 100644
index 0000000..ffec9ab
--- /dev/null
+++ b/iree/compiler/Codegen/TransformDialectExtensions/TransformDialectExtensions.h
@@ -0,0 +1,28 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+// The include guard mirrors this header's path under iree/compiler/Codegen.
+#ifndef IREE_COMPILER_CODEGEN_TRANSFORMDIALECTEXTENSIONS_TRANSFORMDIALECTEXTENSIONS_H_
+#define IREE_COMPILER_CODEGEN_TRANSFORMDIALECTEXTENSIONS_TRANSFORMDIALECTEXTENSIONS_H_
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Registers transformations that require IREE-specific information into the
+/// LinalgTransform dialect by adding the corresponding dialect extension to
+/// `registry`.
+void registerLinalgTransformDialectExtension(DialectRegistry &registry);
+
+}  // namespace iree_compiler
+}  // namespace mlir
+
+#endif  // IREE_COMPILER_CODEGEN_TRANSFORMDIALECTEXTENSIONS_TRANSFORMDIALECTEXTENSIONS_H_
diff --git a/iree/test/e2e/linalg_transform/linalg_transform_spec.mlir b/iree/test/e2e/linalg_transform/linalg_transform_spec.mlir
index c16da93..3f60f5e 100644
--- a/iree/test/e2e/linalg_transform/linalg_transform_spec.mlir
+++ b/iree/test/e2e/linalg_transform/linalg_transform_spec.mlir
@@ -10,16 +10,6 @@
iree_linalg_transform.sequence {
%0 = match @pdl_matmul_target
- // %res contains the tiled op and the linalg_ext.tile op.
- %tiling_1_result:2 = tile_to_iree_linalg_ext_tile_op %0 {sizes = [2]}
- %tiling_2_result:2 = tile_to_iree_linalg_ext_tile_op %tiling_1_result#0 {sizes = [0, 3]}
- %inp_2 = rewrite_iree_linalg_ext_tile_to_in_parallel %tiling_2_result#1
- %inp_1 = rewrite_iree_linalg_ext_tile_to_in_parallel %tiling_1_result#1
- // TODO: Ideally we would bufferize here but we can't atm.
- rewrite_iree_linalg_ext_in_parallel_to_hal %inp_2
- rewrite_iree_linalg_ext_in_parallel_to_hal %inp_1
- // Bufferize happens at the IREE level on HAL operations, we cannot just
- // call the linalg_transform.bufferize operation here.
- // Instead it happens automatically at the end of the linalg-transform-interp
- // pass.
+ iree_set_num_workgroups_to_one
+ iree_bufferize
}
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h
index 7683ae0..1418abb 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h
@@ -28,4 +28,66 @@
#define GET_OP_CLASSES
#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h.inc"
+namespace mlir {
+namespace linalg {
+namespace transform {
+
+/// Name under which the code below refers to the dialect being extended.
+using TransformDialect = LinalgTransformDialect;
+
+/// Base class for extensions of the Transform dialect that supports injecting
+/// operations into the Transform dialect at load time. A concrete extension
+/// derives from this class, passing itself as `DerivedTy`, and records the
+/// operations to register and the dialects to load in its constructor. The
+/// extension can then be registered with the DialectRegistry and is applied
+/// automatically to the Transform dialect when it is loaded.
+template <typename DerivedTy, typename... ExtraDialects>
+class TransformDialectExtension
+    : public DialectExtension<DerivedTy, TransformDialect, ExtraDialects...> {
+  using DialectLoader = std::function<void(MLIRContext *)>;
+  using Initializer = std::function<void(TransformDialect *)>;
+
+public:
+  /// Extension application hook; not expected to be called directly. Runs all
+  /// recorded dialect loaders first, then all recorded op initializers, each
+  /// group in the order in which its entries were recorded.
+  void apply(MLIRContext *context, TransformDialect *transformDialect,
+             ExtraDialects *...) const final {
+    for (const DialectLoader &loadDialect : dialectLoaders)
+      loadDialect(context);
+    for (const Initializer &registerOp : opInitializers)
+      registerOp(transformDialect);
+  }
+
+protected:
+  /// Declares that the Transform dialect depends on `DialectTy`: when the
+  /// Transform dialect is loaded, `DialectTy` is loaded as well. This is
+  /// intended for dialects that contain attributes and types used in creation
+  /// and canonicalization of the injected operations.
+  template <typename DialectTy>
+  void declareDependentDialect() {
+    dialectLoaders.emplace_back(
+        [](MLIRContext *ctx) { ctx->loadDialect<DialectTy>(); });
+  }
+
+  /// Records `OpTy` for injection into the Transform dialect.
+  template <typename OpTy>
+  void registerTransformOp() {
+    opInitializers.emplace_back([](TransformDialect *dialect) {
+      RegisteredOperationName::insert<OpTy>(*dialect);
+    });
+  }
+
+  /// Records every op in `OpTys` for injection into the Transform dialect, in
+  /// the order in which they are listed.
+  template <typename... OpTys>
+  void registerTransformOps() {
+    // The leading 0 keeps the array non-empty for an empty pack; the
+    // initializers are evaluated left to right.
+    int expandPack[] = {0, (registerTransformOp<OpTys>(), 0)...};
+    (void)expandPack;
+  }
+
+private:
+  SmallVector<DialectLoader> dialectLoaders;
+  SmallVector<Initializer> opInitializers;
+};
+
+} // namespace transform
+} // namespace linalg
+} // namespace mlir
+
#endif // MLIR_DIALECT_LINALG_IR_LINALGTRANSFORMOPS_H