Revert #3585 (#3632)
This breaks LinalgToSpirv tests internally in a non-trivial way.
Reverts "Introduce tile and multi-threads distribution of linalg ops
for LLVMIR backends" (https://github.com/google/iree/pull/3585)
diff --git a/iree/compiler/Conversion/CodegenUtils/BUILD b/iree/compiler/Conversion/CodegenUtils/BUILD
index 4b58a54..1c346f6 100644
--- a/iree/compiler/Conversion/CodegenUtils/BUILD
+++ b/iree/compiler/Conversion/CodegenUtils/BUILD
@@ -25,22 +25,14 @@
srcs = [
"ForOpCanonicalization.cpp",
"FunctionUtils.cpp",
- "GetNumWorkgroups.cpp",
- "MarkerUtils.cpp",
"MatmulCodegenStrategy.cpp",
],
hdrs = [
"ForOpCanonicalization.h",
"FunctionUtils.h",
- "GetNumWorkgroups.h",
- "MarkerUtils.h",
"MatmulCodegenStrategy.h",
],
deps = [
- "//iree/compiler/Dialect/HAL/IR",
- "//iree/compiler/Dialect/HAL/IR:HALDialect",
- "//iree/compiler/Dialect/HAL/Utils",
- "//iree/compiler/Dialect/Shape/IR",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:Analysis",
diff --git a/iree/compiler/Conversion/CodegenUtils/CMakeLists.txt b/iree/compiler/Conversion/CodegenUtils/CMakeLists.txt
index bf8ddd7..35c2030 100644
--- a/iree/compiler/Conversion/CodegenUtils/CMakeLists.txt
+++ b/iree/compiler/Conversion/CodegenUtils/CMakeLists.txt
@@ -20,14 +20,10 @@
HDRS
"ForOpCanonicalization.h"
"FunctionUtils.h"
- "GetNumWorkgroups.h"
- "MarkerUtils.h"
"MatmulCodegenStrategy.h"
SRCS
"ForOpCanonicalization.cpp"
"FunctionUtils.cpp"
- "GetNumWorkgroups.cpp"
- "MarkerUtils.cpp"
"MatmulCodegenStrategy.cpp"
DEPS
LLVMSupport
@@ -43,9 +39,5 @@
MLIRTransforms
MLIRVector
MLIRVectorToSCF
- iree::compiler::Dialect::HAL::IR
- iree::compiler::Dialect::HAL::IR::HALDialect
- iree::compiler::Dialect::HAL::Utils
- iree::compiler::Dialect::Shape::IR
PUBLIC
)
diff --git a/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.cpp b/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.cpp
deleted file mode 100644
index b365b72..0000000
--- a/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.cpp
+++ /dev/null
@@ -1,210 +0,0 @@
-#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "iree/compiler/Dialect/HAL/Utils/TypeUtils.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
-#include "llvm/ADT/SetVector.h"
-#include "llvm/Support/Debug.h"
-#include "mlir/Analysis/SliceAnalysis.h"
-#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
-#include "mlir/Dialect/Linalg/Utils/Utils.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/Function.h"
-#include "mlir/IR/Module.h"
-#include "mlir/IR/PatternMatch.h"
-
-#define DEBUG_TYPE "workgroup-calculation"
-
-namespace mlir {
-namespace iree_compiler {
-namespace utils {
-FuncOp getNumWorkgroupsFn(FuncOp entryPointFn,
- llvm::StringRef numWorkgroupsFnAttr) {
- SymbolRefAttr attr =
- entryPointFn.getAttrOfType<SymbolRefAttr>(numWorkgroupsFnAttr);
- if (!attr) {
- entryPointFn.emitError("missing attribute '") << numWorkgroupsFnAttr << "'";
- return nullptr;
- }
- FuncOp numWorkgroupsFn = dyn_cast_or_null<FuncOp>(SymbolTable::lookupSymbolIn(
- entryPointFn.getParentOfType<ModuleOp>(), attr));
- if (!numWorkgroupsFn) {
- entryPointFn.emitError("unable to find num workgroups fn ") << attr;
- return nullptr;
- }
- return numWorkgroupsFn;
-}
-
-/// Computes the bounds of the parallel loops partitioned across workgroups.
-static Optional<SmallVector<Value, 2>> getParallelLoopRange(
- PatternRewriter &rewriter, FuncOp numWorkgroupsFn, Location loc,
- linalg::LinalgOp linalgOp) {
- if (!numWorkgroupsFn.empty()) {
- numWorkgroupsFn.emitError("num workgroups fn expected to be empty");
- return {};
- }
- LLVM_DEBUG({
- llvm::dbgs() << "Found num workgroups function : "
- << numWorkgroupsFn.getName();
- });
-
- rewriter.setInsertionPointToEnd(numWorkgroupsFn.addEntryBlock());
- llvm::SetVector<Operation *> slice;
- getBackwardSlice(linalgOp, &slice);
- BlockAndValueMapping mapper;
- for (Operation *op : slice) {
- rewriter.clone(*op, mapper);
- }
- // Clone the linalg operation just to compute the loop bounds.
- linalg::LinalgOp clonedLinalgOp =
- rewriter.clone(*linalgOp.getOperation(), mapper);
- Optional<SmallVector<Value, 4>> bounds =
- getLoopRanges(rewriter, clonedLinalgOp);
- unsigned numParallelLoops = linalgOp.iterator_types()
- .getValue()
- .take_while([](Attribute attr) -> bool {
- return attr.cast<StringAttr>().getValue() ==
- getParallelIteratorTypeName();
- })
- .size();
- SmallVector<Value, 2> returnVals(
- bounds->begin(), std::next(bounds->begin(), numParallelLoops));
- rewriter.eraseOp(clonedLinalgOp);
- return returnVals;
-}
-
-/// Utility method to build IR that computes ceil(`numerator` / `denominator`)
-static Value buildCeilDiv(PatternRewriter &rewriter, Location loc,
- Value numerator, Value denominator) {
- Value one = rewriter.create<ConstantIndexOp>(loc, 1);
- Value t = rewriter.create<AddIOp>(
- loc, numerator, rewriter.create<SubIOp>(loc, denominator, one));
- return rewriter.create<SignedDivIOp>(loc, t, denominator);
-}
-
-/// Utility method to build IR that computes ceil(`numerator` / `denominator`)
-/// when denominator is a constant.
-static Value buildCeilDivConstDenominator(PatternRewriter &rewriter,
- Location loc, Value numerator,
- int64_t denominator) {
- return buildCeilDiv(rewriter, loc, numerator,
- rewriter.create<ConstantIndexOp>(loc, denominator));
-}
-
-LogicalResult createNumWorkgroupsFromResultShape(
- PatternRewriter &rewriter, linalg::LinalgOp linalgOp, FuncOp entryPointFn,
- llvm::StringRef numWorkgroupsFnAttr, ArrayRef<int64_t> tileSizes) {
- FuncOp numWorkgroupsFn = getNumWorkgroupsFn(
- linalgOp.getParentOfType<FuncOp>(), numWorkgroupsFnAttr);
- if (!numWorkgroupsFn) return failure();
-
- Location loc = linalgOp.getLoc();
- OpBuilder::InsertionGuard gaurd(rewriter);
- Optional<SmallVector<Value, 2>> parallelLoopRange =
- getParallelLoopRange(rewriter, numWorkgroupsFn, loc, linalgOp);
- if (!parallelLoopRange) return failure();
- Value one = rewriter.create<ConstantIndexOp>(loc, 1);
- SmallVector<Value, 3> returnValues(3, one);
- for (size_t i = 0, e = std::min<size_t>(parallelLoopRange->size(), 3); i != e;
- ++i) {
- if (tileSizes[e - i - 1] != 0) {
- returnValues[i] = buildCeilDivConstDenominator(
- rewriter, loc, (*parallelLoopRange)[e - i - 1], tileSizes[e - i - 1]);
- }
- }
- rewriter.create<mlir::ReturnOp>(loc, returnValues);
- return success();
-}
-
-LogicalResult createNumWorkgroupsFromLinearizedResultShape(
- PatternRewriter &rewriter, linalg::LinalgOp linalgOp, FuncOp entryPointFn,
- llvm::StringRef numWorkgroupsFnAttr, int64_t workgroupSizeX) {
- FuncOp numWorkgroupsFn = getNumWorkgroupsFn(
- linalgOp.getParentOfType<FuncOp>(), numWorkgroupsFnAttr);
- if (!numWorkgroupsFn) return failure();
- if (!numWorkgroupsFn.empty()) {
- // TODO(ravishankarm): We can end up with multiple linalg operations
- // (typically linalg.generic operations) that have the same workload in a
- // dispatch region. In that case, the first linalg.generic creates the body
- // of number of workgroups. For now, just returning if the body is not empty
- // assuming that it is correct for all the ops in the dispatch region. This
- // needs to be enforced somehow.
- return success();
- }
-
- Location loc = linalgOp.getLoc();
- OpBuilder::InsertionGuard gaurd(rewriter);
- Optional<SmallVector<Value, 2>> parallelLoopRange =
- getParallelLoopRange(rewriter, numWorkgroupsFn, loc, linalgOp);
- if (!parallelLoopRange) return failure();
- Value one = rewriter.create<ConstantIndexOp>(loc, 1);
- SmallVector<Value, 3> returnValues(3, one);
- for (auto range : *parallelLoopRange) {
- returnValues[0] = rewriter.create<MulIOp>(loc, range, returnValues[0]);
- }
- returnValues[0] = buildCeilDivConstDenominator(rewriter, loc, returnValues[0],
- workgroupSizeX);
- rewriter.create<mlir::ReturnOp>(loc, returnValues);
- return success();
-}
-
-/// The codegeneration emits a function `numWorkgroupsFn` for each entry point
-/// function. This function has arguments the !shapex.ranked_shape for all the
-/// input and output shaped types. Using this the function returns the number of
-/// workgroups to use. To use this function on the host side, generate the
-/// !shapex.ranked_shape values that describe the shape of the inputs and
-/// outputs of the dispatch region and "inline" the function body.
-std::array<Value, 3> calculateWorkgroupCountFromNumWorkgroupsFn(
- Location loc, FuncOp numWorkgroupsFn, IREE::HAL::InterfaceOp interface,
- ArrayRef<Optional<IREE::HAL::TensorRewriteAdaptor>> operands,
- ArrayRef<Optional<IREE::HAL::TensorRewriteAdaptor>> results,
- ConversionPatternRewriter &rewriter) {
- std::array<Value, 3> returnValue = {nullptr, nullptr, nullptr};
- // TODO: This is really just inlining a function. For now assume that the
- // `numWorkgroupsFn` has a single block to make inlining easier.
- if (!numWorkgroupsFn || !llvm::hasSingleElement(numWorkgroupsFn))
- return returnValue;
- SmallVector<SmallVector<Value, 4>, 4> shapeValues;
- shapeValues.reserve(operands.size() + results.size());
- auto getShapeValuesFn =
- [&](ArrayRef<Optional<IREE::HAL::TensorRewriteAdaptor>> values)
- -> LogicalResult {
- for (auto val : values) {
- if (!val) continue;
- Optional<SmallVector<Value, 4>> shape = val->getShapeDims(rewriter);
- if (!shape) return emitError(loc, "shape computation for operand failed");
- shapeValues.push_back(shape.getValue());
- }
- return success();
- };
- if (failed(getShapeValuesFn(operands)) || failed(getShapeValuesFn(results)))
- return returnValue;
- BlockAndValueMapping mapper;
- for (Operation &op : numWorkgroupsFn.front()) {
- if (isa<mlir::ReturnOp>(op)) {
- for (unsigned i = 0, e = std::min<unsigned>(3, op.getNumOperands());
- i != e; ++i) {
- returnValue[i] = mapper.lookupOrNull(op.getOperand(i));
- }
- break;
- }
- if (auto shapeOp = dyn_cast<Shape::RankedDimOp>(op)) {
- if (BlockArgument arg = shapeOp.shape().dyn_cast<BlockArgument>()) {
- auto &dimValues = shapeValues[arg.getArgNumber()];
- mapper.map(shapeOp.result(), dimValues[shapeOp.getIndex()]);
- continue;
- }
- return returnValue;
- }
- // If all its operands are mapped, clone it.
- if (llvm::all_of(op.getOperands(), [&mapper](Value operand) {
- return mapper.contains(operand);
- })) {
- rewriter.clone(op, mapper);
- continue;
- }
- }
- return returnValue;
-}
-
-} // namespace utils
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h b/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h
deleted file mode 100644
index 4ae8248..0000000
--- a/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h
+++ /dev/null
@@ -1,98 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef MLIR_EDGE_BENCHMARKS_STRATEGIES_WORKGROUPCALULCATION_H_
-#define MLIR_EDGE_BENCHMARKS_STRATEGIES_WORKGROUPCALULCATION_H_
-
-#include <cstdint>
-
-namespace llvm {
-class StringRef;
-template <typename T>
-class ArrayRef;
-template <typename T>
-class Optional;
-} // namespace llvm
-
-namespace mlir {
-class Location;
-class FuncOp;
-class LogicalResult;
-class PatternRewriter;
-class ConversionPatternRewriter;
-class Value;
-namespace linalg {
-class LinalgOp;
-} // namespace linalg
-
-namespace iree_compiler {
-namespace IREE {
-namespace HAL {
-class InterfaceOp;
-class TensorRewriteAdaptor;
-} // namespace HAL
-} // namespace IREE
-} // namespace iree_compiler
-namespace iree_compiler {
-namespace utils {
-/// Generates a function that computes the number of workgroups as
-/// [ceil(`parallelLoopRange`[2] / `tileSizes`[2]),
-/// ceil(`parallelLoopRange`[1] / `tileSizes`[1]),
-/// ceil(`parallelLoopRange`[0] / `tileSizes`[0])]
-/// where `parallelLoopRange` is the ranges of the parallel loops of `linalgOp`
-/// distributed across workgroups.
-LogicalResult createNumWorkgroupsFromResultShape(
- PatternRewriter &rewriter, linalg::LinalgOp linalgOp, FuncOp entryPointFn,
- llvm::StringRef numWorkgroupsFnAttr, llvm::ArrayRef<int64_t> tileSizes);
-
-/// Generates a function that computes the number of workgroups as
-/// ceil(`parallelLoopRange`[0] * `parallelLoopRange`[1] * ... *
-/// `parallelLoopRange`[n-1] / `workgroupSizeX`)
-/// where `parallelLoopRange` is the ranges of the parallel loops of `linalgOp`
-/// distributed across workgroups.
-LogicalResult createNumWorkgroupsFromLinearizedResultShape(
- PatternRewriter &rewriter, linalg::LinalgOp linalgOp, FuncOp entryPointFn,
- llvm::StringRef numWorkgroupsFnAttr, int64_t workgroupSizeX);
-
-/// For a given `entryPointFn` return the function that computes the number of
-/// workgroups to use at launch time.
-FuncOp getNumWorkgroupsFn(FuncOp entryPointFn,
- llvm::StringRef numWorkgroupsFnAttr);
-
-LogicalResult createNumWorkgroupsFromLinearizedResultShape(
- PatternRewriter &rewriter, linalg::LinalgOp linalgOp, FuncOp entryPointFn,
- llvm::StringRef numWorkgroupsFnAttr, int64_t workgroupSizeX);
-
-/// The codegeneration emits a function `numWorkgroupsFn` for each entry point
-/// function. This function has arguments the !shapex.ranked_shape for all the
-/// input and output shaped types. Using this the function returns the number of
-/// workgroups to use. To use this function on the host side, generate the
-/// !shapex.ranked_shape values that describe the shape of the inputs and
-/// outputs of the dispatch region and "inline" the function body.
-std::array<Value, 3> calculateWorkgroupCountFromNumWorkgroupsFn(
- Location loc, FuncOp numWorkgroupsFn,
- mlir::iree_compiler::IREE::HAL::InterfaceOp interface,
- llvm::ArrayRef<
- llvm::Optional<mlir::iree_compiler::IREE::HAL::TensorRewriteAdaptor>>
- operands,
- llvm::ArrayRef<
- llvm::Optional<mlir::iree_compiler::IREE::HAL::TensorRewriteAdaptor>>
- results,
- ConversionPatternRewriter &rewriter);
-
-} // namespace utils
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // MLIR_EDGE_BENCHMARKS_STRATEGIES_WORKGROUPCALULCATION_H_
diff --git a/iree/compiler/Conversion/Common/BUILD b/iree/compiler/Conversion/Common/BUILD
deleted file mode 100644
index 7c9e34c..0000000
--- a/iree/compiler/Conversion/Common/BUILD
+++ /dev/null
@@ -1,41 +0,0 @@
-# Copyright 2019 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-cc_library(
- name = "Common",
- srcs = [
- "DeclareNumWorkgroupsFnPass.cpp",
- "LegalizeNumWorkgroupsFnPass.cpp",
- ],
- hdrs = [
- "Passes.h",
- ],
- deps = [
- "//iree/compiler/Conversion/CodegenUtils",
- "//iree/compiler/Dialect/HAL/IR",
- "//iree/compiler/Dialect/IREE/IR",
- "//iree/compiler/Dialect/Shape/IR",
- "@llvm-project//mlir:CFGTransforms",
- "@llvm-project//mlir:IR",
- "@llvm-project//mlir:Pass",
- "@llvm-project//mlir:StandardOps",
- "@org_tensorflow//tensorflow/compiler/mlir/hlo",
- ],
-)
diff --git a/iree/compiler/Conversion/Common/CMakeLists.txt b/iree/compiler/Conversion/Common/CMakeLists.txt
deleted file mode 100644
index 94c7b88..0000000
--- a/iree/compiler/Conversion/Common/CMakeLists.txt
+++ /dev/null
@@ -1,36 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-iree_add_all_subdirs()
-
-iree_cc_library(
- NAME
- Common
- HDRS
- "Passes.h"
- SRCS
- "DeclareNumWorkgroupsFnPass.cpp"
- "LegalizeNumWorkgroupsFnPass.cpp"
- DEPS
- MLIRIR
- MLIRPass
- MLIRSCFToStandard
- MLIRStandard
- iree::compiler::Conversion::CodegenUtils
- iree::compiler::Dialect::HAL::IR
- iree::compiler::Dialect::IREE::IR
- iree::compiler::Dialect::Shape::IR
- tensorflow::mlir_hlo
- PUBLIC
-)
diff --git a/iree/compiler/Conversion/Common/DeclareNumWorkgroupsFnPass.cpp b/iree/compiler/Conversion/Common/DeclareNumWorkgroupsFnPass.cpp
deleted file mode 100644
index 7896724..0000000
--- a/iree/compiler/Conversion/Common/DeclareNumWorkgroupsFnPass.cpp
+++ /dev/null
@@ -1,159 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-//===- DeclareNumWorkgroupsFnPass.cpp - Declares num_workgroups_fn --------===//
-//
-// Define the function that computes the number of workgroups for every entry
-// point function. This pass only defines the function. Its body will be filled
-// in later.
-//
-//===----------------------------------------------------------------------===//
-#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
-#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/Function.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/Pass/Pass.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace common {
-
-static constexpr const char kNumWorkgroupsStr[] = "__num_workgroups__";
-
-namespace {
-
-/// The contract between the host and the device is captured by the _impl
-/// function that is called from the main entry point function. This pattern
-/// looks for the call operation and
-/// - Declares (doesnt define) the function that computes the number of
-/// workgroups to use for this entry point function. It is defined later in
-/// the codegen pipeline, when the computation is mapped to
-/// workgroups/workitems. The signature of this function is
-///
-/// (!shapex.ranked_shape, !shapex.ranked_shape, ....) ->
-/// (index, index, index)
-///
-/// where the arguments are the shape of the tensor inputs + outputs of the
-/// dispatch region.
-/// - Sets the attribute `operand_result_index` on the
-/// `hal.interface.load.tensor`/`hal.interface.store.tensor` ops that are
-/// later used in the generation of the function declared here.
-struct DeclareNumWorkgroupsFn : OpRewritePattern<FuncOp> {
- DeclareNumWorkgroupsFn(MLIRContext *context,
- llvm::StringRef numWorkgroupsFnAttrName,
- PatternBenefit benefit = 1)
- : OpRewritePattern(context, benefit),
- numWorkgroupsFnAttrName(numWorkgroupsFnAttrName) {}
- LogicalResult matchAndRewrite(FuncOp entryPointFn,
- PatternRewriter &rewriter) const override {
- if (!isEntryPoint(entryPointFn) ||
- entryPointFn.getAttr(numWorkgroupsFnAttrName))
- return failure();
- Region &body = entryPointFn.getBody();
- if (!llvm::hasSingleElement(body)) {
- return entryPointFn.emitError(
- "unhandled dispatch function with multiple blocks");
- }
- auto callOps = body.front().getOps<CallOp>();
- if (!llvm::hasSingleElement(callOps)) {
- return entryPointFn.emitError(
- "expected dispatch function to have a single call operation");
- }
- CallOp callOp = *callOps.begin();
-
- SmallVector<ShapedType, 4> shapedTypes;
- shapedTypes.reserve(callOp.getNumOperands() - 1 + callOp.getNumResults());
-
- // Add `operand_result_index` attribute to `hal.interface.load.tensor`
- // operations that define the operands of the call op.
- for (Value operand : callOp.operands()) {
- if (!operand.getType().isa<ShapedType>()) continue;
- if (auto definingOp =
- operand.getDefiningOp<IREE::HAL::InterfaceLoadTensorOp>()) {
- definingOp.setAttr(getOperandResultNumAttrName(),
- rewriter.getI32IntegerAttr(shapedTypes.size()));
- }
- shapedTypes.push_back(operand.getType().cast<ShapedType>());
- }
-
- // Add `operand_result_index` attribute to the `hal.interface.store.tensor`
- // that use the value returned by the call op.
- for (Value result : callOp.getResults()) {
- if (!result.getType().isa<ShapedType>()) continue;
- for (auto &use : result.getUses()) {
- if (auto storeOp =
- dyn_cast<IREE::HAL::InterfaceStoreTensorOp>(use.getOwner())) {
- storeOp.setAttr(getOperandResultNumAttrName(),
- rewriter.getI32IntegerAttr(shapedTypes.size()));
- }
- }
- shapedTypes.push_back(result.getType().cast<ShapedType>());
- }
-
- IndexType indexType = rewriter.getIndexType();
- SmallVector<Type, 4> argTypes = llvm::to_vector<4>(
- llvm::map_range(shapedTypes, [&rewriter](ShapedType t) -> Type {
- return Shape::RankedShapeType::get(t.getShape(),
- rewriter.getContext());
- }));
- FuncOp numWorkgroupsFn = rewriter.create<FuncOp>(
- entryPointFn.getLoc(), entryPointFn.getName().str() + kNumWorkgroupsStr,
- rewriter.getFunctionType(argTypes, {indexType, indexType, indexType}));
- numWorkgroupsFn.setVisibility(FuncOp::Visibility::Private);
- entryPointFn.setAttr(numWorkgroupsFnAttrName,
- rewriter.getSymbolRefAttr(numWorkgroupsFn));
- rewriter.updateRootInPlace(entryPointFn, []() {});
- return success();
- }
-
- private:
- llvm::StringRef numWorkgroupsFnAttrName;
-};
-
-/// Pass to define the function for number of workgroups for every entry point
-/// function.
-struct DeclareNumWorkgroupsFnPass
- : public PassWrapper<DeclareNumWorkgroupsFnPass, OperationPass<ModuleOp>> {
- DeclareNumWorkgroupsFnPass(llvm::StringRef numWorkgroupsFnAttrName)
- : numWorkgroupsFnAttrName(numWorkgroupsFnAttrName) {}
- DeclareNumWorkgroupsFnPass(const DeclareNumWorkgroupsFnPass &pass) {}
- void runOnOperation() override;
-
- void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<ShapeDialect>();
- }
-
- llvm::StringRef numWorkgroupsFnAttrName;
-};
-} // namespace
-
-void DeclareNumWorkgroupsFnPass::runOnOperation() {
- OwningRewritePatternList patterns;
- MLIRContext *context = &getContext();
- patterns.insert<DeclareNumWorkgroupsFn>(context, numWorkgroupsFnAttrName);
- applyPatternsAndFoldGreedily(getOperation(), patterns);
-}
-
-std::unique_ptr<OperationPass<ModuleOp>> createDeclareNumWorkgroupsFnPass(
- const llvm::StringRef numWorkgroupsFnAttrName) {
- return std::make_unique<DeclareNumWorkgroupsFnPass>(numWorkgroupsFnAttrName);
-}
-
-} // namespace common
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Conversion/Common/LegalizeNumWorkgroupsFnPass.cpp b/iree/compiler/Conversion/Common/LegalizeNumWorkgroupsFnPass.cpp
deleted file mode 100644
index 1f1bac0..0000000
--- a/iree/compiler/Conversion/Common/LegalizeNumWorkgroupsFnPass.cpp
+++ /dev/null
@@ -1,122 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-//===-LegalizeNumWorkgroupsFnPass.cpp - Legalize to be runnable on host ---===//
-//
-// The function generated by the codegeneration pass to compute the number of
-// workgroups uses a slice of the device-side code. Legalize it to run on the
-// host.
-//
-//===----------------------------------------------------------------------===//
-#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
-#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
-#include "mlir/IR/Module.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace common {
-namespace {
-
-/// Pattern to legalize shapex.tie_shape operation to tie the shape of the
-/// `iree.placeholder` result to the argument of the function.
-struct LegalizeTieShapeOp : OpRewritePattern<Shape::TieShapeOp> {
- using OpRewritePattern<Shape::TieShapeOp>::OpRewritePattern;
- LogicalResult matchAndRewrite(Shape::TieShapeOp tieShapeOp,
- PatternRewriter &rewriter) const override {
- if (tieShapeOp.shape().isa<BlockArgument>()) return failure();
- auto phOp = dyn_cast_or_null<IREE::PlaceholderOp>(
- tieShapeOp.operand().getDefiningOp());
- if (!phOp) return failure();
- IntegerAttr operandNumAttr =
- phOp.getAttrOfType<IntegerAttr>(getOperandResultNumAttrName());
- if (!operandNumAttr) {
- return phOp.emitRemark("expected operand_result_index attribute");
- }
- FuncOp numWorkgroupsFn = phOp.getParentOfType<FuncOp>();
- rewriter.replaceOpWithNewOp<Shape::TieShapeOp>(
- tieShapeOp, phOp,
- numWorkgroupsFn.getArgument(
- phOp.getAttrOfType<IntegerAttr>(getOperandResultNumAttrName())
- .getInt()));
- return success();
- }
-};
-
-/// Pattern to remove dead `iree.placeholder` ops. They arent removed since
-/// they are tagged as having `MemoryEffect`.
-struct RemoveDeadPlaceholderOp : OpRewritePattern<IREE::PlaceholderOp> {
- using OpRewritePattern<IREE::PlaceholderOp>::OpRewritePattern;
- LogicalResult matchAndRewrite(IREE::PlaceholderOp phOp,
- PatternRewriter &rewriter) const override {
- if (phOp.use_empty()) {
- rewriter.eraseOp(phOp);
- return success();
- }
- return failure();
- }
-};
-
-/// Pass to legalize the function that computes the number of workgroups to
-/// use for launch to be runnable on the host-side.
-struct LegalizeNumWorkgroupsFnPass
- : public PassWrapper<LegalizeNumWorkgroupsFnPass, OperationPass<ModuleOp>> {
- LegalizeNumWorkgroupsFnPass(llvm::StringRef numWorkgroupsFnAttrName)
- : numWorkgroupsFnAttrName(numWorkgroupsFnAttrName) {}
- LegalizeNumWorkgroupsFnPass(const LegalizeNumWorkgroupsFnPass &pass) {}
- void runOnOperation() override;
- llvm::StringRef numWorkgroupsFnAttrName;
-};
-} // namespace
-
-static void populateLegalizeNumWorkgroupsFnPattern(
- MLIRContext *context, OwningRewritePatternList &patterns) {
- patterns.insert<LegalizeTieShapeOp, RemoveDeadPlaceholderOp>(context);
-}
-
-void LegalizeNumWorkgroupsFnPass::runOnOperation() {
- ModuleOp module = getOperation();
- auto fns = module.getOps<FuncOp>();
- OwningRewritePatternList patterns;
- MLIRContext *context = &getContext();
- populateLegalizeNumWorkgroupsFnPattern(context, patterns);
-
- SymbolTable symbolTable(module.getOperation());
- for (FuncOp fn : fns) {
- if (!isEntryPoint(fn)) continue;
- auto numWorkgroupsFnAttr =
- fn.getAttrOfType<SymbolRefAttr>(numWorkgroupsFnAttrName);
- if (!numWorkgroupsFnAttr) continue;
- StringRef numWorkgroupsFnName = numWorkgroupsFnAttr.getLeafReference();
- FuncOp numWorkgroupsFn = symbolTable.lookup<FuncOp>(numWorkgroupsFnName);
- if (!numWorkgroupsFn) {
- fn.emitError("unable to find function to compute number of workgroups ")
- << numWorkgroupsFnName;
- return signalPassFailure();
- }
- if (failed(applyPatternsAndFoldGreedily(numWorkgroupsFn, patterns)))
- return signalPassFailure();
- }
-}
-
-std::unique_ptr<OperationPass<ModuleOp>> createLegalizeNumWorkgroupsFnPass(
- llvm::StringRef numWorkgroupsFnAttrName) {
- return std::make_unique<LegalizeNumWorkgroupsFnPass>(numWorkgroupsFnAttrName);
-}
-
-} // namespace common
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Conversion/Common/Passes.h b/iree/compiler/Conversion/Common/Passes.h
deleted file mode 100644
index 90c918b..0000000
--- a/iree/compiler/Conversion/Common/Passes.h
+++ /dev/null
@@ -1,29 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-namespace mlir {
-namespace iree_compiler {
-namespace common {
-
-std::unique_ptr<OperationPass<ModuleOp>> createDeclareNumWorkgroupsFnPass(
- llvm::StringRef numWorkgroupsFnAttrName);
-
-/// Pass to legalize function that returns number of workgroups to use for
-/// launch to be runnable on the host.
-std::unique_ptr<OperationPass<ModuleOp>> createLegalizeNumWorkgroupsFnPass(
- llvm::StringRef numWorkgroupsFnAttrName);
-
-} // namespace common
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Attributes.h b/iree/compiler/Conversion/LinalgToLLVM/Attributes.h
deleted file mode 100644
index 690da4a..0000000
--- a/iree/compiler/Conversion/LinalgToLLVM/Attributes.h
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_CONVERSION_LINALGTOLLVM_ATTRIBUTES_H_
-#define IREE_COMPILER_CONVERSION_LINALGTOLLVM_ATTRIBUTES_H_
-
-#include "llvm/ADT/StringRef.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-/// Attribute on a entry point function that specifies which function computes
-/// the number of workgroups.
-inline llvm::StringRef getNumWorkgroupsFnAttrName() {
- return "llvm.num_workgroups_fn";
-}
-
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_CONVERSION_LINALGTOSPIRV_ATTRIBUTES_H_
diff --git a/iree/compiler/Conversion/LinalgToLLVM/BUILD b/iree/compiler/Conversion/LinalgToLLVM/BUILD
index 555e78a..9261ede 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/BUILD
+++ b/iree/compiler/Conversion/LinalgToLLVM/BUILD
@@ -23,19 +23,14 @@
srcs = [
"ConvImg2ColMatmulConversion.cpp",
"ConvertToLLVM.cpp",
- "KernelDispatch.cpp",
- "LinalgTileAndDistributePass.cpp",
"MatMulVectorization.cpp",
"Passes.cpp",
],
hdrs = [
- "Attributes.h",
- "KernelDispatch.h",
"Passes.h",
],
deps = [
"//iree/compiler/Conversion/CodegenUtils",
- "//iree/compiler/Conversion/Common",
"//iree/compiler/Conversion/HLOToHLO",
"//iree/compiler/Conversion/HLOToLinalg",
"//iree/compiler/Dialect/HAL/IR",
@@ -48,7 +43,6 @@
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LLVMDialect",
"@llvm-project//mlir:LLVMTransforms",
- "@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:LinalgToLLVM",
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:Pass",
diff --git a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
index 4b807c4..f7b79e0 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
@@ -18,21 +18,16 @@
NAME
LinalgToLLVM
HDRS
- "Attributes.h"
- "KernelDispatch.h"
"Passes.h"
SRCS
"ConvImg2ColMatmulConversion.cpp"
"ConvertToLLVM.cpp"
- "KernelDispatch.cpp"
- "LinalgTileAndDistributePass.cpp"
"MatMulVectorization.cpp"
"Passes.cpp"
DEPS
MLIRAffineToStandard
MLIRIR
MLIRLLVMIR
- MLIRLinalg
MLIRLinalgToLLVM
MLIRLinalgTransforms
MLIRPass
@@ -45,7 +40,6 @@
MLIRVectorToLLVM
MLIRVectorToSCF
iree::compiler::Conversion::CodegenUtils
- iree::compiler::Conversion::Common
iree::compiler::Conversion::HLOToHLO
iree::compiler::Conversion::HLOToLinalg
iree::compiler::Dialect::HAL::IR
diff --git a/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp b/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp
index 15fd098..f2525d4 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp
@@ -12,14 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
-#include "iree/compiler/Conversion/LinalgToLLVM/Attributes.h"
#include "iree/compiler/Conversion/LinalgToLLVM/Passes.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "iree/compiler/Dialect/IREE/IR/IREEDialect.h"
#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
@@ -141,11 +137,10 @@
}
// Change signature of entry function to func
-// clang-format off
-// entry_func(%packed_buffers_arg_ptr: !<llvm.int8**>, thread_idx: !llvm.i32, thread_idy: !llvm.i32, thread_idz: !llvm.i32) and lower IREE and HAL ops to
+// entry_func(%packed_buffers_arg_ptr:
+// !<llvm.int8**>, %push_constant: !<llvm.int64*>) and lower IREE and HAL ops to
// corresponding LLVMIR ops to construct memref descriptors and load
// push_constant values.
-// clang-format on
class ConvertFuncWithHALInterface : public ConvertToLLVMPattern {
public:
explicit ConvertFuncWithHALInterface(MLIRContext *context,
@@ -168,7 +163,6 @@
// Get interface buffers from all the blocks.
SmallVector<IREE::PlaceholderOp, 8> bufferOps;
SmallVector<IREE::HAL::InterfaceLoadConstantOp, 8> loadOps;
- SmallVector<IREE::WorkgroupIdOp, 3> workgroupIdOps;
for (Block &block : funcOp.getBlocks()) {
for (Operation &op : block) {
if (auto phOp = dyn_cast<IREE::PlaceholderOp>(op))
@@ -176,9 +170,6 @@
if (auto phOp = dyn_cast<IREE::HAL::InterfaceLoadConstantOp>(op)) {
loadOps.push_back(phOp);
}
- if (auto threadIdOp = dyn_cast<IREE::WorkgroupIdOp>(op)) {
- workgroupIdOps.push_back(threadIdOp);
- }
}
}
@@ -226,39 +217,22 @@
}
TypeConverter::SignatureConversion signatureConverter(/*numOrigInputs=*/0);
- // clang-format off
- // func foo(%packed_buffer_args: !llvm<i8**>, %push_constant: !llvm<i32*>, thread_idx: i32, thread_idy, thread_idz: i32)
- // clang-format on
+
+ // func foo(%packed_buffer_args: !llvm<i8**>, %push_constant: !llvm<i32*>)
MLIRContext *context = rewriter.getContext();
auto packedBuffersArgsTy =
LLVM::LLVMType::getInt8PtrTy(context).getPointerTo();
auto pushConstantArgTy = LLVM::LLVMType::getInt32Ty(context).getPointerTo();
- auto threadIdXTy = LLVM::LLVMType::getInt32Ty(context);
- auto threadIdYTy = LLVM::LLVMType::getInt32Ty(context);
- auto threadIdZTy = LLVM::LLVMType::getInt32Ty(context);
signatureConverter.addInputs(packedBuffersArgsTy);
signatureConverter.addInputs(pushConstantArgTy);
- signatureConverter.addInputs(threadIdXTy);
- signatureConverter.addInputs(threadIdYTy);
- signatureConverter.addInputs(threadIdZTy);
+ // Create the new function's signature.
Location loc = funcOp.getLoc();
-
- // Construct newFunc with all attributes except return type & symbol name.
- SmallVector<NamedAttribute, 4> funcAttrs;
- for (auto attr : funcOp.getAttrs()) {
- if (attr.first == SymbolTable::getSymbolAttrName() ||
- attr.first == mlir::impl::getTypeAttrName()) {
- continue;
- }
- funcAttrs.push_back(attr);
- }
-
auto newFuncOp = rewriter.create<FuncOp>(
loc, funcOp.getName(),
rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
llvm::None),
- funcAttrs);
+ ArrayRef<NamedAttribute>());
// Move all ops in the old function's region to the new function.
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
@@ -310,26 +284,6 @@
rewriter.replaceOp(loadOp, dimConstantCasted);
}
- // Lower iree.workgroup_idd ops to get indices from function arugments.
- for (auto workgroupCoordOp : workgroupIdOps) {
- auto attr = workgroupCoordOp.getAttrOfType<StringAttr>("dimension");
- int argIndex = -1;
- if (attr.getValue().str() == "x") {
- argIndex = 2;
- } else if (attr.getValue().str() == "y") {
- argIndex = 3;
- } else if (attr.getValue().str() == "z") {
- argIndex = 4;
- } else {
- return rewriter.notifyMatchFailure(
- funcOp,
- "Unable to map to workgroup coordinate : " + attr.getValue().str());
- }
- Value threadXIndex = builder.create<LLVM::ZExtOp>(
- loc, typeConverter.convertType(workgroupCoordOp.getType()),
- newFuncOp.getArgument(argIndex));
- rewriter.replaceOp(workgroupCoordOp, threadXIndex);
- }
rewriter.eraseOp(funcOp);
return success();
}
@@ -398,21 +352,9 @@
RemoveInterfaceOpPattern>(&getContext(), converter);
LLVMConversionTarget target(getContext());
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
-
- // Pass through workspace count calculation. This isn't going to be translated
- // to LLVM.
- // TODO(ataei): Should be handled somewhere else ?
- target.addDynamicallyLegalDialect<ShapeDialect, StandardOpsDialect,
- IREEDialect>(
- Optional<ConversionTarget::DynamicLegalityCallbackFn>([](Operation *op) {
- auto funcOp = dyn_cast<FuncOp>(op->getParentOp());
- if (funcOp && !isEntryPoint(funcOp)) return true;
- return false;
- }));
-
+ target.addIllegalOp<IREE::PlaceholderOp>();
target.addDynamicallyLegalOp<FuncOp>([](FuncOp funcOp) {
bool any = false;
- if (!isEntryPoint(funcOp)) return true;
funcOp.walk([&](IREE::PlaceholderOp placeholderOp) { any = true; });
return any ? false : true;
});
diff --git a/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.cpp b/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.cpp
deleted file mode 100644
index 76f59c2..0000000
--- a/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.cpp
+++ /dev/null
@@ -1,23 +0,0 @@
-
-#include "iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h"
-
-#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
-#include "mlir/IR/Operation.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-llvm::SmallVector<int64_t, 4> getTileSizesImpl(linalg::MatmulOp op) {
- return {128, 128};
-}
-
-llvm::SmallVector<int64_t, 4> CPUKernelDispatch::getTileSizes(
- Operation* op) const {
- if (isa<linalg::MatmulOp>(op)) {
- return getTileSizesImpl(dyn_cast<linalg::MatmulOp>(op));
- }
- return {1, 1, 1};
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h b/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h
deleted file mode 100644
index 09d6c41..0000000
--- a/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h
+++ /dev/null
@@ -1,30 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <cstdint>
-
-#include "llvm/ADT/SmallVector.h"
-
-namespace mlir {
-class Operation;
-
-namespace iree_compiler {
-
-class CPUKernelDispatch {
- public:
- llvm::SmallVector<int64_t, 4> getTileSizes(Operation* op) const;
-};
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributePass.cpp b/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributePass.cpp
deleted file mode 100644
index 620cdc1..0000000
--- a/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributePass.cpp
+++ /dev/null
@@ -1,220 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
-#include "iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h"
-#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
-#include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
-#include "iree/compiler/Conversion/LinalgToLLVM/Attributes.h"
-#include "iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h"
-#include "iree/compiler/Dialect/IREE/IR/IREEDialect.h"
-#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-
-#define DEBUG_TYPE "iree-linalg-to-llvm-tile-and-distribute"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-struct LinalgTileAndDistributePass
- : public PassWrapper<LinalgTileAndDistributePass, OperationPass<ModuleOp>> {
- void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<linalg::LinalgDialect, IREEDialect, AffineDialect,
- scf::SCFDialect>();
- }
- LinalgTileAndDistributePass() = default;
- LinalgTileAndDistributePass(const LinalgTileAndDistributePass &pass) {}
- void runOnOperation() override;
-
- private:
- ListOption<int64_t> tileSizes{
- *this, "tile-sizes", llvm::cl::desc("Set tile sizes to use"),
- llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
-};
-} // namespace
-
-namespace {
-template <typename LinalgOpTy>
-struct TileToCPUThreads : public linalg::LinalgBaseTilingPattern {
- using Base = linalg::LinalgBaseTilingPattern;
- TileToCPUThreads(MLIRContext *context,
- const linalg::LinalgDependenceGraph &dependenceGraph,
- const CPUKernelDispatch &cpuKernelDispatch,
- linalg::LinalgTilingOptions options,
- linalg::LinalgMarker marker, PatternBenefit benefit = 1)
- : Base(LinalgOpTy::getOperationName(), context, options, marker, benefit),
- cpuKernelDispatch(cpuKernelDispatch) {}
-
- LogicalResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override {
- // Find the parent FuncOp before tiling. If tiling succeeds, the op will be
- // erased.
- FuncOp funcOp = op->getParentOfType<FuncOp>();
- SmallVector<Value, 4> tensorResults;
- if (!funcOp ||
- failed(Base::matchAndRewriteBase(op, rewriter, tensorResults)) ||
- !tensorResults.empty() ||
- (funcOp.getAttr(getNumWorkgroupsFnAttrName()) &&
- failed(utils::createNumWorkgroupsFromResultShape(
- rewriter, cast<linalg::LinalgOp>(op), funcOp,
- getNumWorkgroupsFnAttrName(),
- cpuKernelDispatch.getTileSizes(op))))) {
- return failure();
- }
- rewriter.eraseOp(op);
- return success();
- }
- CPUKernelDispatch cpuKernelDispatch;
-};
-
-template <typename LinalgOpTy>
-struct TileAndFuseToCPUThreads
- : public linalg::LinalgTileAndFusePattern<LinalgOpTy> {
- using Base = linalg::LinalgTileAndFusePattern<LinalgOpTy>;
- TileAndFuseToCPUThreads(MLIRContext *context,
- const linalg::LinalgDependenceGraph &dependenceGraph,
- const CPUKernelDispatch &cpuKernelDispatch,
- linalg::LinalgTilingOptions tilingOptions,
- linalg::LinalgMarker marker,
- PatternBenefit benefit = 1)
- : Base(context, dependenceGraph, tilingOptions,
- linalg::LinalgFusionOptions().setIndicesToFuse({2}), marker,
- marker,
- linalg::LinalgMarker(ArrayRef<Identifier>(),
- Identifier::get(getDeleteMarker(), context)),
- benefit),
- dependenceGraph(dependenceGraph),
- cpuKernelDispatch(cpuKernelDispatch) {}
-
- virtual LogicalResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const {
- FuncOp funcOp = op->getParentOfType<FuncOp>();
- linalg::LinalgOp linalgOp = cast<linalg::LinalgOp>(op);
- if (!funcOp || !dependenceGraph.hasDependentOperations(linalgOp) ||
- failed(Base::matchAndRewrite(op, rewriter)) ||
- failed(utils::createNumWorkgroupsFromResultShape(
- rewriter, cast<linalg::LinalgOp>(op), funcOp,
- getNumWorkgroupsFnAttrName().str(),
- cpuKernelDispatch.getTileSizes(op)))) {
- return failure();
- }
- return success();
- }
-
- const linalg::LinalgDependenceGraph &dependenceGraph;
- CPUKernelDispatch cpuKernelDispatch;
-};
-
-} // namespace
-
-void LinalgTileAndDistributePass::runOnOperation() {
- MLIRContext *context = &getContext();
- ModuleOp module = getOperation();
-
- static linalg::LinalgLoopDistributionOptions workgroupDistributionOptions = {
- [](OpBuilder &builder, Location loc, ArrayRef<Range> parallelLoopRanges) {
- Type indexType = builder.getIndexType();
- auto numParallelDims = parallelLoopRanges.size();
- SmallVector<linalg::ProcInfo, 2> procInfo(numParallelDims);
- for (int dim = 0; dim < numParallelDims; ++dim) {
- std::array<StringRef, 3> dimAttr{"x", "y", "z"};
- StringAttr attr =
- builder.getStringAttr(dimAttr[std::min<unsigned>(dim, 3)]);
- procInfo[numParallelDims - dim - 1] = {
- builder.create<IREE::WorkgroupIdOp>(loc, indexType, attr),
- builder.create<IREE::WorkgroupSizeOp>(loc, indexType, attr)};
- }
- return procInfo;
- },
- {linalg::DistributionMethod::CyclicNumProcsEqNumIters,
- linalg::DistributionMethod::CyclicNumProcsEqNumIters,
- linalg::DistributionMethod::CyclicNumProcsEqNumIters}};
-
- CPUKernelDispatch cpuKernelDispatch;
-
- // Function to compute first level tiling values.
- std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>
- getOuterTileSizeFn =
- [&cpuKernelDispatch](OpBuilder &builder,
- Operation *operation) -> SmallVector<Value, 4> {
- auto tileSizes = cpuKernelDispatch.getTileSizes(operation);
- if (tileSizes.empty()) return {};
- SmallVector<Value, 4> tileSizesVal;
- tileSizesVal.reserve(tileSizes.size());
- for (auto val : tileSizes) {
- tileSizesVal.push_back(
- builder.create<ConstantIndexOp>(operation->getLoc(), val));
- }
- return tileSizesVal;
- };
-
- for (FuncOp funcOp : module.getOps<FuncOp>()) {
- if (!isEntryPoint(funcOp)) continue;
-
- // Compute the Linalg Dependence Graph.
- linalg::Aliases aliases;
- linalg::LinalgDependenceGraph dependenceGraph =
- linalg::LinalgDependenceGraph::buildDependenceGraph(aliases, funcOp);
-
- OwningRewritePatternList patterns;
-
- auto linalgTilingOptions =
- linalg::LinalgTilingOptions()
- .setDistributionOptions(workgroupDistributionOptions)
- .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops);
- tileSizes.empty()
- ? linalgTilingOptions.setTileSizeComputationFunction(getOuterTileSizeFn)
- : linalgTilingOptions.setTileSizes(ArrayRef<int64_t>(tileSizes));
- patterns.insert<TileAndFuseToCPUThreads<linalg::MatmulOp>,
- TileToCPUThreads<linalg::MatmulOp>>(
- context, dependenceGraph, cpuKernelDispatch, linalgTilingOptions,
- linalg::LinalgMarker(ArrayRef<Identifier>(),
- Identifier::get(getWorkgroupMarker(), context)));
-
- // Tile and distribute to CPU threads.
- applyPatternsAndFoldGreedily(funcOp, patterns);
-
- // Apply canonicalization patterns.
- OwningRewritePatternList canonicalizationPatterns;
- canonicalizationPatterns.insert<AffineMinCanonicalizationPattern>(context);
- AffineApplyOp::getCanonicalizationPatterns(canonicalizationPatterns,
- context);
- AffineMinOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
- SubViewOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
-
- applyPatternsAndFoldGreedily(funcOp, canonicalizationPatterns);
-
- // Delete the ops that are marked for deletion.
- funcOp.walk([](linalg::LinalgOp linalgOp) {
- if (hasMarker(linalgOp.getOperation(), getDeleteMarker()))
- linalgOp.getOperation()->erase();
- });
- }
-}
-
-std::unique_ptr<OperationPass<ModuleOp>> createLinalgTileAndDistributePass() {
- return std::make_unique<LinalgTileAndDistributePass>();
-}
-
-static PassRegistration<LinalgTileAndDistributePass> pass(
- "iree-codegen-llvm-linalg-tile-and-distribute",
- "Tile and distribute Linalg operations on buffers",
- [] { return std::make_unique<LinalgTileAndDistributePass>(); });
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
index 9f63a0e..498f4a6 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
@@ -14,9 +14,7 @@
#include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
-#include "iree/compiler/Conversion/Common/Passes.h"
#include "iree/compiler/Conversion/HLOToHLO/Passes.h"
-#include "iree/compiler/Conversion/LinalgToLLVM/Attributes.h"
#include "iree/compiler/Conversion/LinalgToLLVM/Passes.h"
#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
@@ -34,19 +32,7 @@
"linag.matmul"),
llvm::cl::init(false));
-static llvm::cl::opt<bool> llvmLinalgTileAndDistributePass(
- "iree-codegen-linalg-to-llvm-tile-and-distrobute",
- llvm::cl::desc("Tile and distribute linalg ops among iree threads"),
- llvm::cl::init(false));
-
void addLinalgToLLVMPasses(OpPassManager &passManager) {
- // Distribute linalg op among a 3d grid of parallel threads.
- if (llvmLinalgTileAndDistributePass) {
- passManager.addPass(createLinalgTileAndDistributePass());
- passManager.addPass(common::createLegalizeNumWorkgroupsFnPass(
- getNumWorkgroupsFnAttrName()));
- }
-
// Linalg.ConvOp -> (Img2Col packing + matmul)
if (convImg2ColConversion) {
passManager.addPass(createConvImg2ColMatmulConversionPass());
@@ -71,9 +57,6 @@
}
void buildLLVMTransformPassPipeline(OpPassManager &passManager) {
- passManager.addPass(
- common::createDeclareNumWorkgroupsFnPass(getNumWorkgroupsFnAttrName()));
-
passManager.addPass(createInlinerPass());
// Propagates dynamic shapes computation on tensors.
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.h b/iree/compiler/Conversion/LinalgToLLVM/Passes.h
index 60cedbe..c79c885 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/Passes.h
+++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.h
@@ -20,15 +20,6 @@
namespace mlir {
namespace iree_compiler {
-// Options that can be used to configure LLVMIR codegeneration.
-struct LLVMIRCodegenOptions {
- SmallVector<int64_t, 3> workgroupSize = {};
- SmallVector<int64_t, 3> tileSizes = {};
- bool useWorkgroupMemory = false;
- bool useVectorization = false;
- bool useVectorPass = false;
-};
-
/// Converts linalg::MatmulOp into LLVM dialect
std::unique_ptr<FunctionPass> createMatMulTileAndVectorizePass();
@@ -36,8 +27,6 @@
/// linalg::MatmulOp.
std::unique_ptr<FunctionPass> createConvImg2ColMatmulConversionPass();
-std::unique_ptr<FunctionPass> createLinalgTileAndDistributePass();
-
/// Populates patterns to rewrite linalg::ConvOp into packed img2col operation
/// followed by linalg::MatmulOp.
void populateConvImg2ColMatmulConversionPatterns(
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/convert_to_llvm.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/convert_to_llvm.mlir
index b6ef3e3..0624289 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/test/convert_to_llvm.mlir
+++ b/iree/compiler/Conversion/LinalgToLLVM/test/convert_to_llvm.mlir
@@ -14,7 +14,7 @@
hal.interface @legacy_io attributes {push_constants = 2 : i32, sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
}
-// CHECK: llvm.func @convert_dynamic_shape(%[[ARG0:.+]]: !llvm.ptr<ptr<i8>>, %[[ARG1:.+]]: !llvm.ptr<i32>, %[[THREAD_X_ID:.+]]: !llvm.i32, %[[THREAD_Y_ID:.+]]: !llvm.i32, %[[THREAD_Z_ID:.+]]: !llvm.i32)
+// CHECK: llvm.func @convert_dynamic_shape(%[[ARG0:.+]]: !llvm.ptr<ptr<i8>>, %[[ARG1:.+]]: !llvm.ptr<i32>)
// CHECK: %[[PACKED_ARGS_PTR:.+]] = llvm.bitcast %[[ARG0]] : !llvm.ptr<ptr<i8>> to !llvm.ptr<struct<(ptr<float>)>>
// CHECK: %[[PACKED_ARGS:.+]] = llvm.load %[[PACKED_ARGS_PTR]] : !llvm.ptr<struct<(ptr<float>)>>
// CHECK: %[[MEMREF0_DATA_PTR:.+]] = llvm.extractvalue %[[PACKED_ARGS]][0] : !llvm.struct<(ptr<float>)>
@@ -38,8 +38,6 @@
// CHECK: %[[STRIDE_DIM0:.+]] = llvm.mul %[[STRIDE_DIM1]], %[[DIM1_0]] : !llvm.i64
// CHECK: llvm.insertvalue %[[STRIDE_DIM0]], %[[MEMREF0_4]][4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-// -----
-
// CHECK_LABEL: @convert_dynamic_shape2
func @convert_dynamic_shape2() -> f32 {
%c0 = constant 0 : index
@@ -53,7 +51,8 @@
hal.interface @legacy_io2 attributes {push_constants = 1 : i32, sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
}
-// CHECK: llvm.func @convert_dynamic_shape2(%[[ARG0:.+]]: !llvm.ptr<ptr<i8>>, %[[ARG1:.+]]: !llvm.ptr<i32>, %[[THREAD_X_ID:.+]]: !llvm.i32, %[[THREAD_Y_ID:.+]]: !llvm.i32, %[[THREAD_Z_ID:.+]]: !llvm.i32)
+
+// CHECK: llvm.func @convert_dynamic_shape2(%[[ARG0:.+]]: !llvm.ptr<ptr<i8>>, %[[ARG1:.+]]: !llvm.ptr<i32>)
// CHECK: %[[PACKED_ARGS_PTR:.+]] = llvm.bitcast %[[ARG0]] : !llvm.ptr<ptr<i8>> to !llvm.ptr<struct<(ptr<float>)>>
// CHECK: %[[PACKED_ARGS:.+]] = llvm.load %[[PACKED_ARGS_PTR]] : !llvm.ptr<struct<(ptr<float>)>>
// CHECK: %[[MEMREF0_DATA_PTR:.+]] = llvm.extractvalue %[[PACKED_ARGS]][0] : !llvm.struct<(ptr<float>)>
@@ -85,17 +84,3 @@
// CHECK: %[[GET_PTR:.+]] = llvm.getelementptr %[[EXTRACT1:.+]][%[[ADD2:.+]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
// CHECK: %[[LOAD:.+]] = llvm.load %[[GET_PTR:.+]] : !llvm.ptr<float>
-// -----
-
-// CHECK_LABEL: @distribute_lookup
-func @distribute_lookup() -> f32 {
- %0 = iree.placeholder for "interface buffer" {binding = @legacy_io3::@arg0} : memref<2x2x2xf32>
- %1 = iree.workgroup_id {dimension = "x"} : index
- %2 = iree.workgroup_id {dimension = "y"} : index
- %3 = iree.workgroup_id {dimension = "z"} : index
- %4 = load %0[%1, %2, %3] : memref<2x2x2xf32>
- return %4 : f32
-}
-hal.interface @legacy_io3 attributes {push_constants = 1 : i32, sym_visibility = "private"} {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
-}
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute.mlir
deleted file mode 100644
index baffba1..0000000
--- a/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute.mlir
+++ /dev/null
@@ -1,56 +0,0 @@
-// RUN: iree-opt --iree-codegen-llvm-linalg-tile-and-distribute=tile-sizes=2,4,1 -cse -split-input-file %s | IreeFileCheck %s
-
-func @dynamic_matmul(%lhs: memref<?x?xf32>, %rhs: memref<?x?xf32>, %result: memref<?x?xf32>) {
- linalg.matmul ins(%lhs, %rhs : memref<?x?xf32>, memref<?x?xf32>) outs(%result : memref<?x?xf32>)
- return
-}
-// CHECK: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 2)>
-// CHECK: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (2, s1 - s0 * 2)>
-// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
-// CHECK: #[[MAP3:.+]] = affine_map<()[s0] -> (s0 * 4)>
-// CHECK: #[[MAP4:.+]] = affine_map<()[s0, s1] -> (4, s1 - s0 * 4)>
-// CHECK: func @dynamic_matmul(%[[LHS:.+]]: memref<?x?xf32>, %[[RHS:.+]]: memref<?x?xf32>, %[[RESULT:.+]]: memref<?x?xf32>)
-// CHECK: %[[CONST_0:.+]] = constant 0 : index
-// CHECK: %[[CONST_1:.+]] = constant 1 : index
-// CHECK: %[[DIM_K:.+]] = dim %[[LHS]], %[[CONST_1]]
-// CHECK: %[[THREAD_X_ID:.+]] = iree.workgroup_id {dimension = "x"} : index
-// CHECK: %[[THREAD_Y_ID:.+]] = iree.workgroup_id {dimension = "y"} : index
-// CHECK: scf.for %[[K:.+]] = %[[CONST_0]] to %[[DIM_K]]
-// CHECK: %[[I:.+]] = affine.apply #[[MAP0]]()[%[[THREAD_Y_ID]]]
-// CHECK: %[[DIM_I:.+]] = dim %[[LHS]], %[[CONST_0]]
-// CHECK: %[[I_OFFSET:.+]] = affine.min #[[MAP1]]()[%[[THREAD_Y_ID]], %[[DIM_I]]]
-// CHECK: %[[LHS_SUBVIEW:.+]] = subview %[[LHS]][%[[I]], %[[K]]] [%[[I_OFFSET]], 1] [1, 1]
-// CHECK: %[[J:.+]] = affine.apply #[[MAP3]]()[%[[THREAD_X_ID]]]
-// CHECK: %[[DIM_J:.+]] = dim %[[RHS]], %[[CONST_1]]
-// CHECK: %[[J_OFFSET:.+]] = affine.min #[[MAP4]]()[%[[THREAD_X_ID]], %[[DIM_J]]]
-// CHECK: %[[RHS_SUBVIEW:.+]] = subview %[[RHS]][%[[K]], %[[J]]] [1, %[[J_OFFSET]]] [1, 1]
-// CHECK: %[[DIM_I:.+]] = dim %[[RESULT]], %[[CONST_0]]
-// CHECK: %[[DIM_I_OFFSET:.+]] = affine.min #[[MAP1]]()[%[[THREAD_Y_ID]], %[[DIM_I]]]
-// CHECK: %[[DIM_J:.+]] = dim %[[RESULT]], %[[CONST_1]]
-// CHECK: %[[DIM_J_OFFSET:.+]] = affine.min #[[MAP4]]()[%[[THREAD_X_ID]], %[[DIM_J]]]
-// CHECK: %[[RESULT_SUBVIEW:.+]] = subview %[[RESULT]][%[[I]], %[[J]]] [%[[DIM_I_OFFSET]], %[[DIM_J_OFFSET]]] [1, 1]
-// CHECK: linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%[[LHS_SUBVIEW]], %[[RHS_SUBVIEW]] : memref<?x1xf32, #[[MAP2]]>, memref<1x?xf32, #[[MAP2]]>) outs(%[[RESULT_SUBVIEW]] : memref<?x?xf32, #[[MAP2]]>)
-
-// -----
-
-func @static_matmul(%lhs: memref<16x4xf32>, %rhs: memref<4x8xf32>, %result: memref<16x8xf32>) {
- linalg.matmul ins(%lhs, %rhs : memref<16x4xf32>, memref<4x8xf32>) outs(%result : memref<16x8xf32>)
- return
-}
-// CHECK: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 2)>
-// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1)[s0] -> (d0 * 4 + s0 + d1)>
-// CHECK: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 4)>
-// CHECK: #[[MAP3:.+]] = affine_map<(d0, d1)[s0] -> (d0 * 8 + s0 + d1)>
-// CHECK: func @static_matmul(%[[LHS:.+]]: memref<16x4xf32>, %[[RHS:.+]]: memref<4x8xf32>, %[[RESULT:.+]]: memref<16x8xf32>)
-// CHECK: %[[CONST_0:.+]] = constant 0 : index
-// CHECK: %[[CONST_4:.+]] = constant 4 : index
-// CHECK: %[[CONST_1:.+]] = constant 1 : index
-// CHECK: %[[THREAD_X_ID:.+]] = iree.workgroup_id {dimension = "x"} : index
-// CHECK: %[[THREAD_Y_ID:.+]] = iree.workgroup_id {dimension = "y"} : index
-// CHECK: scf.for %[[K:.+]] = %[[CONST_0]] to %[[CONST_4]] step %[[CONST_1]]
-// CHECK: %[[I:.+]] = affine.apply #[[MAP0]]()[%[[THREAD_Y_ID]]]
-// CHECK: %[[LHS_SUBVIEW:.+]] = subview %[[LHS]][%[[I]], %[[K]]] [2, 1] [1, 1] : memref<16x4xf32> to memref<2x1xf32, #[[MAP1]]>
-// CHECK: %[[J:.+]] = affine.apply #[[MAP2]]()[%[[THREAD_X_ID]]]
-// CHECK: %[[RHS_SUBVIEW:.+]] = subview %[[RHS]][%[[K]], %[[J]]] [1, 4] [1, 1] : memref<4x8xf32> to memref<1x4xf32, #[[MAP3]]>
-// CHECK: %[[RESULT_SUBVIEW:.+]] = subview %[[RESULT]][%[[I]], %[[J]]] [2, 4] [1, 1] : memref<16x8xf32> to memref<2x4xf32, #[[MAP3]]>
-// CHECK: linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%[[LHS_SUBVIEW]], %[[RHS_SUBVIEW]] : memref<2x1xf32, #[[MAP1]]>, memref<1x4xf32, #[[MAP3]]>) outs(%6 : memref<2x4xf32, #[[MAP3]]>)
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/BUILD b/iree/compiler/Conversion/LinalgToSPIRV/BUILD
index fcedafe..f5fe4d9 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/BUILD
+++ b/iree/compiler/Conversion/LinalgToSPIRV/BUILD
@@ -28,6 +28,7 @@
"KernelDispatchUtils.cpp",
"LegalizeNumWorkgroupsFnPass.cpp",
"LinalgTileAndFusePass.cpp",
+ "MarkerUtils.cpp",
"MatMulVectorizationTest.cpp",
"Passes.cpp",
"SplitDispatchFunctionPass.cpp",
@@ -39,6 +40,7 @@
"Attributes.h",
"CooperativeMatrixAnalysis.h",
"KernelDispatchUtils.h",
+ "MarkerUtils.h",
"MemorySpace.h",
"Passes.h",
"Utils.h",
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
index d788461..79578dd 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
@@ -21,6 +21,7 @@
"Attributes.h"
"CooperativeMatrixAnalysis.h"
"KernelDispatchUtils.h"
+ "MarkerUtils.h"
"MemorySpace.h"
"Passes.h"
"Utils.h"
@@ -32,6 +33,7 @@
"KernelDispatchUtils.cpp"
"LegalizeNumWorkgroupsFnPass.cpp"
"LinalgTileAndFusePass.cpp"
+ "MarkerUtils.cpp"
"MatMulVectorizationTest.cpp"
"Passes.cpp"
"SplitDispatchFunctionPass.cpp"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
index 236820f..5421cc9 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
@@ -22,9 +22,9 @@
#include <numeric>
#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
-#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Attributes.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/MemorySpace.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Utils.h"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
index 5872e76..508d0dc 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
@@ -21,8 +21,8 @@
//
//===----------------------------------------------------------------------===//
-#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/CooperativeMatrixAnalysis.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
#include "llvm/ADT/STLExtras.h"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
index b74d5e3..022672d 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -18,10 +18,10 @@
//
//===----------------------------------------------------------------------===//
#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
-#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
#include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Attributes.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/MemorySpace.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Utils.h"
diff --git a/iree/compiler/Conversion/CodegenUtils/MarkerUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp
similarity index 96%
rename from iree/compiler/Conversion/CodegenUtils/MarkerUtils.cpp
rename to iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp
index 20f1616..dff5292 100644
--- a/iree/compiler/Conversion/CodegenUtils/MarkerUtils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/Attributes.h"
diff --git a/iree/compiler/Conversion/CodegenUtils/MarkerUtils.h b/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h
similarity index 81%
rename from iree/compiler/Conversion/CodegenUtils/MarkerUtils.h
rename to iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h
index a839807..78d4304 100644
--- a/iree/compiler/Conversion/CodegenUtils/MarkerUtils.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h
@@ -19,8 +19,8 @@
//
//===----------------------------------------------------------------------===//
-#ifndef IREE_COMPILER_CONVERSION_CODEGENUTILS_MARKERUTILS_H_
-#define IREE_COMPILER_CONVERSION_CODEGENUTILS_MARKERUTILS_H_
+#ifndef IREE_COMPILER_CONVERSION_LINALGTOSPIRV_MARKERUTILS_H_
+#define IREE_COMPILER_CONVERSION_LINALGTOSPIRV_MARKERUTILS_H_
#include "llvm/ADT/ArrayRef.h"
#include "mlir/Support/LLVM.h"
@@ -30,12 +30,11 @@
class Operation;
namespace iree_compiler {
-/// Marker to denote that a linalg operation has been partitioned to
-/// workgroups.
+/// Marker to denote that a linalg operation has been partitioned to workgroups.
StringRef getWorkgroupMarker();
-/// Marker to denote that a linalg operation has been partitioned to
-/// workgroups and operands promoted to scratchspace memory.
+/// Marker to denote that a linalg operation has been partitioned to workgroups
+/// and operands promoted to scratchspace memory.
StringRef getWorkgroupMemoryMarker();
/// Marker for copy operations that are moving data from StorageClass to
@@ -45,10 +44,10 @@
/// Marker for operations that are going to be vectorized.
StringRef getVectorizeMarker();
-/// Marker for tagging an operation for deletion. Tile and fuse pattern does
-/// not delete the original operation to not invalidate the
-/// `linalg::LinalgDependenceGraph` data structure. Instead it is marked with
-/// a marker that can be used later to delete these operations.
+/// Marker for tagging an operation for deletion. Tile and fuse pattern does not
+/// delete the original operation to not invalidate the
+/// `linalg::LinalgDependenceGraph` data structure. Instead it is marked with a
+/// marker that can be used later to delete these operations.
StringRef getDeleteMarker();
/// Returns true if an operation has the specified `marker`. When `marker` is
@@ -61,4 +60,4 @@
} // namespace iree_compiler
} // namespace mlir
-#endif // IREE_COMPILER_CONVERSION_CODEGENUTILS_MARKERUTILS_H_
+#endif // IREE_COMPILER_CONVERSION_LINALGTOSPIRV_MARKERUTILS_H_
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp
index ddbf478..2c14351 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp
@@ -20,7 +20,7 @@
#include "iree/compiler/Conversion/LinalgToSPIRV/Utils.h"
-#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/MemorySpace.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
index 7e93452..eb2a32e 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
@@ -22,9 +22,9 @@
#include <memory>
#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
-#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
#include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/CooperativeMatrixAnalysis.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
diff --git a/iree/compiler/Conversion/init_conversions.h b/iree/compiler/Conversion/init_conversions.h
index 4d4c53c..faf79ab 100644
--- a/iree/compiler/Conversion/init_conversions.h
+++ b/iree/compiler/Conversion/init_conversions.h
@@ -61,9 +61,6 @@
inline void registerLinalgToLLVMPasses() {
static bool init_once = []() {
// LinalgToLLVM
- createConvImg2ColMatmulConversionPass();
- createLinalgTileAndDistributePass();
- createMatMulTileAndVectorizePass();
return true;
}();
(void)init_once;
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTarget.cpp
index 6060600..199daf3 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTarget.cpp
@@ -46,17 +46,7 @@
// multi-threading issues.
llvm::LLVMContext context;
- // Remove all private functions, e.g tile size calcuations.
- SmallVector<FuncOp, 4> nonPublicFn;
- for (auto func : targetOp.getInnerModule().getOps<FuncOp>()) {
- if (SymbolTable::getSymbolVisibility(func) !=
- SymbolTable::Visibility::Public) {
- nonPublicFn.push_back(func);
- }
- }
- for (auto func : nonPublicFn) {
- func.erase();
- }
+ iree::DyLibExecutableDefT dyLibExecutableDef;
// At this moment we are leaving MLIR LLVM dialect land translating module
// into target independent LLVMIR.
@@ -66,7 +56,6 @@
return failure();
}
- iree::DyLibExecutableDefT dyLibExecutableDef;
// Create invocation function an populate entry_points.
auto entryPointOps = targetOp.getBlock().getOps<ExecutableEntryPointOp>();
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/BUILD b/iree/compiler/Dialect/HAL/Target/LLVM/BUILD
index 6458f0b..b644afe 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/BUILD
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/BUILD
@@ -40,7 +40,6 @@
deps = [
":LLVMIRPasses",
":LLVMTargetOptions",
- "//iree/compiler/Conversion/CodegenUtils",
"//iree/compiler/Conversion/LinalgToLLVM",
"//iree/compiler/Dialect/HAL/Target",
"@llvm-project//llvm:Support",
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt
index d01c728..ed01729 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt
@@ -34,7 +34,6 @@
MLIRLinalg
MLIRSCF
MLIRVector
- iree::compiler::Conversion::CodegenUtils
iree::compiler::Conversion::LinalgToLLVM
iree::compiler::Dialect::HAL::Target
PUBLIC
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/IR/BUILD b/iree/compiler/Dialect/HAL/Target/LLVM/IR/BUILD
index eed06b7..2194e66 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/IR/BUILD
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/IR/BUILD
@@ -41,7 +41,6 @@
"//iree/compiler/Dialect/HAL/Target/LLVM:LLVMBaseTarget",
"//iree/compiler/Dialect/HAL/Target/LLVM:LLVMIRPasses",
"//iree/compiler/Dialect/HAL/Target/LLVM:LLVMTargetOptions",
- "//iree/compiler/Dialect/Shape/IR",
"//iree/schemas:llvmir_executable_def_cc_fbs",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:Support",
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/IR/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/LLVM/IR/CMakeLists.txt
index c490eb0..345fe5a 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/IR/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/IR/CMakeLists.txt
@@ -33,7 +33,6 @@
iree::compiler::Dialect::HAL::Target::LLVM::LLVMBaseTarget
iree::compiler::Dialect::HAL::Target::LLVM::LLVMIRPasses
iree::compiler::Dialect::HAL::Target::LLVM::LLVMTargetOptions
- iree::compiler::Dialect::Shape::IR
iree::schemas::llvmir_executable_def_cc_fbs
PUBLIC
)
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/IR/LLVMIRTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/IR/LLVMIRTarget.cpp
index b18ecbf..3a03471 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/IR/LLVMIRTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/IR/LLVMIRTarget.cpp
@@ -42,23 +42,11 @@
// Perform the translation to LLVM in a separate context to avoid
// multi-threading issues.
llvm::LLVMContext context;
- // Remove all private functions, e.g tile size calcuations.
- SmallVector<FuncOp, 4> nonPublicFn;
- for (auto func : targetOp.getInnerModule().getOps<FuncOp>()) {
- if (SymbolTable::getSymbolVisibility(func) !=
- SymbolTable::Visibility::Public) {
- nonPublicFn.push_back(func);
- }
- }
- for (auto func : nonPublicFn) {
- func.erase();
- }
// At this moment we are leaving MLIR LLVM dialect land translating module
// into target independent LLVMIR.
auto llvmModule =
mlir::translateModuleToLLVMIR(targetOp.getInnerModule(), context);
-
if (!llvmModule) {
return targetOp.emitError("Failed to translate executable to LLVM IR");
}
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.cpp
index 1b4bfe7..a0c82c2 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.cpp
@@ -14,8 +14,6 @@
#include "iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.h"
-#include "iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h"
-#include "iree/compiler/Conversion/LinalgToLLVM/Attributes.h"
#include "iree/compiler/Conversion/LinalgToLLVM/Passes.h"
#include "iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.h"
#include "llvm/Support/FormatVariadic.h"
@@ -185,78 +183,13 @@
return success();
}
-LogicalResult LLVMBaseTargetBackend::recordDispatch(
- Location loc, DispatchState dispatchState,
- DeviceSwitchRewriter &switchRewriter) {
- IREE::HAL::ExecutableOp executableOp = dispatchState.executableOp;
- ModuleOp llvmIRModuleOp;
- for (auto executableTargetOp :
- executableOp.getBlock().getOps<IREE::HAL::ExecutableTargetOp>()) {
- if (matchPattern(executableTargetOp.target_backend_filter(),
- filter_pattern())) {
- ModuleOp innerModuleOp = executableTargetOp.getInnerModule();
- llvmIRModuleOp = innerModuleOp;
- break;
- }
- }
- if (!llvmIRModuleOp)
- return executableOp.emitError("unable to find executable llvmIR module");
-
- SmallVector<LLVM::LLVMFuncOp, 2> entryPointFns;
- for (LLVM::LLVMFuncOp funcOp : llvmIRModuleOp.getOps<LLVM::LLVMFuncOp>()) {
- if (SymbolTable::getSymbolVisibility(funcOp) ==
- SymbolTable::Visibility::Public) {
- entryPointFns.push_back(funcOp);
- }
- }
-
- auto *region = switchRewriter.addConditionRegion(
- IREE::HAL::DeviceMatchIDAttr::get(filter_pattern(), loc.getContext()),
- {
- dispatchState.workload,
- dispatchState.commandBuffer,
- });
- auto &entryBlock = region->front();
- ConversionPatternRewriter &rewriter = switchRewriter.getRewriter();
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToEnd(&entryBlock);
-
- auto commandBuffer = entryBlock.getArgument(1);
- for (auto it : llvm::enumerate(entryPointFns)) {
- LLVM::LLVMFuncOp funcOp = it.value();
- FlatSymbolRefAttr numWorkgroupsFnAttr =
- funcOp.getAttrOfType<FlatSymbolRefAttr>(getNumWorkgroupsFnAttrName());
- if (!numWorkgroupsFnAttr) {
- return funcOp.emitError("expected llvm.num_workgroups_fn ");
- }
- std::array<Value, 3> workgroupCount = {nullptr, nullptr, nullptr};
- FuncOp numWorkgroupsFn = dyn_cast<FuncOp>(SymbolTable::lookupSymbolIn(
- funcOp.getParentOfType<ModuleOp>(), numWorkgroupsFnAttr));
- if (!numWorkgroupsFn) {
- return funcOp.emitError("unable to find function ")
- << numWorkgroupsFnAttr
- << " that computes the number of workgroups to use";
- }
- workgroupCount =
- iree_compiler::utils::calculateWorkgroupCountFromNumWorkgroupsFn(
- loc, numWorkgroupsFn,
- dispatchState.executableOp.getFirstInterfaceOp(),
- dispatchState.operands, dispatchState.results, rewriter);
-
- if (llvm::any_of(workgroupCount,
- [](Value v) -> bool { return v == nullptr; })) {
- auto constantOne = rewriter.createOrFold<mlir::ConstantIndexOp>(loc, 1);
- rewriter.create<IREE::HAL::CommandBufferDispatchSymbolOp>(
- loc, commandBuffer, dispatchState.entryPointOp, constantOne,
- constantOne, constantOne);
- } else {
- rewriter.create<IREE::HAL::CommandBufferDispatchSymbolOp>(
- loc, commandBuffer, dispatchState.entryPointOp, workgroupCount[0],
- workgroupCount[1], workgroupCount[2]);
- }
- }
- rewriter.create<IREE::HAL::ReturnOp>(loc);
- return success();
+std::array<Value, 3> LLVMBaseTargetBackend::calculateDispatchWorkgroupCount(
+ Location loc, IREE::HAL::ExecutableOp executableOp,
+ IREE::HAL::ExecutableEntryPointOp entryPointOp, Value workload,
+ OpBuilder &builder) {
+ // For now we are not tiling and just dispatch everything as 1,1,1.
+ auto constantOne = builder.createOrFold<mlir::ConstantIndexOp>(loc, 1);
+ return {constantOne, constantOne, constantOne};
}
} // namespace HAL
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.h b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.h
index 4b67aa0..52d0dbd 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.h
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.h
@@ -35,8 +35,10 @@
LogicalResult linkExecutables(mlir::ModuleOp moduleOp) override;
- LogicalResult recordDispatch(Location loc, DispatchState dispatchState,
- DeviceSwitchRewriter &switchRewriter) override;
+ std::array<Value, 3> calculateDispatchWorkgroupCount(
+ Location loc, IREE::HAL::ExecutableOp executableOp,
+ IREE::HAL::ExecutableEntryPointOp entryPointOp, Value workload,
+ OpBuilder &builder) override;
protected:
LLVMTargetOptions options_;
diff --git a/iree/compiler/Dialect/HAL/Target/TargetBackend.h b/iree/compiler/Dialect/HAL/Target/TargetBackend.h
index 2d28774..e7b6f6e 100644
--- a/iree/compiler/Dialect/HAL/Target/TargetBackend.h
+++ b/iree/compiler/Dialect/HAL/Target/TargetBackend.h
@@ -345,8 +345,8 @@
}
protected:
- // Calculates the workgroup size (x, y, z). These are the dimension numbers
- // for a single workgroup.
+ // Calculates the workgroup size (x, y, z). Tese are the dimension numbers for
+ // a single workgroup.
virtual std::array<Value, 3> calculateDispatchWorkgroupSize(
Location loc, IREE::HAL::ExecutableOp executableOp,
IREE::HAL::ExecutableEntryPointOp entryPointOp, Value workload,
diff --git a/iree/compiler/Dialect/IREE/IR/IREEOps.td b/iree/compiler/Dialect/IREE/IR/IREEOps.td
index 1f25199..fee13f7 100644
--- a/iree/compiler/Dialect/IREE/IR/IREEOps.td
+++ b/iree/compiler/Dialect/IREE/IR/IREEOps.td
@@ -75,38 +75,6 @@
let assemblyFormat = [{ `for` $purpose attr-dict `:` type($output) }];
}
-def IREE_WorkgroupIdOp : IREE_PureOp<"workgroup_id"> {
- let summary = "Get grid index of an iree workgoup among a specific dimension.";
- let description = [{
- IREE workgroups are logically distributed among a hypergrid, each point sampled
- from the grid corresponds to a logical thread. For example in a 3d grid case
- the op quries (x, y, z) dimensions:
-
- ```mlir
- %0 = iree.workgroup_id {dimension = "x"} : index
- %1 = iree.workgroup_id {dimension = "y"} : index
- %2 = iree.workgroup_id {dimension = "z"} : index
- ```
- }];
-
- let arguments = (ins StrAttr:$dimension);
- let results = (outs Index:$result);
-
- let assemblyFormat = "attr-dict `:` type($result)";
-}
-
-def IREE_WorkgroupSizeOp: IREE_PureOp<"workgoup_size"> {
- let summary = "Get grid size of an iree thread among a specific dimension";
- let description = [{
- Get grid size of an iree workgroups among a specific dimension
- }];
- let arguments = (ins StrAttr:$dimension);
- let results = (outs Index:$result);
-
- let assemblyFormat = "attr-dict `:` type($result)";
-}
-
-
//===----------------------------------------------------------------------===//
// Compiler hints
//===----------------------------------------------------------------------===//
diff --git a/iree/hal/dylib/dylib_executable.cc b/iree/hal/dylib/dylib_executable.cc
index 4bf3c35..10dbf29 100644
--- a/iree/hal/dylib/dylib_executable.cc
+++ b/iree/hal/dylib/dylib_executable.cc
@@ -149,11 +149,10 @@
auto* dispatch_state = static_cast<DyLibDispatchState*>(state);
IREE_TRACE_SCOPE_DYNAMIC(dispatch_state->entry_name);
- auto entry_function = (void (*)(void**, int32_t*, int32_t, int32_t,
- int32_t))dispatch_state->entry_function;
+ auto entry_function =
+ (void (*)(void**, int32_t*))dispatch_state->entry_function;
entry_function(dispatch_state->args.data(),
- dispatch_state->push_constant.data(), workgroup_xyz[0],
- workgroup_xyz[1], workgroup_xyz[2]);
+ dispatch_state->push_constant.data());
return OkStatus();
}
diff --git a/iree/hal/llvmjit/llvmjit_executable.cc b/iree/hal/llvmjit/llvmjit_executable.cc
index 391a217..74f1f9d 100644
--- a/iree/hal/llvmjit/llvmjit_executable.cc
+++ b/iree/hal/llvmjit/llvmjit_executable.cc
@@ -154,10 +154,9 @@
IREE_TRACE_SCOPE0("LLVMJITExecutable::DispatchTile");
auto* dispatch_state = static_cast<LLVMJITDispatchState*>(state);
- auto func_ptr = (void (*)(void**, int32_t*, int32_t, int32_t,
- int32_t))dispatch_state->symbol.getAddress();
- func_ptr(dispatch_state->args.data(), dispatch_state->push_constant.data(),
- workgroup_xyz[0], workgroup_xyz[1], workgroup_xyz[2]);
+ auto func_ptr =
+ (void (*)(void**, int32_t*))dispatch_state->symbol.getAddress();
+ func_ptr(dispatch_state->args.data(), dispatch_state->push_constant.data());
return OkStatus();
}
diff --git a/iree/test/e2e/llvmir_specific/BUILD b/iree/test/e2e/llvmir_specific/BUILD
index a51fcd7..3ab66ea 100644
--- a/iree/test/e2e/llvmir_specific/BUILD
+++ b/iree/test/e2e/llvmir_specific/BUILD
@@ -32,13 +32,3 @@
driver = "llvm",
target_backend = "llvm-ir",
)
-
-iree_check_single_backend_test_suite(
- name = "check_llvm-ir-dot_tile_and_distribute",
- srcs = [
- "dot.mlir",
- ],
- compiler_flags = ["-iree-codegen-linalg-to-llvm-tile-and-distrobute"],
- driver = "llvm",
- target_backend = "llvm-ir",
-)
diff --git a/iree/test/e2e/llvmir_specific/CMakeLists.txt b/iree/test/e2e/llvmir_specific/CMakeLists.txt
index 010407c..88fe959 100644
--- a/iree/test/e2e/llvmir_specific/CMakeLists.txt
+++ b/iree/test/e2e/llvmir_specific/CMakeLists.txt
@@ -26,16 +26,3 @@
COMPILER_FLAGS
"-iree-codegen-linalg-to-llvm-conv-img2col-conversion"
)
-
-iree_check_single_backend_test_suite(
- NAME
- check_llvm-ir-dot_tile_and_distribute
- SRCS
- "dot.mlir"
- TARGET_BACKEND
- llvm-ir
- DRIVER
- llvm
- COMPILER_FLAGS
- "-iree-codegen-linalg-to-llvm-tile-and-distrobute"
-)
diff --git a/iree/test/e2e/llvmir_specific/dot.mlir b/iree/test/e2e/llvmir_specific/dot.mlir
deleted file mode 100644
index a492b62..0000000
--- a/iree/test/e2e/llvmir_specific/dot.mlir
+++ /dev/null
@@ -1,36 +0,0 @@
-func @dot_passthrough() attributes { iree.module.export } {
- %lhs = iree.unfoldable_constant dense<[[0.3, 0.5]]> : tensor<1x2xf32>
- %rhs = iree.unfoldable_constant dense<[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]> : tensor<2x3xf32>
- %res = "mhlo.dot"(%lhs, %rhs) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
- check.expect_almost_eq_const(%res, dense<[[0.23, 0.31, 0.39]]> : tensor<1x3xf32>) : tensor<1x3xf32>
- return
-}
-
-func @gemm() attributes { iree.module.export } {
- %lhs = iree.unfoldable_constant dense<[
- [15.0, 14.0, 13.0],
- [12.0, 11.0, 10.0],
- [09.0, 08.0, 07.0],
- [06.0, 05.0, 04.0],
- [03.0, 02.0, 01.0]]> : tensor<5x3xf32>
- %rhs = iree.unfoldable_constant dense<[
- [15.0, 14.0, 13.0, 12.0, 11.0],
- [10.0, 09.0, 08.0, 07.0, 06.0],
- [05.0, 04.0, 03.0, 02.0, 01.0]]> : tensor<3x5xf32>
- %res = "mhlo.dot"(%lhs, %rhs) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<5x3xf32>, tensor<3x5xf32>) -> tensor<5x5xf32>
- check.expect_almost_eq_const(%res, dense<[
- [430.0, 388.0, 346.0, 304.0, 262.0],
- [340.0, 307.0, 274.0, 241.0, 208.0],
- [250.0, 226.0, 202.0, 178.0, 154.0],
- [160.0, 145.0, 130.0, 115.0, 100.0],
- [70.0, 64.0, 58.0, 52.0, 46.0]]> : tensor<5x5xf32>) : tensor<5x5xf32>
- return
-}
-
-func @large_matmul() attributes { iree.module.export } {
- %lhs = iree.unfoldable_constant dense<1.0> : tensor<32x1024xf32>
- %rhs = iree.unfoldable_constant dense<0.4> : tensor<1024x64xf32>
- %res = "mhlo.dot"(%lhs, %rhs) : (tensor<32x1024xf32>, tensor<1024x64xf32>) -> tensor<32x64xf32>
- check.expect_almost_eq_const(%res, dense<409.596> : tensor<32x64xf32>) : tensor<32x64xf32>
- return
-}