Fix and resubmit #3585
Fix ODR issue and share common code to resubmit #3585
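To exercise the new pass, see the RUN line in the added tile_and_distribute.mlir test: iree-opt --iree-codegen-llvm-linalg-tile-and-distribute=tile-sizes=2,4,1 -cse -split-input-file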
diff --git a/iree/compiler/Conversion/CodegenUtils/BUILD b/iree/compiler/Conversion/CodegenUtils/BUILD
index 1c346f6..4b58a54 100644
--- a/iree/compiler/Conversion/CodegenUtils/BUILD
+++ b/iree/compiler/Conversion/CodegenUtils/BUILD
@@ -25,14 +25,22 @@
srcs = [
"ForOpCanonicalization.cpp",
"FunctionUtils.cpp",
+ "GetNumWorkgroups.cpp",
+ "MarkerUtils.cpp",
"MatmulCodegenStrategy.cpp",
],
hdrs = [
"ForOpCanonicalization.h",
"FunctionUtils.h",
+ "GetNumWorkgroups.h",
+ "MarkerUtils.h",
"MatmulCodegenStrategy.h",
],
deps = [
+ "//iree/compiler/Dialect/HAL/IR",
+ "//iree/compiler/Dialect/HAL/IR:HALDialect",
+ "//iree/compiler/Dialect/HAL/Utils",
+ "//iree/compiler/Dialect/Shape/IR",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:Analysis",
diff --git a/iree/compiler/Conversion/CodegenUtils/CMakeLists.txt b/iree/compiler/Conversion/CodegenUtils/CMakeLists.txt
index 35c2030..bf8ddd7 100644
--- a/iree/compiler/Conversion/CodegenUtils/CMakeLists.txt
+++ b/iree/compiler/Conversion/CodegenUtils/CMakeLists.txt
@@ -20,10 +20,14 @@
HDRS
"ForOpCanonicalization.h"
"FunctionUtils.h"
+ "GetNumWorkgroups.h"
+ "MarkerUtils.h"
"MatmulCodegenStrategy.h"
SRCS
"ForOpCanonicalization.cpp"
"FunctionUtils.cpp"
+ "GetNumWorkgroups.cpp"
+ "MarkerUtils.cpp"
"MatmulCodegenStrategy.cpp"
DEPS
LLVMSupport
@@ -39,5 +43,9 @@
MLIRTransforms
MLIRVector
MLIRVectorToSCF
+ iree::compiler::Dialect::HAL::IR
+ iree::compiler::Dialect::HAL::IR::HALDialect
+ iree::compiler::Dialect::HAL::Utils
+ iree::compiler::Dialect::Shape::IR
PUBLIC
)
diff --git a/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.cpp b/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.cpp
new file mode 100644
index 0000000..62a446d
--- /dev/null
+++ b/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.cpp
@@ -0,0 +1,222 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/HAL/Utils/TypeUtils.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/PatternMatch.h"
+
+#define DEBUG_TYPE "workgroup-calculation"
+
+namespace mlir {
+namespace iree_compiler {
+FuncOp getNumWorkgroupsFn(FuncOp entryPointFn,
+ llvm::StringRef numWorkgroupsFnAttr) {
+ SymbolRefAttr attr =
+ entryPointFn.getAttrOfType<SymbolRefAttr>(numWorkgroupsFnAttr);
+ if (!attr) {
+ entryPointFn.emitError("missing attribute '") << numWorkgroupsFnAttr << "'";
+ return nullptr;
+ }
+ FuncOp numWorkgroupsFn = dyn_cast_or_null<FuncOp>(SymbolTable::lookupSymbolIn(
+ entryPointFn.getParentOfType<ModuleOp>(), attr));
+ if (!numWorkgroupsFn) {
+ entryPointFn.emitError("unable to find num workgroups fn ") << attr;
+ return nullptr;
+ }
+ return numWorkgroupsFn;
+}
+
+/// Computes the bounds of the parallel loops partitioned across workgroups.
+static Optional<SmallVector<Value, 2>> getParallelLoopRange(
+ PatternRewriter &rewriter, FuncOp numWorkgroupsFn, Location loc,
+ linalg::LinalgOp linalgOp) {
+ if (!numWorkgroupsFn.empty()) {
+ numWorkgroupsFn.emitError("num workgroups fn expected to be empty");
+ return {};
+ }
+ LLVM_DEBUG({
+    llvm::dbgs() << "Found num workgroups function: "
+                 << numWorkgroupsFn.getName();
+ });
+
+ rewriter.setInsertionPointToEnd(numWorkgroupsFn.addEntryBlock());
+ llvm::SetVector<Operation *> slice;
+ getBackwardSlice(linalgOp, &slice);
+ BlockAndValueMapping mapper;
+ for (Operation *op : slice) {
+ rewriter.clone(*op, mapper);
+ }
+ // Clone the linalg operation just to compute the loop bounds.
+ linalg::LinalgOp clonedLinalgOp =
+ rewriter.clone(*linalgOp.getOperation(), mapper);
+ Optional<SmallVector<Value, 4>> bounds =
+ getLoopRanges(rewriter, clonedLinalgOp);
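+  // Linalg lists parallel iterators first; count the leading parallel
+  // iterator types to get the number of loops distributed to workgroups.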
+ unsigned numParallelLoops = linalgOp.iterator_types()
+ .getValue()
+ .take_while([](Attribute attr) -> bool {
+ return attr.cast<StringAttr>().getValue() ==
+ getParallelIteratorTypeName();
+ })
+ .size();
+ SmallVector<Value, 2> returnVals(
+ bounds->begin(), std::next(bounds->begin(), numParallelLoops));
+ rewriter.eraseOp(clonedLinalgOp);
+ return returnVals;
+}
+
+/// Utility method to build IR that computes ceil(`numerator` / `denominator`).
+static Value buildCeilDiv(PatternRewriter &rewriter, Location loc,
+ Value numerator, Value denominator) {
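+  // Emit ceil division as (numerator + (denominator - 1)) / denominator.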
+ Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+ Value t = rewriter.create<AddIOp>(
+ loc, numerator, rewriter.create<SubIOp>(loc, denominator, one));
+ return rewriter.create<SignedDivIOp>(loc, t, denominator);
+}
+
+/// Utility method to build IR that computes ceil(`numerator` / `denominator`)
+/// when denominator is a constant.
+static Value buildCeilDivConstDenominator(PatternRewriter &rewriter,
+ Location loc, Value numerator,
+ int64_t denominator) {
+ return buildCeilDiv(rewriter, loc, numerator,
+ rewriter.create<ConstantIndexOp>(loc, denominator));
+}
+
+LogicalResult createNumWorkgroupsFromResultShape(
+ PatternRewriter &rewriter, linalg::LinalgOp linalgOp, FuncOp entryPointFn,
+ llvm::StringRef numWorkgroupsFnAttr, ArrayRef<int64_t> tileSizes) {
+ FuncOp numWorkgroupsFn = getNumWorkgroupsFn(
+ linalgOp.getParentOfType<FuncOp>(), numWorkgroupsFnAttr);
+ if (!numWorkgroupsFn) return failure();
+
+ Location loc = linalgOp.getLoc();
+  OpBuilder::InsertionGuard guard(rewriter);
+ Optional<SmallVector<Value, 2>> parallelLoopRange =
+ getParallelLoopRange(rewriter, numWorkgroupsFn, loc, linalgOp);
+ if (!parallelLoopRange) return failure();
+ Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+ SmallVector<Value, 3> returnValues(3, one);
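+  // The returned workgroup count is ordered [x, y, z], while the parallel
+  // loop ranges are ordered outermost first; hence the reversed `e - i - 1`
+  // indexing into the loop ranges and tile sizes below.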
+ for (size_t i = 0, e = std::min<size_t>(parallelLoopRange->size(), 3); i != e;
+ ++i) {
+ if (tileSizes[e - i - 1] != 0) {
+ returnValues[i] = buildCeilDivConstDenominator(
+ rewriter, loc, (*parallelLoopRange)[e - i - 1], tileSizes[e - i - 1]);
+ }
+ }
+ rewriter.create<mlir::ReturnOp>(loc, returnValues);
+ return success();
+}
+
+LogicalResult createNumWorkgroupsFromLinearizedResultShape(
+ PatternRewriter &rewriter, linalg::LinalgOp linalgOp, FuncOp entryPointFn,
+ llvm::StringRef numWorkgroupsFnAttr, int64_t workgroupSizeX) {
+ FuncOp numWorkgroupsFn = getNumWorkgroupsFn(
+ linalgOp.getParentOfType<FuncOp>(), numWorkgroupsFnAttr);
+ if (!numWorkgroupsFn) return failure();
+ if (!numWorkgroupsFn.empty()) {
+    // TODO(ravishankarm): We can end up with multiple linalg operations
+    // (typically linalg.generic operations) that have the same workload in a
+    // dispatch region. In that case the first linalg.generic creates the body
+    // of the number-of-workgroups function. For now, just return if the body
+    // is not empty, assuming that it is correct for all the ops in the
+    // dispatch region. This needs to be enforced somehow.
+ return success();
+ }
+
+ Location loc = linalgOp.getLoc();
+  OpBuilder::InsertionGuard guard(rewriter);
+ Optional<SmallVector<Value, 2>> parallelLoopRange =
+ getParallelLoopRange(rewriter, numWorkgroupsFn, loc, linalgOp);
+ if (!parallelLoopRange) return failure();
+ Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+ SmallVector<Value, 3> returnValues(3, one);
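+  // Linearize: multiply all the parallel loop ranges into the x dimension,
+  // then divide by the workgroup size along x.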
+ for (auto range : *parallelLoopRange) {
+ returnValues[0] = rewriter.create<MulIOp>(loc, range, returnValues[0]);
+ }
+ returnValues[0] = buildCeilDivConstDenominator(rewriter, loc, returnValues[0],
+ workgroupSizeX);
+ rewriter.create<mlir::ReturnOp>(loc, returnValues);
+ return success();
+}
+
+/// Code generation emits a function `numWorkgroupsFn` for each entry point
+/// function. This function takes as arguments the !shapex.ranked_shape values
+/// for all the input and output shaped types and returns the number of
+/// workgroups to use. To use this function on the host side, generate the
+/// !shapex.ranked_shape values that describe the shapes of the inputs and
+/// outputs of the dispatch region and "inline" the function body.
+std::array<Value, 3> calculateWorkgroupCountFromNumWorkgroupsFn(
+ Location loc, FuncOp numWorkgroupsFn, IREE::HAL::InterfaceOp interface,
+ ArrayRef<Optional<IREE::HAL::TensorRewriteAdaptor>> operands,
+ ArrayRef<Optional<IREE::HAL::TensorRewriteAdaptor>> results,
+ ConversionPatternRewriter &rewriter) {
+ std::array<Value, 3> returnValue = {nullptr, nullptr, nullptr};
+ // TODO: This is really just inlining a function. For now assume that the
+ // `numWorkgroupsFn` has a single block to make inlining easier.
+ if (!numWorkgroupsFn || !llvm::hasSingleElement(numWorkgroupsFn))
+ return returnValue;
+ SmallVector<SmallVector<Value, 4>, 4> shapeValues;
+ shapeValues.reserve(operands.size() + results.size());
+ auto getShapeValuesFn =
+ [&](ArrayRef<Optional<IREE::HAL::TensorRewriteAdaptor>> values)
+ -> LogicalResult {
+ for (auto val : values) {
+ if (!val) continue;
+ Optional<SmallVector<Value, 4>> shape = val->getShapeDims(rewriter);
+ if (!shape) return emitError(loc, "shape computation for operand failed");
+ shapeValues.push_back(shape.getValue());
+ }
+ return success();
+ };
+ if (failed(getShapeValuesFn(operands)) || failed(getShapeValuesFn(results)))
+ return returnValue;
+ BlockAndValueMapping mapper;
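+  // Walk the single block of `numWorkgroupsFn`: map shapex.ranked_dim ops to
+  // the shape values computed above, clone every op whose operands are all
+  // mapped, and stop at the return, whose mapped operands form the result.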
+ for (Operation &op : numWorkgroupsFn.front()) {
+ if (isa<mlir::ReturnOp>(op)) {
+ for (unsigned i = 0, e = std::min<unsigned>(3, op.getNumOperands());
+ i != e; ++i) {
+ returnValue[i] = mapper.lookupOrNull(op.getOperand(i));
+ }
+ break;
+ }
+ if (auto shapeOp = dyn_cast<Shape::RankedDimOp>(op)) {
+ if (BlockArgument arg = shapeOp.shape().dyn_cast<BlockArgument>()) {
+ auto &dimValues = shapeValues[arg.getArgNumber()];
+ mapper.map(shapeOp.result(), dimValues[shapeOp.getIndex()]);
+ continue;
+ }
+ return returnValue;
+ }
+ // If all its operands are mapped, clone it.
+ if (llvm::all_of(op.getOperands(), [&mapper](Value operand) {
+ return mapper.contains(operand);
+ })) {
+ rewriter.clone(op, mapper);
+ continue;
+ }
+ }
+ return returnValue;
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h b/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h
new file mode 100644
index 0000000..0ce710f
--- /dev/null
+++ b/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h
@@ -0,0 +1,97 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_CONVERSION_CODEGENUTILS_GETNUMWORKGROUPS_H_
+#define IREE_COMPILER_CONVERSION_CODEGENUTILS_GETNUMWORKGROUPS_H_
+
+#include <array>
+#include <cstdint>
+
+namespace llvm {
+class StringRef;
+template <typename T>
+class ArrayRef;
+template <typename T>
+class Optional;
+} // namespace llvm
+
+namespace mlir {
+class Location;
+class FuncOp;
+class LogicalResult;
+class PatternRewriter;
+class ConversionPatternRewriter;
+class Value;
+namespace linalg {
+class LinalgOp;
+} // namespace linalg
+
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+class InterfaceOp;
+class TensorRewriteAdaptor;
+} // namespace HAL
+} // namespace IREE
+} // namespace iree_compiler
+namespace iree_compiler {
+
+/// Generates a function that computes the number of workgroups as
+/// [ceil(`parallelLoopRange`[2] / `tileSizes`[2]),
+/// ceil(`parallelLoopRange`[1] / `tileSizes`[1]),
+/// ceil(`parallelLoopRange`[0] / `tileSizes`[0])]
+/// where `parallelLoopRange` contains the ranges of the parallel loops of
+/// `linalgOp` distributed across workgroups.
+LogicalResult createNumWorkgroupsFromResultShape(
+ PatternRewriter &rewriter, linalg::LinalgOp linalgOp, FuncOp entryPointFn,
+ llvm::StringRef numWorkgroupsFnAttr, llvm::ArrayRef<int64_t> tileSizes);
+
+/// Generates a function that computes the number of workgroups as
+/// ceil(`parallelLoopRange`[0] * `parallelLoopRange`[1] * ... *
+/// `parallelLoopRange`[n-1] / `workgroupSizeX`)
+/// where `parallelLoopRange` contains the ranges of the parallel loops of
+/// `linalgOp` distributed across workgroups.
+LogicalResult createNumWorkgroupsFromLinearizedResultShape(
+ PatternRewriter &rewriter, linalg::LinalgOp linalgOp, FuncOp entryPointFn,
+ llvm::StringRef numWorkgroupsFnAttr, int64_t workgroupSizeX);
+
+/// For a given `entryPointFn` return the function that computes the number of
+/// workgroups to use at launch time.
+FuncOp getNumWorkgroupsFn(FuncOp entryPointFn,
+ llvm::StringRef numWorkgroupsFnAttr);
+
+/// Code generation emits a function `numWorkgroupsFn` for each entry point
+/// function. This function takes as arguments the !shapex.ranked_shape values
+/// for all the input and output shaped types and returns the number of
+/// workgroups to use. To use this function on the host side, generate the
+/// !shapex.ranked_shape values that describe the shapes of the inputs and
+/// outputs of the dispatch region and "inline" the function body.
+std::array<Value, 3> calculateWorkgroupCountFromNumWorkgroupsFn(
+ Location loc, FuncOp numWorkgroupsFn,
+ mlir::iree_compiler::IREE::HAL::InterfaceOp interface,
+ llvm::ArrayRef<
+ llvm::Optional<mlir::iree_compiler::IREE::HAL::TensorRewriteAdaptor>>
+ operands,
+ llvm::ArrayRef<
+ llvm::Optional<mlir::iree_compiler::IREE::HAL::TensorRewriteAdaptor>>
+ results,
+ ConversionPatternRewriter &rewriter);
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_CONVERSION_CODEGENUTILS_GETNUMWORKGROUPS_H_
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp b/iree/compiler/Conversion/CodegenUtils/MarkerUtils.cpp
similarity index 96%
rename from iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp
rename to iree/compiler/Conversion/CodegenUtils/MarkerUtils.cpp
index dff5292..20f1616 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp
+++ b/iree/compiler/Conversion/CodegenUtils/MarkerUtils.cpp
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h"
+#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/Attributes.h"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h b/iree/compiler/Conversion/CodegenUtils/MarkerUtils.h
similarity index 81%
rename from iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h
rename to iree/compiler/Conversion/CodegenUtils/MarkerUtils.h
index 78d4304..a839807 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h
+++ b/iree/compiler/Conversion/CodegenUtils/MarkerUtils.h
@@ -19,8 +19,8 @@
//
//===----------------------------------------------------------------------===//
-#ifndef IREE_COMPILER_CONVERSION_LINALGTOSPIRV_MARKERUTILS_H_
-#define IREE_COMPILER_CONVERSION_LINALGTOSPIRV_MARKERUTILS_H_
+#ifndef IREE_COMPILER_CONVERSION_CODEGENUTILS_MARKERUTILS_H_
+#define IREE_COMPILER_CONVERSION_CODEGENUTILS_MARKERUTILS_H_
#include "llvm/ADT/ArrayRef.h"
#include "mlir/Support/LLVM.h"
@@ -30,11 +30,12 @@
class Operation;
namespace iree_compiler {
-/// Marker to denote that a linalg operation has been partitioned to workgroups.
+/// Marker to denote that a linalg operation has been partitioned to
+/// workgroups.
StringRef getWorkgroupMarker();
-/// Marker to denote that a linalg operation has been partitioned to workgroups
-/// and operands promoted to scratchspace memory.
+/// Marker to denote that a linalg operation has been partitioned to
+/// workgroups and operands promoted to scratchspace memory.
StringRef getWorkgroupMemoryMarker();
/// Marker for copy operations that are moving data from StorageClass to
@@ -44,10 +45,10 @@
/// Marker for operations that are going to be vectorized.
StringRef getVectorizeMarker();
-/// Marker for tagging an operation for deletion. Tile and fuse pattern does not
-/// delete the original operation to not invalidate the
-/// `linalg::LinalgDependenceGraph` data structure. Instead it is marked with a
-/// marker that can be used later to delete these operations.
+/// Marker for tagging an operation for deletion. Tile and fuse pattern does
+/// not delete the original operation to not invalidate the
+/// `linalg::LinalgDependenceGraph` data structure. Instead it is marked with
+/// a marker that can be used later to delete these operations.
StringRef getDeleteMarker();
/// Returns true if an operation has the specified `marker`. When `marker` is
@@ -60,4 +61,4 @@
} // namespace iree_compiler
} // namespace mlir
-#endif // IREE_COMPILER_CONVERSION_LINALGTOSPIRV_MARKERUTILS_H_
+#endif // IREE_COMPILER_CONVERSION_CODEGENUTILS_MARKERUTILS_H_
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Attributes.h b/iree/compiler/Conversion/Common/Attributes.h
similarity index 94%
rename from iree/compiler/Conversion/LinalgToSPIRV/Attributes.h
rename to iree/compiler/Conversion/Common/Attributes.h
index 84d6393..c0d1012 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Attributes.h
+++ b/iree/compiler/Conversion/Common/Attributes.h
@@ -23,13 +23,13 @@
/// Attribute on a module op to denote the scheduling order of entry points.
/// The attribute value is expected to be an array of entry point name strings.
inline llvm::StringRef getEntryPointScheduleAttrName() {
- return "vkspv.entry_point_schedule";
+ return "hal.entry_point_schedule";
}
/// Attribute on a entry point function that specifies which function computes
/// the number of workgroups.
inline llvm::StringRef getNumWorkgroupsFnAttrName() {
- return "vkspv.num_workgroups_fn";
+ return "hal.num_workgroups_fn";
}
} // namespace iree_compiler
diff --git a/iree/compiler/Conversion/Common/BUILD b/iree/compiler/Conversion/Common/BUILD
new file mode 100644
index 0000000..b8c5b5a
--- /dev/null
+++ b/iree/compiler/Conversion/Common/BUILD
@@ -0,0 +1,43 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "Common",
+ srcs = [
+ "DeclareNumWorkgroupsFnPass.cpp",
+ "LegalizeNumWorkgroupsFnPass.cpp",
+ ],
+ hdrs = [
+ "Attributes.h",
+ "Passes.h",
+ ],
+ deps = [
+ "//iree/compiler/Conversion/CodegenUtils",
+ "//iree/compiler/Dialect/HAL/IR",
+ "//iree/compiler/Dialect/IREE/IR",
+ "//iree/compiler/Dialect/Shape/IR",
+ "@llvm-project//mlir:CFGTransforms",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:StandardOps",
+ "@llvm-project//mlir:Transforms",
+ "@org_tensorflow//tensorflow/compiler/mlir/hlo",
+ ],
+)
diff --git a/iree/compiler/Conversion/Common/CMakeLists.txt b/iree/compiler/Conversion/Common/CMakeLists.txt
new file mode 100644
index 0000000..b233805
--- /dev/null
+++ b/iree/compiler/Conversion/Common/CMakeLists.txt
@@ -0,0 +1,38 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ Common
+ HDRS
+ "Attributes.h"
+ "Passes.h"
+ SRCS
+ "DeclareNumWorkgroupsFnPass.cpp"
+ "LegalizeNumWorkgroupsFnPass.cpp"
+ DEPS
+ MLIRIR
+ MLIRPass
+ MLIRSCFToStandard
+ MLIRStandard
+ MLIRTransforms
+ iree::compiler::Conversion::CodegenUtils
+ iree::compiler::Dialect::HAL::IR
+ iree::compiler::Dialect::IREE::IR
+ iree::compiler::Dialect::Shape::IR
+ tensorflow::mlir_hlo
+ PUBLIC
+)
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/DeclareNumWorkgroupsFnPass.cpp b/iree/compiler/Conversion/Common/DeclareNumWorkgroupsFnPass.cpp
similarity index 98%
rename from iree/compiler/Conversion/LinalgToSPIRV/DeclareNumWorkgroupsFnPass.cpp
rename to iree/compiler/Conversion/Common/DeclareNumWorkgroupsFnPass.cpp
index 0cc1637..0f46788 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/DeclareNumWorkgroupsFnPass.cpp
+++ b/iree/compiler/Conversion/Common/DeclareNumWorkgroupsFnPass.cpp
@@ -20,7 +20,7 @@
//
//===----------------------------------------------------------------------===//
#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
-#include "iree/compiler/Conversion/LinalgToSPIRV/Attributes.h"
+#include "iree/compiler/Conversion/Common/Attributes.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LegalizeNumWorkgroupsFnPass.cpp b/iree/compiler/Conversion/Common/LegalizeNumWorkgroupsFnPass.cpp
similarity index 98%
rename from iree/compiler/Conversion/LinalgToSPIRV/LegalizeNumWorkgroupsFnPass.cpp
rename to iree/compiler/Conversion/Common/LegalizeNumWorkgroupsFnPass.cpp
index f118b78..ada7efe 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LegalizeNumWorkgroupsFnPass.cpp
+++ b/iree/compiler/Conversion/Common/LegalizeNumWorkgroupsFnPass.cpp
@@ -20,7 +20,7 @@
//
//===----------------------------------------------------------------------===//
#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
-#include "iree/compiler/Conversion/LinalgToSPIRV/Attributes.h"
+#include "iree/compiler/Conversion/Common/Attributes.h"
#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "mlir/IR/Module.h"
diff --git a/iree/compiler/Conversion/Common/Passes.h b/iree/compiler/Conversion/Common/Passes.h
new file mode 100644
index 0000000..00a61e0
--- /dev/null
+++ b/iree/compiler/Conversion/Common/Passes.h
@@ -0,0 +1,27 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_CONVERSION_COMMON_PASSES_H_
+#define IREE_COMPILER_CONVERSION_COMMON_PASSES_H_
+
+#include <memory>
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Pass to legalize the function that computes the number of workgroups to
+/// use for the launch so that it can run on the host.
+std::unique_ptr<OperationPass<ModuleOp>> createLegalizeNumWorkgroupsFnPass();
+
+/// Pass to declare the function that computes the number of workgroups for
+/// each entry point function. The function is only declared here and is
+/// populated later.
+std::unique_ptr<OperationPass<ModuleOp>> createDeclareNumWorkgroupsFnPass();
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_CONVERSION_COMMON_PASSES_H_
diff --git a/iree/compiler/Conversion/LinalgToLLVM/BUILD b/iree/compiler/Conversion/LinalgToLLVM/BUILD
index 9261ede..e7eb509 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/BUILD
+++ b/iree/compiler/Conversion/LinalgToLLVM/BUILD
@@ -23,14 +23,18 @@
srcs = [
"ConvImg2ColMatmulConversion.cpp",
"ConvertToLLVM.cpp",
+ "KernelDispatch.cpp",
+ "LinalgTileAndDistributePass.cpp",
"MatMulVectorization.cpp",
"Passes.cpp",
],
hdrs = [
+ "KernelDispatch.h",
"Passes.h",
],
deps = [
"//iree/compiler/Conversion/CodegenUtils",
+ "//iree/compiler/Conversion/Common",
"//iree/compiler/Conversion/HLOToHLO",
"//iree/compiler/Conversion/HLOToLinalg",
"//iree/compiler/Dialect/HAL/IR",
@@ -43,6 +47,7 @@
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LLVMDialect",
"@llvm-project//mlir:LLVMTransforms",
+ "@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:LinalgToLLVM",
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:Pass",
diff --git a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
index f7b79e0..a98bb34 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
@@ -18,16 +18,20 @@
NAME
LinalgToLLVM
HDRS
+ "KernelDispatch.h"
"Passes.h"
SRCS
"ConvImg2ColMatmulConversion.cpp"
"ConvertToLLVM.cpp"
+ "KernelDispatch.cpp"
+ "LinalgTileAndDistributePass.cpp"
"MatMulVectorization.cpp"
"Passes.cpp"
DEPS
MLIRAffineToStandard
MLIRIR
MLIRLLVMIR
+ MLIRLinalg
MLIRLinalgToLLVM
MLIRLinalgTransforms
MLIRPass
@@ -40,6 +44,7 @@
MLIRVectorToLLVM
MLIRVectorToSCF
iree::compiler::Conversion::CodegenUtils
+ iree::compiler::Conversion::Common
iree::compiler::Conversion::HLOToHLO
iree::compiler::Conversion::HLOToLinalg
iree::compiler::Dialect::HAL::IR
diff --git a/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp b/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp
index 16c399c..85ce689 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp
@@ -12,10 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
#include "iree/compiler/Conversion/LinalgToLLVM/Passes.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/IREE/IR/IREEDialect.h"
#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
@@ -138,10 +141,11 @@
}
// Change signature of entry function to func
-// entry_func(%packed_buffers_arg_ptr:
-// !<llvm.int8**>, %push_constant: !<llvm.int64*>) and lower IREE and HAL ops to
+// clang-format off
+// entry_func(%packed_buffers_arg_ptr: !<llvm.int8**>, %push_constant: !<llvm.int32*>, thread_idx: !llvm.i32, thread_idy: !llvm.i32, thread_idz: !llvm.i32) and lower IREE and HAL ops to
// corresponding LLVMIR ops to construct memref descriptors and load
// push_constant values.
+// clang-format on
class ConvertFuncWithHALInterface : public ConvertToLLVMPattern {
public:
explicit ConvertFuncWithHALInterface(MLIRContext *context,
@@ -164,6 +168,7 @@
// Get interface buffers from all the blocks.
SmallVector<IREE::PlaceholderOp, 8> bufferOps;
SmallVector<IREE::HAL::InterfaceLoadConstantOp, 8> loadOps;
+ SmallVector<IREE::WorkgroupIdOp, 3> workgroupIdOps;
for (Block &block : funcOp.getBlocks()) {
for (Operation &op : block) {
if (auto phOp = dyn_cast<IREE::PlaceholderOp>(op))
@@ -171,6 +176,9 @@
if (auto phOp = dyn_cast<IREE::HAL::InterfaceLoadConstantOp>(op)) {
loadOps.push_back(phOp);
}
+ if (auto threadIdOp = dyn_cast<IREE::WorkgroupIdOp>(op)) {
+ workgroupIdOps.push_back(threadIdOp);
+ }
}
}
@@ -218,22 +226,39 @@
}
TypeConverter::SignatureConversion signatureConverter(/*numOrigInputs=*/0);
-
- // func foo(%packed_buffer_args: !llvm<i8**>, %push_constant: !llvm<i32*>)
+ // clang-format off
+  // func foo(%packed_buffer_args: !llvm<i8**>, %push_constant: !llvm<i32*>, thread_idx: i32, thread_idy: i32, thread_idz: i32)
+ // clang-format on
MLIRContext *context = rewriter.getContext();
auto packedBuffersArgsTy =
LLVM::LLVMType::getInt8PtrTy(context).getPointerTo();
auto pushConstantArgTy = LLVM::LLVMType::getInt32Ty(context).getPointerTo();
+ auto threadIdXTy = LLVM::LLVMType::getInt32Ty(context);
+ auto threadIdYTy = LLVM::LLVMType::getInt32Ty(context);
+ auto threadIdZTy = LLVM::LLVMType::getInt32Ty(context);
signatureConverter.addInputs(packedBuffersArgsTy);
signatureConverter.addInputs(pushConstantArgTy);
+ signatureConverter.addInputs(threadIdXTy);
+ signatureConverter.addInputs(threadIdYTy);
+ signatureConverter.addInputs(threadIdZTy);
- // Create the new function's signature.
Location loc = funcOp.getLoc();
+
+ // Construct newFunc with all attributes except return type & symbol name.
+ SmallVector<NamedAttribute, 4> funcAttrs;
+ for (auto attr : funcOp.getAttrs()) {
+ if (attr.first == SymbolTable::getSymbolAttrName() ||
+ attr.first == mlir::impl::getTypeAttrName()) {
+ continue;
+ }
+ funcAttrs.push_back(attr);
+ }
+
auto newFuncOp = rewriter.create<FuncOp>(
loc, funcOp.getName(),
rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
llvm::None),
- ArrayRef<NamedAttribute>());
+ funcAttrs);
// Move all ops in the old function's region to the new function.
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
@@ -285,6 +310,26 @@
rewriter.replaceOp(loadOp, dimConstantCasted);
}
+  // Lower iree.workgroup_id ops to get indices from function arguments.
+ for (auto workgroupCoordOp : workgroupIdOps) {
+ auto attr = workgroupCoordOp.getAttrOfType<StringAttr>("dimension");
+ int argIndex = -1;
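+    // Arguments 0 and 1 carry the packed buffer pointers and push constants;
+    // the workgroup x/y/z ids are arguments 2, 3, and 4.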
+ if (attr.getValue().str() == "x") {
+ argIndex = 2;
+ } else if (attr.getValue().str() == "y") {
+ argIndex = 3;
+ } else if (attr.getValue().str() == "z") {
+ argIndex = 4;
+ } else {
+ return rewriter.notifyMatchFailure(
+ funcOp,
+ "Unable to map to workgroup coordinate : " + attr.getValue().str());
+ }
+    Value coordValue = rewriter.create<LLVM::ZExtOp>(
+        loc, typeConverter.convertType(workgroupCoordOp.getType()),
+        newFuncOp.getArgument(argIndex));
+    rewriter.replaceOp(workgroupCoordOp, coordValue);
+ }
rewriter.eraseOp(funcOp);
return success();
}
@@ -353,9 +398,21 @@
RemoveInterfaceOpPattern>(&getContext(), converter);
LLVMConversionTarget target(getContext());
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
- target.addIllegalOp<IREE::PlaceholderOp>();
+
+  // Pass through the workgroup count calculation; it is not going to be
+  // translated to LLVM.
+  // TODO(ataei): Should this be handled somewhere else?
+ target.addDynamicallyLegalDialect<ShapeDialect, StandardOpsDialect,
+ IREEDialect>(
+ Optional<ConversionTarget::DynamicLegalityCallbackFn>([](Operation *op) {
+ auto funcOp = dyn_cast<FuncOp>(op->getParentOp());
+ if (funcOp && !isEntryPoint(funcOp)) return true;
+ return false;
+ }));
+
target.addDynamicallyLegalOp<FuncOp>([](FuncOp funcOp) {
bool any = false;
+ if (!isEntryPoint(funcOp)) return true;
funcOp.walk([&](IREE::PlaceholderOp placeholderOp) { any = true; });
return any ? false : true;
});
diff --git a/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.cpp b/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.cpp
new file mode 100644
index 0000000..c44d388
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.cpp
@@ -0,0 +1,37 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h"
+
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/IR/Operation.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+static llvm::SmallVector<int64_t, 4> getTileSizesImpl(linalg::MatmulOp op) {
+ return {128, 128};
+}
+
+llvm::SmallVector<int64_t, 4> CPUKernelDispatch::getTileSizes(
+ Operation* op) const {
+  if (auto matmulOp = dyn_cast<linalg::MatmulOp>(op)) {
+    return getTileSizesImpl(matmulOp);
+  }
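+  // Fall back to unit tile sizes for ops without a dedicated strategy.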
+ return {1, 1, 1};
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h b/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h
new file mode 100644
index 0000000..09d6c41
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h
@@ -0,0 +1,30 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_CONVERSION_LINALGTOLLVM_KERNELDISPATCH_H_
+#define IREE_COMPILER_CONVERSION_LINALGTOLLVM_KERNELDISPATCH_H_
+
+#include <cstdint>
+
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+class Operation;
+
+namespace iree_compiler {
+
+class CPUKernelDispatch {
+ public:
+ llvm::SmallVector<int64_t, 4> getTileSizes(Operation* op) const;
+};
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_CONVERSION_LINALGTOLLVM_KERNELDISPATCH_H_
diff --git a/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributePass.cpp b/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributePass.cpp
new file mode 100644
index 0000000..2cc4a71
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributePass.cpp
@@ -0,0 +1,221 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
+#include "iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h"
+#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
+#include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
+#include "iree/compiler/Conversion/Common/Attributes.h"
+#include "iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h"
+#include "iree/compiler/Dialect/IREE/IR/IREEDialect.h"
+#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-linalg-to-llvm-tile-and-distribute"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+struct LinalgTileAndDistributePass
+ : public PassWrapper<LinalgTileAndDistributePass, OperationPass<ModuleOp>> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<linalg::LinalgDialect, IREEDialect, AffineDialect,
+ scf::SCFDialect>();
+ }
+ LinalgTileAndDistributePass() = default;
+ LinalgTileAndDistributePass(const LinalgTileAndDistributePass &pass) {}
+ void runOnOperation() override;
+
+ private:
+ ListOption<int64_t> tileSizes{
+ *this, "tile-sizes", llvm::cl::desc("Set tile sizes to use"),
+ llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+};
+} // namespace
+
+namespace {
+template <typename LinalgOpTy>
+struct TileToCPUThreads : public linalg::LinalgBaseTilingPattern {
+ using Base = linalg::LinalgBaseTilingPattern;
+ TileToCPUThreads(MLIRContext *context,
+ const linalg::LinalgDependenceGraph &dependenceGraph,
+ const CPUKernelDispatch &cpuKernelDispatch,
+ linalg::LinalgTilingOptions options,
+ linalg::LinalgMarker marker, PatternBenefit benefit = 1)
+ : Base(LinalgOpTy::getOperationName(), context, options, marker, benefit),
+ cpuKernelDispatch(cpuKernelDispatch) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ // Find the parent FuncOp before tiling. If tiling succeeds, the op will be
+ // erased.
+ FuncOp funcOp = op->getParentOfType<FuncOp>();
+ SmallVector<Value, 4> tensorResults;
+ if (!funcOp ||
+ failed(Base::matchAndRewriteBase(op, rewriter, tensorResults)) ||
+ !tensorResults.empty() ||
+ (funcOp.getAttr(getNumWorkgroupsFnAttrName()) &&
+ failed(createNumWorkgroupsFromResultShape(
+ rewriter, cast<linalg::LinalgOp>(op), funcOp,
+ getNumWorkgroupsFnAttrName(),
+ cpuKernelDispatch.getTileSizes(op))))) {
+ return failure();
+ }
+ rewriter.eraseOp(op);
+ return success();
+ }
+ CPUKernelDispatch cpuKernelDispatch;
+};
+
+template <typename LinalgOpTy>
+struct TileAndFuseToCPUThreads
+ : public linalg::LinalgTileAndFusePattern<LinalgOpTy> {
+ using Base = linalg::LinalgTileAndFusePattern<LinalgOpTy>;
+ TileAndFuseToCPUThreads(MLIRContext *context,
+ const linalg::LinalgDependenceGraph &dependenceGraph,
+ const CPUKernelDispatch &cpuKernelDispatch,
+ linalg::LinalgTilingOptions tilingOptions,
+ linalg::LinalgMarker marker,
+ PatternBenefit benefit = 1)
+ : Base(context, dependenceGraph, tilingOptions,
+ linalg::LinalgFusionOptions().setIndicesToFuse({2}), marker,
+ marker,
+ linalg::LinalgMarker(ArrayRef<Identifier>(),
+ Identifier::get(getDeleteMarker(), context)),
+ benefit),
+ dependenceGraph(dependenceGraph),
+ cpuKernelDispatch(cpuKernelDispatch) {}
+
+ virtual LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const {
+ FuncOp funcOp = op->getParentOfType<FuncOp>();
+ linalg::LinalgOp linalgOp = cast<linalg::LinalgOp>(op);
+ if (!funcOp || !dependenceGraph.hasDependentOperations(linalgOp) ||
+ failed(Base::matchAndRewrite(op, rewriter)) ||
+ failed(createNumWorkgroupsFromResultShape(
+ rewriter, cast<linalg::LinalgOp>(op), funcOp,
+ getNumWorkgroupsFnAttrName().str(),
+ cpuKernelDispatch.getTileSizes(op)))) {
+ return failure();
+ }
+ return success();
+ }
+
+ const linalg::LinalgDependenceGraph &dependenceGraph;
+ CPUKernelDispatch cpuKernelDispatch;
+};
+
+} // namespace
+
+void LinalgTileAndDistributePass::runOnOperation() {
+ MLIRContext *context = &getContext();
+ ModuleOp module = getOperation();
+
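+  // Distribute tiled parallel loops cyclically to workgroup ids: the
+  // innermost distributed loop maps to "x", then "y", then "z".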
+ static linalg::LinalgLoopDistributionOptions workgroupDistributionOptions = {
+ [](OpBuilder &builder, Location loc, ArrayRef<Range> parallelLoopRanges) {
+ Type indexType = builder.getIndexType();
+ auto numParallelDims = parallelLoopRanges.size();
+ SmallVector<linalg::ProcInfo, 2> procInfo(numParallelDims);
+ for (int dim = 0; dim < numParallelDims; ++dim) {
+ std::array<StringRef, 3> dimAttr{"x", "y", "z"};
+ StringAttr attr =
+            builder.getStringAttr(dimAttr[std::min<unsigned>(dim, 2)]);
+ procInfo[numParallelDims - dim - 1] = {
+ builder.create<IREE::WorkgroupIdOp>(loc, indexType, attr),
+ builder.create<IREE::WorkgroupSizeOp>(loc, indexType, attr)};
+ }
+ return procInfo;
+ },
+ {linalg::DistributionMethod::CyclicNumProcsEqNumIters,
+ linalg::DistributionMethod::CyclicNumProcsEqNumIters,
+ linalg::DistributionMethod::CyclicNumProcsEqNumIters}};
+
+ CPUKernelDispatch cpuKernelDispatch;
+
+ // Function to compute first level tiling values.
+ std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>
+ getOuterTileSizeFn =
+ [&cpuKernelDispatch](OpBuilder &builder,
+ Operation *operation) -> SmallVector<Value, 4> {
+ auto tileSizes = cpuKernelDispatch.getTileSizes(operation);
+ if (tileSizes.empty()) return {};
+ SmallVector<Value, 4> tileSizesVal;
+ tileSizesVal.reserve(tileSizes.size());
+ for (auto val : tileSizes) {
+ tileSizesVal.push_back(
+ builder.create<ConstantIndexOp>(operation->getLoc(), val));
+ }
+ return tileSizesVal;
+ };
+
+ for (FuncOp funcOp : module.getOps<FuncOp>()) {
+ if (!isEntryPoint(funcOp)) continue;
+
+ // Compute the Linalg Dependence Graph.
+ linalg::Aliases aliases;
+ linalg::LinalgDependenceGraph dependenceGraph =
+ linalg::LinalgDependenceGraph::buildDependenceGraph(aliases, funcOp);
+
+ OwningRewritePatternList patterns;
+
+ auto linalgTilingOptions =
+ linalg::LinalgTilingOptions()
+ .setDistributionOptions(workgroupDistributionOptions)
+ .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops);
+    if (tileSizes.empty()) {
+      linalgTilingOptions.setTileSizeComputationFunction(getOuterTileSizeFn);
+    } else {
+      linalgTilingOptions.setTileSizes(ArrayRef<int64_t>(tileSizes));
+    }
+ patterns.insert<TileAndFuseToCPUThreads<linalg::MatmulOp>,
+ TileToCPUThreads<linalg::MatmulOp>>(
+ context, dependenceGraph, cpuKernelDispatch, linalgTilingOptions,
+ linalg::LinalgMarker(ArrayRef<Identifier>(),
+ Identifier::get(getWorkgroupMarker(), context)));
+
+ // Tile and distribute to CPU threads.
+ applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+
+ // Apply canonicalization patterns.
+ OwningRewritePatternList canonicalizationPatterns;
+ canonicalizationPatterns.insert<AffineMinCanonicalizationPattern>(context);
+ AffineApplyOp::getCanonicalizationPatterns(canonicalizationPatterns,
+ context);
+ AffineMinOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
+ SubViewOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
+
+ applyPatternsAndFoldGreedily(funcOp, std::move(canonicalizationPatterns));
+
+ // Delete the ops that are marked for deletion.
+ funcOp.walk([](linalg::LinalgOp linalgOp) {
+ if (hasMarker(linalgOp.getOperation(), getDeleteMarker()))
+ linalgOp.getOperation()->erase();
+ });
+ }
+}
+
+std::unique_ptr<OperationPass<ModuleOp>> createLinalgTileAndDistributePass() {
+ return std::make_unique<LinalgTileAndDistributePass>();
+}
+
+static PassRegistration<LinalgTileAndDistributePass> pass(
+ "iree-codegen-llvm-linalg-tile-and-distribute",
+ "Tile and distribute Linalg operations on buffers",
+ [] { return std::make_unique<LinalgTileAndDistributePass>(); });
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
index 498f4a6..96ad2a9 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
@@ -14,6 +14,8 @@
#include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
+#include "iree/compiler/Conversion/Common/Attributes.h"
+#include "iree/compiler/Conversion/Common/Passes.h"
#include "iree/compiler/Conversion/HLOToHLO/Passes.h"
#include "iree/compiler/Conversion/LinalgToLLVM/Passes.h"
#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
@@ -32,7 +34,18 @@
"linag.matmul"),
llvm::cl::init(false));
+static llvm::cl::opt<bool> llvmLinalgTileAndDistributePass(
+    "iree-codegen-linalg-to-llvm-tile-and-distribute",
+ llvm::cl::desc("Tile and distribute linalg ops among iree threads"),
+ llvm::cl::init(false));
+
void addLinalgToLLVMPasses(OpPassManager &passManager) {
+  // Distribute linalg ops among a 3-d grid of parallel threads.
+ if (llvmLinalgTileAndDistributePass) {
+ passManager.addPass(createLinalgTileAndDistributePass());
+ passManager.addPass(createLegalizeNumWorkgroupsFnPass());
+ }
+
// Linalg.ConvOp -> (Img2Col packing + matmul)
if (convImg2ColConversion) {
passManager.addPass(createConvImg2ColMatmulConversionPass());
@@ -57,6 +70,8 @@
}
void buildLLVMTransformPassPipeline(OpPassManager &passManager) {
+ passManager.addPass(createDeclareNumWorkgroupsFnPass());
+
passManager.addPass(createInlinerPass());
// Propagates dynamic shapes computation on tensors.
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.h b/iree/compiler/Conversion/LinalgToLLVM/Passes.h
index c79c885..60cedbe 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/Passes.h
+++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.h
@@ -20,6 +20,15 @@
namespace mlir {
namespace iree_compiler {
+// Options that can be used to configure LLVMIR code generation.
+struct LLVMIRCodegenOptions {
+ SmallVector<int64_t, 3> workgroupSize = {};
+ SmallVector<int64_t, 3> tileSizes = {};
+ bool useWorkgroupMemory = false;
+ bool useVectorization = false;
+ bool useVectorPass = false;
+};
+
/// Converts linalg::MatmulOp into LLVM dialect
std::unique_ptr<FunctionPass> createMatMulTileAndVectorizePass();
@@ -27,6 +36,8 @@
/// linalg::MatmulOp.
std::unique_ptr<FunctionPass> createConvImg2ColMatmulConversionPass();
+/// Pass to tile and distribute linalg operations among parallel threads.
+std::unique_ptr<OperationPass<ModuleOp>> createLinalgTileAndDistributePass();
+
/// Populates patterns to rewrite linalg::ConvOp into packed img2col operation
/// followed by linalg::MatmulOp.
void populateConvImg2ColMatmulConversionPatterns(
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/convert_to_llvm.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/convert_to_llvm.mlir
index 0624289..b6ef3e3 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/test/convert_to_llvm.mlir
+++ b/iree/compiler/Conversion/LinalgToLLVM/test/convert_to_llvm.mlir
@@ -14,7 +14,7 @@
hal.interface @legacy_io attributes {push_constants = 2 : i32, sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
}
-// CHECK: llvm.func @convert_dynamic_shape(%[[ARG0:.+]]: !llvm.ptr<ptr<i8>>, %[[ARG1:.+]]: !llvm.ptr<i32>)
+// CHECK: llvm.func @convert_dynamic_shape(%[[ARG0:.+]]: !llvm.ptr<ptr<i8>>, %[[ARG1:.+]]: !llvm.ptr<i32>, %[[THREAD_X_ID:.+]]: !llvm.i32, %[[THREAD_Y_ID:.+]]: !llvm.i32, %[[THREAD_Z_ID:.+]]: !llvm.i32)
// CHECK: %[[PACKED_ARGS_PTR:.+]] = llvm.bitcast %[[ARG0]] : !llvm.ptr<ptr<i8>> to !llvm.ptr<struct<(ptr<float>)>>
// CHECK: %[[PACKED_ARGS:.+]] = llvm.load %[[PACKED_ARGS_PTR]] : !llvm.ptr<struct<(ptr<float>)>>
// CHECK: %[[MEMREF0_DATA_PTR:.+]] = llvm.extractvalue %[[PACKED_ARGS]][0] : !llvm.struct<(ptr<float>)>
@@ -38,6 +38,8 @@
// CHECK: %[[STRIDE_DIM0:.+]] = llvm.mul %[[STRIDE_DIM1]], %[[DIM1_0]] : !llvm.i64
// CHECK: llvm.insertvalue %[[STRIDE_DIM0]], %[[MEMREF0_4]][4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+// -----
+
// CHECK_LABEL: @convert_dynamic_shape2
func @convert_dynamic_shape2() -> f32 {
%c0 = constant 0 : index
@@ -51,8 +53,7 @@
hal.interface @legacy_io2 attributes {push_constants = 1 : i32, sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
}
-
-// CHECK: llvm.func @convert_dynamic_shape2(%[[ARG0:.+]]: !llvm.ptr<ptr<i8>>, %[[ARG1:.+]]: !llvm.ptr<i32>)
+// CHECK: llvm.func @convert_dynamic_shape2(%[[ARG0:.+]]: !llvm.ptr<ptr<i8>>, %[[ARG1:.+]]: !llvm.ptr<i32>, %[[THREAD_X_ID:.+]]: !llvm.i32, %[[THREAD_Y_ID:.+]]: !llvm.i32, %[[THREAD_Z_ID:.+]]: !llvm.i32)
// CHECK: %[[PACKED_ARGS_PTR:.+]] = llvm.bitcast %[[ARG0]] : !llvm.ptr<ptr<i8>> to !llvm.ptr<struct<(ptr<float>)>>
// CHECK: %[[PACKED_ARGS:.+]] = llvm.load %[[PACKED_ARGS_PTR]] : !llvm.ptr<struct<(ptr<float>)>>
// CHECK: %[[MEMREF0_DATA_PTR:.+]] = llvm.extractvalue %[[PACKED_ARGS]][0] : !llvm.struct<(ptr<float>)>
@@ -84,3 +85,17 @@
// CHECK: %[[GET_PTR:.+]] = llvm.getelementptr %[[EXTRACT1:.+]][%[[ADD2:.+]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
// CHECK: %[[LOAD:.+]] = llvm.load %[[GET_PTR:.+]] : !llvm.ptr<float>
+// -----
+
+// CHECK-LABEL: @distribute_lookup
+func @distribute_lookup() -> f32 {
+ %0 = iree.placeholder for "interface buffer" {binding = @legacy_io3::@arg0} : memref<2x2x2xf32>
+ %1 = iree.workgroup_id {dimension = "x"} : index
+ %2 = iree.workgroup_id {dimension = "y"} : index
+ %3 = iree.workgroup_id {dimension = "z"} : index
+ %4 = load %0[%1, %2, %3] : memref<2x2x2xf32>
+ return %4 : f32
+}
+hal.interface @legacy_io3 attributes {push_constants = 1 : i32, sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+}
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute.mlir
new file mode 100644
index 0000000..baffba1
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute.mlir
@@ -0,0 +1,56 @@
+// RUN: iree-opt --iree-codegen-llvm-linalg-tile-and-distribute=tile-sizes=2,4,1 -cse -split-input-file %s | IreeFileCheck %s
+
+func @dynamic_matmul(%lhs: memref<?x?xf32>, %rhs: memref<?x?xf32>, %result: memref<?x?xf32>) {
+ linalg.matmul ins(%lhs, %rhs : memref<?x?xf32>, memref<?x?xf32>) outs(%result : memref<?x?xf32>)
+ return
+}
+// CHECK: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (2, s1 - s0 * 2)>
+// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+// CHECK: #[[MAP3:.+]] = affine_map<()[s0] -> (s0 * 4)>
+// CHECK: #[[MAP4:.+]] = affine_map<()[s0, s1] -> (4, s1 - s0 * 4)>
+// CHECK: func @dynamic_matmul(%[[LHS:.+]]: memref<?x?xf32>, %[[RHS:.+]]: memref<?x?xf32>, %[[RESULT:.+]]: memref<?x?xf32>)
+// CHECK: %[[CONST_0:.+]] = constant 0 : index
+// CHECK: %[[CONST_1:.+]] = constant 1 : index
+// CHECK: %[[DIM_K:.+]] = dim %[[LHS]], %[[CONST_1]]
+// CHECK: %[[THREAD_X_ID:.+]] = iree.workgroup_id {dimension = "x"} : index
+// CHECK: %[[THREAD_Y_ID:.+]] = iree.workgroup_id {dimension = "y"} : index
+// CHECK: scf.for %[[K:.+]] = %[[CONST_0]] to %[[DIM_K]]
+// CHECK: %[[I:.+]] = affine.apply #[[MAP0]]()[%[[THREAD_Y_ID]]]
+// CHECK: %[[DIM_I:.+]] = dim %[[LHS]], %[[CONST_0]]
+// CHECK: %[[I_OFFSET:.+]] = affine.min #[[MAP1]]()[%[[THREAD_Y_ID]], %[[DIM_I]]]
+// CHECK: %[[LHS_SUBVIEW:.+]] = subview %[[LHS]][%[[I]], %[[K]]] [%[[I_OFFSET]], 1] [1, 1]
+// CHECK: %[[J:.+]] = affine.apply #[[MAP3]]()[%[[THREAD_X_ID]]]
+// CHECK: %[[DIM_J:.+]] = dim %[[RHS]], %[[CONST_1]]
+// CHECK: %[[J_OFFSET:.+]] = affine.min #[[MAP4]]()[%[[THREAD_X_ID]], %[[DIM_J]]]
+// CHECK: %[[RHS_SUBVIEW:.+]] = subview %[[RHS]][%[[K]], %[[J]]] [1, %[[J_OFFSET]]] [1, 1]
+// CHECK: %[[DIM_I:.+]] = dim %[[RESULT]], %[[CONST_0]]
+// CHECK: %[[DIM_I_OFFSET:.+]] = affine.min #[[MAP1]]()[%[[THREAD_Y_ID]], %[[DIM_I]]]
+// CHECK: %[[DIM_J:.+]] = dim %[[RESULT]], %[[CONST_1]]
+// CHECK: %[[DIM_J_OFFSET:.+]] = affine.min #[[MAP4]]()[%[[THREAD_X_ID]], %[[DIM_J]]]
+// CHECK: %[[RESULT_SUBVIEW:.+]] = subview %[[RESULT]][%[[I]], %[[J]]] [%[[DIM_I_OFFSET]], %[[DIM_J_OFFSET]]] [1, 1]
+// CHECK: linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%[[LHS_SUBVIEW]], %[[RHS_SUBVIEW]] : memref<?x1xf32, #[[MAP2]]>, memref<1x?xf32, #[[MAP2]]>) outs(%[[RESULT_SUBVIEW]] : memref<?x?xf32, #[[MAP2]]>)
+
+// -----
+
+func @static_matmul(%lhs: memref<16x4xf32>, %rhs: memref<4x8xf32>, %result: memref<16x8xf32>) {
+ linalg.matmul ins(%lhs, %rhs : memref<16x4xf32>, memref<4x8xf32>) outs(%result : memref<16x8xf32>)
+ return
+}
+// CHECK: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1)[s0] -> (d0 * 4 + s0 + d1)>
+// CHECK: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 4)>
+// CHECK: #[[MAP3:.+]] = affine_map<(d0, d1)[s0] -> (d0 * 8 + s0 + d1)>
+// CHECK: func @static_matmul(%[[LHS:.+]]: memref<16x4xf32>, %[[RHS:.+]]: memref<4x8xf32>, %[[RESULT:.+]]: memref<16x8xf32>)
+// CHECK: %[[CONST_0:.+]] = constant 0 : index
+// CHECK: %[[CONST_4:.+]] = constant 4 : index
+// CHECK: %[[CONST_1:.+]] = constant 1 : index
+// CHECK: %[[THREAD_X_ID:.+]] = iree.workgroup_id {dimension = "x"} : index
+// CHECK: %[[THREAD_Y_ID:.+]] = iree.workgroup_id {dimension = "y"} : index
+// CHECK: scf.for %[[K:.+]] = %[[CONST_0]] to %[[CONST_4]] step %[[CONST_1]]
+// CHECK: %[[I:.+]] = affine.apply #[[MAP0]]()[%[[THREAD_Y_ID]]]
+// CHECK: %[[LHS_SUBVIEW:.+]] = subview %[[LHS]][%[[I]], %[[K]]] [2, 1] [1, 1] : memref<16x4xf32> to memref<2x1xf32, #[[MAP1]]>
+// CHECK: %[[J:.+]] = affine.apply #[[MAP2]]()[%[[THREAD_X_ID]]]
+// CHECK: %[[RHS_SUBVIEW:.+]] = subview %[[RHS]][%[[K]], %[[J]]] [1, 4] [1, 1] : memref<4x8xf32> to memref<1x4xf32, #[[MAP3]]>
+// CHECK: %[[RESULT_SUBVIEW:.+]] = subview %[[RESULT]][%[[I]], %[[J]]] [2, 4] [1, 1] : memref<16x8xf32> to memref<2x4xf32, #[[MAP3]]>
+// CHECK: linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%[[LHS_SUBVIEW]], %[[RHS_SUBVIEW]] : memref<2x1xf32, #[[MAP1]]>, memref<1x4xf32, #[[MAP3]]>) outs(%[[RESULT_SUBVIEW]] : memref<2x4xf32, #[[MAP3]]>)
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/BUILD b/iree/compiler/Conversion/LinalgToSPIRV/BUILD
index f5fe4d9..6e3ec22 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/BUILD
+++ b/iree/compiler/Conversion/LinalgToSPIRV/BUILD
@@ -24,11 +24,8 @@
"ConvertToGPUPass.cpp",
"ConvertToSPIRVPass.cpp",
"CooperativeMatrixAnalysis.cpp",
- "DeclareNumWorkgroupsFnPass.cpp",
"KernelDispatchUtils.cpp",
- "LegalizeNumWorkgroupsFnPass.cpp",
"LinalgTileAndFusePass.cpp",
- "MarkerUtils.cpp",
"MatMulVectorizationTest.cpp",
"Passes.cpp",
"SplitDispatchFunctionPass.cpp",
@@ -37,16 +34,15 @@
"VectorizeMemref.cpp",
],
hdrs = [
- "Attributes.h",
"CooperativeMatrixAnalysis.h",
"KernelDispatchUtils.h",
- "MarkerUtils.h",
"MemorySpace.h",
"Passes.h",
"Utils.h",
],
deps = [
"//iree/compiler/Conversion/CodegenUtils",
+ "//iree/compiler/Conversion/Common",
"//iree/compiler/Conversion/HLOToHLO",
"//iree/compiler/Conversion/HLOToLinalg",
"//iree/compiler/Conversion/LinalgToVector",
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
index 79578dd..be9e0c5 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
@@ -18,10 +18,8 @@
NAME
LinalgToSPIRV
HDRS
- "Attributes.h"
"CooperativeMatrixAnalysis.h"
"KernelDispatchUtils.h"
- "MarkerUtils.h"
"MemorySpace.h"
"Passes.h"
"Utils.h"
@@ -29,11 +27,8 @@
"ConvertToGPUPass.cpp"
"ConvertToSPIRVPass.cpp"
"CooperativeMatrixAnalysis.cpp"
- "DeclareNumWorkgroupsFnPass.cpp"
"KernelDispatchUtils.cpp"
- "LegalizeNumWorkgroupsFnPass.cpp"
"LinalgTileAndFusePass.cpp"
- "MarkerUtils.cpp"
"MatMulVectorizationTest.cpp"
"Passes.cpp"
"SplitDispatchFunctionPass.cpp"
@@ -64,6 +59,7 @@
MLIRVector
MLIRVectorToSPIRV
iree::compiler::Conversion::CodegenUtils
+ iree::compiler::Conversion::Common
iree::compiler::Conversion::HLOToHLO
iree::compiler::Conversion::HLOToLinalg
iree::compiler::Conversion::LinalgToVector
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
index b5af0c5..d281d79 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
@@ -22,9 +22,10 @@
#include <numeric>
#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
-#include "iree/compiler/Conversion/LinalgToSPIRV/Attributes.h"
+#include "iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h"
+#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
+#include "iree/compiler/Conversion/Common/Attributes.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h"
-#include "iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/MemorySpace.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Utils.h"
@@ -653,7 +654,7 @@
(funcOp.getAttr(getNumWorkgroupsFnAttrName()) &&
failed(createNumWorkgroupsFromLinearizedResultShape(
rewriter, cast<linalg::LinalgOp>(linalgOp.getOperation()), funcOp,
- workgroupSize[0])))) {
+ getNumWorkgroupsFnAttrName(), workgroupSize[0])))) {
return failure();
}
rewriter.eraseOp(linalgOp);
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
index 4bb9d8d..40ffa42 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
@@ -21,8 +21,8 @@
//
//===----------------------------------------------------------------------===//
+#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/CooperativeMatrixAnalysis.h"
-#include "iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
#include "llvm/ADT/STLExtras.h"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
index 7673cf9..58106b0 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
@@ -23,7 +23,7 @@
#include "iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h"
#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
-#include "iree/compiler/Conversion/LinalgToSPIRV/Attributes.h"
+#include "iree/compiler/Conversion/Common/Attributes.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
@@ -46,145 +46,6 @@
namespace mlir {
namespace iree_compiler {
-//===----------------------------------------------------------------------===//
-// Number of workgroups computation
-//===----------------------------------------------------------------------===//
-
-FuncOp getNumWorkgroupsFn(FuncOp entryPointFn) {
- SymbolRefAttr attr =
- entryPointFn.getAttrOfType<SymbolRefAttr>(getNumWorkgroupsFnAttrName());
- if (!attr) {
- entryPointFn.emitError("missing attribute '")
- << getNumWorkgroupsFnAttrName() << "'";
- return nullptr;
- }
- FuncOp numWorkgroupsFn = dyn_cast_or_null<FuncOp>(SymbolTable::lookupSymbolIn(
- entryPointFn.getParentOfType<ModuleOp>(), attr));
- if (!numWorkgroupsFn) {
- entryPointFn.emitError("unable to find num workgroups fn ") << attr;
- return nullptr;
- }
- return numWorkgroupsFn;
-}
-
-/// Computes the bounds of the parallel loops partitioned across workgroups.
-static Optional<SmallVector<Value, 2>> getParallelLoopRange(
- PatternRewriter &rewriter, FuncOp numWorkgroupsFn, Location loc,
- linalg::LinalgOp linalgOp) {
- if (!numWorkgroupsFn.empty()) {
- numWorkgroupsFn.emitError("num workgroups fn expected to be empty");
- return {};
- }
- LLVM_DEBUG({
- llvm::dbgs() << "Found num workgroups function : "
- << numWorkgroupsFn.getName();
- });
-
- rewriter.setInsertionPointToEnd(numWorkgroupsFn.addEntryBlock());
- llvm::SetVector<Operation *> slice;
- getBackwardSlice(linalgOp, &slice);
- BlockAndValueMapping mapper;
- for (Operation *op : slice) {
- rewriter.clone(*op, mapper);
- }
- // Clone the linalg operation just to compute the loop bounds.
- linalg::LinalgOp clonedLinalgOp =
- rewriter.clone(*linalgOp.getOperation(), mapper);
- Optional<SmallVector<Value, 4>> bounds =
- getLoopRanges(rewriter, clonedLinalgOp);
- unsigned numParallelLoops = linalgOp.iterator_types()
- .getValue()
- .take_while([](Attribute attr) -> bool {
- return attr.cast<StringAttr>().getValue() ==
- getParallelIteratorTypeName();
- })
- .size();
- SmallVector<Value, 2> returnVals(
- bounds->begin(), std::next(bounds->begin(), numParallelLoops));
- rewriter.eraseOp(clonedLinalgOp);
- return returnVals;
-}
-
-/// Utility method to build IR that computes ceil(`numerator` / `denominator`)
-static Value buildCeilDiv(PatternRewriter &rewriter, Location loc,
- Value numerator, Value denominator) {
- Value one = rewriter.create<ConstantIndexOp>(loc, 1);
- Value t = rewriter.create<AddIOp>(
- loc, numerator, rewriter.create<SubIOp>(loc, denominator, one));
- return rewriter.create<SignedDivIOp>(loc, t, denominator);
-}
-
-/// Utility method to build IR that computes ceil(`numerator` / `denominator`)
-/// when denominator is a constant.
-static Value buildCeilDivConstDenominator(PatternRewriter &rewriter,
- Location loc, Value numerator,
- int64_t denominator) {
- return buildCeilDiv(rewriter, loc, numerator,
- rewriter.create<ConstantIndexOp>(loc, denominator));
-}
-
-LogicalResult createNumWorkgroupsFromResultShape(PatternRewriter &rewriter,
- linalg::LinalgOp linalgOp,
- FuncOp entryPointFn,
- ArrayRef<int64_t> tileSizes) {
- FuncOp numWorkgroupsFn =
- getNumWorkgroupsFn(linalgOp.getParentOfType<FuncOp>());
- if (!numWorkgroupsFn) return failure();
-
- Location loc = linalgOp.getLoc();
- OpBuilder::InsertionGuard gaurd(rewriter);
- Optional<SmallVector<Value, 2>> parallelLoopRange =
- getParallelLoopRange(rewriter, numWorkgroupsFn, loc, linalgOp);
- if (!parallelLoopRange) return failure();
- Value one = rewriter.create<ConstantIndexOp>(loc, 1);
- SmallVector<Value, 3> returnValues(3, one);
- for (size_t i = 0, e = std::min<size_t>(parallelLoopRange->size(), 3); i != e;
- ++i) {
- if (tileSizes[e - i - 1] != 0) {
- returnValues[i] = buildCeilDivConstDenominator(
- rewriter, loc, (*parallelLoopRange)[e - i - 1], tileSizes[e - i - 1]);
- }
- }
- rewriter.create<mlir::ReturnOp>(loc, returnValues);
- return success();
-}
-
-LogicalResult createNumWorkgroupsFromLinearizedResultShape(
- PatternRewriter &rewriter, linalg::LinalgOp linalgOp, FuncOp entryPointFn,
- int64_t workgroupSizeX) {
- FuncOp numWorkgroupsFn =
- getNumWorkgroupsFn(linalgOp.getParentOfType<FuncOp>());
- if (!numWorkgroupsFn) return failure();
- if (!numWorkgroupsFn.empty()) {
- // TODO(ravishankarm): We can end up with multiple linalg operations
- // (typically linalg.generic operations) that have the same workload in a
- // dispatch region. In that case, the first linalg.generic creates the body
- // of number of workgroups. For now, just returning if the body is not empty
- // assuming that it is correct for all the ops in the dispatch region. This
- // needs to be enforced somehow.
- return success();
- }
-
- Location loc = linalgOp.getLoc();
- OpBuilder::InsertionGuard gaurd(rewriter);
- Optional<SmallVector<Value, 2>> parallelLoopRange =
- getParallelLoopRange(rewriter, numWorkgroupsFn, loc, linalgOp);
- if (!parallelLoopRange) return failure();
- Value one = rewriter.create<ConstantIndexOp>(loc, 1);
- SmallVector<Value, 3> returnValues(3, one);
- for (auto range : *parallelLoopRange) {
- returnValues[0] = rewriter.create<MulIOp>(loc, range, returnValues[0]);
- }
- returnValues[0] = buildCeilDivConstDenominator(rewriter, loc, returnValues[0],
- workgroupSizeX);
- rewriter.create<mlir::ReturnOp>(loc, returnValues);
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// Launch config calculation.
-//===----------------------------------------------------------------------===//
-
/// Name of the StrAttr that can be used to get the key to access the tile size
/// information.
static const char kLaunchInfoKey[] = "launch_info_key";
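
Note on the ceil-division removed above: the IR that buildCeilDiv emits (moved into CodegenUtils/GetNumWorkgroups.cpp by this patch) is the standard identity for positive denominators. A minimal scalar sketch of the same arithmetic, for reference only:

```c++
#include <cstdint>

// ceil(numerator / denominator), written the way buildCeilDiv emits it:
// (numerator + (denominator - 1)) / denominator. Assumes denominator > 0.
int64_t ceilDiv(int64_t numerator, int64_t denominator) {
  return (numerator + (denominator - 1)) / denominator;
}
```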
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
index 918cd6b..8bf7afc 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
@@ -47,30 +47,6 @@
namespace iree_compiler {
-/// Generates a function that computes the number of workgroups as
-/// [ceil(`parallelLoopRange`[2] / `tileSizes`[2]),
-/// ceil(`parallelLoopRange`[1] / `tileSizes`[1]),
-/// ceil(`parallelLoopRange`[0] / `tileSizes`[0])]
-/// where `parallelLoopRange` is the ranges of the parallel loops of `linalgOp`
-/// distributed across workgroups.
-LogicalResult createNumWorkgroupsFromResultShape(PatternRewriter &rewriter,
- linalg::LinalgOp linalgOp,
- FuncOp entryPointFn,
- ArrayRef<int64_t> tileSizes);
-
-/// Generates a function that computes the number of workgroups as
-/// ceil(`parallelLoopRange`[0] * `parallelLoopRange`[1] * ... *
-/// `parallelLoopRange`[n-1] / `workgroupSizeX`)
-/// where `parallelLoopRange` is the ranges of the parallel loops of `linalgOp`
-/// distributed across workgroups.
-LogicalResult createNumWorkgroupsFromLinearizedResultShape(
- PatternRewriter &rewriter, linalg::LinalgOp linalgOp, FuncOp entryPointFn,
- int64_t workgroupSizeX);
-
-/// For a given `entryPointFn` return the function that computes the number of
-/// workgroups to use at launch time.
-FuncOp getNumWorkgroupsFn(FuncOp entryPointFn);
-
/// Store the tile sizes to use at different levels of tiling as a vector of
/// vectors.
/// - First level tiling maps to workgroups.
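
To make the two removed contracts concrete (they move into CodegenUtils/GetNumWorkgroups.h in this patch): for a 64x128 parallel iteration space with tile sizes [8, 16], the per-dimension form yields

$$\big(\lceil 128/16 \rceil,\ \lceil 64/8 \rceil,\ 1\big) = (8,\ 8,\ 1)$$

workgroups, while the linearized form with workgroupSizeX = 32 yields $\lceil (64 \cdot 128)/32 \rceil = 256$ workgroups along x.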
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
index 8ee8657..a78343d 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -19,10 +19,11 @@
//===----------------------------------------------------------------------===//
#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
+#include "iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h"
+#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
#include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
-#include "iree/compiler/Conversion/LinalgToSPIRV/Attributes.h"
+#include "iree/compiler/Conversion/Common/Attributes.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h"
-#include "iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/MemorySpace.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Utils.h"
@@ -173,7 +174,8 @@
failed(updateWorkGroupSize(funcOp, launchConfig.getWorkgroupSize())) ||
(funcOp.getAttr(getNumWorkgroupsFnAttrName()) &&
failed(createNumWorkgroupsFromResultShape(
- rewriter, linalgOp, funcOp, launchConfig.getTileSizes(op, 0))))) {
+ rewriter, linalgOp, funcOp, getNumWorkgroupsFnAttrName(),
+ launchConfig.getTileSizes(op, 0))))) {
return failure();
}
setMarker(op, getDeleteMarker());
@@ -215,7 +217,8 @@
failed(updateWorkGroupSize(funcOp, launchConfig.getWorkgroupSize())) ||
(funcOp.getAttr(getNumWorkgroupsFnAttrName()) &&
failed(createNumWorkgroupsFromResultShape(
- rewriter, linalgOp, funcOp, launchConfig.getTileSizes(op, 0))))) {
+ rewriter, linalgOp, funcOp, getNumWorkgroupsFnAttrName(),
+ launchConfig.getTileSizes(op, 0))))) {
return failure();
}
return success();
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
index d413df5..b09f909 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
@@ -20,6 +20,7 @@
#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
+#include "iree/compiler/Conversion/Common/Passes.h"
#include "iree/compiler/Conversion/HLOToHLO/Passes.h"
#include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
#include "iree/compiler/Conversion/LinalgToVector/Passes.h"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
index 30067d7..6a661af 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
@@ -32,10 +32,6 @@
bool useVectorizeMemrefPass = false;
};
-/// Pass to initialize the function that computes the number of workgroups for
-/// each entry point function. The function is defined, but is populated later.
-std::unique_ptr<OperationPass<ModuleOp>> createDeclareNumWorkgroupsFnPass();
-
/// Pass to tile and fuse linalg operations on buffers. The pass takes as
/// argument the `workgroupSize` that the tiling should use. Note that the
/// tile-sizes are the reverse of the workgroup size. So workgroup size along
@@ -49,10 +45,6 @@
/// to GPU dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertToGPUPass();
-/// Pass to legalize function that returns number of workgroups to use for
-/// launch to be runnable on the host.
-std::unique_ptr<OperationPass<ModuleOp>> createLegalizeNumWorkgroupsFnPass();
-
/// Pass to perform the final conversion to SPIR-V dialect.
/// This pass converts remaining interface ops into SPIR-V global variables,
/// GPU processor ID ops into SPIR-V global variables, loop/standard ops into
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
index 7e8d1aa..09ceb74 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
@@ -27,7 +27,8 @@
#include <iterator>
#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
-#include "iree/compiler/Conversion/LinalgToSPIRV/Attributes.h"
+#include "iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h"
+#include "iree/compiler/Conversion/Common/Attributes.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
@@ -257,7 +258,8 @@
newFn.setAttr(namedAttr.first, namedAttr.second);
}
// Need special handling for the number of workgroups function.
- if (FuncOp numWorkgroupsFn = getNumWorkgroupsFn(oldFn)) {
+ if (FuncOp numWorkgroupsFn =
+ getNumWorkgroupsFn(oldFn, getNumWorkgroupsFnAttrName())) {
FuncOp newNumWorkgroupsFn =
builder.create<FuncOp>(loc, newFnName.str() + "__num_workgroups__",
numWorkgroupsFn.getType());
@@ -296,7 +298,8 @@
moduleOp.setAttr(getEntryPointScheduleAttrName(),
builder.getArrayAttr(entryPoints));
- if (FuncOp numWorkgroupsFn = getNumWorkgroupsFn(oldFn)) {
+ if (FuncOp numWorkgroupsFn =
+ getNumWorkgroupsFn(oldFn, getNumWorkgroupsFnAttrName())) {
LLVM_DEBUG({
llvm::dbgs() << "Erased num workgroups fn func @"
<< numWorkgroupsFn.getName() << " for func @"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp
index 2c14351..ddbf478 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp
@@ -20,7 +20,7 @@
#include "iree/compiler/Conversion/LinalgToSPIRV/Utils.h"
-#include "iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h"
+#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/MemorySpace.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
index a48460a..bf032b7 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
@@ -21,9 +21,9 @@
#include <memory>
#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
+#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
#include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/CooperativeMatrixAnalysis.h"
-#include "iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
index edc26a5..c4419d8 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
@@ -74,7 +74,7 @@
[Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- func @parallel_4D_static() attributes {vkspv.num_workgroups_fn = @parallel_4D_static__num_workgroups__} {
+ func @parallel_4D_static() attributes {hal.num_workgroups_fn = @parallel_4D_static__num_workgroups__} {
%arg0 = iree.placeholder for "interace buffer"
{binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<3x4x5x6xf32>
%arg1 = iree.placeholder for "interace buffer"
@@ -103,8 +103,8 @@
}
}
// CHECK-LABEL: func @parallel_4D_static()
+// CHECK-SAME: hal.num_workgroups_fn = @[[NUM_WORKGROUPS_FN:[a-zA-Z0-9_]+]]
// CHECK-SAME: local_size = dense<[32, 1, 1]>
-// CHECK-SAME: vkspv.num_workgroups_fn = @[[NUM_WORKGROUPS_FN:[a-zA-Z0-9_]+]]
// CHECK-DAG: %[[C360:.+]] = constant 360 : index
// CHECK-DAG: %[[C120:.+]] = constant 120 : index
// CHECK-DAG: %[[C30:.+]] = constant 30 : index
@@ -145,7 +145,7 @@
[Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- func @scalar_add() attributes {vkspv.num_workgroups_fn = @scalar_add__num_workgroups__} {
+ func @scalar_add() attributes {hal.num_workgroups_fn = @scalar_add__num_workgroups__} {
%arg0 = iree.placeholder for "interace buffer"
{binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<f32>
%arg1 = iree.placeholder for "interace buffer"
@@ -172,7 +172,7 @@
}
}
// CHECK-LABEL: func @scalar_add()
-// CHECK-SAME: vkspv.num_workgroups_fn = @[[NUM_WORKGROUPS_FN:[a-zA-Z0-9_]+]]
+// CHECK-SAME: hal.num_workgroups_fn = @[[NUM_WORKGROUPS_FN:[a-zA-Z0-9_]+]]
// CHECK: load
// CHECK-NEXT: load
// CHECK-NEXT: addf
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
index 580d32a..1f27ed9 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
@@ -41,7 +41,7 @@
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
func @conv_no_padding()
- attributes {vkspv.num_workgroups_fn = @conv_no_padding__num_workgroups__} {
+ attributes {hal.num_workgroups_fn = @conv_no_padding__num_workgroups__} {
%0 = iree.placeholder for "interace buffer"
{binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<?x?x?x?xf32>
%1 = iree.placeholder for "interace buffer"
@@ -65,8 +65,8 @@
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 4)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 32)>
// CHECK: func @conv_no_padding()
+// CHECK-SAME: hal.num_workgroups_fn = @[[NUM_WORKGROUPS_FN:[a-zA-Z0-9_]+]]
// CHECK-SAME: local_size = dense<[32, 4, 1]>
-// CHECK-SAME: vkspv.num_workgroups_fn = @[[NUM_WORKGROUPS_FN:[a-zA-Z0-9_]+]]
// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
// CHECK-DAG: %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@ret0
@@ -113,7 +113,7 @@
[Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- func @matmul() attributes {vkspv.num_workgroups_fn = @matmul__num_workgroups__} {
+ func @matmul() attributes {hal.num_workgroups_fn = @matmul__num_workgroups__} {
%0 = iree.placeholder for "interace buffer"
{binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<?x?xf32>
%1 = iree.placeholder for "interace buffer"
@@ -138,8 +138,8 @@
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 8)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0] -> (s0 * 16)>
// CHECK: func @matmul()
+// CHECK-SAME: hal.num_workgroups_fn = @[[NUM_WORKGROUPS_FN:[a-zA-Z0-9_]+]]
// CHECK-SAME: local_size = dense<[16, 8, 1]>
-// CHECK-SAME: vkspv.num_workgroups_fn = @[[NUM_WORKGROUPS_FN:[a-zA-Z0-9_]+]]
// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
// CHECK-DAG: %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@ret0
@@ -182,7 +182,7 @@
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
func @pooling_sum_no_padding()
- attributes {vkspv.num_workgroups_fn = @pooling_sum_no_padding__num_workgroups__} {
+ attributes {hal.num_workgroups_fn = @pooling_sum_no_padding__num_workgroups__} {
%0 = iree.placeholder for "interace buffer"
{binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<?x?xf32>
%1 = iree.placeholder for "interace buffer"
@@ -206,8 +206,8 @@
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 4)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 32)>
// CHECK: func @pooling_sum_no_padding()
+// CHECK-SAME: hal.num_workgroups_fn = @[[NUM_WORKGROUPS_FN:[a-zA-Z0-9_]+]]
// CHECK-SAME: local_size = dense<[32, 4, 1]>
-// CHECK-SAME: vkspv.num_workgroups_fn = @[[NUM_WORKGROUPS_FN:[a-zA-Z0-9_]+]]
// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
// CHECK-DAG: %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@ret0
@@ -247,7 +247,7 @@
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
func @pooling_max_4D()
- attributes {vkspv.num_workgroups_fn = @pooling_max_4D__num_workgroups__} {
+ attributes {hal.num_workgroups_fn = @pooling_max_4D__num_workgroups__} {
%0 = iree.placeholder for "interace buffer"
{binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<?x?x?x?xf32>
%1 = iree.placeholder for "interace buffer"
@@ -272,8 +272,8 @@
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 4)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 32)>
// CHECK: func @pooling_max_4D()
+// CHECK-SAME: hal.num_workgroups_fn = @[[NUM_WORKGROUPS_FN:[a-zA-Z0-9_]+]]
// CHECK-SAME: local_size = dense<[32, 4, 1]>
-// CHECK-SAME: vkspv.num_workgroups_fn = @[[NUM_WORKGROUPS_FN:[a-zA-Z0-9_]+]]
// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
// CHECK-DAG: %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@ret0
@@ -311,7 +311,7 @@
[Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- func @matmul_fusion() attributes {vkspv.num_workgroups_fn = @matmul_fusion__num_workgroups__} {
+ func @matmul_fusion() attributes {hal.num_workgroups_fn = @matmul_fusion__num_workgroups__} {
%0 = iree.placeholder for "interace buffer"
{binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<?x?xf32>
%1 = iree.placeholder for "interace buffer"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
index e2e8d7e..a6a77b4 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
@@ -6,7 +6,7 @@
// CHECK: linalg.conv
func @kernel_fusable_fill_conv_ops()
- attributes {vkspv.num_workgroups_fn = @kernel_fusable_fill_conv_ops_num_workgroups__} {
+ attributes {hal.num_workgroups_fn = @kernel_fusable_fill_conv_ops_num_workgroups__} {
%cst = constant 0.000000e+00 : f32
%dim = hal.interface.load.constant offset = 0 : index
%shape1 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,2,2,512]>
@@ -40,7 +40,7 @@
// CHECK: linalg.matmul
func @kernel_fusable_fill_matmul_ops()
- attributes {vkspv.num_workgroups_fn = @kernel_fusable_fill_matmul_ops_num_workgroups__} {
+ attributes {hal.num_workgroups_fn = @kernel_fusable_fill_matmul_ops_num_workgroups__} {
%cst = constant 0.000000e+00 : f32
%dimM = hal.interface.load.constant offset = 0 : index
%dimN = hal.interface.load.constant offset = 1 : index
@@ -76,7 +76,7 @@
// CHECK: func @kernel_fusable_pooling()
// CHECK: linalg.fill
// CHECK: linalg.pooling
- func @kernel_fusable_pooling() attributes {vkspv.num_workgroups_fn = @kernel_fusable_pooling__num_workgroups__} {
+ func @kernel_fusable_pooling() attributes {hal.num_workgroups_fn = @kernel_fusable_pooling__num_workgroups__} {
%cst = constant 0.000000e+00 : f32
%0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x?xf32>
%1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<?x?xf32>
@@ -100,7 +100,7 @@
// -----
-// CHECK: module attributes {vkspv.entry_point_schedule = ["kernel_dispatch_0", "kernel_dispatch_1"]}
+// CHECK: module attributes {hal.entry_point_schedule = ["kernel_dispatch_0", "kernel_dispatch_1"]}
module {
// CHECK: func @kernel_dispatch_1()
// CHECK: %[[ZERO:.+]] = constant
@@ -123,7 +123,7 @@
// CHECK: linalg.conv(%[[IN2]], %[[TS1]], %[[TS2]])
// CHECK: return
- func @kernel() attributes {vkspv.num_workgroups_fn = @kernel__num_workgroups__} {
+ func @kernel() attributes {hal.num_workgroups_fn = @kernel__num_workgroups__} {
%cst = constant 0.000000e+00 : f32
%dim = hal.interface.load.constant offset = 0 : index
%shape1 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,2,2,512]>
@@ -151,10 +151,10 @@
// -----
-// CHECK: module attributes {vkspv.entry_point_schedule = ["kernel_dispatch_0", "kernel_dispatch_1", "kernel_dispatch_2"]}
+// CHECK: module attributes {hal.entry_point_schedule = ["kernel_dispatch_0", "kernel_dispatch_1", "kernel_dispatch_2"]}
module {
// CHECK: func @kernel_dispatch_2()
-// CHECK-SAME: {vkspv.num_workgroups_fn = @[[NUM_WORKGROUPS_FN2:.+]]}
+// CHECK-SAME: {hal.num_workgroups_fn = @[[NUM_WORKGROUPS_FN2:.+]]}
// CHECK: %[[DIM:.+]] = hal.interface.load.constant
// CHECK: %[[SHAPE1:.+]] = shapex.make_ranked_shape %[[DIM]]
// CHECK: %[[SHAPE2:.+]] = shapex.make_ranked_shape %[[DIM]]
@@ -169,7 +169,7 @@
// CHECK: func @[[NUM_WORKGROUPS_FN2]]
// CHECK: func @kernel_dispatch_1()
-// CHECK-SAME: {vkspv.num_workgroups_fn = @[[NUM_WORKGROUPS_FN1:.+]]}
+// CHECK-SAME: {hal.num_workgroups_fn = @[[NUM_WORKGROUPS_FN1:.+]]}
// CHECK: %[[C0:.+]] = constant 0 : index
// CHECK: %[[C1:.+]] = constant 1 : index
// CHECK: scf.parallel (%{{.*}}) = (%[[C0]]) to (%[[C1]]) step (%[[C1]])
@@ -179,7 +179,7 @@
// CHECK: func @[[NUM_WORKGROUPS_FN1]]
// CHECK: func @kernel_dispatch_0()
-// CHECK-SAME: {vkspv.num_workgroups_fn = @[[NUM_WORKGROUPS_FN0:.+]]}
+// CHECK-SAME: {hal.num_workgroups_fn = @[[NUM_WORKGROUPS_FN0:.+]]}
// CHECK: %[[ZERO:.+]] = constant
// CHECK: %[[DIM:.+]] = hal.interface.load.constant
// CHECK: %[[SHAPE:.+]] = shapex.make_ranked_shape %[[DIM]]
@@ -190,7 +190,7 @@
// CHECK: func @[[NUM_WORKGROUPS_FN0]]
- func @kernel() attributes {vkspv.num_workgroups_fn = @kernel__num_workgroups__} {
+ func @kernel() attributes {hal.num_workgroups_fn = @kernel__num_workgroups__} {
%cst = constant 0.000000e+00 : f32
%c0 = constant 0 : index
%c1 = constant 1 : index
@@ -226,10 +226,10 @@
// Nothing to do if there is just one Linalg op.
-// CHECK-NOT: vkspv.entry_point_schedule
+// CHECK-NOT: hal.entry_point_schedule
module {
// CHECK-LABEL: @kernel()
- func @kernel() attributes {vkspv.num_workgroups_fn = @kernel__num_workgroups__} {
+ func @kernel() attributes {hal.num_workgroups_fn = @kernel__num_workgroups__} {
%cst = constant 0.000000e+00 : f32
%0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x2x2x512xf32>
%1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
@@ -275,7 +275,7 @@
module {
func @subview_interleaved()
- attributes {vkspv.num_workgroups_fn = @subview_interleaved__num_workgroups__} {
+ attributes {hal.num_workgroups_fn = @subview_interleaved__num_workgroups__} {
%cst = constant 0.000000e+00 : f32
%0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<18x12xf32>
%1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<12x4xf32>
@@ -295,7 +295,7 @@
}
// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0 * 12 + d1 + 53)>
-// CHECK: module attributes {vkspv.entry_point_schedule =
+// CHECK: module attributes {hal.entry_point_schedule =
// CHECK-SAME: ["subview_interleaved_dispatch_0",
// CHECK-SAME: "subview_interleaved_dispatch_1"]}
// CHECK: func @subview_interleaved_dispatch_1()
@@ -318,7 +318,7 @@
module {
func @reshape_interleaved()
- attributes {vkspv.num_workgroups_fn = @reshape_interleaved__num_workgroups__} {
+ attributes {hal.num_workgroups_fn = @reshape_interleaved__num_workgroups__} {
%0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<2x4xf32>
%1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret1} : memref<1x2x4xf32>
%2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<2x4xf32>
@@ -349,7 +349,7 @@
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK: module attributes {vkspv.entry_point_schedule =
+// CHECK: module attributes {hal.entry_point_schedule =
// CHECK-SAME: ["reshape_interleaved_dispatch_0",
// CHECK-SAME: "reshape_interleaved_dispatch_1"]}
// CHECK: func @reshape_interleaved_dispatch_1()
@@ -369,7 +369,7 @@
module {
func @predict_ex_dispatch_0()
- attributes {vkspv.num_workgroups_fn = @predict_ex_dispatch_0__num_workgroups__} {
+ attributes {hal.num_workgroups_fn = @predict_ex_dispatch_0__num_workgroups__} {
%0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x512x1xf32>
%1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret1} : memref<4x8x16xf32>
%2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x512x1xf32>
@@ -397,7 +397,7 @@
hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
}
}
-// CHECK: module attributes {vkspv.entry_point_schedule =
+// CHECK: module attributes {hal.entry_point_schedule =
// CHECK-SAME: ["predict_ex_dispatch_0_dispatch_0",
// CHECK-SAME: "predict_ex_dispatch_0_dispatch_1"]}
// CHECK: func @predict_ex_dispatch_0_dispatch_1
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
index ff1ec02..0697db6 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
@@ -5,7 +5,7 @@
#spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- func @matmul_tile() attributes {vkspv.num_workgroups_fn = @matmul_tile__num_workgroups__} {
+ func @matmul_tile() attributes {hal.num_workgroups_fn = @matmul_tile__num_workgroups__} {
%0 = iree.placeholder for "interace buffer"
{binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<?x?xf32>
%1 = iree.placeholder for "interace buffer"
@@ -57,7 +57,7 @@
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
func @conv_no_padding_tile()
- attributes {vkspv.num_workgroups_fn = @conv_no_padding_tile__num_workgroups__} {
+ attributes {hal.num_workgroups_fn = @conv_no_padding_tile__num_workgroups__} {
%0 = iree.placeholder for "interace buffer"
{binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<3x4x3x2xf32>
%1 = iree.placeholder for "interace buffer"
diff --git a/iree/compiler/Conversion/init_conversions.h b/iree/compiler/Conversion/init_conversions.h
index faf79ab..4d4c53c 100644
--- a/iree/compiler/Conversion/init_conversions.h
+++ b/iree/compiler/Conversion/init_conversions.h
@@ -61,6 +61,9 @@
inline void registerLinalgToLLVMPasses() {
static bool init_once = []() {
// LinalgToLLVM
+ createConvImg2ColMatmulConversionPass();
+ createLinalgTileAndDistributePass();
+ createMatMulTileAndVectorizePass();
return true;
}();
(void)init_once;
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTarget.cpp
index 199daf3..6060600 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTarget.cpp
@@ -46,7 +46,17 @@
// multi-threading issues.
llvm::LLVMContext context;
- iree::DyLibExecutableDefT dyLibExecutableDef;
+ // Remove all private functions, e.g. tile size calculations.
+ SmallVector<FuncOp, 4> nonPublicFn;
+ for (auto func : targetOp.getInnerModule().getOps<FuncOp>()) {
+ if (SymbolTable::getSymbolVisibility(func) !=
+ SymbolTable::Visibility::Public) {
+ nonPublicFn.push_back(func);
+ }
+ }
+ for (auto func : nonPublicFn) {
+ func.erase();
+ }
// At this moment we are leaving MLIR LLVM dialect land translating module
// into target independent LLVMIR.
@@ -56,6 +66,7 @@
return failure();
}
+ iree::DyLibExecutableDefT dyLibExecutableDef;
// Create invocation function and populate entry_points.
auto entryPointOps = targetOp.getBlock().getOps<ExecutableEntryPointOp>();
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/BUILD b/iree/compiler/Dialect/HAL/Target/LLVM/BUILD
index b644afe..9c608c2 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/BUILD
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/BUILD
@@ -40,6 +40,8 @@
deps = [
":LLVMIRPasses",
":LLVMTargetOptions",
+ "//iree/compiler/Conversion/CodegenUtils",
+ "//iree/compiler/Conversion/Common",
"//iree/compiler/Conversion/LinalgToLLVM",
"//iree/compiler/Dialect/HAL/Target",
"@llvm-project//llvm:Support",
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt
index ed01729..1f0f77f 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt
@@ -34,6 +34,8 @@
MLIRLinalg
MLIRSCF
MLIRVector
+ iree::compiler::Conversion::CodegenUtils
+ iree::compiler::Conversion::Common
iree::compiler::Conversion::LinalgToLLVM
iree::compiler::Dialect::HAL::Target
PUBLIC
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/IR/BUILD b/iree/compiler/Dialect/HAL/Target/LLVM/IR/BUILD
index 2194e66..eed06b7 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/IR/BUILD
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/IR/BUILD
@@ -41,6 +41,7 @@
"//iree/compiler/Dialect/HAL/Target/LLVM:LLVMBaseTarget",
"//iree/compiler/Dialect/HAL/Target/LLVM:LLVMIRPasses",
"//iree/compiler/Dialect/HAL/Target/LLVM:LLVMTargetOptions",
+ "//iree/compiler/Dialect/Shape/IR",
"//iree/schemas:llvmir_executable_def_cc_fbs",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:Support",
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/IR/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/LLVM/IR/CMakeLists.txt
index 345fe5a..c490eb0 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/IR/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/IR/CMakeLists.txt
@@ -33,6 +33,7 @@
iree::compiler::Dialect::HAL::Target::LLVM::LLVMBaseTarget
iree::compiler::Dialect::HAL::Target::LLVM::LLVMIRPasses
iree::compiler::Dialect::HAL::Target::LLVM::LLVMTargetOptions
+ iree::compiler::Dialect::Shape::IR
iree::schemas::llvmir_executable_def_cc_fbs
PUBLIC
)
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/IR/LLVMIRTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/IR/LLVMIRTarget.cpp
index 3a03471..b18ecbf 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/IR/LLVMIRTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/IR/LLVMIRTarget.cpp
@@ -42,11 +42,23 @@
// Perform the translation to LLVM in a separate context to avoid
// multi-threading issues.
llvm::LLVMContext context;
+ // Remove all private functions, e.g. tile size calculations.
+ SmallVector<FuncOp, 4> nonPublicFn;
+ for (auto func : targetOp.getInnerModule().getOps<FuncOp>()) {
+ if (SymbolTable::getSymbolVisibility(func) !=
+ SymbolTable::Visibility::Public) {
+ nonPublicFn.push_back(func);
+ }
+ }
+ for (auto func : nonPublicFn) {
+ func.erase();
+ }
// At this moment we are leaving MLIR LLVM dialect land translating module
// into target independent LLVMIR.
auto llvmModule =
mlir::translateModuleToLLVMIR(targetOp.getInnerModule(), context);
+
if (!llvmModule) {
return targetOp.emitError("Failed to translate executable to LLVM IR");
}
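
The pruning loop above is now duplicated in LLVMAOTTarget.cpp and LLVMIRTarget.cpp. In the spirit of this patch's code sharing, a hypothetical helper (name and placement are assumptions, not part of this change) could hold the common logic:

```c++
// Hypothetical shared helper mirroring the duplicated loops above: erases
// every non-public function (e.g. the generated tile-size and
// num-workgroups computations) from the module before LLVM IR translation.
static void erasePrivateFunctions(mlir::ModuleOp module) {
  llvm::SmallVector<mlir::FuncOp, 4> toErase;
  for (auto func : module.getOps<mlir::FuncOp>()) {
    if (mlir::SymbolTable::getSymbolVisibility(func) !=
        mlir::SymbolTable::Visibility::Public) {
      toErase.push_back(func);
    }
  }
  for (auto func : toErase) func.erase();
}
```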
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.cpp
index a0c82c2..27dc57d 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.cpp
@@ -14,6 +14,8 @@
#include "iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.h"
+#include "iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h"
+#include "iree/compiler/Conversion/Common/Attributes.h"
#include "iree/compiler/Conversion/LinalgToLLVM/Passes.h"
#include "iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.h"
#include "llvm/Support/FormatVariadic.h"
@@ -183,13 +185,76 @@
return success();
}
-std::array<Value, 3> LLVMBaseTargetBackend::calculateDispatchWorkgroupCount(
- Location loc, IREE::HAL::ExecutableOp executableOp,
- IREE::HAL::ExecutableEntryPointOp entryPointOp, Value workload,
- OpBuilder &builder) {
- // For now we are not tiling and just dispatch everything as 1,1,1.
- auto constantOne = builder.createOrFold<mlir::ConstantIndexOp>(loc, 1);
- return {constantOne, constantOne, constantOne};
+LogicalResult LLVMBaseTargetBackend::recordDispatch(
+ Location loc, DispatchState dispatchState,
+ DeviceSwitchRewriter &switchRewriter) {
+ IREE::HAL::ExecutableOp executableOp = dispatchState.executableOp;
+ ModuleOp llvmIRModuleOp;
+ for (auto executableTargetOp :
+ executableOp.getBlock().getOps<IREE::HAL::ExecutableTargetOp>()) {
+ if (matchPattern(executableTargetOp.target_backend_filter(),
+ filter_pattern())) {
+ ModuleOp innerModuleOp = executableTargetOp.getInnerModule();
+ llvmIRModuleOp = innerModuleOp;
+ break;
+ }
+ }
+ if (!llvmIRModuleOp)
+ return executableOp.emitError("unable to find executable llvmIR module");
+
+ SmallVector<LLVM::LLVMFuncOp, 2> entryPointFns;
+ for (LLVM::LLVMFuncOp funcOp : llvmIRModuleOp.getOps<LLVM::LLVMFuncOp>()) {
+ if (SymbolTable::getSymbolVisibility(funcOp) ==
+ SymbolTable::Visibility::Public) {
+ entryPointFns.push_back(funcOp);
+ }
+ }
+
+ auto *region = switchRewriter.addConditionRegion(
+ IREE::HAL::DeviceMatchIDAttr::get(filter_pattern(), loc.getContext()),
+ {
+ dispatchState.workload,
+ dispatchState.commandBuffer,
+ });
+ auto &entryBlock = region->front();
+ ConversionPatternRewriter &rewriter = switchRewriter.getRewriter();
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToEnd(&entryBlock);
+
+ auto commandBuffer = entryBlock.getArgument(1);
+ for (auto it : llvm::enumerate(entryPointFns)) {
+ LLVM::LLVMFuncOp funcOp = it.value();
+ FlatSymbolRefAttr numWorkgroupsFnAttr =
+ funcOp.getAttrOfType<FlatSymbolRefAttr>(getNumWorkgroupsFnAttrName());
+ if (!numWorkgroupsFnAttr) {
+ return funcOp.emitError("expected llvm.num_workgroups_fn ");
+ }
+ std::array<Value, 3> workgroupCount = {nullptr, nullptr, nullptr};
+ FuncOp numWorkgroupsFn = dyn_cast<FuncOp>(SymbolTable::lookupSymbolIn(
+ funcOp.getParentOfType<ModuleOp>(), numWorkgroupsFnAttr));
+ if (!numWorkgroupsFn) {
+ return funcOp.emitError("unable to find function ")
+ << numWorkgroupsFnAttr
+ << " that computes the number of workgroups to use";
+ }
+ workgroupCount = iree_compiler::calculateWorkgroupCountFromNumWorkgroupsFn(
+ loc, numWorkgroupsFn, dispatchState.executableOp.getFirstInterfaceOp(),
+ dispatchState.operands, dispatchState.results, rewriter);
+
+ if (llvm::any_of(workgroupCount,
+ [](Value v) -> bool { return v == nullptr; })) {
+ auto constantOne = rewriter.createOrFold<mlir::ConstantIndexOp>(loc, 1);
+ rewriter.create<IREE::HAL::CommandBufferDispatchSymbolOp>(
+ loc, commandBuffer, dispatchState.entryPointOp, constantOne,
+ constantOne, constantOne);
+ } else {
+ rewriter.create<IREE::HAL::CommandBufferDispatchSymbolOp>(
+ loc, commandBuffer, dispatchState.entryPointOp, workgroupCount[0],
+ workgroupCount[1], workgroupCount[2]);
+ }
+ }
+ rewriter.create<IREE::HAL::ReturnOp>(loc);
+ return success();
}
} // namespace HAL
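
Note the graceful degradation in recordDispatch above: if the generated num-workgroups function cannot be resolved into all three dimensions, the dispatch falls back to a single workgroup. The guard reduces to a sketch like this (illustrative only):

```c++
// Fallback predicate matching the llvm::any_of check above: any unresolved
// workgroup-count dimension forces a (1, 1, 1) dispatch.
static bool needsFallbackDispatch(
    const std::array<mlir::Value, 3> &workgroupCount) {
  return llvm::any_of(workgroupCount,
                      [](mlir::Value v) { return v == nullptr; });
}
```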
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.h b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.h
index 52d0dbd..4b67aa0 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.h
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.h
@@ -35,10 +35,8 @@
LogicalResult linkExecutables(mlir::ModuleOp moduleOp) override;
- std::array<Value, 3> calculateDispatchWorkgroupCount(
- Location loc, IREE::HAL::ExecutableOp executableOp,
- IREE::HAL::ExecutableEntryPointOp entryPointOp, Value workload,
- OpBuilder &builder) override;
+ LogicalResult recordDispatch(Location loc, DispatchState dispatchState,
+ DeviceSwitchRewriter &switchRewriter) override;
protected:
LLVMTargetOptions options_;
diff --git a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp
index 4f5d2b4..a4a58ef 100644
--- a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp
@@ -15,7 +15,7 @@
#include "iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.h"
#include "flatbuffers/flatbuffers.h"
-#include "iree/compiler/Conversion/LinalgToSPIRV/Attributes.h"
+#include "iree/compiler/Conversion/Common/Attributes.h"
#include "iree/compiler/Dialect/HAL/Target/MetalSPIRV/SPIRVToMSL.h"
#include "iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
diff --git a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/BUILD b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/BUILD
index ae5d0d2..d8e7419 100644
--- a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/BUILD
+++ b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/BUILD
@@ -37,6 +37,8 @@
"SPIRVTarget.h",
],
deps = [
+ "//iree/compiler/Conversion/CodegenUtils",
+ "//iree/compiler/Conversion/Common",
"//iree/compiler/Conversion/LinalgToSPIRV",
"//iree/compiler/Dialect/Flow/IR",
"//iree/compiler/Dialect/HAL/Target",
diff --git a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/CMakeLists.txt
index 45d0b21..ba08998 100644
--- a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/CMakeLists.txt
@@ -32,6 +32,8 @@
MLIRSPIRVSerialization
MLIRSPIRVTransforms
MLIRSupport
+ iree::compiler::Conversion::CodegenUtils
+ iree::compiler::Conversion::Common
iree::compiler::Conversion::LinalgToSPIRV
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::Target
diff --git a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.cpp
index 25cea52..87dddbd 100644
--- a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.cpp
@@ -14,7 +14,8 @@
#include "iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.h"
-#include "iree/compiler/Conversion/LinalgToSPIRV/Attributes.h"
+#include "iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h"
+#include "iree/compiler/Conversion/Common/Attributes.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
@@ -41,64 +42,6 @@
ArrayRef<Value>{memoryBarrier}, ArrayRef<Value>{});
}
-/// The codegeneration emits a function `numWorkgroupsFn` for each entry point
-/// function. This function has arguments the !shapex.ranked_shape for all the
-/// input and output shaped types. Using this the function returns the number of
-/// workgroups to use. To use this function on the host side, generate the
-/// !shapex.ranked_shape values that describe the shape of the inputs and
-/// outputs of the dispatch region and "inline" the function body.
-static std::array<Value, 3> calculateWorkgroupCountFromNumWorkgroupsFn(
- Location loc, FuncOp numWorkgroupsFn, IREE::HAL::InterfaceOp interface,
- ArrayRef<Optional<TensorRewriteAdaptor>> operands,
- ArrayRef<Optional<TensorRewriteAdaptor>> results,
- ConversionPatternRewriter &rewriter) {
- std::array<Value, 3> returnValue = {nullptr, nullptr, nullptr};
- // TODO: This is really just inlining a function. For now assume that the
- // `numWorkgroupsFn` has a single block to make inlining easier.
- if (!numWorkgroupsFn || !llvm::hasSingleElement(numWorkgroupsFn))
- return returnValue;
- SmallVector<SmallVector<Value, 4>, 4> shapeValues;
- shapeValues.reserve(operands.size() + results.size());
- auto getShapeValuesFn =
- [&](ArrayRef<Optional<TensorRewriteAdaptor>> values) -> LogicalResult {
- for (auto val : values) {
- if (!val) continue;
- Optional<SmallVector<Value, 4>> shape = val->getShapeDims(rewriter);
- if (!shape) return emitError(loc, "shape computation for operand failed");
- shapeValues.push_back(shape.getValue());
- }
- return success();
- };
- if (failed(getShapeValuesFn(operands)) || failed(getShapeValuesFn(results)))
- return returnValue;
- BlockAndValueMapping mapper;
- for (Operation &op : numWorkgroupsFn.front()) {
- if (isa<mlir::ReturnOp>(op)) {
- for (unsigned i = 0, e = std::min<unsigned>(3, op.getNumOperands());
- i != e; ++i) {
- returnValue[i] = mapper.lookupOrNull(op.getOperand(i));
- }
- break;
- }
- if (auto shapeOp = dyn_cast<Shape::RankedDimOp>(op)) {
- if (BlockArgument arg = shapeOp.shape().dyn_cast<BlockArgument>()) {
- auto &dimValues = shapeValues[arg.getArgNumber()];
- mapper.map(shapeOp.result(), dimValues[shapeOp.getIndex()]);
- continue;
- }
- return returnValue;
- }
- // If all its operands are mapped, clone it.
- if (llvm::all_of(op.getOperands(), [&mapper](Value operand) {
- return mapper.contains(operand);
- })) {
- rewriter.clone(op, mapper);
- continue;
- }
- }
- return returnValue;
-}
-
SPIRVTargetBackend::SPIRVTargetBackend(SPIRVCodegenOptions options)
: spvCodeGenOptions_(std::move(options)) {}
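
The helper removed here (shared going forward via CodegenUtils/GetNumWorkgroups.h) inlines the single-block num-workgroups function by walking its body through a BlockAndValueMapping. The core idiom, reduced to a sketch under the assumption that shape dimensions were pre-mapped into the mapper:

```c++
// Sketch of the clone-through-a-mapper idiom used by
// calculateWorkgroupCountFromNumWorkgroupsFn; illustrative only.
static void cloneSingleBlockBody(mlir::Block &body,
                                 mlir::BlockAndValueMapping &mapper,
                                 mlir::OpBuilder &builder) {
  for (mlir::Operation &op : body.without_terminator()) {
    // Clone only ops whose operands are already remapped; anything else
    // stays behind, mirroring the conservative walk in the removed helper.
    if (llvm::all_of(op.getOperands(), [&](mlir::Value operand) {
          return mapper.contains(operand);
        })) {
      builder.clone(op, mapper);
    }
  }
}
```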
diff --git a/iree/compiler/Dialect/HAL/Target/TargetBackend.h b/iree/compiler/Dialect/HAL/Target/TargetBackend.h
index e7b6f6e..2d28774 100644
--- a/iree/compiler/Dialect/HAL/Target/TargetBackend.h
+++ b/iree/compiler/Dialect/HAL/Target/TargetBackend.h
@@ -345,8 +345,8 @@
}
protected:
- // Calculates the workgroup size (x, y, z). Tese are the dimension numbers for
- // a single workgroup.
+ // Calculates the workgroup size (x, y, z). These are the dimension numbers
+ // for a single workgroup.
virtual std::array<Value, 3> calculateDispatchWorkgroupSize(
Location loc, IREE::HAL::ExecutableOp executableOp,
IREE::HAL::ExecutableEntryPointOp entryPointOp, Value workload,
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD
index 97d6bc2..dcf132b 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD
@@ -37,6 +37,7 @@
"VulkanSPIRVTarget.h",
],
deps = [
+ "//iree/compiler/Conversion/Common",
"//iree/compiler/Conversion/LinalgToSPIRV",
"//iree/compiler/Dialect/Flow/IR",
"//iree/compiler/Dialect/HAL/Target",
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt
index 4cba78a..b7ed83c 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt
@@ -37,6 +37,7 @@
MLIRSupport
MLIRVector
flatbuffers
+ iree::compiler::Conversion::Common
iree::compiler::Conversion::LinalgToSPIRV
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::Target
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
index f2787b6..4bc29f1 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
@@ -15,7 +15,7 @@
#include "iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.h"
#include "flatbuffers/flatbuffers.h"
-#include "iree/compiler/Conversion/LinalgToSPIRV/Attributes.h"
+#include "iree/compiler/Conversion/Common/Attributes.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
diff --git a/iree/compiler/Dialect/IREE/IR/IREEOps.td b/iree/compiler/Dialect/IREE/IR/IREEOps.td
index 40b685a..8fcb71d 100644
--- a/iree/compiler/Dialect/IREE/IR/IREEOps.td
+++ b/iree/compiler/Dialect/IREE/IR/IREEOps.td
@@ -75,6 +75,38 @@
let assemblyFormat = [{ `for` $purpose attr-dict `:` type($output) }];
}
+def IREE_WorkgroupIdOp : IREE_PureOp<"workgroup_id"> {
+ let summary = "Get grid index of an iree workgoup among a specific dimension.";
+ let description = [{
+ IREE workgroups are logically distributed over a hypergrid, where each grid
+ point corresponds to a logical thread. For example, in the 3-D case the op
+ queries the (x, y, z) dimensions:
+
+ ```mlir
+ %0 = iree.workgroup_id {dimension = "x"} : index
+ %1 = iree.workgroup_id {dimension = "y"} : index
+ %2 = iree.workgroup_id {dimension = "z"} : index
+ ```
+ }];
+
+ let arguments = (ins StrAttr:$dimension);
+ let results = (outs Index:$result);
+
+ let assemblyFormat = "attr-dict `:` type($result)";
+}
+
+def IREE_WorkgroupSizeOp: IREE_PureOp<"workgoup_size"> {
+ let summary = "Get grid size of an iree thread among a specific dimension";
+ let description = [{
+ Returns the size of the IREE workgroup grid along the given dimension.
+ }];
+ let arguments = (ins StrAttr:$dimension);
+ let results = (outs Index:$result);
+
+ let assemblyFormat = "attr-dict `:` type($result)";
+}
+
+
//===----------------------------------------------------------------------===//
// Compiler hints
//===----------------------------------------------------------------------===//
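
Assuming the standard ODS-generated builders for the two new ops (the signature and class name below are assumptions, not something this patch shows), lowering code would materialize a workgroup-id query roughly as:

```c++
// Sketch: build an iree.workgroup_id query for the x dimension. The
// builder signature and the IREE::WorkgroupIdOp class name are assumed
// from ordinary ODS generation.
mlir::Value buildWorkgroupIdX(mlir::OpBuilder &builder, mlir::Location loc) {
  return builder.create<mlir::iree_compiler::IREE::WorkgroupIdOp>(
      loc, builder.getIndexType(), builder.getStringAttr("x"));
}
```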
diff --git a/iree/hal/dylib/dylib_executable.cc b/iree/hal/dylib/dylib_executable.cc
index 10dbf29..4bf3c35 100644
--- a/iree/hal/dylib/dylib_executable.cc
+++ b/iree/hal/dylib/dylib_executable.cc
@@ -149,10 +149,11 @@
auto* dispatch_state = static_cast<DyLibDispatchState*>(state);
IREE_TRACE_SCOPE_DYNAMIC(dispatch_state->entry_name);
- auto entry_function =
- (void (*)(void**, int32_t*))dispatch_state->entry_function;
+ auto entry_function = (void (*)(void**, int32_t*, int32_t, int32_t,
+ int32_t))dispatch_state->entry_function;
entry_function(dispatch_state->args.data(),
- dispatch_state->push_constant.data());
+ dispatch_state->push_constant.data(), workgroup_xyz[0],
+ workgroup_xyz[1], workgroup_xyz[2]);
return OkStatus();
}
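
Spelled out, the widened entry-point ABI that both runtime loaders (here and in llvmjit_executable.cc below) now cast to is:

```c++
#include <cstdint>

// ABI implied by the casts in dylib_executable.cc and llvmjit_executable.cc:
// packed buffer args, push constants, then the workgroup x/y/z coordinates.
// Parameter names are illustrative.
using DispatchEntryFn = void (*)(void** args, int32_t* push_constants,
                                 int32_t workgroup_x, int32_t workgroup_y,
                                 int32_t workgroup_z);
```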
diff --git a/iree/hal/llvmjit/llvmjit_executable.cc b/iree/hal/llvmjit/llvmjit_executable.cc
index 74f1f9d..391a217 100644
--- a/iree/hal/llvmjit/llvmjit_executable.cc
+++ b/iree/hal/llvmjit/llvmjit_executable.cc
@@ -154,9 +154,10 @@
IREE_TRACE_SCOPE0("LLVMJITExecutable::DispatchTile");
auto* dispatch_state = static_cast<LLVMJITDispatchState*>(state);
- auto func_ptr =
- (void (*)(void**, int32_t*))dispatch_state->symbol.getAddress();
- func_ptr(dispatch_state->args.data(), dispatch_state->push_constant.data());
+ auto func_ptr = (void (*)(void**, int32_t*, int32_t, int32_t,
+ int32_t))dispatch_state->symbol.getAddress();
+ func_ptr(dispatch_state->args.data(), dispatch_state->push_constant.data(),
+ workgroup_xyz[0], workgroup_xyz[1], workgroup_xyz[2]);
return OkStatus();
}
diff --git a/iree/test/e2e/llvmir_specific/BUILD b/iree/test/e2e/llvmir_specific/BUILD
index 3ab66ea..a51fcd7 100644
--- a/iree/test/e2e/llvmir_specific/BUILD
+++ b/iree/test/e2e/llvmir_specific/BUILD
@@ -32,3 +32,13 @@
driver = "llvm",
target_backend = "llvm-ir",
)
+
+iree_check_single_backend_test_suite(
+ name = "check_llvm-ir-dot_tile_and_distribute",
+ srcs = [
+ "dot.mlir",
+ ],
+ compiler_flags = ["-iree-codegen-linalg-to-llvm-tile-and-distrobute"],
+ driver = "llvm",
+ target_backend = "llvm-ir",
+)
diff --git a/iree/test/e2e/llvmir_specific/CMakeLists.txt b/iree/test/e2e/llvmir_specific/CMakeLists.txt
index 88fe959..010407c 100644
--- a/iree/test/e2e/llvmir_specific/CMakeLists.txt
+++ b/iree/test/e2e/llvmir_specific/CMakeLists.txt
@@ -26,3 +26,16 @@
COMPILER_FLAGS
"-iree-codegen-linalg-to-llvm-conv-img2col-conversion"
)
+
+iree_check_single_backend_test_suite(
+ NAME
+ check_llvm-ir-dot_tile_and_distribute
+ SRCS
+ "dot.mlir"
+ TARGET_BACKEND
+ llvm-ir
+ DRIVER
+ llvm
+ COMPILER_FLAGS
+ "-iree-codegen-linalg-to-llvm-tile-and-distrobute"
+)
diff --git a/iree/test/e2e/llvmir_specific/dot.mlir b/iree/test/e2e/llvmir_specific/dot.mlir
new file mode 100644
index 0000000..a492b62
--- /dev/null
+++ b/iree/test/e2e/llvmir_specific/dot.mlir
@@ -0,0 +1,36 @@
+func @dot_passthrough() attributes { iree.module.export } {
+ %lhs = iree.unfoldable_constant dense<[[0.3, 0.5]]> : tensor<1x2xf32>
+ %rhs = iree.unfoldable_constant dense<[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]> : tensor<2x3xf32>
+ %res = "mhlo.dot"(%lhs, %rhs) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
+ check.expect_almost_eq_const(%res, dense<[[0.23, 0.31, 0.39]]> : tensor<1x3xf32>) : tensor<1x3xf32>
+ return
+}
+
+func @gemm() attributes { iree.module.export } {
+ %lhs = iree.unfoldable_constant dense<[
+ [15.0, 14.0, 13.0],
+ [12.0, 11.0, 10.0],
+ [09.0, 08.0, 07.0],
+ [06.0, 05.0, 04.0],
+ [03.0, 02.0, 01.0]]> : tensor<5x3xf32>
+ %rhs = iree.unfoldable_constant dense<[
+ [15.0, 14.0, 13.0, 12.0, 11.0],
+ [10.0, 09.0, 08.0, 07.0, 06.0],
+ [05.0, 04.0, 03.0, 02.0, 01.0]]> : tensor<3x5xf32>
+ %res = "mhlo.dot"(%lhs, %rhs) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<5x3xf32>, tensor<3x5xf32>) -> tensor<5x5xf32>
+ check.expect_almost_eq_const(%res, dense<[
+ [430.0, 388.0, 346.0, 304.0, 262.0],
+ [340.0, 307.0, 274.0, 241.0, 208.0],
+ [250.0, 226.0, 202.0, 178.0, 154.0],
+ [160.0, 145.0, 130.0, 115.0, 100.0],
+ [70.0, 64.0, 58.0, 52.0, 46.0]]> : tensor<5x5xf32>) : tensor<5x5xf32>
+ return
+}
+
+func @large_matmul() attributes { iree.module.export } {
+ %lhs = iree.unfoldable_constant dense<1.0> : tensor<32x1024xf32>
+ %rhs = iree.unfoldable_constant dense<0.4> : tensor<1024x64xf32>
+ %res = "mhlo.dot"(%lhs, %rhs) : (tensor<32x1024xf32>, tensor<1024x64xf32>) -> tensor<32x64xf32>
+ check.expect_almost_eq_const(%res, dense<409.596> : tensor<32x64xf32>) : tensor<32x64xf32>
+ return
+}
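
Sanity check on @large_matmul's expected value: each result element accumulates 1024 products of $1.0 \times 0.4$, so the exact answer is

$$\sum_{k=1}^{1024} 1.0 \cdot 0.4 = 1024 \cdot 0.4 = 409.6,$$

and the slightly lower reference constant 409.596 stays within what check.expect_almost_eq_const tolerates under f32 accumulation.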