Enable use of vector.contract within IREE codegen. (#2986)
This change adds two-level tiling to target subgroups. This is only
being done for matrix-matrix multiplies for now but could be extended
for other ops. Some associated changes. The 2-level tiled code is then
vectorized to get vector.contract operations.
This also changes the way tile size are computed by using a
LaunchConfig class. While this change introduces a default way of
computing this, it could be extended for switching between many
heuristics based on architecture, etc.
Some refactoring of the code to make it easier to see the
first tile for workgroups ->
Promote to use workgroup memory ->
second level tiling for subgroups ->
vectorization
flow.
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
index ea85ec2..01b35f3 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
@@ -14,19 +14,24 @@
//===- KernelDispatchUtils.cpp - Utilities for generating dispatch info ---===//
//
-// This file defines utility functions that can be used to create information
-// the dispatch on the host side needs to execute an entry point function, like
-// the number of workgroups to use for launch, etc.
+// This file defines utility functions that can be used to get the information
+// about tile sizes to use to partition work across workgroups, the workgroup
+// sizes and to create information the dispatch on the host side needs to
+// execute an entry point function (e.g. total number of workgroups).
//
//===----------------------------------------------------------------------===//
+#include "iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h"
+
#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Attributes.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/SPIRV/TargetAndABI.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Function.h"
@@ -40,6 +45,10 @@
namespace mlir {
namespace iree_compiler {
+//===----------------------------------------------------------------------===//
+// Number of workgroups computation
+//===----------------------------------------------------------------------===//
+
FuncOp getNumWorkgroupsFn(FuncOp entryPointFn) {
SymbolRefAttr attr =
entryPointFn.getAttrOfType<SymbolRefAttr>(getNumWorkgroupsFnAttrName());
@@ -153,5 +162,271 @@
return success();
}
+//===----------------------------------------------------------------------===//
+// Launch config calculation.
+//===----------------------------------------------------------------------===//
+
+/// Given `nprocs` try to distribute it evenly across 2 logical x and y.
+static std::tuple<int64_t, int64_t> distributeProcs2D(int64_t nprocs) {
+ int64_t nprocs_x = std::max<int64_t>(
+ 1, static_cast<int64_t>(
+ llvm::PowerOf2Ceil(static_cast<uint64_t>(std::sqrt(nprocs)))));
+ return std::make_tuple(nprocs_x, nprocs / nprocs_x);
+}
+
+/// For a given operation `op`, `options` and `resourceLimits` of the hardware
+/// compute the
+/// 1) number of tiling levels and tile sizes to use (updates `tileSizes`),
+/// 2) workgroup size to use (updates `workgroupSize`),
+/// 3) number of subgroups to use if two level tiling is used (updates
+/// `numSubgroups`).
+template <typename T>
+static LogicalResult getOpLaunchConfig(T op, const SPIRVCodegenOptions &options,
+ spirv::ResourceLimitsAttr resourceLimits,
+ TileSizesListType &tileSizes,
+ std::array<int64_t, 3> &workgroupSize,
+ std::array<int64_t, 3> &numSubgroups) {
+ return op.emitError("undefined launch config for tiled operation");
+}
+
+/// Launch config for `linalg.batchmatmul`.
+template <>
+LogicalResult getOpLaunchConfig(linalg::BatchMatmulOp op,
+ const SPIRVCodegenOptions &options,
+ spirv::ResourceLimitsAttr resourceLimits,
+ TileSizesListType &tileSizes,
+ std::array<int64_t, 3> &workgroupSize,
+ std::array<int64_t, 3> &numSubgroups) {
+ unsigned maxWorkgroupSize =
+ resourceLimits.max_compute_workgroup_invocations().getInt();
+ std::tie(workgroupSize[0], workgroupSize[1]) =
+ distributeProcs2D(maxWorkgroupSize);
+ workgroupSize[2] = 1;
+ // TODO(#3131): This is just being hard-wired for now to be minimal viable,
+ // but this can be decided better when we have better estimates of device
+ // charecteristics.
+ const int64_t nRowsPerWorkitem = 1;
+ const int64_t nColsPerWorkitem = 1;
+ const int64_t nBatchesPerWorkitem = 1;
+ int64_t tileSizeK = 0;
+ if (options.useWorkgroupMemory) {
+ // TODO(#3131): This number should be decided based on the amount of
+ // shared memory available (maybe). For now, just hard-wire it.
+ tileSizeK = 32;
+ }
+ assert(tileSizes.empty());
+ SmallVector<int64_t, 4> ts = {nBatchesPerWorkitem,
+ nRowsPerWorkitem * workgroupSize[1],
+ nColsPerWorkitem * workgroupSize[0], tileSizeK};
+ tileSizes.emplace_back(std::move(ts));
+ return success();
+}
+
+/// The size of the co-operative matrix multiply operations on the device.
+// TODO(#3131): This needs to be queried from the device.
+Optional<std::array<int64_t, 3>> getCooperativeMatmulSubgroupSize(
+ Type dataType, Type accumulatorType) {
+ if (dataType.isInteger(8) && accumulatorType.isInteger(32)) {
+ return std::array<int64_t, 3>{8, 8, 32};
+ }
+ if (dataType.isF16() &&
+ (accumulatorType.isF32() || accumulatorType.isF16())) {
+ return std::array<int64_t, 3>{8, 8, 16};
+ }
+ return {};
+}
+
+/// Launch configuration for using spv.CooperativeMatrixMulAddNV
+/// operations. Needs two levels of tiling.
+static LogicalResult getConfigForCooperativeMatmul(
+ linalg::MatmulOp op, spirv::ResourceLimitsAttr resourceLimits,
+ TileSizesListType &tileSizes, std::array<int64_t, 3> &workgroupSize,
+ std::array<int64_t, 3> &numSubgroups) {
+ auto targetEnv = spirv::TargetEnv(spirv::lookupTargetEnv(op));
+ if (!targetEnv.allows(spirv::Capability::CooperativeMatrixNV) ||
+ !targetEnv.allows(spirv::Extension::SPV_NV_cooperative_matrix))
+ return failure();
+
+ ShapedType lhsType = op.getOperand(0).getType().cast<ShapedType>();
+ ArrayRef<int64_t> lhsShape = lhsType.getShape();
+ ShapedType rhsType = op.getOperand(1).getType().cast<ShapedType>();
+ ArrayRef<int64_t> rhsShape = rhsType.getShape();
+ ShapedType outputType = op.getOperand(2).getType().cast<ShapedType>();
+
+ Optional<std::array<int64_t, 3>> coopMatmulSize =
+ getCooperativeMatmulSubgroupSize(lhsType.getElementType(),
+ outputType.getElementType());
+ if (!coopMatmulSize) return failure();
+
+ // Check that the matmul sizes are a multiple of the tilesize.
+ auto isMultipleOf = [](int64_t s, int64_t ts) {
+ return !ShapedType::isDynamic(s) && (s % ts) == 0;
+ };
+ if (!isMultipleOf(lhsShape[0], (*coopMatmulSize)[0]) ||
+ !isMultipleOf(rhsShape[1], (*coopMatmulSize)[1]) ||
+ !isMultipleOf(lhsShape[1], (*coopMatmulSize)[2]) ||
+ !isMultipleOf(rhsShape[0], (*coopMatmulSize)[2]))
+ return failure();
+
+ // TODO(ravishankarm, antiagainst): For now hardwire the subgroup size.
+ const int64_t subgroupSize = 32;
+ unsigned maxWorkgroupSize =
+ resourceLimits.max_compute_workgroup_invocations().getInt();
+ std::tie(numSubgroups[0], numSubgroups[1]) =
+ distributeProcs2D(maxWorkgroupSize / subgroupSize);
+ numSubgroups[2] = 1;
+ // TODO(#3131): This is just being hard-wired for now to be minimal viable,
+ // but this can be decided better when we have better estimates of device
+ // charecteristics.
+ const int64_t numVecMatmulPerSubgroupX = 1;
+ const int64_t numVecMatmulPerSubgroupY = 1;
+ SmallVector<int64_t, 4> ts = {
+ numVecMatmulPerSubgroupY * (*coopMatmulSize)[0] * numSubgroups[1],
+ numVecMatmulPerSubgroupX * (*coopMatmulSize)[1] * numSubgroups[0]};
+ tileSizes.emplace_back(std::move(ts));
+
+ workgroupSize[0] = numSubgroups[0] * subgroupSize;
+ workgroupSize[1] = numSubgroups[1];
+ workgroupSize[2] = 1;
+ // Subgroup tile sizes
+ SmallVector<int64_t, 4> subgroupTs = {
+ numVecMatmulPerSubgroupY * (*coopMatmulSize)[0],
+ numVecMatmulPerSubgroupX * (*coopMatmulSize)[1], (*coopMatmulSize)[2]};
+ tileSizes.emplace_back(std::move(subgroupTs));
+ return success();
+}
+
+template <>
+LogicalResult getOpLaunchConfig(linalg::MatmulOp op,
+ const SPIRVCodegenOptions &options,
+ spirv::ResourceLimitsAttr resourceLimits,
+ TileSizesListType &tileSizes,
+ std::array<int64_t, 3> &workgroupSize,
+ std::array<int64_t, 3> &numSubgroups) {
+ if (options.useVectorization &&
+ succeeded(getConfigForCooperativeMatmul(op, resourceLimits, tileSizes,
+ workgroupSize, numSubgroups))) {
+ return success();
+ }
+ unsigned maxWorkgroupSize =
+ resourceLimits.max_compute_workgroup_invocations().getInt();
+ std::tie(workgroupSize[0], workgroupSize[1]) =
+ distributeProcs2D(maxWorkgroupSize);
+ workgroupSize[2] = 1;
+ const int nRowsPerWorkitem = 1;
+ const int nColsPerWorkitem = 1;
+ int64_t tileSizeK = 0;
+ if (options.useWorkgroupMemory) {
+ // TODO(#3131): This number should be decided based on the amount of shared
+ // memory available (maybe). For now, just hard-wire it.
+ tileSizeK = 32;
+ }
+ assert(tileSizes.empty());
+ SmallVector<int64_t, 4> ts = {nRowsPerWorkitem * workgroupSize[1],
+ nColsPerWorkitem * workgroupSize[0], tileSizeK};
+ tileSizes.emplace_back(std::move(ts));
+ return success();
+}
+
+template <>
+LogicalResult getOpLaunchConfig(linalg::ConvOp op,
+ const SPIRVCodegenOptions &options,
+ spirv::ResourceLimitsAttr resourceLimits,
+ TileSizesListType &tileSizes,
+ std::array<int64_t, 3> &workgroupSize,
+ std::array<int64_t, 3> &numSubgroups) {
+ unsigned maxWorkgroupSize =
+ resourceLimits.max_compute_workgroup_invocations().getInt();
+ const int64_t tileSizeX = 32;
+ int64_t tileSizeY = maxWorkgroupSize / tileSizeX;
+ SmallVector<int64_t, 4> ts = {1, tileSizeY, tileSizeX};
+ tileSizes.emplace_back(std::move(ts));
+ workgroupSize = {tileSizeX, tileSizeY, 1};
+ return success();
+}
+
+static LogicalResult getPoolingOpLaunchConfig(
+ const SPIRVCodegenOptions &options,
+ spirv::ResourceLimitsAttr resourceLimits, TileSizesListType &tileSizes,
+ std::array<int64_t, 3> &workgroupSize,
+ std::array<int64_t, 3> &numSubgroups) {
+ unsigned maxWorkgroupSize =
+ resourceLimits.max_compute_workgroup_invocations().getInt();
+ const int64_t tileSizeX = 32;
+ int64_t tileSizeY = maxWorkgroupSize / tileSizeX;
+ SmallVector<int64_t, 4> ts = {tileSizeY, tileSizeX};
+ tileSizes.emplace_back(std::move(ts));
+ workgroupSize = {tileSizeX, tileSizeY, 1};
+ return success();
+}
+
+#define DEFINE_POOLING_OP_CONFIG(opName) \
+ template <> \
+ LogicalResult getOpLaunchConfig( \
+ opName op, const SPIRVCodegenOptions &options, \
+ spirv::ResourceLimitsAttr resourceLimits, TileSizesListType &tileSizes, \
+ std::array<int64_t, 3> &workgroupSize, \
+ std::array<int64_t, 3> &numSubgroups) { \
+ return getPoolingOpLaunchConfig(options, resourceLimits, tileSizes, \
+ workgroupSize, numSubgroups); \
+ }
+
+DEFINE_POOLING_OP_CONFIG(linalg::PoolingMaxOp)
+DEFINE_POOLING_OP_CONFIG(linalg::PoolingMinOp)
+DEFINE_POOLING_OP_CONFIG(linalg::PoolingSumOp)
+
+#undef DEFINE_POOLINGOP_CONFIG
+
+LogicalResult LaunchConfig::init(const SPIRVCodegenOptions &options,
+ ArrayRef<linalg::LinalgOp> linalgOps) {
+ if (!options.workgroupSize.empty()) {
+ for (linalg::LinalgOp op : linalgOps)
+ tileSizes[op.getOperation()->getName().getStringRef()] = {};
+ workgroupSize = {1, 1, 1};
+ for (unsigned i = 0,
+ e = std::min<unsigned>(3, options.workgroupSize.size());
+ i != e; ++i)
+ workgroupSize[i] = options.workgroupSize[i];
+ return success();
+ }
+
+ if (linalgOps.empty()) return success();
+
+ spirv::ResourceLimitsAttr resourceLimits =
+ spirv::lookupTargetEnv(*linalgOps.begin()).getResourceLimits();
+
+ for (linalg::LinalgOp op : linalgOps) {
+ StringRef key = op.getOperation()->getName().getStringRef();
+ if (tileSizes.count(key)) {
+ return op.emitError("unexpected multiple ")
+ << key << " operations within dispatch region";
+ }
+
+ TileSizesListType &tileSizesInfo = tileSizes[key];
+
+#define DISPATCH(opName) \
+ if (auto lOp = dyn_cast<opName>(op.getOperation())) { \
+ if (failed(getOpLaunchConfig(lOp, options, resourceLimits, tileSizesInfo, \
+ workgroupSize, numSubgroups))) { \
+ return failure(); \
+ } \
+ continue; \
+ }
+
+ DISPATCH(linalg::BatchMatmulOp)
+ DISPATCH(linalg::ConvOp)
+ DISPATCH(linalg::MatmulOp)
+ DISPATCH(linalg::PoolingMaxOp)
+ DISPATCH(linalg::PoolingMinOp)
+ DISPATCH(linalg::PoolingSumOp)
+
+#undef DISPATCH
+ }
+
+ // TODO(ravishankarm): Verify that the set configurations is within the device
+ // limits.
+ return success();
+}
+
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
index 1573c55..d690a0a 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
@@ -22,11 +22,17 @@
#ifndef IREE_COMPILER_CONVERSION_LINALGTOSPIRV_KERNELDISPATCHUTILS_H_
#define IREE_COMPILER_CONVERSION_LINALGTOSPIRV_KERNELDISPATCHUTILS_H_
+#include <array>
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringMap.h"
+#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"
namespace mlir {
class FuncOp;
class LogicalResult;
+class Operation;
class PatternRewriter;
class ShapedType;
class Value;
@@ -34,6 +40,9 @@
namespace linalg {
class LinalgOp;
}
+namespace iree_compiler {
+struct SPIRVCodegenOptions;
+}
namespace iree_compiler {
@@ -61,6 +70,68 @@
/// workgroups to use at launch time.
FuncOp getNumWorkgroupsFn(FuncOp entryPointFn);
+/// Store the tile sizes to use at different levels of tiling as a vector of
+/// vectors.
+/// - First level tiling maps to workgroups.
+/// - Second level tiling maps to subgroups.
+using TileSizesListType = SmallVector<SmallVector<int64_t, 4>, 1>;
+
+/// Based on the linalg operations in a dispatch region, the number of levels of
+/// tiling, the tile sizes needed, the workgroup size, etc. need to be
+/// decided. These parameters are called `LaunchConfig`. This class implements
+/// one heuristic to compute these for the different linalg operations on
+/// buffers. This can be adapted later to support multiple configurations that
+/// can be picked based on device information/problem size information. It
+/// exposes the information needed by the codegenerators, and hides the
+/// implementation from the rest of the pipeline.
+class LaunchConfig {
+ public:
+ LaunchConfig() : workgroupSize({1, 1, 1}), numSubgroups({1, 1, 1}) {}
+
+ /// Given the sequence of `linalgOps` (and `options`), decide the launch
+ /// configuration by deciding
+ /// - the number of levels of tiling,
+ /// - tile sizes for each level,
+ /// - the workgroup size, and
+ /// - number of subgroups to use.
+ LogicalResult init(const SPIRVCodegenOptions &options,
+ ArrayRef<linalg::LinalgOp> linalgOps);
+
+ /// Gets the tile size computed for an operation at all levels.
+ TileSizesListType getTileSizes(Operation *op) const {
+ return tileSizes.lookup(op->getName().getStringRef());
+ }
+
+ /// Gets the tile size computed for an operation for an level.
+ ArrayRef<int64_t> getTileSizes(Operation *op, size_t level) const {
+ auto it = tileSizes.find(op->getName().getStringRef());
+ if (it == tileSizes.end() || level >= it->second.size()) return {};
+ return it->second[level];
+ }
+
+ /// Returns the workgroup size to use based on the tile sizes.
+ ArrayRef<int64_t> getWorkgroupSize() const { return workgroupSize; }
+
+ /// Returns the number of subgroups to use.
+ ArrayRef<int64_t> getNumSubgroups() const { return numSubgroups; }
+
+ protected:
+ /// Current tile size configuration per operation.
+
+ // TODO: For now just use the operation name for the mapping. The tile sizes
+ // will be selected only for operations like matmul, conv, pool, etc. and
+ // assume that there is only one such operation per dispatch
+ // region. Eventually this might need to be relaxed, and some name-marker
+ // based mechanism might be needed.
+ llvm::StringMap<TileSizesListType> tileSizes;
+
+ /// Workgroup size to use.
+ std::array<int64_t, 3> workgroupSize;
+
+ /// Number of subgroups that are logically distributed along x, y & z.
+ std::array<int64_t, 3> numSubgroups;
+};
+
} // namespace iree_compiler
} // namespace mlir
#endif // IREE_COMPILER_CONVERSION_LINALGTOSPIRV_DISPATCHUTILS_H_
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
index 4c7417e..135961a 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -18,6 +18,7 @@
//
//===----------------------------------------------------------------------===//
#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
+#include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Attributes.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h"
@@ -29,7 +30,6 @@
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
-#include "mlir/Dialect/SPIRV/TargetAndABI.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/Matchers.h"
@@ -37,7 +37,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/FoldUtils.h"
-#define DEBUG_TYPE "iree-linalg-tile-and-fuse-buffer"
+#define DEBUG_TYPE "iree-linalg-tile-and-fuse"
namespace mlir {
namespace iree_compiler {
@@ -46,18 +46,6 @@
// Utility functions
//===----------------------------------------------------------------------===//
-static ArrayRef<int64_t> dropTrailingOnes(ArrayRef<int64_t> vector) {
- if (vector.empty()) return vector;
- auto numTrailingOnes = 0;
- for (unsigned i = vector.size() - 1; i > 0; --i) {
- if (vector[i] != 1) {
- break;
- }
- numTrailingOnes++;
- }
- return vector.drop_back(numTrailingOnes);
-}
-
/// Returns true if the linalg op has padding attribute, and that it has
/// non-zero entries.
template <typename OpTy>
@@ -68,129 +56,6 @@
[](APInt v) -> bool { return !v.isNullValue(); });
}
-namespace {
-
-/// Computes tile sizes (and workgroup size) to use based on operations within
-/// the function, and resource constraints on the module.
-class TileSizeCalculator {
- public:
- TileSizeCalculator(FuncOp funcOp)
- : resourceLimits(spirv::lookupTargetEnv(funcOp).getResourceLimits()) {
- if (DenseIntElementsAttr attr = spirv::lookupLocalWorkGroupSize(funcOp)) {
- for (auto val : attr.getValues<APInt>())
- workgroupSize.push_back(val.getSExtValue());
- }
- workgroupSize.resize(3, 1);
- }
-
- /// Set tile sizes to use.
- void setTileSizes(ArrayRef<int64_t> sizes) {
- tileSizes.assign(sizes.begin(), sizes.end());
- }
-
- /// Set workgroup size to use.
- void setWorkgroupSize(ArrayRef<int64_t> sizes) {
- workgroupSize.assign(sizes.begin(), sizes.end());
- }
-
- /// Compute the tile sizes based on the Linalg Ops within the dispatch region.
- LogicalResult inferTileAndWorkgroupSize(ArrayRef<linalg::LinalgOp> linalgOps);
-
- /// Get the current tile size computed.
- ArrayRef<int64_t> getTileSizes() const { return tileSizes; }
-
- /// Returns the workgroup size to use based on the tile sizes.
- ArrayRef<int64_t> getWorkgroupSize() const { return workgroupSize; }
-
- private:
- /// Current tile size configuration.
- SmallVector<int64_t, 4> tileSizes;
-
- /// Workgroup size to use.
- SmallVector<int64_t, 3> workgroupSize;
-
- /// Attribute for device constraints.
- spirv::ResourceLimitsAttr resourceLimits;
-};
-} // namespace
-
-LogicalResult TileSizeCalculator::inferTileAndWorkgroupSize(
- ArrayRef<linalg::LinalgOp> linalgOps) {
- tileSizes.clear();
- if (linalgOps.empty()) {
- tileSizes = {1, 1, 1};
- workgroupSize = {1, 1, 1};
- return success();
- }
- // The tile size will be driven by operations like matmul, conv, etc. within
- // the list. So see what operation exists in the list to decide the tile size.
- // If there are two such operations in the list, return error.
- enum OpInfo : uint32_t {
- None = 0x0,
- Convolution = 0x1,
- Matmul = 0x2,
- Pooling = 0x4,
- BatchMatmul = 0x8,
- };
- uint32_t opInfo = OpInfo::None;
- for (linalg::LinalgOp linalgOp : linalgOps) {
- Operation *op = linalgOp.getOperation();
- if (isa<linalg::ConvOp>(op))
- opInfo |= OpInfo::Convolution;
- else if (isa<linalg::MatmulOp>(op))
- opInfo |= OpInfo::Matmul;
- else if (isa<linalg::BatchMatmulOp>(op))
- opInfo |= OpInfo::BatchMatmul;
- else if (isa<linalg::PoolingMaxOp>(op))
- opInfo |= OpInfo::Pooling;
- else if (isa<linalg::PoolingMinOp>(op))
- opInfo |= OpInfo::Pooling;
- else if (isa<linalg::PoolingSumOp>(op))
- opInfo |= OpInfo::Pooling;
- }
- // If there are no tilable ops, there is nothing to do here.
- if (!opInfo) return success();
-
- Operation *linalgOp = *(linalgOps.begin());
- if (llvm::countPopulation(opInfo) != 1)
- return linalgOp->getParentOfType<FuncOp>().emitError(
- "unhandled fusion of ops in dispatch function");
-
- // TODO(ravishanarm, antiagainst): Only the maximum workgroup size is used
- // here for computing tile sizes. In reality we also need the maximum
- // workgroup memory size available (per workgroup) to compute the tile sizes
- // effectively.
- unsigned maxWorkgroupSize =
- resourceLimits.max_compute_workgroup_invocations().getInt();
- if (opInfo & OpInfo::Convolution) {
- int64_t tileSizeX = 32;
- int64_t tileSizeY = maxWorkgroupSize / 32;
- tileSizes = {1, tileSizeY, tileSizeX};
- workgroupSize = {tileSizeX, tileSizeY, 1};
- return success();
- }
- if (opInfo & OpInfo::Matmul) {
- // TODO: For now just hard wire this, but we can do better.
- tileSizes = {8, 8, 4};
- workgroupSize = {8, 8, 1};
- return success();
- }
- if (opInfo & OpInfo::BatchMatmul) {
- tileSizes = {2, 8, 8, 4};
- workgroupSize = {8, 8, 2};
- return success();
- }
- if (opInfo & OpInfo::Pooling) {
- int64_t tileSizeX = 32;
- int64_t tileSizeY = maxWorkgroupSize / 32;
- tileSizes = {tileSizeY, tileSizeX};
- workgroupSize = {tileSizeX, tileSizeY, 1};
- return success();
- }
- return linalgOp->getParentOfType<FuncOp>().emitError(
- "unable to find tile size for ops in this dispatch function");
-}
-
//===----------------------------------------------------------------------===//
// Pass and patterns
//===----------------------------------------------------------------------===//
@@ -199,37 +64,53 @@
/// Function pass that implements tiling and fusion in Linalg on buffers.
struct LinalgTileAndFusePass
: public PassWrapper<LinalgTileAndFusePass, OperationPass<ModuleOp>> {
- LinalgTileAndFusePass(ArrayRef<int64_t> workgroupSize = {},
- ArrayRef<int64_t> tileSizes = {},
- bool useWorkgroupMem = false) {
- this->workgroupSize = workgroupSize;
- this->tileSizes = tileSizes;
- this->useWorkgroupMemory = useWorkgroupMem;
+ LinalgTileAndFusePass() = default;
+ LinalgTileAndFusePass(const SPIRVCodegenOptions &passedOptions) {
+ options = passedOptions;
}
LinalgTileAndFusePass(const LinalgTileAndFusePass &pass) {}
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<AffineDialect, gpu::GPUDialect, linalg::LinalgDialect,
- scf::SCFDialect, ShapeDialect>();
+ scf::SCFDialect, ShapeDialect, vector::VectorDialect>();
}
void runOnOperation() override;
private:
- Option<bool> useWorkgroupMemory{
- *this, "use-workgroup-memory",
- llvm::cl::desc("Promote subviews to use workgroup memory"),
- llvm::cl::init(false)};
+ SPIRVCodegenOptions options;
- ListOption<int64_t> workgroupSize{
- *this, "workgroup-size",
- llvm::cl::desc("Override the default workgroup size"),
- llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
-
+ // TODO: Find a common place to put these options. They are defined three
+ // times, once here, once for the pass pipeline and once for the binary.
ListOption<int64_t> tileSizes{
*this, "tile-sizes", llvm::cl::desc("Set tile sizes to use"),
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+
+ ListOption<int64_t> workgroupSize{
+ *this, "workgroup-size",
+ llvm::cl::desc(
+ "Number of workgroups to dispatch for the SPIR-V module; at most "
+ "three integers standarding for the x, y, and z dimension; "
+ "additional arguments will be ignored (used only for testing)"),
+ llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+
+ Option<bool> useWorkgroupMemory{
+ *this, "use-workgroup-memory",
+ llvm::cl::desc(
+ "Enable use of workgroup memory in SPIR-V code generation pipeline"),
+ llvm::cl::init(false)};
+
+ Option<bool> useVectorization{
+ *this, "use-vectorization",
+ llvm::cl::desc(
+ "Enable use of vectorization in SPIR-V code generation pipeline"),
+ llvm::cl::init(false)};
};
+//===----------------------------------------------------------------------===//
+// Patterns to tile computation to map to workgroups.
+//===----------------------------------------------------------------------===//
+
+/// Distribution options for linalg.matmul when targeting workgroups.
static linalg::LinalgLoopDistributionOptions matmulDistributionOptions = {
[](OpBuilder &builder, Location loc,
ArrayRef<SubViewOp::Range> parallelLoopRanges) {
@@ -237,6 +118,7 @@
builder, loc, parallelLoopRanges.size());
},
{linalg::DistributionMethod::CyclicNumProcsEqNumIters,
+ linalg::DistributionMethod::CyclicNumProcsEqNumIters,
linalg::DistributionMethod::CyclicNumProcsEqNumIters}};
/// Pattern for tiling operations. Updates the workgroup size in the surrounding
@@ -245,8 +127,8 @@
struct TileMatmulPattern : public linalg::LinalgBaseTilingPattern {
using Base = linalg::LinalgBaseTilingPattern;
TileMatmulPattern(MLIRContext *context, linalg::LinalgTilingOptions options,
- ArrayRef<int64_t> tileSizes,
- ArrayRef<int64_t> workgroupSize, PatternBenefit benefit = 1)
+ const LaunchConfig &launchConfig,
+ PatternBenefit benefit = 1)
: Base(MatmulOp::getOperationName(), context,
options.setDistributionOptions(matmulDistributionOptions),
linalg::LinalgMarker(
@@ -254,8 +136,7 @@
Identifier::get(getWorkgroupNumItemsGENumItersMarker(),
context)),
benefit),
- tileSizes(tileSizes.begin(), tileSizes.end()),
- workgroupSize(workgroupSize.begin(), workgroupSize.end()) {}
+ launchConfig(launchConfig) {}
virtual LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
@@ -263,20 +144,37 @@
// erased.
FuncOp funcOp = op->getParentOfType<FuncOp>();
if (!funcOp || failed(Base::matchAndRewrite(op, rewriter)) ||
- failed(updateWorkGroupSize(funcOp, workgroupSize)) ||
+ failed(updateWorkGroupSize(funcOp, launchConfig.getWorkgroupSize())) ||
(funcOp.getAttr(getNumWorkgroupsFnAttrName()) &&
failed(createNumWorkgroupsFromResultShape(
- rewriter, cast<linalg::LinalgOp>(op), funcOp, tileSizes)))) {
+ rewriter, cast<linalg::LinalgOp>(op), funcOp,
+ launchConfig.getTileSizes(op, 0))))) {
return failure();
}
rewriter.eraseOp(op);
return success();
}
- SmallVector<int64_t, 3> tileSizes;
- SmallVector<int64_t, 3> workgroupSize;
+ const LaunchConfig &launchConfig;
};
+/// Pattern to tile linalg.matmul for subgroups.
+struct TileMatmulSubgroupPattern
+ : public linalg::LinalgTilingPattern<linalg::MatmulOp> {
+ using Base = linalg::LinalgTilingPattern<linalg::MatmulOp>;
+ TileMatmulSubgroupPattern(MLIRContext *context,
+ linalg::LinalgTilingOptions options,
+ PatternBenefit benefit = 1)
+ : Base(context, options,
+ linalg::LinalgMarker(
+ Identifier::get(getWorkgroupNumItemsGENumItersMarker(),
+ context),
+ Identifier::get(getVectorizeMarker(), context)),
+ benefit) {}
+};
+
+/// Distribution options for targeting workgroups for convolution/pooling
+/// operations.
static linalg::LinalgLoopDistributionOptions convPoolDistributionOptions = {
[](OpBuilder &builder, Location loc,
ArrayRef<SubViewOp::Range> parallelLoopRanges) {
@@ -286,14 +184,12 @@
{linalg::DistributionMethod::Cyclic, linalg::DistributionMethod::Cyclic,
linalg::DistributionMethod::Cyclic}};
-/// Pattern for tiling convolution and pooling operations. Currently is just a
-/// way to not tile when the operation has padding.
+/// Pattern for tiling convolution and pooling operations.
template <typename OpTy>
struct TileConvPoolPattern : public linalg::LinalgTilingPattern<OpTy> {
using Base = linalg::LinalgTilingPattern<OpTy>;
TileConvPoolPattern(MLIRContext *context, linalg::LinalgTilingOptions options,
- ArrayRef<int64_t> tileSizes,
- ArrayRef<int64_t> workgroupSize,
+ const LaunchConfig &launchConfig,
PatternBenefit benefit = 1)
: Base(context,
options.setDistributionOptions(convPoolDistributionOptions),
@@ -301,38 +197,58 @@
ArrayRef<Identifier>(),
Identifier::get(getWorkgroupMarker(), context)),
benefit),
- tileSizes(tileSizes.begin(), tileSizes.end()),
- workgroupSize(workgroupSize.begin(), workgroupSize.end()) {}
+ launchConfig(launchConfig) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
if (hasPadding(cast<OpTy>(op))) return failure();
FuncOp funcOp = op->getParentOfType<FuncOp>();
if (!funcOp || failed(Base::matchAndRewrite(op, rewriter)) ||
- failed(updateWorkGroupSize(funcOp, this->workgroupSize)))
+ failed(updateWorkGroupSize(funcOp, launchConfig.getWorkgroupSize())))
return failure();
funcOp.removeAttr(getNumWorkgroupsFnAttrName());
return success();
}
- SmallVector<int64_t, 3> tileSizes;
- SmallVector<int64_t, 3> workgroupSize;
+ const LaunchConfig &launchConfig;
};
+/// Populate patterns for first-level tiling.
+static void populateTilingToWorkgroupPatterns(
+ MLIRContext *context, const LaunchConfig &launchConfig,
+ OwningRewritePatternList &patterns) {
+ // Function to compute first level tiling values.
+ std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>
+ getOuterTileSizeFn =
+ [&launchConfig](OpBuilder &builder,
+ Operation *operation) -> SmallVector<Value, 4> {
+ ArrayRef<int64_t> tileSizes = launchConfig.getTileSizes(operation, 0);
+ if (tileSizes.empty()) return {};
+ SmallVector<Value, 4> tileSizesVal;
+ tileSizesVal.reserve(tileSizes.size());
+ for (auto val : tileSizes) {
+ tileSizesVal.push_back(
+ builder.create<ConstantIndexOp>(operation->getLoc(), val));
+ }
+ return tileSizesVal;
+ };
+ patterns.insert<TileConvPoolPattern<linalg::ConvOp>,
+ TileMatmulPattern<linalg::MatmulOp>,
+ TileMatmulPattern<linalg::BatchMatmulOp>,
+ TileConvPoolPattern<linalg::PoolingMaxOp>,
+ TileConvPoolPattern<linalg::PoolingMinOp>,
+ TileConvPoolPattern<linalg::PoolingSumOp>>(
+ context,
+ linalg::LinalgTilingOptions()
+ .setTileSizeComputationFunction(getOuterTileSizeFn)
+ .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops),
+ launchConfig);
+}
+
//===----------------------------------------------------------------------===//
// Patterns to promote subviews to workgroup memory
//===----------------------------------------------------------------------===//
-/// Function used as callback for copyin/copyout in promotion pattern used to
-/// promote subviews to workgroup memory when the number of threads is known to
-/// be greater than equal to the number of iteration of loops the copy is
-/// lowered to.
-static LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst) {
- auto copyOp = b.create<linalg::CopyOp>(src.getLoc(), src, dst);
- setMarker(copyOp, getCopyToWorkgroupMemoryMarker());
- return success();
-}
-
/// Pattern to promote matmul operands to workgroup memory.
struct PromoteMatmulSubviewsPattern
: public linalg::LinalgPromotionPattern<linalg::MatmulOp> {
@@ -381,10 +297,114 @@
};
} // namespace
+static void populatePromotionPatterns(MLIRContext *context,
+ OwningRewritePatternList &patterns) {
+ patterns
+ .insert<PromoteMatmulSubviewsPattern, PromoteConvolutionSubviewsPattern>(
+ context,
+ linalg::LinalgPromotionOptions()
+ .setAllocationDeallocationFns(allocateWorkgroupMemory,
+ deallocateWorkgroupMemory)
+ .setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory));
+}
+
+//===----------------------------------------------------------------------===//
+// Patterns and methods for subgroup tiling.
+//===----------------------------------------------------------------------===//
+
+/// Computes the Value for subgroupID along each dimension given number of
+/// subgroups `numSubGroups` along each dimension (x-first, y-second, z-third).
+static SmallVector<linalg::ProcInfo, 2> getSubgroupIdsAndCounts(
+ OpBuilder &builder, Location loc, ArrayRef<int64_t> numSubgroups) {
+ Type indexType = builder.getIndexType();
+ Value subgroupId = builder.create<gpu::SubgroupIdOp>(loc, indexType);
+ SmallVector<linalg::ProcInfo, 2> procInfo(numSubgroups.size());
+
+ // subgroupID
+ // = id.z * nsubgroups.y * nsubgroups.x + id.y * nsubgroups.x + id.x
+ using edsc::op::operator%;
+ for (size_t i = 0, e = numSubgroups.size(); i != e; ++i) {
+ Value nprocs = builder.create<ConstantIndexOp>(loc, numSubgroups[i]);
+ Value procId = subgroupId % nprocs;
+ procInfo[e - i - 1] = linalg::ProcInfo{procId, nprocs};
+ subgroupId = builder.create<SignedDivIOp>(loc, subgroupId, nprocs);
+ }
+ return procInfo;
+}
+
+/// Patterns for second level tiling to target subgroups.
+static void populateTilingToSubgroupPatterns(
+ MLIRContext *context, const LaunchConfig &launchConfig,
+ OwningRewritePatternList &patterns) {
+ std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>
+ getInnerTileSizeFn =
+ [&launchConfig](OpBuilder &builder,
+ Operation *operation) -> SmallVector<Value, 4> {
+ ArrayRef<int64_t> tileSizes = launchConfig.getTileSizes(operation, 1);
+ if (tileSizes.empty()) return {};
+ SmallVector<Value, 4> tileSizesVal;
+ tileSizesVal.reserve(tileSizes.size());
+ for (auto val : tileSizes) {
+ tileSizesVal.push_back(
+ builder.create<ConstantIndexOp>(operation->getLoc(), val));
+ }
+ return tileSizesVal;
+ };
+
+ auto getSubgroupProcInfoFn =
+ [&launchConfig](OpBuilder &builder, Location loc,
+ ArrayRef<SubViewOp::Range> parallelLoopRanges) {
+ ArrayRef<int64_t> numSubgroups =
+ launchConfig.getNumSubgroups().take_front(
+ parallelLoopRanges.size());
+ return getSubgroupIdsAndCounts(builder, loc, numSubgroups);
+ };
+ linalg::LinalgLoopDistributionOptions subgroupDistributionOptions = {
+ getSubgroupProcInfoFn,
+ {linalg::DistributionMethod::CyclicNumProcsEqNumIters,
+ linalg::DistributionMethod::CyclicNumProcsEqNumIters}};
+ patterns.insert<TileMatmulSubgroupPattern>(
+ context, linalg::LinalgTilingOptions()
+ .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops)
+ .setTileSizeComputationFunction(getInnerTileSizeFn)
+ .setDistributionOptions(subgroupDistributionOptions));
+}
+
+//====---------------------------------------------------------------------===//
+// Patterns for vectorization
+//====---------------------------------------------------------------------===//
+
+static void populateVectorizationPatterns(MLIRContext *context,
+ const LaunchConfig &launchConfig,
+ OwningRewritePatternList &patterns) {
+ patterns.insert<linalg::LinalgVectorizationPattern<linalg::MatmulOp>>(
+ context,
+ linalg::LinalgMarker(Identifier::get(getVectorizeMarker(), context)));
+}
+
+/// Apply canonicalizations related to tiling to make promotion/vectorization
+/// easier.
+static void applyCanonicalizationPatterns(MLIRContext *context, Operation *op) {
+ OwningRewritePatternList canonicalizationPatterns;
+ canonicalizationPatterns.insert<AffineMinCanonicalizationPattern>(context);
+ AffineApplyOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
+ AffineMinOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
+ SubViewOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
+ applyPatternsAndFoldGreedily(op, canonicalizationPatterns);
+}
+
void LinalgTileAndFusePass::runOnOperation() {
MLIRContext *context = &getContext();
ModuleOp module = getOperation();
+ // Override options with command line values.
+ if (!tileSizes.empty())
+ options.tileSizes.assign(tileSizes.begin(), tileSizes.end());
+ if (!workgroupSize.empty())
+ options.workgroupSize.assign(workgroupSize.begin(), workgroupSize.end());
+ if (useWorkgroupMemory) options.useWorkgroupMemory = true;
+ if (useVectorization) options.useVectorization = true;
+
LLVM_DEBUG(
llvm::dbgs() << "--- IREE Linalg tile and fuse configuration ---\n";);
for (FuncOp funcOp : module.getOps<FuncOp>()) {
@@ -397,58 +417,62 @@
}
Block &block = body.front();
auto linalgOps = block.getOps<linalg::LinalgOp>();
- if (linalgOps.empty()) return;
+ if (linalgOps.empty()) continue;
- TileSizeCalculator tileSizeCalculator(funcOp);
- if (tileSizes.empty()) {
- // Get the tile sizes to use for the lowering.
- SmallVector<int64_t, 3> tileSizes;
- SmallVector<linalg::LinalgOp, 1> opsVec(linalgOps.begin(),
- linalgOps.end());
- if (failed(tileSizeCalculator.inferTileAndWorkgroupSize(opsVec)))
- return signalPassFailure();
- } else {
- tileSizeCalculator.setTileSizes(tileSizes);
- if (!workgroupSize.empty())
- tileSizeCalculator.setWorkgroupSize(workgroupSize);
+ LaunchConfig launchConfig;
+ SmallVector<linalg::LinalgOp, 4> linalgOpsVec(linalgOps.begin(),
+ linalgOps.end());
+ if (failed(launchConfig.init(options, linalgOpsVec))) {
+ funcOp.emitError("unable to find launch configuration");
+ return signalPassFailure();
}
LLVM_DEBUG({
llvm::dbgs() << "@func " << funcOp.getName() << ": # workgroup sizes: [";
- interleaveComma(tileSizeCalculator.getWorkgroupSize(), llvm::dbgs());
- llvm::dbgs() << "]\ntile sizes: [";
- interleaveComma(tileSizeCalculator.getTileSizes(), llvm::dbgs());
+ interleaveComma(launchConfig.getWorkgroupSize(), llvm::dbgs());
llvm::dbgs() << "]\n";
+ for (auto op : linalgOps) {
+ llvm::dbgs() << "\t" << op.getOperation()->getName() << " : ";
+ TileSizesListType const &tileSizes = launchConfig.getTileSizes(op);
+ llvm::dbgs() << "{";
+ std::string sep = "";
+ for (auto &level : enumerate(tileSizes)) {
+ llvm::dbgs() << sep << level.index() << " : [";
+ sep = ", ";
+ interleaveComma(level.value(), llvm::dbgs());
+ llvm::dbgs() << "]";
+ }
+ llvm::dbgs() << "}\n";
+ }
});
- OwningRewritePatternList tilingPatterns;
- tilingPatterns.insert<TileConvPoolPattern<linalg::ConvOp>,
- TileMatmulPattern<linalg::MatmulOp>,
- TileMatmulPattern<linalg::BatchMatmulOp>,
- TileConvPoolPattern<linalg::PoolingMaxOp>,
- TileConvPoolPattern<linalg::PoolingMinOp>,
- TileConvPoolPattern<linalg::PoolingSumOp>>(
- context,
- linalg::LinalgTilingOptions()
- .setTileSizes(tileSizeCalculator.getTileSizes())
- .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops),
- tileSizeCalculator.getTileSizes(),
- tileSizeCalculator.getWorkgroupSize());
- applyPatternsAndFoldGreedily(funcOp, tilingPatterns);
+ OwningRewritePatternList firstLevelTilingPatterns;
+ populateTilingToWorkgroupPatterns(context, launchConfig,
+ firstLevelTilingPatterns);
+ applyPatternsAndFoldGreedily(funcOp, firstLevelTilingPatterns);
+ applyCanonicalizationPatterns(context, funcOp);
- if (useWorkgroupMemory) {
+ if (options.useWorkgroupMemory) {
// The promotion patterns are put separate from the tiling patterns to
// make sure that the allocated scratchspace memory is constant sizes
// which requires some folding to trigger.
OwningRewritePatternList promotionPatterns;
- promotionPatterns.insert<PromoteMatmulSubviewsPattern,
- PromoteConvolutionSubviewsPattern>(
- context,
- linalg::LinalgPromotionOptions()
- .setAllocationDeallocationFns(allocateWorkgroupMemory,
- deallocateWorkgroupMemory)
- .setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory));
+ populatePromotionPatterns(context, promotionPatterns);
applyPatternsAndFoldGreedily(funcOp, promotionPatterns);
+ applyCanonicalizationPatterns(context, funcOp);
+ }
+
+ if (options.useVectorization) {
+ OwningRewritePatternList secondLevelTilingPatterns;
+ populateTilingToSubgroupPatterns(context, launchConfig,
+ secondLevelTilingPatterns);
+ applyPatternsAndFoldGreedily(funcOp, secondLevelTilingPatterns);
+ applyCanonicalizationPatterns(context, funcOp);
+
+ OwningRewritePatternList vectorizationPatterns;
+ populateVectorizationPatterns(context, launchConfig,
+ vectorizationPatterns);
+ applyPatternsAndFoldGreedily(funcOp, vectorizationPatterns);
}
}
}
@@ -458,10 +482,8 @@
//===----------------------------------------------------------------------===//
std::unique_ptr<OperationPass<ModuleOp>> createLinalgTileAndFusePass(
- ArrayRef<int64_t> workgroupSize, ArrayRef<int64_t> tileSizes,
- bool useWorkgroupMemory) {
- return std::make_unique<LinalgTileAndFusePass>(workgroupSize, tileSizes,
- useWorkgroupMemory);
+ const SPIRVCodegenOptions &options) {
+ return std::make_unique<LinalgTileAndFusePass>(options);
}
static PassRegistration<LinalgTileAndFusePass> pass(
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
index d82d79f..04044a7 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
@@ -24,6 +24,7 @@
#include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
#include "iree/compiler/Conversion/LinalgToVector/Passes.h"
#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
+#include "llvm/Support/CommandLine.h"
#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h"
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
@@ -51,22 +52,16 @@
namespace iree_compiler {
namespace {
-/// Command line options for use with SPIR-V code-generation pass pipeline.
-struct SPIRVCodegenClOpts : public PassPipelineOptions<SPIRVCodegenClOpts> {
- ListOption<int64_t> workgroupSize{
- *this, "workgroup-size",
+
+/// Linalg to SPIR-V pass pipeline options. In theory this is a superset of all
+/// options in all passes in the pipeline. Adding those based on need, and they
+/// should be needed only for testing.
+struct LinalgToSPIRVPassPipelineOptions
+ : public PassPipelineOptions<LinalgToSPIRVPassPipelineOptions> {
+ Option<bool> useVectorization{
+ *this, "use-vectorization",
llvm::cl::desc(
- "Number of workgroups to dispatch for the SPIR-V module; at most "
- "three integers standarding for the x, y, and z dimension; "
- "additional arguments will be ignored (used only for testing)"),
- llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
- ListOption<int64_t> tileSizes{
- *this, "tile-sizes", llvm::cl::desc("Set tile sizes to use"),
- llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
- Option<bool> useWorkgroupMemory{
- *this, "use-workgroup-memory",
- llvm::cl::desc(
- "Enable use of workgroup memory in SPIR-V code generation pipeline"),
+ "Enable use of vectorization in SPIR-V code generation pipeline"),
llvm::cl::init(false)};
Option<bool> useVectorPass{
*this, "use-vector-pass",
@@ -107,8 +102,7 @@
// with the second tile and fuse pass.
//===--------------------------------------------------------------------===//
pm.addPass(createSplitDispatchFunctionPass());
- pm.addPass(createLinalgTileAndFusePass(
- options.workgroupSize, options.tileSizes, options.useWorkgroupMemory));
+ pm.addPass(createLinalgTileAndFusePass(options));
if (options.useVectorPass) {
pm.addPass(createLoadStoreVectorizationPass());
}
@@ -261,30 +255,33 @@
addLinalgToSPIRVPasses(pm, options);
}
-static SPIRVCodegenOptions getSPIRVCodegenOptions(
- const SPIRVCodegenClOpts &clOpts) {
- SPIRVCodegenOptions options;
- options.workgroupSize.assign(clOpts.workgroupSize.begin(),
- clOpts.workgroupSize.end());
- options.tileSizes.assign(clOpts.tileSizes.begin(), clOpts.tileSizes.end());
- options.useWorkgroupMemory = clOpts.useWorkgroupMemory;
- return options;
+static SPIRVCodegenOptions initializeCodegenOptions(
+ const LinalgToSPIRVPassPipelineOptions &options) {
+ SPIRVCodegenOptions codegenOptions;
+ codegenOptions.useVectorization = options.useVectorization;
+ return codegenOptions;
}
-static PassPipelineRegistration<SPIRVCodegenClOpts> linalgToSPIRVPipeline(
- "iree-codegen-linalg-to-spirv-pipeline",
- "Runs the progressive lowering pipeline from Linalg to SPIR-V",
- [](OpPassManager &passManager, const SPIRVCodegenClOpts &options) {
- addLinalgToSPIRVPasses(passManager, getSPIRVCodegenOptions(options));
- });
+static PassPipelineRegistration<LinalgToSPIRVPassPipelineOptions>
+ linalgToSPIRVPipeline(
+ "iree-codegen-linalg-to-spirv-pipeline",
+ "Runs the progressive lowering pipeline from Linalg to SPIR-V",
+ [](OpPassManager &passManager,
+ const LinalgToSPIRVPassPipelineOptions &options) {
+ addLinalgToSPIRVPasses(passManager,
+ initializeCodegenOptions(options));
+ });
-static PassPipelineRegistration<SPIRVCodegenClOpts> hloToLinalgSPIRVPipeline(
- "iree-codegen-hlo-to-spirv-pipeline",
- "Runs the progressive lowering pipeline from XLA HLO to Linalg to SPIR-V",
- [](OpPassManager &passManager, const SPIRVCodegenClOpts &options) {
- buildSPIRVTransformPassPipeline(passManager,
- getSPIRVCodegenOptions(options));
- });
+static PassPipelineRegistration<LinalgToSPIRVPassPipelineOptions>
+ hloToLinalgSPIRVPipeline(
+ "iree-codegen-hlo-to-spirv-pipeline",
+ "Runs the progressive lowering pipeline from XLA HLO to Linalg to "
+ "SPIR-V",
+ [](OpPassManager &passManager,
+ const LinalgToSPIRVPassPipelineOptions &options) {
+ buildSPIRVTransformPassPipeline(passManager,
+ initializeCodegenOptions(options));
+ });
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
index 38483ca..80805e2 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
@@ -17,6 +17,7 @@
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassOptions.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
@@ -27,6 +28,7 @@
SmallVector<int64_t, 3> workgroupSize = {};
SmallVector<int64_t, 3> tileSizes = {};
bool useWorkgroupMemory = false;
+ bool useVectorization = false;
bool useVectorPass = false;
};
@@ -41,8 +43,7 @@
/// it exists) and along "z" for the next loop (if it exists). The workgroup
/// size is expected to be of size at-most 3.
std::unique_ptr<OperationPass<ModuleOp>> createLinalgTileAndFusePass(
- ArrayRef<int64_t> workGroupSize = {}, ArrayRef<int64_t> tileSizes = {},
- bool useWorkgroupMem = false);
+ const SPIRVCodegenOptions &options);
/// Pass to add the synchronizations and attributes needed to lower from PLoops
/// to GPU dialect.
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp
index 47cc37d..2c14351 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp
@@ -107,6 +107,11 @@
Value blockId = builder.create<gpu::BlockIdOp>(loc, indexType, attr);
Value blockDim = builder.create<gpu::BlockDimOp>(loc, indexType, attr);
Value threadId = builder.create<gpu::ThreadIdOp>(loc, indexType, attr);
+ // TODO(ravishankarm): Using affine_maps here would be beneficial, and we can
+ // do this because the blockDim is constant. But this would lead to an
+ // ordering issue cause it assumes that the workgroup size has already been
+ // set. If using affine_map can help, make sure that the workgroup size is set
+ // before.
return {builder.create<AddIOp>(
loc, builder.create<MulIOp>(loc, blockId, blockDim), threadId),
builder.create<MulIOp>(loc, blockDim, gridDim)};
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Utils.h b/iree/compiler/Conversion/LinalgToSPIRV/Utils.h
index 934728d..682583a 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Utils.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Utils.h
@@ -62,7 +62,6 @@
SmallVector<linalg::ProcInfo, 2> getGPUProcessorIdsAndCounts(OpBuilder &builder,
Location loc,
unsigned numDims);
-
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
index e06eae5..f6604bf 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
@@ -64,7 +64,6 @@
// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
// CHECK-DAG: %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@ret0
-// CHECK-DAG: %[[C0:.+]] = constant 0
// CHECK-DAG: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
// CHECK-DAG: %[[NBLOCKSX:.+]] = "gpu.grid_dim"() {dimension = "x"}
// CHECK-DAG: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
@@ -78,9 +77,9 @@
// CHECK: scf.parallel (%[[IV0:.+]], %[[IV1:.+]], %[[IV2:.+]]) = (%[[BIDZ]], %[[LBY]], %[[LBX]])
// CHECK-SAME: step (%[[NBLOCKSZ]], %[[STEPY]], %[[STEPX]])
// CHECK: %[[VIEW1:.+]] = subview %[[ARG1]]
-// CHECK-SAME: [%[[IV0]], %[[IV1]], %[[IV2]], %[[C0]]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]], %[[IV2]], 0]
// CHECK: %[[VIEW2:.+]] = subview %[[RET0]]
-// CHECK-SAME: [%[[IV0]], %[[IV1]], %[[IV2]], %[[C0]]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]], %[[IV2]], 0]
// CHECK: linalg.conv
// CHECK-SAME: %[[ARG0]], %[[VIEW1]], %[[VIEW2]]
// CHECK-SAME: "workgroup"
@@ -116,37 +115,38 @@
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0] -> (s0 * 16)>
// CHECK: func @matmul()
-// CHECK-SAME: local_size = dense<[8, 8, 1]>
+// CHECK-SAME: local_size = dense<[16, 8, 1]>
// CHECK-SAME: vkspv.num_workgroups_fn = @[[NUM_WORKGROUPS_FN:.[a-zA-Z0-9_]+]]
// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
// CHECK-DAG: %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@ret0
-// CHECK-DAG: %[[C0:.+]] = constant 0
-// CHECK-DAG: %[[C4:.+]] = constant 4
// CHECK-DAG: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
// CHECK-DAG: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
// CHECK-NOT: scf.parallel
-// CHECK: scf.for %[[IV:.+]] = %[[C0]] to %{{.+}} step %[[C4]]
-// CHECK: %[[LBY:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-// CHECK: %[[VIEW0:.+]] = subview %[[ARG0]][%[[LBY]], %[[IV]]
-// CHECK: %[[LBX:.+]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
-// CHECK: %[[VIEW1:.+]] = subview %[[ARG1]][%[[IV]], %[[LBX]]]
-// CHECK: %[[LBY_2:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-// CHECK: %[[LBX_2:.+]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
-// CHECK: %[[VIEW2:.+]] = subview %[[RET0]][%[[LBY_2]], %[[LBX_2]]]
-// CHECK: linalg.matmul
-// CHECK-SAME: "workgroup_numprocs_ge_numiters"
-// CHECK-SAME: %[[VIEW0]], %[[VIEW1]], %[[VIEW2]]
+// CHECK-NOT: scf.for
+// CHECK: %[[LBY:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[VIEW0:.+]] = subview %[[ARG0]][%[[LBY]], 0]
+// CHECK: %[[LBX:.+]] = affine.apply #[[MAP3]]()[%[[BIDX]]]
+// CHECK: %[[VIEW1:.+]] = subview %[[ARG1]][0, %[[LBX]]]
+// CHECK: %[[LBY_2:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[LBX_2:.+]] = affine.apply #[[MAP3]]()[%[[BIDX]]]
+// CHECK: %[[VIEW2:.+]] = subview %[[RET0]][%[[LBY_2]], %[[LBX_2]]]
+// CHECK: linalg.matmul
+// CHECK-SAME: "workgroup_numprocs_ge_numiters"
+// CHECK-SAME: %[[VIEW0]], %[[VIEW1]], %[[VIEW2]]
// CHECK: func @[[NUM_WORKGROUPS_FN]]
// CHECK-DAG: %[[C8:.+]] = constant 8 : index
// CHECK-DAG: %[[C7:.+]] = constant 7 : index
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C16:.+]] = constant 16 : index
+// CHECK-DAG: %[[C15:.+]] = constant 15 : index
// CHECK: %[[DIM0:.+]] = dim %{{.*}}, %[[C0]]
// CHECK: %[[DIM1:.+]] = dim %{{.*}}, %[[C1]]
-// CHECK: %[[T0:.+]] = addi %[[DIM1]], %[[C7]]
-// CHECK: %[[T1:.+]] = divi_signed %[[T0]], %[[C8]]
+// CHECK: %[[T0:.+]] = addi %[[DIM1]], %[[C15]]
+// CHECK: %[[T1:.+]] = divi_signed %[[T0]], %[[C16]]
// CHECK: %[[T2:.+]] = addi %[[DIM0]], %[[C7]]
// CHECK: %[[T3:.+]] = divi_signed %[[T2]], %[[C8]]
// CHECK: return %[[T1]], %[[T3]], %[[C1]]
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_subgroup.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_subgroup.mlir
new file mode 100644
index 0000000..8421f81
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_subgroup.mlir
@@ -0,0 +1,77 @@
+// RUN: iree-opt -split-input-file -iree-codegen-linalg-tile-and-fuse=use-vectorization %s | IreeFileCheck %s
+
+module attributes {
+ spv.target_env =
+ #spv.target_env<#spv.vce<v1.3,
+ [Shader, CooperativeMatrixNV],
+ [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>,
+ {max_compute_workgroup_invocations = 512 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+ func @matmul_static_shape()
+ attributes {vkspv.num_workgroups_fn = @matmul_static_shape__num_workgroups__} {
+ %arg0 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg0, operand_result_num = 0 : i32} : memref<128x64xf16>
+ %arg1 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg1, operand_result_num = 1 : i32} : memref<64x256xf16>
+ %ret0 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@ret0, operand_result_num = 2 : i32} : memref<128x256xf16>
+ linalg.matmul %arg0, %arg1, %ret0 :
+ (memref<128x64xf16>, memref<64x256xf16>, memref<128x256xf16>)
+ return
+ }
+ func @matmul_static_shape__num_workgroups__
+ (!shapex.ranked_shape<[128, 64]>, !shapex.ranked_shape<[64, 256]>,
+ !shapex.ranked_shape<[128, 256]>) -> (index, index, index)
+ attributes {sym_visibility = "private"}
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+ }
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 32)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>
+// CHECK: func @matmul_static_shape
+// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
+// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
+// CHECK-DAG: %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@ret0
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[CST:.+]] = constant 0.0
+// CHECK-DAG: %[[C4:.+]] = constant 4 : index
+// CHECK: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
+// CHECK: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
+// CHECK: %[[BOFFSET_Y:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[SUBVIEW_LHS:.+]] = subview %[[ARG0]]
+// CHECK-SAME: [%[[BOFFSET_Y]], 0] [32, 64]
+// CHECK: %[[BOFFSET_X:.+]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+// CHECK: %[[SUBVIEW_RHS:.+]] = subview %[[ARG1]]
+// CHECK-SAME: [0, %[[BOFFSET_X]]] [64, 32]
+// CHECK: %[[BOFFSET_Y_2:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[BOFFSET_X_2:.+]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+// CHECK: %[[SUBVIEW_RESULT:.+]] = subview %[[RET0]]
+// CHECK-SAME: [%[[BOFFSET_Y_2]], %[[BOFFSET_X_2]]] [32, 32]
+// CHECK: %[[SGID:.+]] = gpu.subgroup_id
+// CHECK: %[[SGID_Y:.+]] = divi_signed %[[SGID]], %[[C4]]
+// CHECK: scf.for %[[IV2:.+]] =
+// CHECK: %[[SGOFFSET_Y:.+]] = affine.apply #[[MAP3]]()[%[[SGID_Y]]]
+// CHECK: %[[SUBVIEW2_LHS:.+]] = subview %[[SUBVIEW_LHS]]
+// CHECK-SAME: [%[[SGOFFSET_Y]], %[[IV2]]] [8, 16]
+// CHECK: %[[SGOFFSET_X:.+]] = affine.apply #[[MAP3]]()[%[[SGID]]]
+// CHECK: %[[SUBVIEW2_RHS:.+]] = subview %[[SUBVIEW_RHS]]
+// CHECK-SAME: [%[[IV2]], %[[SGOFFSET_X]]] [16, 8]
+// CHECK: %[[SGOFFSET_Y_2:.+]] = affine.apply #[[MAP3]]()[%[[SGID_Y]]]
+// CHECK: %[[SGOFFSET_X_2:.+]] = affine.apply #[[MAP3]]()[%[[SGID]]]
+// CHECK: %[[SUBVIEW2_RESULT:.+]] = subview %[[SUBVIEW_RESULT]]
+// CHECK-SAME: [%[[SGOFFSET_Y_2]], %[[SGOFFSET_X_2]]] [8, 8]
+// CHECK: %[[VTR_LHS:.+]] = vector.transfer_read %[[SUBVIEW2_LHS]]
+// CHECK-SAME: [%[[C0]], %[[C0]]], %[[CST]] {masked = [false, false]}
+// CHECK: %[[VTR_RHS:.+]] = vector.transfer_read %[[SUBVIEW2_RHS]]
+// CHECK-SAME: [%[[C0]], %[[C0]]], %[[CST]] {masked = [false, false]}
+// CHECK: %[[VTR_RESULT:.+]] = vector.transfer_read %[[SUBVIEW2_RESULT]]
+// CHECK-SAME: [%[[C0]], %[[C0]]], %[[CST]] {masked = [false, false]}
+// CHECK: %[[VECTOR_CONTRACT:.+]] = vector.contract
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-SAME: vector<8x16xf16>, vector<16x8xf16> into vector<8x8xf16>
+// CHECK: vector.transfer_write %[[VECTOR_CONTRACT]], %[[SUBVIEW2_RESULT]]
+// CHECK-SAME: masked = [false, false]
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir
new file mode 100644
index 0000000..ddd2530
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir
@@ -0,0 +1,37 @@
+// RUN: iree-opt -split-input-file -iree-codegen-linalg-to-spirv-pipeline=use-vectorization %s | IreeFileCheck %s
+
+module attributes {
+ spv.target_env =
+ #spv.target_env<#spv.vce<v1.3,
+ [Float16, Shader, CooperativeMatrixNV],
+ [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>,
+ {max_compute_workgroup_invocations = 512 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+ func @matmul_static_shape()
+ attributes {vkspv.num_workgroups_fn = @matmul_static_shape__num_workgroups__} {
+ %0 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg0, operand_result_num = 0} : memref<128x64xf16>
+ %1 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg1, operand_result_num = 1} : memref<64x256xf16>
+ %2 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@ret0, operand_result_num = 2} : memref<128x256xf16>
+ linalg.matmul %0, %1, %2 :
+ (memref<128x64xf16>, memref<64x256xf16>, memref<128x256xf16>)
+ return
+ }
+ func @matmul_static_shape__num_workgroups__
+ (!shapex.ranked_shape<[128, 64]>, !shapex.ranked_shape<[64, 256]>,
+ !shapex.ranked_shape<[128, 256]>) -> (index, index, index)
+ attributes {sym_visibility = "private"}
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ }
+}
+// CHECK-LABEL: spv.func @matmul_static_shape
+// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixMulAddNV
+// CHECK: spv.CooperativeMatrixStoreNV
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
index 9ff21ba..5b1d440 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
@@ -5,8 +5,7 @@
#spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- func @matmul_tile()
- attributes {signature = (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)} {
+ func @matmul_tile() attributes {vkspv.num_workgroups_fn = @matmul_tile__num_workgroups__} {
%0 = iree.placeholder for "interace buffer"
{binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<?x?xf32>
%1 = iree.placeholder for "interace buffer"
@@ -16,6 +15,10 @@
linalg.matmul %0, %1, %2 : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
return
}
+ func @matmul_tile__num_workgroups__
+ (!shapex.ranked_shape<[?,?]>, !shapex.ranked_shape<[?,?]>,
+ !shapex.ranked_shape<[?,?]>) -> (index, index, index)
+ attributes {sym_visibility = "private"}
hal.interface @legacy_io attributes {sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
@@ -30,9 +33,9 @@
// CHECK: %[[ARG0SV:.+]] = subview %[[ARG0]]
// CHECK: %[[ARG1SV:.+]] = subview %[[ARG1]]
// CHECK: %[[RET0SV:.+]] = subview %[[RET0]]
-// CHECK: %[[ALLOC1:.+]] = alloc() : memref<8x4xf32, 3>
+// CHECK: %[[ALLOC1:.+]] = alloc() : memref<8x32xf32, 3>
// CHECK: %[[SUBVIEW1:.+]] = subview %[[ALLOC1]]
-// CHECK: %[[ALLOC2:.+]] = alloc() : memref<4x8xf32, 3>
+// CHECK: %[[ALLOC2:.+]] = alloc() : memref<32x16xf32, 3>
// CHECK: %[[SUBVIEW2:.+]] = subview %[[ALLOC2]]
// CHECK: linalg.copy(%[[ARG0SV]], %[[SUBVIEW1]])
// CHECK-SAME: "copy_to_workgroup_memory"
@@ -41,8 +44,8 @@
// CHECK: linalg.matmul
// CHECK-SAME: "workgroup_memory_numprocs_ge_numiters"
// CHECK-SAME: %[[SUBVIEW1]], %[[SUBVIEW2]], %[[RET0SV]]
-// CHECK-DAG: dealloc %[[ALLOC1]] : memref<8x4xf32, 3>
-// CHECK-DAG: dealloc %[[ALLOC2]] : memref<4x8xf32, 3>
+// CHECK-DAG: dealloc %[[ALLOC1]] : memref<8x32xf32, 3>
+// CHECK-DAG: dealloc %[[ALLOC2]] : memref<32x16xf32, 3>
// -----
@@ -51,8 +54,7 @@
#spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- func @conv_no_padding_tile()
- attributes {signature = (tensor<3x4x3x2xf32>, tensor<?x?x?x3xf32>) -> (tensor<?x?x?x2xf32>)} {
+ func @conv_no_padding_tile() {
%0 = iree.placeholder for "interace buffer"
{binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<3x4x3x2xf32>
%1 = iree.placeholder for "interace buffer"
@@ -63,6 +65,11 @@
: memref<3x4x3x2xf32>, memref<?x?x?x3xf32>, memref<?x?x?x2xf32>
return
}
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+ }
}
// CHECK: func @conv_no_padding_tile()
// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder for "interace buffer" {binding = @legacy_io::@arg0
diff --git a/iree/compiler/Conversion/init_conversions.h b/iree/compiler/Conversion/init_conversions.h
index a5c42e7..c06e2d8 100644
--- a/iree/compiler/Conversion/init_conversions.h
+++ b/iree/compiler/Conversion/init_conversions.h
@@ -46,7 +46,7 @@
static bool init_once = []() {
// LinalgToSPIRV
createConvertToGPUPass();
- createLinalgTileAndFusePass();
+ createLinalgTileAndFusePass(SPIRVCodegenOptions());
createSplitDispatchFunctionPass();
createVectorToGPUPass();
createMatMulTileAndVectorizeGPUPass();