Adopt TransformDialectExtension and add iree_bufferize + iree_set_num_workgroups_to_one transform ops (#8821)

* Adopt TransformDialectExtension

This revision adds support for dynamic registration of ops into the transform dialect (see https://reviews.llvm.org/D123135 for more context).
This PR also add ire_set_num_workgroups_to_one and iree_bufferize as transform dialect operations.

With this layering improvement we can now run all transformations as a transform and connect
to vector lowering all the way to LLVM (inclusive or not depending on the needs).

For registration, we piggy-back on Interfaces.cpp::registerCodegenInterfaces for now. This layering should be improved in the future.

The C++-specific parts of setnumthreads and bufferization can now be removed from the interp codegen path.
The --linalg-transform-interp-disable-bufferization is now obsolete and deleted.

* Extract the implementation of SetNumWorkgroup in a free function.

This allows reusing the implementation from the transform dialect and ditch the pass application.
Passes, nesting and non-Module root ops behave very surprisingly and may easily do nothing and not report anything.

As a consequence, we would hit bug #8823 and get flaky single-thread results ..
diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
index 1454782..6316cc6 100644
--- a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
+++ b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
@@ -13,6 +13,9 @@
     "//build_tools:dl": ["${CMAKE_DL_LIBS}"],
 
     # IREE llvm-external-projects
+    "//llvm-external-projects/iree-dialects:IREEDialectsTransforms": [
+        "IREEDialectsTransforms"
+    ],
     "//llvm-external-projects/iree-dialects:IREEInputDialect": [
         "IREEInputDialect"
     ],
diff --git a/iree/compiler/Codegen/Common/BUILD b/iree/compiler/Codegen/Common/BUILD
index 175581b..a3359f6 100644
--- a/iree/compiler/Codegen/Common/BUILD
+++ b/iree/compiler/Codegen/Common/BUILD
@@ -51,7 +51,6 @@
         "PolynomialApproximationPass.cpp",
         "RemoveTrivialLoops.cpp",
         "RewriteLinalgDestructiveUpdatesPass.cpp",
-        "SetNumWorkgroupsFromLinalgExtPass.cpp",
         "SetNumWorkgroupsPass.cpp",
         "TileAndDistributeToWorkgroupsPass.cpp",
         "TypePropagationPass.cpp",
@@ -61,6 +60,7 @@
     hdrs = [
         "BufferizationAnalysis.h",
         "DestructiveUpdateUtils.h",
+        "Transforms.h",
     ],
     deps = [
         "//iree/compiler/Codegen:PassHeaders",
diff --git a/iree/compiler/Codegen/Common/CMakeLists.txt b/iree/compiler/Codegen/Common/CMakeLists.txt
index dffea2a..af26185 100644
--- a/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -25,6 +25,7 @@
   HDRS
     "BufferizationAnalysis.h"
     "DestructiveUpdateUtils.h"
+    "Transforms.h"
   SRCS
     "BufferizationAnalysis.cpp"
     "BufferizeCopyOnlyDispatchesPass.cpp"
@@ -43,7 +44,6 @@
     "PolynomialApproximationPass.cpp"
     "RemoveTrivialLoops.cpp"
     "RewriteLinalgDestructiveUpdatesPass.cpp"
-    "SetNumWorkgroupsFromLinalgExtPass.cpp"
     "SetNumWorkgroupsPass.cpp"
     "TileAndDistributeToWorkgroupsPass.cpp"
     "TypePropagationPass.cpp"
diff --git a/iree/compiler/Codegen/Common/SetNumWorkgroupsFromLinalgExtPass.cpp b/iree/compiler/Codegen/Common/SetNumWorkgroupsFromLinalgExtPass.cpp
deleted file mode 100644
index 796a1c4..0000000
--- a/iree/compiler/Codegen/Common/SetNumWorkgroupsFromLinalgExtPass.cpp
+++ /dev/null
@@ -1,190 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
-#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
-#include "iree-dialects/Dialect/LinalgExt/Transforms/Utils.h"
-#include "iree/compiler/Codegen/Dialect/LoweringConfig.h"
-#include "iree/compiler/Codegen/PassDetail.h"
-#include "iree/compiler/Codegen/Passes.h"
-#include "iree/compiler/Codegen/Transforms/Transforms.h"
-#include "iree/compiler/Codegen/Utils/Utils.h"
-#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
-#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
-#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "llvm/Support/Debug.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-#define DEBUG_TYPE "iree-codegen-set-num-workgroups-from-linalg-ext"
-
-using namespace mlir;
-using namespace mlir::iree_compiler::IREE::LinalgExt;
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-class SetNumWorkgroupsFromLinalgExtPass
-    : public SetNumWorkgroupsFromLinalgExtBase<
-          SetNumWorkgroupsFromLinalgExtPass> {
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<AffineDialect, IREE::HAL::HALDialect, IREELinalgExtDialect,
-                    linalg::LinalgDialect>();
-  }
-  void runOnOperation() override;
-};
-
-int64_t getBuilderArgs(HALInterfaceWorkgroupIDOp op) {
-  return op.dimension().getZExtValue();
-}
-
-int64_t getBuilderArgs(HALInterfaceWorkgroupCountOp op) {
-  return op.dimension().getZExtValue();
-}
-
-ValueRange getBuilderArgs(HALReturnOp op) { return op.getOperands(); }
-
-/// Generic implementation of one-to-one conversion from "SourceOp" to
-/// "TargetOp".
-template <typename SourceOp, typename TargetOp>
-class OneToOneRewritePattern : public OpRewritePattern<SourceOp> {
- public:
-  using OpRewritePattern<SourceOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(SourceOp op,
-                                PatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<TargetOp>(op, getBuilderArgs(op));
-    return success();
-  }
-};
-
-/// Forward LinalgExt::InParallel -> Tensor::InsertSlice -> Flow::TensorStore.
-/// This pattern is necessary for correctness, it accounts for the fact that
-/// InParallel is distributed across multiple workgroups when lowering to HAL
-/// but it then connects to a sequential tensor.insert_slice and then to
-/// flow.dispatch.tensor_store.
-///
-// TODO: All the rewrites in this file this should be done as part of InParallel
-// -> HAL rewrite. But because of dialect dependencies and layering, we have
-// some phase ordering that prevents it atm.
-class ForwardInParallelResultToFlow
-    : public OpRewritePattern<IREE::Flow::DispatchTensorStoreOp> {
- public:
-  using OpRewritePattern<IREE::Flow::DispatchTensorStoreOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(IREE::Flow::DispatchTensorStoreOp op,
-                                PatternRewriter &rewriter) const override {
-    auto insertSliceOp = op.value().getDefiningOp<tensor::InsertSliceOp>();
-    if (!insertSliceOp) return failure();
-
-    // TODO: this should be done as part of InParallel -> HAL rewrite.
-    // But because of dialect dependencies and layering, we have some phase
-    // ordering that prevents it atm. It does not make sense to move the pattern
-    // because of this temporary layering problem, so we just ignore the
-    // condition for now.
-    //
-    // auto inParallelOp =
-    //     insertSliceOp.source().getDefiningOp<IREE::LinalgExt::InParallelOp>();
-    // if (!inParallelOp) return failure();
-
-    SmallVector<OpFoldResult> offsets, sizes, strides;
-    // `tensor.insert_slice` (i.e. the producer) folds **into**
-    // `flow.dispatch.tensor.store` (i.e. the consumer).
-    if (failed(foldOffsetsSizesAndStrides(rewriter, op.getLoc(), insertSliceOp,
-                                          op, offsets, sizes, strides)))
-      return failure();
-    rewriter.replaceOpWithNewOp<IREE::Flow::DispatchTensorStoreOp>(
-        op, insertSliceOp.source(), op.target(), op.target_dims(), offsets,
-        sizes, strides);
-
-    return success();
-  }
-};
-
-}  // namespace
-
-void SetNumWorkgroupsFromLinalgExtPass::runOnOperation() {
-  MLIRContext *context = &getContext();
-  IREE::HAL::ExecutableVariantOp variantOp = getOperation();
-  ModuleOp module = variantOp.getInnerModule();
-
-  // Perform 1-1 rewrites first: after the ExecutableEntryPointOp is
-  // modified this will be more annoying to track.
-  RewritePatternSet oneToOneRewrites(context);
-  oneToOneRewrites
-      .insert<OneToOneRewritePattern<HALInterfaceWorkgroupIDOp,
-                                     IREE::HAL::InterfaceWorkgroupIDOp>,
-              OneToOneRewritePattern<HALInterfaceWorkgroupCountOp,
-                                     IREE::HAL::InterfaceWorkgroupCountOp>,
-              OneToOneRewritePattern<HALReturnOp, IREE::HAL::ReturnOp>>(
-          context);
-  if (failed(applyPatternsAndFoldGreedily(module, std::move(oneToOneRewrites))))
-    return signalPassFailure();
-
-  // Perform forwarding patterns to bridge the tensor / flow gap.
-  // This is necessary for correctness.
-  // TODO: given existing bufferization tricks, this may trigger unnecessary
-  // copies that need to be further investigated.
-  RewritePatternSet forwardPatterns(context);
-  forwardPatterns.insert<ForwardInParallelResultToFlow>(context);
-  if (failed(applyPatternsAndFoldGreedily(module, std::move(forwardPatterns))))
-    return signalPassFailure();
-
-  llvm::StringMap<IREE::HAL::ExecutableEntryPointOp> entryPoints =
-      getAllEntryPoints(module);
-  for (auto funcOp : module.getOps<FuncOp>()) {
-    auto entryPointOp = entryPoints.lookup(funcOp.getName());
-    if (!entryPointOp) continue;
-
-    bool numWorkgroupIsSet = false;
-    assert(!entryPointOp.getWorkgroupCountBody() &&
-           "Expected a single entryPoint op with no workgroup_count body");
-
-    funcOp->walk([&](HALExecutableEntryPointOp op) {
-      assert(!numWorkgroupIsSet);
-      numWorkgroupIsSet = true;
-      IRRewriter rewriter(op->getContext());
-      rewriter.setInsertionPoint(entryPointOp);
-      auto clonedEntryPointOp =
-          rewriter.create<IREE::HAL::ExecutableEntryPointOp>(
-              entryPointOp.getLoc(), entryPointOp.sym_nameAttr(),
-              entryPointOp.ordinalAttr(), entryPointOp.layoutAttr(),
-              entryPointOp.workgroup_sizeAttr(),
-              entryPointOp.workgroup_local_memoryAttr());
-      Block &block = clonedEntryPointOp.workgroup_count_region().emplaceBlock();
-      rewriter.mergeBlocks(&op.workgroup_count_region().front(), &block);
-      // TODO: Don't add args post-hoc and instead replace them during
-      // `mergeBlocks`.
-      for (int64_t i = 0, e = HALExecutableEntryPointOp::getNumWorkgroupDims();
-           i < e; ++i) {
-        block.addArgument(rewriter.getIndexType(), op->getLoc());
-      }
-      op->erase();
-      entryPointOp.erase();
-    });
-  }
-
-  // Apply post-distribution canonicalization passes.
-  RewritePatternSet canonicalization(context);
-  AffineApplyOp::getCanonicalizationPatterns(canonicalization, context);
-  AffineMinOp::getCanonicalizationPatterns(canonicalization, context);
-  populateAffineMinSCFCanonicalizationPattern(canonicalization);
-  IREE::Flow::populateFlowDispatchCanonicalizationPatterns(canonicalization,
-                                                           context);
-  if (failed(
-          applyPatternsAndFoldGreedily(module, std::move(canonicalization)))) {
-    return signalPassFailure();
-  }
-}
-
-std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
-createSetNumWorkgroupsFromLinalgExtPass() {
-  return std::make_unique<SetNumWorkgroupsFromLinalgExtPass>();
-}
-
-}  // namespace iree_compiler
-}  // namespace mlir
diff --git a/iree/compiler/Codegen/Common/SetNumWorkgroupsPass.cpp b/iree/compiler/Codegen/Common/SetNumWorkgroupsPass.cpp
index 2899e66..c89fc41 100644
--- a/iree/compiler/Codegen/Common/SetNumWorkgroupsPass.cpp
+++ b/iree/compiler/Codegen/Common/SetNumWorkgroupsPass.cpp
@@ -75,9 +75,9 @@
 };
 }  // namespace
 
-void SetNumWorkgroupsPass::runOnOperation() {
-  MLIRContext *context = &getContext();
-  IREE::HAL::ExecutableVariantOp variantOp = getOperation();
+LogicalResult setNumWorkgroupsImpl(IREE::HAL::ExecutableVariantOp variantOp,
+                                   ArrayRef<int64_t> workloadPerWorkgroup) {
+  MLIRContext *context = variantOp.getContext();
   ModuleOp module = variantOp.getInnerModule();
 
   llvm::StringMap<IREE::HAL::ExecutableEntryPointOp> entryPoints =
@@ -102,9 +102,8 @@
       RewritePatternSet patterns(funcOp.getContext());
       patterns.insert<SetWorkgroupSizePattern>(funcOp.getContext(),
                                                currWorkloadPerWorkgroup);
-      if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
-        return signalPassFailure();
-      }
+      if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
+        return failure();
     }
 
     // The workgroup count region might already be set by op-specific
@@ -140,8 +139,7 @@
     }
 
     OpBuilder builder(context);
-    if (failed(defineWorkgroupCountRegion(builder, funcOp, regionBuilder)))
-      return signalPassFailure();
+    return defineWorkgroupCountRegion(builder, funcOp, regionBuilder);
   }
 
   // Apply post distribution canonicalization passes.
@@ -150,10 +148,13 @@
   populateAffineMinSCFCanonicalizationPattern(canonicalization);
   IREE::Flow::populateFlowDispatchCanonicalizationPatterns(canonicalization,
                                                            context);
-  if (failed(
-          applyPatternsAndFoldGreedily(module, std::move(canonicalization)))) {
-    return signalPassFailure();
-  }
+  return applyPatternsAndFoldGreedily(module, std::move(canonicalization));
+}
+
+void SetNumWorkgroupsPass::runOnOperation() {
+  IREE::HAL::ExecutableVariantOp variantOp = getOperation();
+  if (failed(setNumWorkgroupsImpl(variantOp, workloadPerWorkgroup)))
+    signalPassFailure();
 }
 
 std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
diff --git a/iree/compiler/Codegen/Common/Transforms.h b/iree/compiler/Codegen/Common/Transforms.h
new file mode 100644
index 0000000..cc4276a
--- /dev/null
+++ b/iree/compiler/Codegen/Common/Transforms.h
@@ -0,0 +1,19 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Expose the implementation of the set num workgroups pass as a free function
+/// because passes are surprisingly hard to apply reliably when they need to
+/// anchor on special (i.e. non-Module) ops.
+LogicalResult setNumWorkgroupsImpl(IREE::HAL::ExecutableVariantOp variantOp,
+                                   ArrayRef<int64_t> workloadPerWorkgroup);
+
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Codegen/Interfaces/BUILD b/iree/compiler/Codegen/Interfaces/BUILD
index 3d168b6..7e26d02 100644
--- a/iree/compiler/Codegen/Interfaces/BUILD
+++ b/iree/compiler/Codegen/Interfaces/BUILD
@@ -41,6 +41,9 @@
     deps = [
         ":BufferizationInterfaces",
         ":ProcessorOpInterfaces",
+        # TODO: Remove this dependency once the transform dialect extensions
+        # have a better registration mechanism.
+        "//iree/compiler/Codegen/TransformDialectExtensions",
     ],
 )
 
diff --git a/iree/compiler/Codegen/Interfaces/CMakeLists.txt b/iree/compiler/Codegen/Interfaces/CMakeLists.txt
index fe86298..7235cf6 100644
--- a/iree/compiler/Codegen/Interfaces/CMakeLists.txt
+++ b/iree/compiler/Codegen/Interfaces/CMakeLists.txt
@@ -20,6 +20,7 @@
   DEPS
     ::BufferizationInterfaces
     ::ProcessorOpInterfaces
+    iree::compiler::Codegen::TransformDialectExtensions
   PUBLIC
 )
 
diff --git a/iree/compiler/Codegen/Interfaces/Interfaces.cpp b/iree/compiler/Codegen/Interfaces/Interfaces.cpp
index 64a1698..e682c69 100644
--- a/iree/compiler/Codegen/Interfaces/Interfaces.cpp
+++ b/iree/compiler/Codegen/Interfaces/Interfaces.cpp
@@ -8,6 +8,9 @@
 
 #include "iree/compiler/Codegen/Interfaces/BufferizationInterfaces.h"
 #include "iree/compiler/Codegen/Interfaces/ProcessorOpInterfaces.h"
+// TODO: Remove this dependency once the transform dialect extensions
+// have a better registration mechanism.
+#include "iree/compiler/Codegen/TransformDialectExtensions/TransformDialectExtensions.h"
 
 namespace mlir {
 namespace iree_compiler {
@@ -15,6 +18,10 @@
 void registerCodegenInterfaces(DialectRegistry &registry) {
   registerProcessorOpInterfaceExternalModels(registry);
   registerBufferizationInterfaces(registry);
+  // TODO: Remove this dependency once the transform dialect extensions
+  // have a better registration mechanism.
+  // TODO: when warranted, move to its own file.
+  registerLinalgTransformDialectExtension(registry);
 }
 
 }  // namespace iree_compiler
diff --git a/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index 2973bac..6ab54ac 100644
--- a/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -35,13 +35,6 @@
                    "before conversion to LLVM IR"),
     llvm::cl::init(false));
 
-// TODO: Remove this flag once we can call bufferize from the transform dialect.
-static llvm::cl::opt<bool> clDisableLinalgTransformInterpBufferization(
-    "linalg-transform-interp-disable-bufferization",
-    llvm::cl::desc("Disables bufferization when running the linalg transform "
-                   "interp pass (testing only)."),
-    llvm::cl::init(false));
-
 //===---------------------------------------------------------------------===//
 // Default allocation functions for CPU backend
 //===---------------------------------------------------------------------===//
@@ -426,28 +419,12 @@
 void addLinalgTransformInterpPasses(OpPassManager &passManager) {
   // Give control to the linalg_transform dialect.
   passManager.addPass(createLinalgTransformInterpreterPass());
+
   // Dropping the schedule is only needed if we want to embed the transform in
   // the module: we should drop the schedule once applied.
   // This pass does nothing in the case where we apply a separate policy
   // through a file.
   passManager.addPass(createDropSchedulePass());
-
-  // Sets the number of workgroups using kFakeHAL op information.
-  passManager.addPass(createSetNumWorkgroupsFromLinalgExtPass());
-
-  // TODO: Remove this flag and the code below once we can call bufferize from
-  // the transform dialect.
-  if (clDisableLinalgTransformInterpBufferization) return;
-
-  OpPassManager &modulePM = passManager.nest<ModuleOp>();
-  // Bufferize the dispatch.
-  BufferizationOptions::AllocationFn allocationFn =
-      cpuComprehensiveBufferizeAllocationFn;
-  BufferizationOptions::DeallocationFn deallocationFn =
-      cpuComprehensiveBufferizeDeallocationFn;
-  BufferizationOptions::MemCpyFn memcpyFn = cpuComprehensiveBufferizeCopyFn;
-  addIREEComprehensiveBufferizePasses(modulePM, allocationFn, deallocationFn,
-                                      memcpyFn);
 }
 
 static void addLowerToLLVMPasses(OpPassManager &passManager) {
diff --git a/iree/compiler/Codegen/LLVMCPU/test/BUILD b/iree/compiler/Codegen/LLVMCPU/test/BUILD
index 5fe380c..910be59 100644
--- a/iree/compiler/Codegen/LLVMCPU/test/BUILD
+++ b/iree/compiler/Codegen/LLVMCPU/test/BUILD
@@ -25,7 +25,6 @@
             "hal_interface_constants.mlir",
             "hal_interface_workgroup_info.mlir",
             "illegal_configuration.mlir",
-            "linalg_ext_hal_to_hal.mlir",
             "linalg_transform.mlir",
             "materialize_launch_configuration.mlir",
             "synchronize_symbol_visibility.mlir",
diff --git a/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt b/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt
index d65c024..6d7d25a 100644
--- a/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt
+++ b/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt
@@ -19,7 +19,6 @@
     "hal_interface_constants.mlir"
     "hal_interface_workgroup_info.mlir"
     "illegal_configuration.mlir"
-    "linalg_ext_hal_to_hal.mlir"
     "linalg_transform.mlir"
     "materialize_launch_configuration.mlir"
     "synchronize_symbol_visibility.mlir"
diff --git a/iree/compiler/Codegen/LLVMCPU/test/linalg_ext_hal_to_hal.mlir b/iree/compiler/Codegen/LLVMCPU/test/linalg_ext_hal_to_hal.mlir
deleted file mode 100644
index d0fe6ae..0000000
--- a/iree/compiler/Codegen/LLVMCPU/test/linalg_ext_hal_to_hal.mlir
+++ /dev/null
@@ -1,45 +0,0 @@
-// RUN: iree-opt %s  -pass-pipeline='hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target))' -iree-set-num-workgroups-from-linalg-ext | FileCheck %s
-
-hal.executable @_matmul_static_dispatch_0 {
-hal.executable.variant public @embedded_elf_x86_64, target = <"llvm", "embedded-elf-x86_64", {cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 16 : index, target_triple = "x86_64-unknown-unknown-eabi-elf"}> {
-  // CHECK: hal.executable.entry_point public @_matmul_static_dispatch_0 
-  // CHECK:   %[[C1:.*]] = arith.constant 1 : index
-  // CHECK:   %[[C3:.*]] = arith.constant 3 : index
-  // CHECK:   hal.return %[[C3]], %[[C1]], %[[C1]] : index, index, index
-  // CHECK: }
-  hal.executable.entry_point public @_matmul_static_dispatch_0 ordinal(0) layout(#hal.executable.layout<push_constants = 0, sets = [#hal.descriptor_set.layout<0, bindings = [#hal.descriptor_set.binding<0, storage_buffer>]>]>) {translation_info = #iree_codegen.translation_info<LinalgTransformInterpCodegen>}
-  builtin.module {
-    // CHECK: func @_matmul_static_dispatch_0
-    func @_matmul_static_dispatch_0() {
-      %cst = arith.constant dense<[[1.500000e+01, 1.400000e+01, 1.300000e+01], [1.200000e+01, 1.100000e+01, 1.000000e+01], [9.000000e+00, 8.000000e+00, 7.000000e+00], [6.000000e+00, 5.000000e+00, 4.000000e+00], [3.000000e+00, 2.000000e+00, 1.000000e+00]]> : tensor<5x3xf32>
-      %cst_0 = arith.constant dense<[[1.500000e+01, 1.400000e+01, 1.300000e+01, 1.200000e+01, 1.100000e+01], [1.000000e+01, 9.000000e+00, 8.000000e+00, 7.000000e+00, 6.000000e+00], [5.000000e+00, 4.000000e+00, 3.000000e+00, 2.000000e+00, 1.000000e+00]]> : tensor<3x5xf32>
-      %c0 = arith.constant 0 : index
-      %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readwrite:5x5xf32>
-      %1 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [5, 5], strides = [1, 1] : !flow.dispatch.tensor<readwrite:5x5xf32> -> tensor<5x5xf32>
-      
-      // The body of this op is folded into the top-level hal.executable.entry_point body.
-      // CHECK-NOT: iree_linalg_ext.hal.executable.entry_point
-      "iree_linalg_ext.hal.executable.entry_point"() ({
-        %c1 = arith.constant 1 : index
-        %c3 = arith.constant 3 : index
-        // CHECK-NOT: iree_linalg_ext.hal.return
-        iree_linalg_ext.hal.return %c3, %c1, %c1 : index, index, index
-      }) : () -> ()
-
-      // CHECK: = hal.interface.workgroup.id[0] : index
-      %2 = iree_linalg_ext.hal.interface.workgroup.id[0] : index
-      // Unused, just goes away.
-      // CHECK-NOT: = hal.interface.workgroup.id[0] : index
-      %3 = iree_linalg_ext.hal.interface.workgroup.count[0] : index
-      %4 = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%2]
-      %5 = affine.min affine_map<()[s0] -> (s0 * -2 + 5, 2)>()[%2]
-      %6 = tensor.extract_slice %1[%4, 0] [%5, 5] [1, 1] : tensor<5x5xf32> to tensor<?x5xf32>
-      %7 = tensor.extract_slice %cst[%4, 0] [%5, 3] [1, 1] : tensor<5x3xf32> to tensor<?x3xf32>
-      %8 = linalg.matmul {iree_linalg_transform.matched} ins(%7, %cst_0 : tensor<?x3xf32>, tensor<3x5xf32>) outs(%6 : tensor<?x5xf32>) -> tensor<?x5xf32>
-      %9 = tensor.insert_slice %8 into %1[%4, 0] [%5, 5] [1, 1] : tensor<?x5xf32> into tensor<5x5xf32>
-      flow.dispatch.tensor.store %9, %0, offsets = [0, 0], sizes = [5, 5], strides = [1, 1] : tensor<5x5xf32> -> !flow.dispatch.tensor<readwrite:5x5xf32>
-      return
-    }
-  }
-}
-}
diff --git a/iree/compiler/Codegen/LLVMCPU/test/linalg_transform.mlir b/iree/compiler/Codegen/LLVMCPU/test/linalg_transform.mlir
index bcc69a5..0f6b716 100644
--- a/iree/compiler/Codegen/LLVMCPU/test/linalg_transform.mlir
+++ b/iree/compiler/Codegen/LLVMCPU/test/linalg_transform.mlir
@@ -1,14 +1,9 @@
-// RUN: iree-opt %s  -pass-pipeline='hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target))' --iree-codegen-use-linalg-transform-interp --linalg-transform-interp-disable-bufferization --linalg-transform-file-name=%p/linalg_transform_spec.mlir | FileCheck %s
+// RUN: iree-opt %s  -pass-pipeline='hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target))' --iree-codegen-use-linalg-transform-interp --linalg-transform-file-name=%p/linalg_transform_spec.mlir | FileCheck %s
 
 #device_target_cpu = #hal.device.target<"cpu", {executable_targets = [#hal.executable.target<"llvm", "embedded-elf-x86_64", {cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 16 : index, target_triple = "x86_64-unknown-unknown-eabi-elf"}>]}>
 #executable_layout = #hal.executable.layout<push_constants = 0, sets = [#hal.descriptor_set.layout<0, bindings = [#hal.descriptor_set.binding<0, storage_buffer>, #hal.descriptor_set.binding<1, storage_buffer>, #hal.descriptor_set.binding<2, storage_buffer>]>]>
 #executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm", "embedded-elf-x86_64", {cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 16 : index, target_triple = "x86_64-unknown-unknown-eabi-elf"}>
 
-// CHECK-DAG: #[[$map0:.*]] = affine_map<()[s0] -> (s0 * 2)>
-// CHECK-DAG: #[[$map1:.*]] = affine_map<()[s0] -> (s0 * -2 + 250, 2)>
-// CHECK-DAG: #[[$map2:.*]] = affine_map<()[s0] -> (s0 * 4)>
-// CHECK-DAG: #[[$map3:.*]] = affine_map<()[s0] -> (s0 * -4 + 1020, 4)>
-
 hal.executable private @pad_matmul_static_dispatch_0 {
   hal.executable.variant public @embedded_elf_x86_64, target = #executable_target_embedded_elf_x86_64_ {
     hal.executable.entry_point public @pad_matmul_static_dispatch_0 ordinal(0) layout(#executable_layout)
@@ -21,29 +16,15 @@
         %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [250, 500], strides = [1, 1] : !flow.dispatch.tensor<readonly:250x500xf32> -> tensor<250x500xf32>
         %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [500, 1020], strides = [1, 1] : !flow.dispatch.tensor<readonly:500x1020xf32> -> tensor<500x1020xf32>
 
-        //     CHECK: hal.executable.entry_point public @pad_matmul_static_dispatch_0 ordinal(0) layout(#executable_layout) {
-        //     CHECK:   %[[C1:.*]] = arith.constant 1 : index
-        // CHECK-DAG: %[[C125:.*]] = arith.constant 125 : index
-        // CHECK-DAG: %[[C255:.*]] = arith.constant 255 : index
-        //     CHECK: hal.return %[[C125]], %[[C255]], %[[C1]] : index, index, index
         %50 = linalg.init_tensor [250, 1020] : tensor<250x1020xf32>
         %cst = arith.constant 0.000000e+00 : f32
         %5 = linalg.fill ins(%cst : f32) outs(%50 : tensor<250x1020xf32>) -> tensor<250x1020xf32>
 
-        //  CHECK-NOT: iree_linalg_ext
-        //      CHECK: %[[IDX:.*]] = hal.interface.workgroup.id[0] : index
-        //      CHECK: %[[OFFX:.*]] = affine.apply #[[$map0]]()[%[[IDX]]]
-        //      CHECK:  %[[SZX:.*]] = affine.min #[[$map1]]()[%[[IDX]]]
-        //      CHECK: tensor.extract_slice {{.*}}[%[[OFFX]]{{.*}}[%[[SZX]]
-        //      CHECK: %[[IDY:.*]] = hal.interface.workgroup.id[1] : index
-        //      CHECK: %[[OFFY:.*]] = affine.apply #[[$map2]]()[%[[IDY]]]
-        //      CHECK:  %[[SZY:.*]] = affine.min #[[$map3]]()[%[[IDY]]]
-        //      CHECK: tensor.extract_slice {{.*}}[{{.*}}, %[[OFFY]]] [{{.*}}, %[[SZY]]]
-        //      CHECK:  %[[MM:.*]] = linalg.matmul{{.*}}ins{{.*}}outs
-        //      CHECK: %[[RES_OFFX:.*]] = affine.apply #[[$map0]]()[%[[IDX]]]
-        //      CHECK: %[[RES_OFFY:.*]] = affine.apply #[[$map2]]()[%[[IDY]]]
-        //      CHECK: flow.dispatch.tensor.store %[[MM]], %{{.*}}, offsets = [%[[RES_OFFX]], %[[RES_OFFY]]]{{.*}} : tensor<?x?xf32> -> !flow.dispatch.tensor<readwrite:250x1020xf32>
-        //      CHECK: return
+        //      CHECK: memref.assume_alignment %{{.*}}, 64 : memref<250x1020xf32>
+        // CHECK-NEXT: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : memref<250x1020xf32>)
+        // CHECK-NEXT: linalg.matmul{{.*}}ins(%{{.*}} : memref<250x500xf32>, memref<500x1020xf32>) outs(%{{.*}} : memref<250x1020xf32>)
+        // CHECK-NEXT: return
+
         %6 = linalg.matmul ins(%3, %4 : tensor<250x500xf32>, tensor<500x1020xf32>) outs(%5 : tensor<250x1020xf32>) -> tensor<250x1020xf32>
         flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [250, 1020], strides = [1, 1] : tensor<250x1020xf32> -> !flow.dispatch.tensor<readwrite:250x1020xf32>
         return
diff --git a/iree/compiler/Codegen/LLVMCPU/test/linalg_transform_spec.mlir b/iree/compiler/Codegen/LLVMCPU/test/linalg_transform_spec.mlir
index dd59910..e2c3023 100644
--- a/iree/compiler/Codegen/LLVMCPU/test/linalg_transform_spec.mlir
+++ b/iree/compiler/Codegen/LLVMCPU/test/linalg_transform_spec.mlir
@@ -9,17 +9,6 @@
 }
 
 iree_linalg_transform.sequence {
-  %0 = match @pdl_matmul_target
-  // %res contains the tiled op and the linalg_ext.tile op.
-  %tiling_1_result:2 = tile_to_iree_linalg_ext_tile_op %0 {sizes = [2]}
-  %tiling_2_result:2 = tile_to_iree_linalg_ext_tile_op %tiling_1_result#0 {sizes = [0, 4]}
-  %inp_2 = rewrite_iree_linalg_ext_tile_to_in_parallel %tiling_2_result#1
-  %inp_1 = rewrite_iree_linalg_ext_tile_to_in_parallel %tiling_1_result#1
-  // TODO: Ideally we would bufferize here but we can't atm.
-  rewrite_iree_linalg_ext_in_parallel_to_hal %inp_2
-  rewrite_iree_linalg_ext_in_parallel_to_hal %inp_1
-  // Bufferize happens at the IREE level on HAL operations, we cannot just 
-  // call the linalg_transform.bufferize operation here.
-  // Instead it happens automatically at the end of the linalg-transform-interp
-  // pass.
+  %0 = match @pdl_matmul_target 
+  iree_bufferize
 }
diff --git a/iree/compiler/Codegen/Passes.h b/iree/compiler/Codegen/Passes.h
index 5594cc3..fa25554 100644
--- a/iree/compiler/Codegen/Passes.h
+++ b/iree/compiler/Codegen/Passes.h
@@ -144,11 +144,6 @@
 std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
 createSetNumWorkgroupsPass(ArrayRef<int64_t> workgroupSize = {});
 
-/// Propagates the number of workgroups to use for each entry point in the
-/// dispatch region.
-std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
-createSetNumWorkgroupsFromLinalgExtPass();
-
 /// Pass to optimize vector transfer_read and transfer_write.
 std::unique_ptr<OperationPass<func::FuncOp>> createOptimizeVectorTransferPass();
 
diff --git a/iree/compiler/Codegen/Passes.td b/iree/compiler/Codegen/Passes.td
index c93fdb4..eab90a1 100644
--- a/iree/compiler/Codegen/Passes.td
+++ b/iree/compiler/Codegen/Passes.td
@@ -126,16 +126,6 @@
   let constructor = "mlir::iree_compiler::createSetNumWorkgroupsPass()";
 }
 
-// TODO: Consider removing or moving to HAL/Transforms in order to avoid
-// polluting common pass declarations with HAL specific ops.
-def SetNumWorkgroupsFromLinalgExt :
-    Pass<"iree-set-num-workgroups-from-linalg-ext",
-         "mlir::iree_compiler::IREE::HAL::ExecutableVariantOp"> {
-  let summary =
-      "Propagate the number of workgroups for entry point functions";
-  let constructor = "mlir::iree_compiler::createSetNumWorkgroupsFromLinalgExtPass()";
-}
-
 // TODO: Rename argument to be fully qualified.
 def LinalgToVectorVectorizeConv :
     Pass<"iree-codegen-vectorize-linalg-conv", "func::FuncOp"> {
diff --git a/iree/compiler/Codegen/TransformDialectExtensions/BUILD b/iree/compiler/Codegen/TransformDialectExtensions/BUILD
new file mode 100644
index 0000000..dad1ffa
--- /dev/null
+++ b/iree/compiler/Codegen/TransformDialectExtensions/BUILD
@@ -0,0 +1,48 @@
+# Copyright 2019 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+package(
+    default_visibility = ["//visibility:public"],
+    features = ["layering_check"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+cc_library(
+    name = "TransformDialectExtensions",
+    srcs = [
+        "TransformDialectExtensions.cpp",
+    ],
+    hdrs = [
+        "TransformDialectExtensions.h",
+    ],
+    deps = [
+        "//iree/compiler/Codegen",
+        "//iree/compiler/Codegen:PassHeaders",
+        "//iree/compiler/Codegen/Common",
+        "//iree/compiler/Codegen/Interfaces:BufferizationInterfaces",
+        "//iree/compiler/Codegen/Utils",
+        "//iree/compiler/Dialect/HAL/IR",
+        "//llvm-external-projects/iree-dialects:IREEDialectsTransforms",
+        "//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
+        "//llvm-external-projects/iree-dialects:IREELinalgExtTransforms",
+        "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:Affine",
+        "@llvm-project//mlir:AffineUtils",
+        "@llvm-project//mlir:ArithmeticDialect",
+        "@llvm-project//mlir:BufferizationDialect",
+        "@llvm-project//mlir:BufferizationTransforms",
+        "@llvm-project//mlir:FuncDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:LinalgInterfaces",
+        "@llvm-project//mlir:LinalgOps",
+        "@llvm-project//mlir:PDLDialect",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:SCFDialect",
+        "@llvm-project//mlir:TensorDialect",
+        "@llvm-project//mlir:TransformUtils",
+    ],
+)
diff --git a/iree/compiler/Codegen/TransformDialectExtensions/CMakeLists.txt b/iree/compiler/Codegen/TransformDialectExtensions/CMakeLists.txt
new file mode 100644
index 0000000..788fa1c
--- /dev/null
+++ b/iree/compiler/Codegen/TransformDialectExtensions/CMakeLists.txt
@@ -0,0 +1,48 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from           #
+# iree/compiler/Codegen/TransformDialectExtensions/BUILD                       #
+#                                                                              #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary   #
+# CMake-only content.                                                          #
+#                                                                              #
+# To disable autogeneration for this file entirely, delete this header.        #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+  NAME
+    TransformDialectExtensions
+  HDRS
+    "TransformDialectExtensions.h"
+  SRCS
+    "TransformDialectExtensions.cpp"
+  DEPS
+    IREEDialectsTransforms
+    IREELinalgExtDialect
+    IREELinalgExtTransforms
+    IREELinalgTransformDialect
+    LLVMSupport
+    MLIRAffine
+    MLIRAffineUtils
+    MLIRArithmetic
+    MLIRBufferization
+    MLIRBufferizationTransforms
+    MLIRFunc
+    MLIRIR
+    MLIRLinalg
+    MLIRPDL
+    MLIRPass
+    MLIRSCF
+    MLIRTensor
+    MLIRTransformUtils
+    iree::compiler::Codegen
+    iree::compiler::Codegen::Common
+    iree::compiler::Codegen::Interfaces::BufferizationInterfaces
+    iree::compiler::Codegen::PassHeaders
+    iree::compiler::Codegen::Utils
+    iree::compiler::Dialect::HAL::IR
+  PUBLIC
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Codegen/TransformDialectExtensions/TransformDialectExtensions.cpp b/iree/compiler/Codegen/TransformDialectExtensions/TransformDialectExtensions.cpp
new file mode 100644
index 0000000..955b9fd
--- /dev/null
+++ b/iree/compiler/Codegen/TransformDialectExtensions/TransformDialectExtensions.cpp
@@ -0,0 +1,187 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "TransformDialectExtensions.h"
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
+#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h"
+#include "iree-dialects/Dialect/LinalgTransform/TransformOpInterface.h"
+#include "iree-dialects/Transforms/Functional.h"
+#include "iree/compiler/Codegen/Common/Transforms.h"
+#include "iree/compiler/Codegen/Interfaces/BufferizationInterfaces.h"
+#include "iree/compiler/Codegen/Passes.h"
+#include "iree/compiler/Codegen/Utils/Utils.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Pass/PassManager.h"
+
+using namespace mlir;
+using namespace mlir::iree_compiler::IREE;
+
+//===---------------------------------------------------------------------===//
+// Default allocation functions for CPU backend
+// TODO: register the bufferization behavior in a target-specific way.
+//===---------------------------------------------------------------------===//
+
+// Default allocation function to use with IREEs bufferization.
+static Value cpuAllocationFunction(OpBuilder &builder, Location loc,
+                                   ArrayRef<int64_t> staticShape,
+                                   Type elementType,
+                                   ArrayRef<Value> dynamicSizes) {
+  MemRefType allocType = MemRefType::get(staticShape, elementType);
+  return builder.create<memref::AllocaOp>(loc, allocType, dynamicSizes);
+}
+
+// Allocation callbacks to use with upstream comprehensive bufferization
+static FailureOr<Value> cpuComprehensiveBufferizeAllocationFn(
+    OpBuilder &builder, Location loc, MemRefType memRefType,
+    ValueRange dynamicSizes, unsigned alignment) {
+  return builder
+      .create<memref::AllocaOp>(loc, memRefType, dynamicSizes,
+                                builder.getI64IntegerAttr(alignment))
+      .getResult();
+}
+
+static LogicalResult cpuComprehensiveBufferizeDeallocationFn(OpBuilder &builder,
+                                                             Location loc,
+                                                             Value allocation) {
+  return success();
+}
+
+static LogicalResult cpuComprehensiveBufferizeCopyFn(OpBuilder &builder,
+                                                     Location loc, Value from,
+                                                     Value to) {
+  // TODO: ideally we should use linalg.copy which was recently reintroduced as
+  // an OpDSL named op. However, IREE-specific patterns to cleanup spurious
+  // post-bufferization copies do not trigger properly.
+  // So we keep using `createLinalgCopyOp` which builds a GenericOp.
+  //   builder.create<linalg::CopyOp>(loc, from, to);
+  mlir::iree_compiler::createLinalgCopyOp(builder, loc, from, to);
+  return success();
+}
+
+//===---------------------------------------------------------------------===//
+// IREE-specific transformations defined outside of iree_linalg_transform.
+//===---------------------------------------------------------------------===//
+
+// Note: with the recent TypeID changes, hiding these classes inside an
+// anonymous namespace would require specific `MLIR_DECLARE_EXPLICIT_TYPE_ID`
+// for each class.
+
+// namespace {
+
+// TODO: Move to tablegen. Until this stabilizes upstream, simple C++ is enough.
+class IREEBufferizeOp
+    : public Op<IREEBufferizeOp,
+                linalg::transform::TransformOpInterface::Trait> {
+ public:
+  using Op::Op;
+
+  static ArrayRef<StringRef> getAttributeNames() { return {}; }
+
+  static constexpr llvm::StringLiteral getOperationName() {
+    return llvm::StringLiteral("iree_linalg_transform.iree_bufferize");
+  }
+
+  Value target() { return nullptr; }
+
+  LogicalResult apply(linalg::transform::TransformResults &results,
+                      linalg::transform::TransformState &state) {
+    PassManager pm(getContext());
+    // Bufferize the dispatch.
+    using mlir::bufferization::BufferizationOptions;
+    BufferizationOptions::AllocationFn allocationFn =
+        cpuComprehensiveBufferizeAllocationFn;
+    BufferizationOptions::DeallocationFn deallocationFn =
+        cpuComprehensiveBufferizeDeallocationFn;
+    BufferizationOptions::MemCpyFn memcpyFn = cpuComprehensiveBufferizeCopyFn;
+    mlir::iree_compiler::addIREEComprehensiveBufferizePasses(
+        pm, allocationFn, deallocationFn, memcpyFn);
+    WalkResult res = state.getTopLevel()->walk([&](ModuleOp moduleOp) {
+      if (failed(pm.run(moduleOp))) return WalkResult::interrupt();
+      return WalkResult::advance();
+    });
+    return failure(res.wasInterrupted());
+  }
+
+  // let assemblyFormat = "attr-dict";
+  static ParseResult parse(OpAsmParser &parser, OperationState &state) {
+    parser.parseOptionalAttrDict(state.attributes);
+    return success();
+  }
+
+  // let assemblyFormat = "attr-dict";
+  void print(OpAsmPrinter &printer) {
+    printer.printOptionalAttrDict((*this)->getAttrs());
+  }
+};
+
+// TODO: Move to tablegen. Until this stabilizes upstream, simple C++ is enough.
+class IREESetNumWorkgroupToOneOp
+    : public Op<IREESetNumWorkgroupToOneOp,
+                linalg::transform::TransformOpInterface::Trait> {
+ public:
+  using Op::Op;
+
+  static ArrayRef<StringRef> getAttributeNames() { return {}; }
+
+  static constexpr llvm::StringLiteral getOperationName() {
+    return llvm::StringLiteral(
+        "iree_linalg_transform.iree_set_num_workgroups_to_one");
+  }
+
+  Value target() { return nullptr; }
+
+  LogicalResult apply(linalg::transform::TransformResults &results,
+                      linalg::transform::TransformState &state) {
+    auto variantOp = dyn_cast<HAL::ExecutableVariantOp>(state.getTopLevel());
+    if (!variantOp) return failure();
+    return iree_compiler::setNumWorkgroupsImpl(variantOp, {});
+  }
+
+  // let assemblyFormat = "attr-dict";
+  static ParseResult parse(OpAsmParser &parser, OperationState &state) {
+    parser.parseOptionalAttrDict(state.attributes);
+    return success();
+  }
+
+  // let assemblyFormat = "attr-dict";
+  void print(OpAsmPrinter &printer) {
+    printer.printOptionalAttrDict((*this)->getAttrs());
+  }
+};
+
+/// Test extension of the Transform dialect. Registers additional ops and
+/// declares PDL as dependent dialect since the additional ops are using PDL
+/// types for operands and results.
+class LinalgTransformDialectExtension
+    : public mlir::linalg::transform::TransformDialectExtension<
+          LinalgTransformDialectExtension> {
+ public:
+  LinalgTransformDialectExtension() {
+    declareDependentDialect<pdl::PDLDialect>();
+    registerTransformOp<IREEBufferizeOp>();
+    registerTransformOp<IREESetNumWorkgroupToOneOp>();
+    // TODO: hook up to Tablegen.
+    //     registerTransformOps<
+    // #define GET_OP_LIST
+    // #include "LinalgTransformDialectExtension.cpp.inc"
+    //         >();
+  }
+};
+
+// } // namespace anonymous
+
+void mlir::iree_compiler::registerLinalgTransformDialectExtension(
+    DialectRegistry &registry) {
+  registry.addExtensions<LinalgTransformDialectExtension>();
+}
diff --git a/iree/compiler/Codegen/TransformDialectExtensions/TransformDialectExtensions.h b/iree/compiler/Codegen/TransformDialectExtensions/TransformDialectExtensions.h
new file mode 100644
index 0000000..ffec9ab
--- /dev/null
+++ b/iree/compiler/Codegen/TransformDialectExtensions/TransformDialectExtensions.h
@@ -0,0 +1,28 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_CODEGEN_COMMON_TRANSFORMDIALECTEXTENSIONS_H_
+#define IREE_COMPILER_CODEGEN_COMMON_TRANSFORMDIALECTEXTENSIONS_H_
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Registers transformations that require IREE-specific information into the
+/// LinalgTransform dialect.
+void registerLinalgTransformDialectExtension(DialectRegistry &registry);
+
+}  // namespace iree_compiler
+}  // namespace mlir
+
+#endif  // IREE_COMPILER_CODEGEN_COMMON_TRANSFORMDIALECTEXTENSIONS_H_
diff --git a/iree/test/e2e/linalg_transform/linalg_transform_spec.mlir b/iree/test/e2e/linalg_transform/linalg_transform_spec.mlir
index c16da93..3f60f5e 100644
--- a/iree/test/e2e/linalg_transform/linalg_transform_spec.mlir
+++ b/iree/test/e2e/linalg_transform/linalg_transform_spec.mlir
@@ -10,16 +10,6 @@
 
 iree_linalg_transform.sequence {
   %0 = match @pdl_matmul_target
-  // %res contains the tiled op and the linalg_ext.tile op.
-  %tiling_1_result:2 = tile_to_iree_linalg_ext_tile_op %0 {sizes = [2]}
-  %tiling_2_result:2 = tile_to_iree_linalg_ext_tile_op %tiling_1_result#0 {sizes = [0, 3]}
-  %inp_2 = rewrite_iree_linalg_ext_tile_to_in_parallel %tiling_2_result#1
-  %inp_1 = rewrite_iree_linalg_ext_tile_to_in_parallel %tiling_1_result#1
-  // TODO: Ideally we would bufferize here but we can't atm.
-  rewrite_iree_linalg_ext_in_parallel_to_hal %inp_2
-  rewrite_iree_linalg_ext_in_parallel_to_hal %inp_1
-  // Bufferize happens at the IREE level on HAL operations, we cannot just 
-  // call the linalg_transform.bufferize operation here.
-  // Instead it happens automatically at the end of the linalg-transform-interp
-  // pass.
+  iree_set_num_workgroups_to_one
+  iree_bufferize
 }
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h
index 7683ae0..1418abb 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h
@@ -28,4 +28,66 @@
 #define GET_OP_CLASSES
 #include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h.inc"
 
+namespace mlir {
+namespace linalg {
+namespace transform {
+
+/// Base class for extensions of the Transform dialect that supports injecting
+/// operations into the Transform dialect at load time. Concrete extensions
+/// are expected to derive this class and register operations in the
+/// constructor. They can be registered with the DialectRegistry and
+/// automatically applied to the Transform dialect when it is loaded.
+using TransformDialect = LinalgTransformDialect;
+template <typename DerivedTy, typename... ExtraDialects>
+class TransformDialectExtension
+    : public DialectExtension<DerivedTy, TransformDialect, ExtraDialects...> {
+  using Initializer = std::function<void(TransformDialect *)>;
+  using DialectLoader = std::function<void(MLIRContext *)>;
+
+public:
+  /// Extension application hook. Actually loads the dependent dialects and
+  /// registers the additional operations. Not expected to be called directly.
+  void apply(MLIRContext *context, TransformDialect *transformDialect,
+             ExtraDialects *...) const final {
+    for (const DialectLoader &loader : dialectLoaders)
+      loader(context);
+    for (const Initializer &init : opInitializers)
+      init(transformDialect);
+  }
+
+protected:
+  /// Injects the operation into the Transform dialect.
+  template <typename OpTy>
+  void registerTransformOp() {
+    opInitializers.push_back([](TransformDialect *transformDialect) {
+      RegisteredOperationName::insert<OpTy>(*transformDialect);
+    });
+  }
+
+  /// Injects the operations into the Transform dialect.
+  template <typename... OpTys>
+  void registerTransformOps() {
+    (void)std::initializer_list<int>{(registerTransformOp<OpTys>(), 0)...};
+  }
+
+  /// Declares that the Transform dialect depends on the dialect provided as
+  /// template parameter. When the Transform dialect is loaded, dependent
+  /// dialects will be loaded as well. This is intended for dialects that
+  /// contain attributes and types used in creation and canonicalization of
+  /// the injected operations.
+  template <typename DialectTy>
+  void declareDependentDialect() {
+    dialectLoaders.push_back(
+        [](MLIRContext *context) { context->loadDialect<DialectTy>(); });
+  }
+
+private:
+  SmallVector<Initializer> opInitializers;
+  SmallVector<DialectLoader> dialectLoaders;
+};
+
+} // namespace transform
+} // namespace linalg
+} // namespace mlir
+
 #endif // MLIR_DIALECT_LINALG_IR_LINALGTRANSFORMOPS_H