Merge pull request #3172 from ScottTodd:main-to-google
PiperOrigin-RevId: 332123717
diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
index 19a0fd3..7c0f678 100644
--- a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
+++ b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
@@ -32,6 +32,7 @@
"@llvm-project//mlir:Affine": ["MLIRAffineOps"],
"@llvm-project//mlir:AffineToStandardTransforms": ["MLIRAffineToStandard"],
"@llvm-project//mlir:CFGTransforms": ["MLIRSCFToStandard"],
+ "@llvm-project//mlir:DialectUtils": [""],
"@llvm-project//mlir:ExecutionEngineUtils": ["MLIRExecutionEngine"],
"@llvm-project//mlir:GPUDialect": ["MLIRGPU"],
"@llvm-project//mlir:GPUTransforms": ["MLIRGPU"],
diff --git a/experimental/ModelBuilder/test/BUILD b/experimental/ModelBuilder/test/BUILD
index 64a88a9..9c6c084 100644
--- a/experimental/ModelBuilder/test/BUILD
+++ b/experimental/ModelBuilder/test/BUILD
@@ -27,7 +27,10 @@
name = "lit",
srcs = glob(
["*.cpp"],
- exclude = ["Bench*"],
+ exclude = [
+ "Bench*",
+ "TestMatMulVulkan.cpp", # b/168746423
+ ],
),
data = [
# runtime libraries
diff --git a/experimental/ModelBuilder/test/TestMatMulVulkan.cpp b/experimental/ModelBuilder/test/TestMatMulVulkan.cpp
index 4252271..61a01c1 100644
--- a/experimental/ModelBuilder/test/TestMatMulVulkan.cpp
+++ b/experimental/ModelBuilder/test/TestMatMulVulkan.cpp
@@ -97,13 +97,14 @@
ModelRunner runner(modelBuilder.getModuleRef(),
ModelRunner::Target::GPUTarget);
CompilationOptions options;
+ mlir::iree_compiler::SPIRVCodegenOptions codegenOptions;
SmallVector<Type, 3> args = {typeA, typeB, typeC};
- SmallVector<int64_t, 4> vWorkgroupSizes(workgroupSize.begin(),
- workgroupSize.end());
- SmallVector<int64_t, 4> vTileSizes(tileSizes.begin(), tileSizes.end());
+ codegenOptions.workgroupSize.assign(workgroupSize.begin(),
+ workgroupSize.end());
+ codegenOptions.tileSizes.assign(tileSizes.begin(), tileSizes.end());
auto lowering = [&](mlir::PassManager &pm) {
- pm.addPass(mlir::iree_compiler::createLinalgTileAndFusePass(
- vWorkgroupSizes, vTileSizes, useWorkgroupMemory));
+ pm.addPass(
+ mlir::iree_compiler::createLinalgTileAndFusePass(codegenOptions));
pm.addPass(mlir::iree_compiler::createConvertToGPUPass());
pm.addPass(mlir::createLowerAffinePass());
pm.addPass(mlir::createLegalizeStdOpsForSPIRVLoweringPass());
@@ -119,14 +120,14 @@
spirvModulePM.addPass(
mlir::spirv::createUpdateVersionCapabilityExtensionPass());
- int numWorkgroupX =
- vWorkgroupSizes.empty()
- ? 1
- : (width + vWorkgroupSizes[0] - 1) / vWorkgroupSizes[0];
- int numWorkgroupY =
- vWorkgroupSizes.size() < 2
- ? 1
- : (height + vWorkgroupSizes[1] - 1) / vWorkgroupSizes[1];
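+  // The number of workgroups along each dimension is the ceil-division of the
+  // problem size by the first-level tile size.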
+ int numWorkgroupX = codegenOptions.tileSizes.empty()
+ ? 1
+ : (width + codegenOptions.tileSizes[0] - 1) /
+ codegenOptions.tileSizes[0];
+ int numWorkgroupY = codegenOptions.tileSizes.size() < 2
+ ? 1
+ : (height + codegenOptions.tileSizes[1] - 1) /
+ codegenOptions.tileSizes[1];
pm.addPass(mlir::createAddVulkanLaunchWrapperPass(
{numWorkgroupX, numWorkgroupY, 1}, args));
mlir::LowerToLLVMOptions llvmOptions = {
diff --git a/iree/compiler/Conversion/BUILD b/iree/compiler/Conversion/BUILD
index 35f424e..bfa7317 100644
--- a/iree/compiler/Conversion/BUILD
+++ b/iree/compiler/Conversion/BUILD
@@ -27,5 +27,6 @@
"//iree/compiler/Conversion/HLOToLinalg",
"//iree/compiler/Conversion/LinalgToLLVM",
"//iree/compiler/Conversion/LinalgToSPIRV",
+ "//iree/compiler/Conversion/LinalgToVector",
],
)
diff --git a/iree/compiler/Conversion/CMakeLists.txt b/iree/compiler/Conversion/CMakeLists.txt
index 6a022f2..79ead57 100644
--- a/iree/compiler/Conversion/CMakeLists.txt
+++ b/iree/compiler/Conversion/CMakeLists.txt
@@ -23,5 +23,6 @@
iree::compiler::Conversion::HLOToLinalg
iree::compiler::Conversion::LinalgToLLVM
iree::compiler::Conversion::LinalgToSPIRV
+ iree::compiler::Conversion::LinalgToVector
PUBLIC
)
diff --git a/iree/compiler/Conversion/HLOToLinalg/BUILD b/iree/compiler/Conversion/HLOToLinalg/BUILD
index da3a73e..a1f9de8 100644
--- a/iree/compiler/Conversion/HLOToLinalg/BUILD
+++ b/iree/compiler/Conversion/HLOToLinalg/BUILD
@@ -38,6 +38,7 @@
"//iree/compiler/Dialect/IREE/IR",
"//iree/compiler/Dialect/Shape/IR",
"@llvm-project//llvm:Support",
+ "@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:LinalgTransforms",
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
index 87ed137..be68c3a 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
@@ -36,6 +36,7 @@
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Function.h"
@@ -79,13 +80,11 @@
/// Returns an ArrayAttr that contains `nLoops` attributes. All the attributes
/// are "parallel" except the last `nReduction` elements, where are "reduction"
/// attributes.
-// TODO(hanchung): Use helpers in StructuredOpsUtils.h instead of hardcoded
-// strings once the build system is set up.
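+/// For example, nLoops = 3 with nReduction = 1 yields
+/// ["parallel", "parallel", "reduction"].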
static ArrayAttr getParallelAndReductionIterAttrs(Builder b, unsigned nLoops,
unsigned nReduction) {
- SmallVector<Attribute, 3> attrs(nLoops - nReduction,
- b.getStringAttr("parallel"));
- attrs.append(nReduction, b.getStringAttr("reduction"));
+ SmallVector<Attribute, 3> attrs(
+ nLoops - nReduction, b.getStringAttr(getParallelIteratorTypeName()));
+ attrs.append(nReduction, b.getStringAttr(getReductionIteratorTypeName()));
return b.getArrayAttr(attrs);
}
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/BUILD b/iree/compiler/Conversion/LinalgToSPIRV/BUILD
index dd6717a..cd0d823 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/BUILD
+++ b/iree/compiler/Conversion/LinalgToSPIRV/BUILD
@@ -49,6 +49,7 @@
"//iree/compiler/Conversion/CodegenUtils",
"//iree/compiler/Conversion/HLOToHLO",
"//iree/compiler/Conversion/HLOToLinalg",
+ "//iree/compiler/Conversion/LinalgToVector",
"//iree/compiler/Dialect/HAL/IR",
"//iree/compiler/Dialect/IREE/IR",
"//iree/compiler/Dialect/Shape/IR",
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
index ce9a6cb..42538e3 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
@@ -65,6 +65,7 @@
iree::compiler::Conversion::CodegenUtils
iree::compiler::Conversion::HLOToHLO
iree::compiler::Conversion::HLOToLinalg
+ iree::compiler::Conversion::LinalgToVector
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::IREE::IR
iree::compiler::Dialect::Shape::IR
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
index ea85ec2..01b35f3 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
@@ -14,19 +14,24 @@
//===- KernelDispatchUtils.cpp - Utilities for generating dispatch info ---===//
//
-// This file defines utility functions that can be used to create information
-// the dispatch on the host side needs to execute an entry point function, like
-// the number of workgroups to use for launch, etc.
+// This file defines utility functions to compute the tile sizes used to
+// partition work across workgroups, the workgroup sizes, and the information
+// the host-side dispatch needs to execute an entry point function (e.g. the
+// total number of workgroups).
//
//===----------------------------------------------------------------------===//
+#include "iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h"
+
#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Attributes.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/SPIRV/TargetAndABI.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Function.h"
@@ -40,6 +45,10 @@
namespace mlir {
namespace iree_compiler {
+//===----------------------------------------------------------------------===//
+// Number of workgroups computation
+//===----------------------------------------------------------------------===//
+
FuncOp getNumWorkgroupsFn(FuncOp entryPointFn) {
SymbolRefAttr attr =
entryPointFn.getAttrOfType<SymbolRefAttr>(getNumWorkgroupsFnAttrName());
@@ -153,5 +162,271 @@
return success();
}
+//===----------------------------------------------------------------------===//
+// Launch config calculation.
+//===----------------------------------------------------------------------===//
+
+/// Given `nprocs`, distributes it as evenly as possible across the two
+/// logical dimensions x and y.
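+/// For example, nprocs = 512 gives nprocs_x = PowerOf2Ceil(sqrt(512)) = 32
+/// and nprocs_y = 512 / 32 = 16.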
+static std::tuple<int64_t, int64_t> distributeProcs2D(int64_t nprocs) {
+ int64_t nprocs_x = std::max<int64_t>(
+ 1, static_cast<int64_t>(
+ llvm::PowerOf2Ceil(static_cast<uint64_t>(std::sqrt(nprocs)))));
+ return std::make_tuple(nprocs_x, nprocs / nprocs_x);
+}
+
+/// For a given operation `op`, the codegen `options`, and the hardware's
+/// `resourceLimits`, computes the
+/// 1) number of tiling levels and tile sizes to use (updates `tileSizes`),
+/// 2) workgroup size to use (updates `workgroupSize`),
+/// 3) number of subgroups to use if two level tiling is used (updates
+/// `numSubgroups`).
+template <typename T>
+static LogicalResult getOpLaunchConfig(T op, const SPIRVCodegenOptions &options,
+ spirv::ResourceLimitsAttr resourceLimits,
+ TileSizesListType &tileSizes,
+ std::array<int64_t, 3> &workgroupSize,
+ std::array<int64_t, 3> &numSubgroups) {
+ return op.emitError("undefined launch config for tiled operation");
+}
+
+/// Launch config for `linalg.batch_matmul`.
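+/// For example, with max_compute_workgroup_invocations = 512 this produces a
+/// (32, 16, 1) workgroup and tile sizes {1, 16, 32, tileSizeK}.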
+template <>
+LogicalResult getOpLaunchConfig(linalg::BatchMatmulOp op,
+ const SPIRVCodegenOptions &options,
+ spirv::ResourceLimitsAttr resourceLimits,
+ TileSizesListType &tileSizes,
+ std::array<int64_t, 3> &workgroupSize,
+ std::array<int64_t, 3> &numSubgroups) {
+ unsigned maxWorkgroupSize =
+ resourceLimits.max_compute_workgroup_invocations().getInt();
+ std::tie(workgroupSize[0], workgroupSize[1]) =
+ distributeProcs2D(maxWorkgroupSize);
+ workgroupSize[2] = 1;
+  // TODO(#3131): This is just hard-wired for now to be minimally viable, but
+  // this can be decided better when we have better estimates of device
+  // characteristics.
+ const int64_t nRowsPerWorkitem = 1;
+ const int64_t nColsPerWorkitem = 1;
+ const int64_t nBatchesPerWorkitem = 1;
+ int64_t tileSizeK = 0;
+ if (options.useWorkgroupMemory) {
+ // TODO(#3131): This number should be decided based on the amount of
+ // shared memory available (maybe). For now, just hard-wire it.
+ tileSizeK = 32;
+ }
+ assert(tileSizes.empty());
+ SmallVector<int64_t, 4> ts = {nBatchesPerWorkitem,
+ nRowsPerWorkitem * workgroupSize[1],
+ nColsPerWorkitem * workgroupSize[0], tileSizeK};
+ tileSizes.emplace_back(std::move(ts));
+ return success();
+}
+
+/// The size of the cooperative matrix multiply operations on the device.
+// TODO(#3131): This needs to be queried from the device.
+Optional<std::array<int64_t, 3>> getCooperativeMatmulSubgroupSize(
+ Type dataType, Type accumulatorType) {
+ if (dataType.isInteger(8) && accumulatorType.isInteger(32)) {
+ return std::array<int64_t, 3>{8, 8, 32};
+ }
+ if (dataType.isF16() &&
+ (accumulatorType.isF32() || accumulatorType.isF16())) {
+ return std::array<int64_t, 3>{8, 8, 16};
+ }
+ return {};
+}
+
+/// Launch configuration for using spv.CooperativeMatrixMulAddNV
+/// operations. Needs two levels of tiling.
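+/// For example, an f16 matmul on a device with 512 max invocations and the
+/// hard-wired subgroup size of 32 gets numSubgroups = (4, 4, 1),
+/// workgroup-level tile sizes {32, 32}, a (128, 4, 1) workgroup, and
+/// subgroup-level tile sizes {8, 8, 16}.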
+static LogicalResult getConfigForCooperativeMatmul(
+ linalg::MatmulOp op, spirv::ResourceLimitsAttr resourceLimits,
+ TileSizesListType &tileSizes, std::array<int64_t, 3> &workgroupSize,
+ std::array<int64_t, 3> &numSubgroups) {
+ auto targetEnv = spirv::TargetEnv(spirv::lookupTargetEnv(op));
+ if (!targetEnv.allows(spirv::Capability::CooperativeMatrixNV) ||
+ !targetEnv.allows(spirv::Extension::SPV_NV_cooperative_matrix))
+ return failure();
+
+ ShapedType lhsType = op.getOperand(0).getType().cast<ShapedType>();
+ ArrayRef<int64_t> lhsShape = lhsType.getShape();
+ ShapedType rhsType = op.getOperand(1).getType().cast<ShapedType>();
+ ArrayRef<int64_t> rhsShape = rhsType.getShape();
+ ShapedType outputType = op.getOperand(2).getType().cast<ShapedType>();
+
+ Optional<std::array<int64_t, 3>> coopMatmulSize =
+ getCooperativeMatmulSubgroupSize(lhsType.getElementType(),
+ outputType.getElementType());
+ if (!coopMatmulSize) return failure();
+
+  // Check that the matmul sizes are a multiple of the tile size.
+ auto isMultipleOf = [](int64_t s, int64_t ts) {
+ return !ShapedType::isDynamic(s) && (s % ts) == 0;
+ };
+ if (!isMultipleOf(lhsShape[0], (*coopMatmulSize)[0]) ||
+ !isMultipleOf(rhsShape[1], (*coopMatmulSize)[1]) ||
+ !isMultipleOf(lhsShape[1], (*coopMatmulSize)[2]) ||
+ !isMultipleOf(rhsShape[0], (*coopMatmulSize)[2]))
+ return failure();
+
+ // TODO(ravishankarm, antiagainst): For now hardwire the subgroup size.
+ const int64_t subgroupSize = 32;
+ unsigned maxWorkgroupSize =
+ resourceLimits.max_compute_workgroup_invocations().getInt();
+ std::tie(numSubgroups[0], numSubgroups[1]) =
+ distributeProcs2D(maxWorkgroupSize / subgroupSize);
+ numSubgroups[2] = 1;
+  // TODO(#3131): This is just hard-wired for now to be minimally viable, but
+  // this can be decided better when we have better estimates of device
+  // characteristics.
+ const int64_t numVecMatmulPerSubgroupX = 1;
+ const int64_t numVecMatmulPerSubgroupY = 1;
+ SmallVector<int64_t, 4> ts = {
+ numVecMatmulPerSubgroupY * (*coopMatmulSize)[0] * numSubgroups[1],
+ numVecMatmulPerSubgroupX * (*coopMatmulSize)[1] * numSubgroups[0]};
+ tileSizes.emplace_back(std::move(ts));
+
+ workgroupSize[0] = numSubgroups[0] * subgroupSize;
+ workgroupSize[1] = numSubgroups[1];
+ workgroupSize[2] = 1;
+ // Subgroup tile sizes
+ SmallVector<int64_t, 4> subgroupTs = {
+ numVecMatmulPerSubgroupY * (*coopMatmulSize)[0],
+ numVecMatmulPerSubgroupX * (*coopMatmulSize)[1], (*coopMatmulSize)[2]};
+ tileSizes.emplace_back(std::move(subgroupTs));
+ return success();
+}
+
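+/// Launch config for `linalg.matmul`. Uses the cooperative matrix
+/// configuration when vectorization is enabled and supported; otherwise
+/// distributes one row/column per workitem, e.g. a (32, 16, 1) workgroup with
+/// tile sizes {16, 32, tileSizeK} for 512 max invocations.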
+template <>
+LogicalResult getOpLaunchConfig(linalg::MatmulOp op,
+ const SPIRVCodegenOptions &options,
+ spirv::ResourceLimitsAttr resourceLimits,
+ TileSizesListType &tileSizes,
+ std::array<int64_t, 3> &workgroupSize,
+ std::array<int64_t, 3> &numSubgroups) {
+ if (options.useVectorization &&
+ succeeded(getConfigForCooperativeMatmul(op, resourceLimits, tileSizes,
+ workgroupSize, numSubgroups))) {
+ return success();
+ }
+ unsigned maxWorkgroupSize =
+ resourceLimits.max_compute_workgroup_invocations().getInt();
+ std::tie(workgroupSize[0], workgroupSize[1]) =
+ distributeProcs2D(maxWorkgroupSize);
+ workgroupSize[2] = 1;
+ const int nRowsPerWorkitem = 1;
+ const int nColsPerWorkitem = 1;
+ int64_t tileSizeK = 0;
+ if (options.useWorkgroupMemory) {
+ // TODO(#3131): This number should be decided based on the amount of shared
+ // memory available (maybe). For now, just hard-wire it.
+ tileSizeK = 32;
+ }
+ assert(tileSizes.empty());
+ SmallVector<int64_t, 4> ts = {nRowsPerWorkitem * workgroupSize[1],
+ nColsPerWorkitem * workgroupSize[0], tileSizeK};
+ tileSizes.emplace_back(std::move(ts));
+ return success();
+}
+
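+/// Launch config for `linalg.conv`: a fixed tile of 32 along x, with the y
+/// tile sized to use up the whole workgroup, e.g. tile sizes {1, 16, 32} and
+/// a (32, 16, 1) workgroup for 512 max invocations.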
+template <>
+LogicalResult getOpLaunchConfig(linalg::ConvOp op,
+ const SPIRVCodegenOptions &options,
+ spirv::ResourceLimitsAttr resourceLimits,
+ TileSizesListType &tileSizes,
+ std::array<int64_t, 3> &workgroupSize,
+ std::array<int64_t, 3> &numSubgroups) {
+ unsigned maxWorkgroupSize =
+ resourceLimits.max_compute_workgroup_invocations().getInt();
+ const int64_t tileSizeX = 32;
+ int64_t tileSizeY = maxWorkgroupSize / tileSizeX;
+ SmallVector<int64_t, 4> ts = {1, tileSizeY, tileSizeX};
+ tileSizes.emplace_back(std::move(ts));
+ workgroupSize = {tileSizeX, tileSizeY, 1};
+ return success();
+}
+
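+/// Launch config shared by all the pooling ops: the same x/y tiling as
+/// convolution, minus the batch dimension.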
+static LogicalResult getPoolingOpLaunchConfig(
+ const SPIRVCodegenOptions &options,
+ spirv::ResourceLimitsAttr resourceLimits, TileSizesListType &tileSizes,
+ std::array<int64_t, 3> &workgroupSize,
+ std::array<int64_t, 3> &numSubgroups) {
+ unsigned maxWorkgroupSize =
+ resourceLimits.max_compute_workgroup_invocations().getInt();
+ const int64_t tileSizeX = 32;
+ int64_t tileSizeY = maxWorkgroupSize / tileSizeX;
+ SmallVector<int64_t, 4> ts = {tileSizeY, tileSizeX};
+ tileSizes.emplace_back(std::move(ts));
+ workgroupSize = {tileSizeX, tileSizeY, 1};
+ return success();
+}
+
+#define DEFINE_POOLING_OP_CONFIG(opName) \
+ template <> \
+ LogicalResult getOpLaunchConfig( \
+ opName op, const SPIRVCodegenOptions &options, \
+ spirv::ResourceLimitsAttr resourceLimits, TileSizesListType &tileSizes, \
+ std::array<int64_t, 3> &workgroupSize, \
+ std::array<int64_t, 3> &numSubgroups) { \
+ return getPoolingOpLaunchConfig(options, resourceLimits, tileSizes, \
+ workgroupSize, numSubgroups); \
+ }
+
+DEFINE_POOLING_OP_CONFIG(linalg::PoolingMaxOp)
+DEFINE_POOLING_OP_CONFIG(linalg::PoolingMinOp)
+DEFINE_POOLING_OP_CONFIG(linalg::PoolingSumOp)
+
+#undef DEFINE_POOLING_OP_CONFIG
+
+LogicalResult LaunchConfig::init(const SPIRVCodegenOptions &options,
+ ArrayRef<linalg::LinalgOp> linalgOps) {
+ if (!options.workgroupSize.empty()) {
+ for (linalg::LinalgOp op : linalgOps)
+ tileSizes[op.getOperation()->getName().getStringRef()] = {};
+ workgroupSize = {1, 1, 1};
+ for (unsigned i = 0,
+ e = std::min<unsigned>(3, options.workgroupSize.size());
+ i != e; ++i)
+ workgroupSize[i] = options.workgroupSize[i];
+ return success();
+ }
+
+ if (linalgOps.empty()) return success();
+
+ spirv::ResourceLimitsAttr resourceLimits =
+ spirv::lookupTargetEnv(*linalgOps.begin()).getResourceLimits();
+
+ for (linalg::LinalgOp op : linalgOps) {
+ StringRef key = op.getOperation()->getName().getStringRef();
+ if (tileSizes.count(key)) {
+ return op.emitError("unexpected multiple ")
+ << key << " operations within dispatch region";
+ }
+
+ TileSizesListType &tileSizesInfo = tileSizes[key];
+
+#define DISPATCH(opName) \
+ if (auto lOp = dyn_cast<opName>(op.getOperation())) { \
+ if (failed(getOpLaunchConfig(lOp, options, resourceLimits, tileSizesInfo, \
+ workgroupSize, numSubgroups))) { \
+ return failure(); \
+ } \
+ continue; \
+ }
+
+ DISPATCH(linalg::BatchMatmulOp)
+ DISPATCH(linalg::ConvOp)
+ DISPATCH(linalg::MatmulOp)
+ DISPATCH(linalg::PoolingMaxOp)
+ DISPATCH(linalg::PoolingMinOp)
+ DISPATCH(linalg::PoolingSumOp)
+
+#undef DISPATCH
+ }
+
+  // TODO(ravishankarm): Verify that the configuration set is within the
+  // device limits.
+ return success();
+}
+
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
index 1573c55..d690a0a 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
@@ -22,11 +22,17 @@
#ifndef IREE_COMPILER_CONVERSION_LINALGTOSPIRV_KERNELDISPATCHUTILS_H_
#define IREE_COMPILER_CONVERSION_LINALGTOSPIRV_KERNELDISPATCHUTILS_H_
+#include <array>
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringMap.h"
+#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"
namespace mlir {
class FuncOp;
class LogicalResult;
+class Operation;
class PatternRewriter;
class ShapedType;
class Value;
@@ -34,6 +40,9 @@
namespace linalg {
class LinalgOp;
}
+namespace iree_compiler {
+struct SPIRVCodegenOptions;
+}
namespace iree_compiler {
@@ -61,6 +70,68 @@
/// workgroups to use at launch time.
FuncOp getNumWorkgroupsFn(FuncOp entryPointFn);
+/// Store the tile sizes to use at different levels of tiling as a vector of
+/// vectors.
+/// - First level tiling maps to workgroups.
+/// - Second level tiling maps to subgroups.
+using TileSizesListType = SmallVector<SmallVector<int64_t, 4>, 1>;
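+/// For example, a two-level tiled matmul might use {{32, 32}, {8, 8, 16}}:
+/// workgroup tiles first, subgroup tiles second.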
+
+/// Based on the linalg operations in a dispatch region, the number of tiling
+/// levels, the tile sizes, the workgroup size, etc. need to be decided. These
+/// parameters are collectively called the `LaunchConfig`. This class
+/// implements one heuristic to compute them for the different linalg
+/// operations on buffers. It can be adapted later to support multiple
+/// configurations picked based on device information and problem size. It
+/// exposes the information needed by the code generators and hides the
+/// implementation from the rest of the pipeline.
+class LaunchConfig {
+ public:
+ LaunchConfig() : workgroupSize({1, 1, 1}), numSubgroups({1, 1, 1}) {}
+
+  /// Given the sequence of `linalgOps` (and `options`), computes the launch
+  /// configuration by deciding
+ /// - the number of levels of tiling,
+ /// - tile sizes for each level,
+ /// - the workgroup size, and
+ /// - number of subgroups to use.
+ LogicalResult init(const SPIRVCodegenOptions &options,
+ ArrayRef<linalg::LinalgOp> linalgOps);
+
+ /// Gets the tile size computed for an operation at all levels.
+ TileSizesListType getTileSizes(Operation *op) const {
+ return tileSizes.lookup(op->getName().getStringRef());
+ }
+
+  /// Gets the tile sizes computed for an operation at a given level.
+ ArrayRef<int64_t> getTileSizes(Operation *op, size_t level) const {
+ auto it = tileSizes.find(op->getName().getStringRef());
+ if (it == tileSizes.end() || level >= it->second.size()) return {};
+ return it->second[level];
+ }
+
+ /// Returns the workgroup size to use based on the tile sizes.
+ ArrayRef<int64_t> getWorkgroupSize() const { return workgroupSize; }
+
+ /// Returns the number of subgroups to use.
+ ArrayRef<int64_t> getNumSubgroups() const { return numSubgroups; }
+
+ protected:
+ /// Current tile size configuration per operation.
+
+  // TODO: For now just use the operation name for the mapping. The tile sizes
+  // will be selected only for operations like matmul, conv, pool, etc., with
+  // the assumption that there is only one such operation per dispatch
+  // region. Eventually this might need to be relaxed, and some name-marker
+  // based mechanism might be needed.
+ llvm::StringMap<TileSizesListType> tileSizes;
+
+ /// Workgroup size to use.
+ std::array<int64_t, 3> workgroupSize;
+
+ /// Number of subgroups that are logically distributed along x, y & z.
+ std::array<int64_t, 3> numSubgroups;
+};
+
} // namespace iree_compiler
} // namespace mlir
#endif  // IREE_COMPILER_CONVERSION_LINALGTOSPIRV_KERNELDISPATCHUTILS_H_
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
index 4c7417e..135961a 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -18,6 +18,7 @@
//
//===----------------------------------------------------------------------===//
#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
+#include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Attributes.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h"
@@ -29,7 +30,6 @@
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
-#include "mlir/Dialect/SPIRV/TargetAndABI.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/Matchers.h"
@@ -37,7 +37,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/FoldUtils.h"
-#define DEBUG_TYPE "iree-linalg-tile-and-fuse-buffer"
+#define DEBUG_TYPE "iree-linalg-tile-and-fuse"
namespace mlir {
namespace iree_compiler {
@@ -46,18 +46,6 @@
// Utility functions
//===----------------------------------------------------------------------===//
-static ArrayRef<int64_t> dropTrailingOnes(ArrayRef<int64_t> vector) {
- if (vector.empty()) return vector;
- auto numTrailingOnes = 0;
- for (unsigned i = vector.size() - 1; i > 0; --i) {
- if (vector[i] != 1) {
- break;
- }
- numTrailingOnes++;
- }
- return vector.drop_back(numTrailingOnes);
-}
-
/// Returns true if the linalg op has padding attribute, and that it has
/// non-zero entries.
template <typename OpTy>
@@ -68,129 +56,6 @@
[](APInt v) -> bool { return !v.isNullValue(); });
}
-namespace {
-
-/// Computes tile sizes (and workgroup size) to use based on operations within
-/// the function, and resource constraints on the module.
-class TileSizeCalculator {
- public:
- TileSizeCalculator(FuncOp funcOp)
- : resourceLimits(spirv::lookupTargetEnv(funcOp).getResourceLimits()) {
- if (DenseIntElementsAttr attr = spirv::lookupLocalWorkGroupSize(funcOp)) {
- for (auto val : attr.getValues<APInt>())
- workgroupSize.push_back(val.getSExtValue());
- }
- workgroupSize.resize(3, 1);
- }
-
- /// Set tile sizes to use.
- void setTileSizes(ArrayRef<int64_t> sizes) {
- tileSizes.assign(sizes.begin(), sizes.end());
- }
-
- /// Set workgroup size to use.
- void setWorkgroupSize(ArrayRef<int64_t> sizes) {
- workgroupSize.assign(sizes.begin(), sizes.end());
- }
-
- /// Compute the tile sizes based on the Linalg Ops within the dispatch region.
- LogicalResult inferTileAndWorkgroupSize(ArrayRef<linalg::LinalgOp> linalgOps);
-
- /// Get the current tile size computed.
- ArrayRef<int64_t> getTileSizes() const { return tileSizes; }
-
- /// Returns the workgroup size to use based on the tile sizes.
- ArrayRef<int64_t> getWorkgroupSize() const { return workgroupSize; }
-
- private:
- /// Current tile size configuration.
- SmallVector<int64_t, 4> tileSizes;
-
- /// Workgroup size to use.
- SmallVector<int64_t, 3> workgroupSize;
-
- /// Attribute for device constraints.
- spirv::ResourceLimitsAttr resourceLimits;
-};
-} // namespace
-
-LogicalResult TileSizeCalculator::inferTileAndWorkgroupSize(
- ArrayRef<linalg::LinalgOp> linalgOps) {
- tileSizes.clear();
- if (linalgOps.empty()) {
- tileSizes = {1, 1, 1};
- workgroupSize = {1, 1, 1};
- return success();
- }
- // The tile size will be driven by operations like matmul, conv, etc. within
- // the list. So see what operation exists in the list to decide the tile size.
- // If there are two such operations in the list, return error.
- enum OpInfo : uint32_t {
- None = 0x0,
- Convolution = 0x1,
- Matmul = 0x2,
- Pooling = 0x4,
- BatchMatmul = 0x8,
- };
- uint32_t opInfo = OpInfo::None;
- for (linalg::LinalgOp linalgOp : linalgOps) {
- Operation *op = linalgOp.getOperation();
- if (isa<linalg::ConvOp>(op))
- opInfo |= OpInfo::Convolution;
- else if (isa<linalg::MatmulOp>(op))
- opInfo |= OpInfo::Matmul;
- else if (isa<linalg::BatchMatmulOp>(op))
- opInfo |= OpInfo::BatchMatmul;
- else if (isa<linalg::PoolingMaxOp>(op))
- opInfo |= OpInfo::Pooling;
- else if (isa<linalg::PoolingMinOp>(op))
- opInfo |= OpInfo::Pooling;
- else if (isa<linalg::PoolingSumOp>(op))
- opInfo |= OpInfo::Pooling;
- }
- // If there are no tilable ops, there is nothing to do here.
- if (!opInfo) return success();
-
- Operation *linalgOp = *(linalgOps.begin());
- if (llvm::countPopulation(opInfo) != 1)
- return linalgOp->getParentOfType<FuncOp>().emitError(
- "unhandled fusion of ops in dispatch function");
-
- // TODO(ravishanarm, antiagainst): Only the maximum workgroup size is used
- // here for computing tile sizes. In reality we also need the maximum
- // workgroup memory size available (per workgroup) to compute the tile sizes
- // effectively.
- unsigned maxWorkgroupSize =
- resourceLimits.max_compute_workgroup_invocations().getInt();
- if (opInfo & OpInfo::Convolution) {
- int64_t tileSizeX = 32;
- int64_t tileSizeY = maxWorkgroupSize / 32;
- tileSizes = {1, tileSizeY, tileSizeX};
- workgroupSize = {tileSizeX, tileSizeY, 1};
- return success();
- }
- if (opInfo & OpInfo::Matmul) {
- // TODO: For now just hard wire this, but we can do better.
- tileSizes = {8, 8, 4};
- workgroupSize = {8, 8, 1};
- return success();
- }
- if (opInfo & OpInfo::BatchMatmul) {
- tileSizes = {2, 8, 8, 4};
- workgroupSize = {8, 8, 2};
- return success();
- }
- if (opInfo & OpInfo::Pooling) {
- int64_t tileSizeX = 32;
- int64_t tileSizeY = maxWorkgroupSize / 32;
- tileSizes = {tileSizeY, tileSizeX};
- workgroupSize = {tileSizeX, tileSizeY, 1};
- return success();
- }
- return linalgOp->getParentOfType<FuncOp>().emitError(
- "unable to find tile size for ops in this dispatch function");
-}
-
//===----------------------------------------------------------------------===//
// Pass and patterns
//===----------------------------------------------------------------------===//
@@ -199,37 +64,53 @@
/// Function pass that implements tiling and fusion in Linalg on buffers.
struct LinalgTileAndFusePass
: public PassWrapper<LinalgTileAndFusePass, OperationPass<ModuleOp>> {
- LinalgTileAndFusePass(ArrayRef<int64_t> workgroupSize = {},
- ArrayRef<int64_t> tileSizes = {},
- bool useWorkgroupMem = false) {
- this->workgroupSize = workgroupSize;
- this->tileSizes = tileSizes;
- this->useWorkgroupMemory = useWorkgroupMem;
+ LinalgTileAndFusePass() = default;
+ LinalgTileAndFusePass(const SPIRVCodegenOptions &passedOptions) {
+ options = passedOptions;
}
LinalgTileAndFusePass(const LinalgTileAndFusePass &pass) {}
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<AffineDialect, gpu::GPUDialect, linalg::LinalgDialect,
- scf::SCFDialect, ShapeDialect>();
+ scf::SCFDialect, ShapeDialect, vector::VectorDialect>();
}
void runOnOperation() override;
private:
- Option<bool> useWorkgroupMemory{
- *this, "use-workgroup-memory",
- llvm::cl::desc("Promote subviews to use workgroup memory"),
- llvm::cl::init(false)};
+ SPIRVCodegenOptions options;
- ListOption<int64_t> workgroupSize{
- *this, "workgroup-size",
- llvm::cl::desc("Override the default workgroup size"),
- llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
-
+ // TODO: Find a common place to put these options. They are defined three
+ // times, once here, once for the pass pipeline and once for the binary.
ListOption<int64_t> tileSizes{
*this, "tile-sizes", llvm::cl::desc("Set tile sizes to use"),
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+
+ ListOption<int64_t> workgroupSize{
+ *this, "workgroup-size",
+ llvm::cl::desc(
+ "Number of workgroups to dispatch for the SPIR-V module; at most "
+ "three integers standarding for the x, y, and z dimension; "
+ "additional arguments will be ignored (used only for testing)"),
+ llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+
+ Option<bool> useWorkgroupMemory{
+ *this, "use-workgroup-memory",
+ llvm::cl::desc(
+ "Enable use of workgroup memory in SPIR-V code generation pipeline"),
+ llvm::cl::init(false)};
+
+ Option<bool> useVectorization{
+ *this, "use-vectorization",
+ llvm::cl::desc(
+ "Enable use of vectorization in SPIR-V code generation pipeline"),
+ llvm::cl::init(false)};
};
+//===----------------------------------------------------------------------===//
+// Patterns to tile computation to map to workgroups.
+//===----------------------------------------------------------------------===//
+
+/// Distribution options for linalg.matmul and linalg.batch_matmul when
+/// targeting workgroups (three parallel dimensions to cover the batch case).
static linalg::LinalgLoopDistributionOptions matmulDistributionOptions = {
[](OpBuilder &builder, Location loc,
ArrayRef<SubViewOp::Range> parallelLoopRanges) {
@@ -237,6 +118,7 @@
builder, loc, parallelLoopRanges.size());
},
{linalg::DistributionMethod::CyclicNumProcsEqNumIters,
+ linalg::DistributionMethod::CyclicNumProcsEqNumIters,
linalg::DistributionMethod::CyclicNumProcsEqNumIters}};
/// Pattern for tiling operations. Updates the workgroup size in the surrounding
@@ -245,8 +127,8 @@
struct TileMatmulPattern : public linalg::LinalgBaseTilingPattern {
using Base = linalg::LinalgBaseTilingPattern;
TileMatmulPattern(MLIRContext *context, linalg::LinalgTilingOptions options,
- ArrayRef<int64_t> tileSizes,
- ArrayRef<int64_t> workgroupSize, PatternBenefit benefit = 1)
+ const LaunchConfig &launchConfig,
+ PatternBenefit benefit = 1)
: Base(MatmulOp::getOperationName(), context,
options.setDistributionOptions(matmulDistributionOptions),
linalg::LinalgMarker(
@@ -254,8 +136,7 @@
Identifier::get(getWorkgroupNumItemsGENumItersMarker(),
context)),
benefit),
- tileSizes(tileSizes.begin(), tileSizes.end()),
- workgroupSize(workgroupSize.begin(), workgroupSize.end()) {}
+ launchConfig(launchConfig) {}
virtual LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
@@ -263,20 +144,37 @@
// erased.
FuncOp funcOp = op->getParentOfType<FuncOp>();
if (!funcOp || failed(Base::matchAndRewrite(op, rewriter)) ||
- failed(updateWorkGroupSize(funcOp, workgroupSize)) ||
+ failed(updateWorkGroupSize(funcOp, launchConfig.getWorkgroupSize())) ||
(funcOp.getAttr(getNumWorkgroupsFnAttrName()) &&
failed(createNumWorkgroupsFromResultShape(
- rewriter, cast<linalg::LinalgOp>(op), funcOp, tileSizes)))) {
+ rewriter, cast<linalg::LinalgOp>(op), funcOp,
+ launchConfig.getTileSizes(op, 0))))) {
return failure();
}
rewriter.eraseOp(op);
return success();
}
- SmallVector<int64_t, 3> tileSizes;
- SmallVector<int64_t, 3> workgroupSize;
+ const LaunchConfig &launchConfig;
};
+/// Pattern to tile linalg.matmul for subgroups.
+struct TileMatmulSubgroupPattern
+ : public linalg::LinalgTilingPattern<linalg::MatmulOp> {
+ using Base = linalg::LinalgTilingPattern<linalg::MatmulOp>;
+ TileMatmulSubgroupPattern(MLIRContext *context,
+ linalg::LinalgTilingOptions options,
+ PatternBenefit benefit = 1)
+ : Base(context, options,
+ linalg::LinalgMarker(
+ Identifier::get(getWorkgroupNumItemsGENumItersMarker(),
+ context),
+ Identifier::get(getVectorizeMarker(), context)),
+ benefit) {}
+};
+
+/// Distribution options for targeting workgroups for convolution/pooling
+/// operations.
static linalg::LinalgLoopDistributionOptions convPoolDistributionOptions = {
[](OpBuilder &builder, Location loc,
ArrayRef<SubViewOp::Range> parallelLoopRanges) {
@@ -286,14 +184,12 @@
{linalg::DistributionMethod::Cyclic, linalg::DistributionMethod::Cyclic,
linalg::DistributionMethod::Cyclic}};
-/// Pattern for tiling convolution and pooling operations. Currently is just a
-/// way to not tile when the operation has padding.
+/// Pattern for tiling convolution and pooling operations.
template <typename OpTy>
struct TileConvPoolPattern : public linalg::LinalgTilingPattern<OpTy> {
using Base = linalg::LinalgTilingPattern<OpTy>;
TileConvPoolPattern(MLIRContext *context, linalg::LinalgTilingOptions options,
- ArrayRef<int64_t> tileSizes,
- ArrayRef<int64_t> workgroupSize,
+ const LaunchConfig &launchConfig,
PatternBenefit benefit = 1)
: Base(context,
options.setDistributionOptions(convPoolDistributionOptions),
@@ -301,38 +197,58 @@
ArrayRef<Identifier>(),
Identifier::get(getWorkgroupMarker(), context)),
benefit),
- tileSizes(tileSizes.begin(), tileSizes.end()),
- workgroupSize(workgroupSize.begin(), workgroupSize.end()) {}
+ launchConfig(launchConfig) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
if (hasPadding(cast<OpTy>(op))) return failure();
FuncOp funcOp = op->getParentOfType<FuncOp>();
if (!funcOp || failed(Base::matchAndRewrite(op, rewriter)) ||
- failed(updateWorkGroupSize(funcOp, this->workgroupSize)))
+ failed(updateWorkGroupSize(funcOp, launchConfig.getWorkgroupSize())))
return failure();
funcOp.removeAttr(getNumWorkgroupsFnAttrName());
return success();
}
- SmallVector<int64_t, 3> tileSizes;
- SmallVector<int64_t, 3> workgroupSize;
+ const LaunchConfig &launchConfig;
};
+/// Populate patterns for first-level tiling.
+static void populateTilingToWorkgroupPatterns(
+ MLIRContext *context, const LaunchConfig &launchConfig,
+ OwningRewritePatternList &patterns) {
+ // Function to compute first level tiling values.
+ std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>
+ getOuterTileSizeFn =
+ [&launchConfig](OpBuilder &builder,
+ Operation *operation) -> SmallVector<Value, 4> {
+ ArrayRef<int64_t> tileSizes = launchConfig.getTileSizes(operation, 0);
+ if (tileSizes.empty()) return {};
+ SmallVector<Value, 4> tileSizesVal;
+ tileSizesVal.reserve(tileSizes.size());
+ for (auto val : tileSizes) {
+ tileSizesVal.push_back(
+ builder.create<ConstantIndexOp>(operation->getLoc(), val));
+ }
+ return tileSizesVal;
+ };
+ patterns.insert<TileConvPoolPattern<linalg::ConvOp>,
+ TileMatmulPattern<linalg::MatmulOp>,
+ TileMatmulPattern<linalg::BatchMatmulOp>,
+ TileConvPoolPattern<linalg::PoolingMaxOp>,
+ TileConvPoolPattern<linalg::PoolingMinOp>,
+ TileConvPoolPattern<linalg::PoolingSumOp>>(
+ context,
+ linalg::LinalgTilingOptions()
+ .setTileSizeComputationFunction(getOuterTileSizeFn)
+ .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops),
+ launchConfig);
+}
+
//===----------------------------------------------------------------------===//
// Patterns to promote subviews to workgroup memory
//===----------------------------------------------------------------------===//
-/// Function used as callback for copyin/copyout in promotion pattern used to
-/// promote subviews to workgroup memory when the number of threads is known to
-/// be greater than equal to the number of iteration of loops the copy is
-/// lowered to.
-static LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst) {
- auto copyOp = b.create<linalg::CopyOp>(src.getLoc(), src, dst);
- setMarker(copyOp, getCopyToWorkgroupMemoryMarker());
- return success();
-}
-
/// Pattern to promote matmul operands to workgroup memory.
struct PromoteMatmulSubviewsPattern
: public linalg::LinalgPromotionPattern<linalg::MatmulOp> {
@@ -381,10 +297,114 @@
};
} // namespace
+static void populatePromotionPatterns(MLIRContext *context,
+ OwningRewritePatternList &patterns) {
+ patterns
+ .insert<PromoteMatmulSubviewsPattern, PromoteConvolutionSubviewsPattern>(
+ context,
+ linalg::LinalgPromotionOptions()
+ .setAllocationDeallocationFns(allocateWorkgroupMemory,
+ deallocateWorkgroupMemory)
+ .setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory));
+}
+
+//===----------------------------------------------------------------------===//
+// Patterns and methods for subgroup tiling.
+//===----------------------------------------------------------------------===//
+
+/// Computes the subgroup ID along each dimension, given the number of
+/// subgroups `numSubgroups` along each dimension (x first, y second, z
+/// third).
+static SmallVector<linalg::ProcInfo, 2> getSubgroupIdsAndCounts(
+ OpBuilder &builder, Location loc, ArrayRef<int64_t> numSubgroups) {
+ Type indexType = builder.getIndexType();
+ Value subgroupId = builder.create<gpu::SubgroupIdOp>(loc, indexType);
+ SmallVector<linalg::ProcInfo, 2> procInfo(numSubgroups.size());
+
+ // subgroupID
+ // = id.z * nsubgroups.y * nsubgroups.x + id.y * nsubgroups.x + id.x
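+  // For example, with numSubgroups = {4, 4}, flat id 6 maps to x = 6 % 4 = 2
+  // and y = 6 / 4 = 1; procInfo lists the outermost dimension first.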
+ using edsc::op::operator%;
+ for (size_t i = 0, e = numSubgroups.size(); i != e; ++i) {
+ Value nprocs = builder.create<ConstantIndexOp>(loc, numSubgroups[i]);
+ Value procId = subgroupId % nprocs;
+ procInfo[e - i - 1] = linalg::ProcInfo{procId, nprocs};
+ subgroupId = builder.create<SignedDivIOp>(loc, subgroupId, nprocs);
+ }
+ return procInfo;
+}
+
+/// Patterns for second level tiling to target subgroups.
+static void populateTilingToSubgroupPatterns(
+ MLIRContext *context, const LaunchConfig &launchConfig,
+ OwningRewritePatternList &patterns) {
+ std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>
+ getInnerTileSizeFn =
+ [&launchConfig](OpBuilder &builder,
+ Operation *operation) -> SmallVector<Value, 4> {
+ ArrayRef<int64_t> tileSizes = launchConfig.getTileSizes(operation, 1);
+ if (tileSizes.empty()) return {};
+ SmallVector<Value, 4> tileSizesVal;
+ tileSizesVal.reserve(tileSizes.size());
+ for (auto val : tileSizes) {
+ tileSizesVal.push_back(
+ builder.create<ConstantIndexOp>(operation->getLoc(), val));
+ }
+ return tileSizesVal;
+ };
+
+ auto getSubgroupProcInfoFn =
+ [&launchConfig](OpBuilder &builder, Location loc,
+ ArrayRef<SubViewOp::Range> parallelLoopRanges) {
+ ArrayRef<int64_t> numSubgroups =
+ launchConfig.getNumSubgroups().take_front(
+ parallelLoopRanges.size());
+ return getSubgroupIdsAndCounts(builder, loc, numSubgroups);
+ };
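+  // CyclicNumProcsEqNumIters assumes the number of subgroups equals the
+  // number of tiles, so distribution generates no residual loop or bounds
+  // check.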
+ linalg::LinalgLoopDistributionOptions subgroupDistributionOptions = {
+ getSubgroupProcInfoFn,
+ {linalg::DistributionMethod::CyclicNumProcsEqNumIters,
+ linalg::DistributionMethod::CyclicNumProcsEqNumIters}};
+ patterns.insert<TileMatmulSubgroupPattern>(
+ context, linalg::LinalgTilingOptions()
+ .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops)
+ .setTileSizeComputationFunction(getInnerTileSizeFn)
+ .setDistributionOptions(subgroupDistributionOptions));
+}
+
+//===----------------------------------------------------------------------===//
+// Patterns for vectorization
+//===----------------------------------------------------------------------===//
+
+static void populateVectorizationPatterns(MLIRContext *context,
+ const LaunchConfig &launchConfig,
+ OwningRewritePatternList &patterns) {
+ patterns.insert<linalg::LinalgVectorizationPattern<linalg::MatmulOp>>(
+ context,
+ linalg::LinalgMarker(Identifier::get(getVectorizeMarker(), context)));
+}
+
+/// Apply canonicalizations related to tiling to make promotion/vectorization
+/// easier.
+static void applyCanonicalizationPatterns(MLIRContext *context, Operation *op) {
+ OwningRewritePatternList canonicalizationPatterns;
+ canonicalizationPatterns.insert<AffineMinCanonicalizationPattern>(context);
+ AffineApplyOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
+ AffineMinOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
+ SubViewOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
+ applyPatternsAndFoldGreedily(op, canonicalizationPatterns);
+}
+
void LinalgTileAndFusePass::runOnOperation() {
MLIRContext *context = &getContext();
ModuleOp module = getOperation();
+ // Override options with command line values.
+ if (!tileSizes.empty())
+ options.tileSizes.assign(tileSizes.begin(), tileSizes.end());
+ if (!workgroupSize.empty())
+ options.workgroupSize.assign(workgroupSize.begin(), workgroupSize.end());
+ if (useWorkgroupMemory) options.useWorkgroupMemory = true;
+ if (useVectorization) options.useVectorization = true;
+
LLVM_DEBUG(
llvm::dbgs() << "--- IREE Linalg tile and fuse configuration ---\n";);
for (FuncOp funcOp : module.getOps<FuncOp>()) {
@@ -397,58 +417,62 @@
}
Block &block = body.front();
auto linalgOps = block.getOps<linalg::LinalgOp>();
- if (linalgOps.empty()) return;
+ if (linalgOps.empty()) continue;
- TileSizeCalculator tileSizeCalculator(funcOp);
- if (tileSizes.empty()) {
- // Get the tile sizes to use for the lowering.
- SmallVector<int64_t, 3> tileSizes;
- SmallVector<linalg::LinalgOp, 1> opsVec(linalgOps.begin(),
- linalgOps.end());
- if (failed(tileSizeCalculator.inferTileAndWorkgroupSize(opsVec)))
- return signalPassFailure();
- } else {
- tileSizeCalculator.setTileSizes(tileSizes);
- if (!workgroupSize.empty())
- tileSizeCalculator.setWorkgroupSize(workgroupSize);
+ LaunchConfig launchConfig;
+ SmallVector<linalg::LinalgOp, 4> linalgOpsVec(linalgOps.begin(),
+ linalgOps.end());
+ if (failed(launchConfig.init(options, linalgOpsVec))) {
+ funcOp.emitError("unable to find launch configuration");
+ return signalPassFailure();
}
LLVM_DEBUG({
llvm::dbgs() << "@func " << funcOp.getName() << ": # workgroup sizes: [";
- interleaveComma(tileSizeCalculator.getWorkgroupSize(), llvm::dbgs());
- llvm::dbgs() << "]\ntile sizes: [";
- interleaveComma(tileSizeCalculator.getTileSizes(), llvm::dbgs());
+ interleaveComma(launchConfig.getWorkgroupSize(), llvm::dbgs());
llvm::dbgs() << "]\n";
+ for (auto op : linalgOps) {
+ llvm::dbgs() << "\t" << op.getOperation()->getName() << " : ";
+ TileSizesListType const &tileSizes = launchConfig.getTileSizes(op);
+ llvm::dbgs() << "{";
+ std::string sep = "";
+ for (auto &level : enumerate(tileSizes)) {
+ llvm::dbgs() << sep << level.index() << " : [";
+ sep = ", ";
+ interleaveComma(level.value(), llvm::dbgs());
+ llvm::dbgs() << "]";
+ }
+ llvm::dbgs() << "}\n";
+ }
});
- OwningRewritePatternList tilingPatterns;
- tilingPatterns.insert<TileConvPoolPattern<linalg::ConvOp>,
- TileMatmulPattern<linalg::MatmulOp>,
- TileMatmulPattern<linalg::BatchMatmulOp>,
- TileConvPoolPattern<linalg::PoolingMaxOp>,
- TileConvPoolPattern<linalg::PoolingMinOp>,
- TileConvPoolPattern<linalg::PoolingSumOp>>(
- context,
- linalg::LinalgTilingOptions()
- .setTileSizes(tileSizeCalculator.getTileSizes())
- .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops),
- tileSizeCalculator.getTileSizes(),
- tileSizeCalculator.getWorkgroupSize());
- applyPatternsAndFoldGreedily(funcOp, tilingPatterns);
+ OwningRewritePatternList firstLevelTilingPatterns;
+ populateTilingToWorkgroupPatterns(context, launchConfig,
+ firstLevelTilingPatterns);
+ applyPatternsAndFoldGreedily(funcOp, firstLevelTilingPatterns);
+ applyCanonicalizationPatterns(context, funcOp);
- if (useWorkgroupMemory) {
+ if (options.useWorkgroupMemory) {
// The promotion patterns are put separate from the tiling patterns to
// make sure that the allocated scratchspace memory is constant sizes
// which requires some folding to trigger.
OwningRewritePatternList promotionPatterns;
- promotionPatterns.insert<PromoteMatmulSubviewsPattern,
- PromoteConvolutionSubviewsPattern>(
- context,
- linalg::LinalgPromotionOptions()
- .setAllocationDeallocationFns(allocateWorkgroupMemory,
- deallocateWorkgroupMemory)
- .setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory));
+ populatePromotionPatterns(context, promotionPatterns);
applyPatternsAndFoldGreedily(funcOp, promotionPatterns);
+ applyCanonicalizationPatterns(context, funcOp);
+ }
+
+ if (options.useVectorization) {
+ OwningRewritePatternList secondLevelTilingPatterns;
+ populateTilingToSubgroupPatterns(context, launchConfig,
+ secondLevelTilingPatterns);
+ applyPatternsAndFoldGreedily(funcOp, secondLevelTilingPatterns);
+ applyCanonicalizationPatterns(context, funcOp);
+
+ OwningRewritePatternList vectorizationPatterns;
+ populateVectorizationPatterns(context, launchConfig,
+ vectorizationPatterns);
+ applyPatternsAndFoldGreedily(funcOp, vectorizationPatterns);
}
}
}
@@ -458,10 +482,8 @@
//===----------------------------------------------------------------------===//
std::unique_ptr<OperationPass<ModuleOp>> createLinalgTileAndFusePass(
- ArrayRef<int64_t> workgroupSize, ArrayRef<int64_t> tileSizes,
- bool useWorkgroupMemory) {
- return std::make_unique<LinalgTileAndFusePass>(workgroupSize, tileSizes,
- useWorkgroupMemory);
+ const SPIRVCodegenOptions &options) {
+ return std::make_unique<LinalgTileAndFusePass>(options);
}
static PassRegistration<LinalgTileAndFusePass> pass(
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
index cc90ccc..04044a7 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
@@ -22,7 +22,9 @@
#include "iree/compiler/Conversion/HLOToHLO/Passes.h"
#include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
+#include "iree/compiler/Conversion/LinalgToVector/Passes.h"
#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
+#include "llvm/Support/CommandLine.h"
#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h"
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
@@ -50,22 +52,21 @@
namespace iree_compiler {
namespace {
-/// Command line options for use with SPIR-V code-generation pass pipeline.
-struct SPIRVCodegenClOpts : public PassPipelineOptions<SPIRVCodegenClOpts> {
- ListOption<int64_t> workgroupSize{
- *this, "workgroup-size",
+
+/// Linalg to SPIR-V pass pipeline options. In theory this is a superset of
+/// all options in all passes in the pipeline; options are added based on need
+/// and should be needed only for testing.
+struct LinalgToSPIRVPassPipelineOptions
+ : public PassPipelineOptions<LinalgToSPIRVPassPipelineOptions> {
+ Option<bool> useVectorization{
+ *this, "use-vectorization",
llvm::cl::desc(
- "Number of workgroups to dispatch for the SPIR-V module; at most "
- "three integers standarding for the x, y, and z dimension; "
- "additional arguments will be ignored (used only for testing)"),
- llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
- ListOption<int64_t> tileSizes{
- *this, "tile-sizes", llvm::cl::desc("Set tile sizes to use"),
- llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
- Option<bool> useWorkgroupMemory{
- *this, "use-workgroup-memory",
- llvm::cl::desc(
- "Enable use of workgroup memory in SPIR-V code generation pipeline"),
+ "Enable use of vectorization in SPIR-V code generation pipeline"),
+ llvm::cl::init(false)};
+ Option<bool> useVectorPass{
+ *this, "use-vector-pass",
+ llvm::cl::desc("Enable use of Linalg vectorization in SPIR-V code "
+ "generation pipeline"),
llvm::cl::init(false)};
};
@@ -101,8 +102,10 @@
// with the second tile and fuse pass.
//===--------------------------------------------------------------------===//
pm.addPass(createSplitDispatchFunctionPass());
- pm.addPass(createLinalgTileAndFusePass(
- options.workgroupSize, options.tileSizes, options.useWorkgroupMemory));
+ pm.addPass(createLinalgTileAndFusePass(options));
+ if (options.useVectorPass) {
+ pm.addPass(createLoadStoreVectorizationPass());
+ }
pm.addPass(createCanonicalizerPass());
//===--------------------------------------------------------------------===//
@@ -252,30 +255,33 @@
addLinalgToSPIRVPasses(pm, options);
}
-static SPIRVCodegenOptions getSPIRVCodegenOptions(
- const SPIRVCodegenClOpts &clOpts) {
- SPIRVCodegenOptions options;
- options.workgroupSize.assign(clOpts.workgroupSize.begin(),
- clOpts.workgroupSize.end());
- options.tileSizes.assign(clOpts.tileSizes.begin(), clOpts.tileSizes.end());
- options.useWorkgroupMemory = clOpts.useWorkgroupMemory;
- return options;
+static SPIRVCodegenOptions initializeCodegenOptions(
+ const LinalgToSPIRVPassPipelineOptions &options) {
+ SPIRVCodegenOptions codegenOptions;
+ codegenOptions.useVectorization = options.useVectorization;
+ return codegenOptions;
}
-static PassPipelineRegistration<SPIRVCodegenClOpts> linalgToSPIRVPipeline(
- "iree-codegen-linalg-to-spirv-pipeline",
- "Runs the progressive lowering pipeline from Linalg to SPIR-V",
- [](OpPassManager &passManager, const SPIRVCodegenClOpts &options) {
- addLinalgToSPIRVPasses(passManager, getSPIRVCodegenOptions(options));
- });
+static PassPipelineRegistration<LinalgToSPIRVPassPipelineOptions>
+ linalgToSPIRVPipeline(
+ "iree-codegen-linalg-to-spirv-pipeline",
+ "Runs the progressive lowering pipeline from Linalg to SPIR-V",
+ [](OpPassManager &passManager,
+ const LinalgToSPIRVPassPipelineOptions &options) {
+ addLinalgToSPIRVPasses(passManager,
+ initializeCodegenOptions(options));
+ });
-static PassPipelineRegistration<SPIRVCodegenClOpts> hloToLinalgSPIRVPipeline(
- "iree-codegen-hlo-to-spirv-pipeline",
- "Runs the progressive lowering pipeline from XLA HLO to Linalg to SPIR-V",
- [](OpPassManager &passManager, const SPIRVCodegenClOpts &options) {
- buildSPIRVTransformPassPipeline(passManager,
- getSPIRVCodegenOptions(options));
- });
+static PassPipelineRegistration<LinalgToSPIRVPassPipelineOptions>
+ hloToLinalgSPIRVPipeline(
+ "iree-codegen-hlo-to-spirv-pipeline",
+ "Runs the progressive lowering pipeline from XLA HLO to Linalg to "
+ "SPIR-V",
+ [](OpPassManager &passManager,
+ const LinalgToSPIRVPassPipelineOptions &options) {
+ buildSPIRVTransformPassPipeline(passManager,
+ initializeCodegenOptions(options));
+ });
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
index 6e07cb0..80805e2 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
@@ -17,6 +17,7 @@
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassOptions.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
@@ -27,6 +28,8 @@
SmallVector<int64_t, 3> workgroupSize = {};
SmallVector<int64_t, 3> tileSizes = {};
bool useWorkgroupMemory = false;
+ bool useVectorization = false;
+ bool useVectorPass = false;
};
/// Pass to initialize the function that computes the number of workgroups for
@@ -40,8 +43,7 @@
/// it exists) and along "z" for the next loop (if it exists). The workgroup
/// size is expected to be of size at most 3.
std::unique_ptr<OperationPass<ModuleOp>> createLinalgTileAndFusePass(
- ArrayRef<int64_t> workGroupSize = {}, ArrayRef<int64_t> tileSizes = {},
- bool useWorkgroupMem = false);
+ const SPIRVCodegenOptions &options);
/// Pass to add the synchronizations and attributes needed to lower from PLoops
/// to GPU dialect.
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp
index 47cc37d..2c14351 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp
@@ -107,6 +107,11 @@
Value blockId = builder.create<gpu::BlockIdOp>(loc, indexType, attr);
Value blockDim = builder.create<gpu::BlockDimOp>(loc, indexType, attr);
Value threadId = builder.create<gpu::ThreadIdOp>(loc, indexType, attr);
+  // TODO(ravishankarm): Using affine_maps here would be beneficial, and we can
+  // do this because the blockDim is constant. But it would lead to an ordering
+  // issue, since it assumes that the workgroup size has already been set. If
+  // using affine_map can help, make sure that the workgroup size is set
+  // beforehand.
return {builder.create<AddIOp>(
loc, builder.create<MulIOp>(loc, blockId, blockDim), threadId),
builder.create<MulIOp>(loc, blockDim, gridDim)};
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Utils.h b/iree/compiler/Conversion/LinalgToSPIRV/Utils.h
index 934728d..682583a 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Utils.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Utils.h
@@ -62,7 +62,6 @@
SmallVector<linalg::ProcInfo, 2> getGPUProcessorIdsAndCounts(OpBuilder &builder,
Location loc,
unsigned numDims);
-
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
index e06eae5..f6604bf 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
@@ -64,7 +64,6 @@
// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
// CHECK-DAG: %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@ret0
-// CHECK-DAG: %[[C0:.+]] = constant 0
// CHECK-DAG: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
// CHECK-DAG: %[[NBLOCKSX:.+]] = "gpu.grid_dim"() {dimension = "x"}
// CHECK-DAG: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
@@ -78,9 +77,9 @@
// CHECK: scf.parallel (%[[IV0:.+]], %[[IV1:.+]], %[[IV2:.+]]) = (%[[BIDZ]], %[[LBY]], %[[LBX]])
// CHECK-SAME: step (%[[NBLOCKSZ]], %[[STEPY]], %[[STEPX]])
// CHECK: %[[VIEW1:.+]] = subview %[[ARG1]]
-// CHECK-SAME: [%[[IV0]], %[[IV1]], %[[IV2]], %[[C0]]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]], %[[IV2]], 0]
// CHECK: %[[VIEW2:.+]] = subview %[[RET0]]
-// CHECK-SAME: [%[[IV0]], %[[IV1]], %[[IV2]], %[[C0]]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]], %[[IV2]], 0]
// CHECK: linalg.conv
// CHECK-SAME: %[[ARG0]], %[[VIEW1]], %[[VIEW2]]
// CHECK-SAME: "workgroup"
@@ -116,37 +115,38 @@
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0] -> (s0 * 16)>
// CHECK: func @matmul()
-// CHECK-SAME: local_size = dense<[8, 8, 1]>
+// CHECK-SAME: local_size = dense<[16, 8, 1]>
// CHECK-SAME: vkspv.num_workgroups_fn = @[[NUM_WORKGROUPS_FN:.[a-zA-Z0-9_]+]]
// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
// CHECK-DAG: %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@ret0
-// CHECK-DAG: %[[C0:.+]] = constant 0
-// CHECK-DAG: %[[C4:.+]] = constant 4
// CHECK-DAG: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
// CHECK-DAG: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
// CHECK-NOT: scf.parallel
-// CHECK: scf.for %[[IV:.+]] = %[[C0]] to %{{.+}} step %[[C4]]
-// CHECK: %[[LBY:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-// CHECK: %[[VIEW0:.+]] = subview %[[ARG0]][%[[LBY]], %[[IV]]
-// CHECK: %[[LBX:.+]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
-// CHECK: %[[VIEW1:.+]] = subview %[[ARG1]][%[[IV]], %[[LBX]]]
-// CHECK: %[[LBY_2:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-// CHECK: %[[LBX_2:.+]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
-// CHECK: %[[VIEW2:.+]] = subview %[[RET0]][%[[LBY_2]], %[[LBX_2]]]
-// CHECK: linalg.matmul
-// CHECK-SAME: "workgroup_numprocs_ge_numiters"
-// CHECK-SAME: %[[VIEW0]], %[[VIEW1]], %[[VIEW2]]
+// CHECK-NOT: scf.for
+// CHECK: %[[LBY:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[VIEW0:.+]] = subview %[[ARG0]][%[[LBY]], 0]
+// CHECK: %[[LBX:.+]] = affine.apply #[[MAP3]]()[%[[BIDX]]]
+// CHECK: %[[VIEW1:.+]] = subview %[[ARG1]][0, %[[LBX]]]
+// CHECK: %[[LBY_2:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[LBX_2:.+]] = affine.apply #[[MAP3]]()[%[[BIDX]]]
+// CHECK: %[[VIEW2:.+]] = subview %[[RET0]][%[[LBY_2]], %[[LBX_2]]]
+// CHECK: linalg.matmul
+// CHECK-SAME: "workgroup_numprocs_ge_numiters"
+// CHECK-SAME: %[[VIEW0]], %[[VIEW1]], %[[VIEW2]]
// CHECK: func @[[NUM_WORKGROUPS_FN]]
// CHECK-DAG: %[[C8:.+]] = constant 8 : index
// CHECK-DAG: %[[C7:.+]] = constant 7 : index
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C16:.+]] = constant 16 : index
+// CHECK-DAG: %[[C15:.+]] = constant 15 : index
// CHECK: %[[DIM0:.+]] = dim %{{.*}}, %[[C0]]
// CHECK: %[[DIM1:.+]] = dim %{{.*}}, %[[C1]]
-// CHECK: %[[T0:.+]] = addi %[[DIM1]], %[[C7]]
-// CHECK: %[[T1:.+]] = divi_signed %[[T0]], %[[C8]]
+// CHECK: %[[T0:.+]] = addi %[[DIM1]], %[[C15]]
+// CHECK: %[[T1:.+]] = divi_signed %[[T0]], %[[C16]]
// CHECK: %[[T2:.+]] = addi %[[DIM0]], %[[C7]]
// CHECK: %[[T3:.+]] = divi_signed %[[T2]], %[[C8]]
// CHECK: return %[[T1]], %[[T3]], %[[C1]]
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_subgroup.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_subgroup.mlir
new file mode 100644
index 0000000..8421f81
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_subgroup.mlir
@@ -0,0 +1,77 @@
+// RUN: iree-opt -split-input-file -iree-codegen-linalg-tile-and-fuse=use-vectorization %s | IreeFileCheck %s
+
+module attributes {
+ spv.target_env =
+ #spv.target_env<#spv.vce<v1.3,
+ [Shader, CooperativeMatrixNV],
+ [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>,
+ {max_compute_workgroup_invocations = 512 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+ func @matmul_static_shape()
+ attributes {vkspv.num_workgroups_fn = @matmul_static_shape__num_workgroups__} {
+ %arg0 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg0, operand_result_num = 0 : i32} : memref<128x64xf16>
+ %arg1 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg1, operand_result_num = 1 : i32} : memref<64x256xf16>
+ %ret0 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@ret0, operand_result_num = 2 : i32} : memref<128x256xf16>
+ linalg.matmul %arg0, %arg1, %ret0 :
+ (memref<128x64xf16>, memref<64x256xf16>, memref<128x256xf16>)
+ return
+ }
+ func @matmul_static_shape__num_workgroups__
+ (!shapex.ranked_shape<[128, 64]>, !shapex.ranked_shape<[64, 256]>,
+ !shapex.ranked_shape<[128, 256]>) -> (index, index, index)
+ attributes {sym_visibility = "private"}
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+ }
+}
+
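+// The CHECK lines below verify the tiling: each workgroup computes a 32x32
+// tile of the 128x256 result, and each subgroup an 8x8 sub-tile of it,
+// stepping over the reduction dimension in slices of 16.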
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 32)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>
+// CHECK: func @matmul_static_shape
+// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
+// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
+// CHECK-DAG: %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@ret0
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[CST:.+]] = constant 0.0
+// CHECK-DAG: %[[C4:.+]] = constant 4 : index
+// CHECK: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
+// CHECK: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
+// CHECK: %[[BOFFSET_Y:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[SUBVIEW_LHS:.+]] = subview %[[ARG0]]
+// CHECK-SAME: [%[[BOFFSET_Y]], 0] [32, 64]
+// CHECK: %[[BOFFSET_X:.+]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+// CHECK: %[[SUBVIEW_RHS:.+]] = subview %[[ARG1]]
+// CHECK-SAME: [0, %[[BOFFSET_X]]] [64, 32]
+// CHECK: %[[BOFFSET_Y_2:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[BOFFSET_X_2:.+]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+// CHECK: %[[SUBVIEW_RESULT:.+]] = subview %[[RET0]]
+// CHECK-SAME: [%[[BOFFSET_Y_2]], %[[BOFFSET_X_2]]] [32, 32]
+// CHECK: %[[SGID:.+]] = gpu.subgroup_id
+// CHECK: %[[SGID_Y:.+]] = divi_signed %[[SGID]], %[[C4]]
+// CHECK: scf.for %[[IV2:.+]] =
+// CHECK: %[[SGOFFSET_Y:.+]] = affine.apply #[[MAP3]]()[%[[SGID_Y]]]
+// CHECK: %[[SUBVIEW2_LHS:.+]] = subview %[[SUBVIEW_LHS]]
+// CHECK-SAME: [%[[SGOFFSET_Y]], %[[IV2]]] [8, 16]
+// CHECK: %[[SGOFFSET_X:.+]] = affine.apply #[[MAP3]]()[%[[SGID]]]
+// CHECK: %[[SUBVIEW2_RHS:.+]] = subview %[[SUBVIEW_RHS]]
+// CHECK-SAME: [%[[IV2]], %[[SGOFFSET_X]]] [16, 8]
+// CHECK: %[[SGOFFSET_Y_2:.+]] = affine.apply #[[MAP3]]()[%[[SGID_Y]]]
+// CHECK: %[[SGOFFSET_X_2:.+]] = affine.apply #[[MAP3]]()[%[[SGID]]]
+// CHECK: %[[SUBVIEW2_RESULT:.+]] = subview %[[SUBVIEW_RESULT]]
+// CHECK-SAME: [%[[SGOFFSET_Y_2]], %[[SGOFFSET_X_2]]] [8, 8]
+// CHECK: %[[VTR_LHS:.+]] = vector.transfer_read %[[SUBVIEW2_LHS]]
+// CHECK-SAME: [%[[C0]], %[[C0]]], %[[CST]] {masked = [false, false]}
+// CHECK: %[[VTR_RHS:.+]] = vector.transfer_read %[[SUBVIEW2_RHS]]
+// CHECK-SAME: [%[[C0]], %[[C0]]], %[[CST]] {masked = [false, false]}
+// CHECK: %[[VTR_RESULT:.+]] = vector.transfer_read %[[SUBVIEW2_RESULT]]
+// CHECK-SAME: [%[[C0]], %[[C0]]], %[[CST]] {masked = [false, false]}
+// CHECK: %[[VECTOR_CONTRACT:.+]] = vector.contract
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-SAME: vector<8x16xf16>, vector<16x8xf16> into vector<8x8xf16>
+// CHECK: vector.transfer_write %[[VECTOR_CONTRACT]], %[[SUBVIEW2_RESULT]]
+// CHECK-SAME: masked = [false, false]
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir
new file mode 100644
index 0000000..ddd2530
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir
@@ -0,0 +1,37 @@
+// RUN: iree-opt -split-input-file -iree-codegen-linalg-to-spirv-pipeline=use-vectorization %s | IreeFileCheck %s
+
+module attributes {
+ spv.target_env =
+ #spv.target_env<#spv.vce<v1.3,
+ [Float16, Shader, CooperativeMatrixNV],
+ [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>,
+ {max_compute_workgroup_invocations = 512 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+ func @matmul_static_shape()
+ attributes {vkspv.num_workgroups_fn = @matmul_static_shape__num_workgroups__} {
+ %0 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg0, operand_result_num = 0} : memref<128x64xf16>
+ %1 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg1, operand_result_num = 1} : memref<64x256xf16>
+ %2 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@ret0, operand_result_num = 2} : memref<128x256xf16>
+ linalg.matmul %0, %1, %2 :
+ (memref<128x64xf16>, memref<64x256xf16>, memref<128x256xf16>)
+ return
+ }
+ func @matmul_static_shape__num_workgroups__
+ (!shapex.ranked_shape<[128, 64]>, !shapex.ranked_shape<[64, 256]>,
+ !shapex.ranked_shape<[128, 256]>) -> (index, index, index)
+ attributes {sym_visibility = "private"}
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ }
+}
+// CHECK-LABEL: spv.func @matmul_static_shape
+// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixMulAddNV
+// CHECK: spv.CooperativeMatrixStoreNV
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
index 9ff21ba..5b1d440 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
@@ -5,8 +5,7 @@
#spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- func @matmul_tile()
- attributes {signature = (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)} {
+ func @matmul_tile() attributes {vkspv.num_workgroups_fn = @matmul_tile__num_workgroups__} {
%0 = iree.placeholder for "interace buffer"
{binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<?x?xf32>
%1 = iree.placeholder for "interace buffer"
@@ -16,6 +15,10 @@
linalg.matmul %0, %1, %2 : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
return
}
+ func @matmul_tile__num_workgroups__
+ (!shapex.ranked_shape<[?,?]>, !shapex.ranked_shape<[?,?]>,
+ !shapex.ranked_shape<[?,?]>) -> (index, index, index)
+ attributes {sym_visibility = "private"}
hal.interface @legacy_io attributes {sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
@@ -30,9 +33,9 @@
// CHECK: %[[ARG0SV:.+]] = subview %[[ARG0]]
// CHECK: %[[ARG1SV:.+]] = subview %[[ARG1]]
// CHECK: %[[RET0SV:.+]] = subview %[[RET0]]
-// CHECK: %[[ALLOC1:.+]] = alloc() : memref<8x4xf32, 3>
+// CHECK: %[[ALLOC1:.+]] = alloc() : memref<8x32xf32, 3>
// CHECK: %[[SUBVIEW1:.+]] = subview %[[ALLOC1]]
-// CHECK: %[[ALLOC2:.+]] = alloc() : memref<4x8xf32, 3>
+// CHECK: %[[ALLOC2:.+]] = alloc() : memref<32x16xf32, 3>
// CHECK: %[[SUBVIEW2:.+]] = subview %[[ALLOC2]]
// CHECK: linalg.copy(%[[ARG0SV]], %[[SUBVIEW1]])
// CHECK-SAME: "copy_to_workgroup_memory"
@@ -41,8 +44,8 @@
// CHECK: linalg.matmul
// CHECK-SAME: "workgroup_memory_numprocs_ge_numiters"
// CHECK-SAME: %[[SUBVIEW1]], %[[SUBVIEW2]], %[[RET0SV]]
-// CHECK-DAG: dealloc %[[ALLOC1]] : memref<8x4xf32, 3>
-// CHECK-DAG: dealloc %[[ALLOC2]] : memref<4x8xf32, 3>
+// CHECK-DAG: dealloc %[[ALLOC1]] : memref<8x32xf32, 3>
+// CHECK-DAG: dealloc %[[ALLOC2]] : memref<32x16xf32, 3>
// -----
@@ -51,8 +54,7 @@
#spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- func @conv_no_padding_tile()
- attributes {signature = (tensor<3x4x3x2xf32>, tensor<?x?x?x3xf32>) -> (tensor<?x?x?x2xf32>)} {
+ func @conv_no_padding_tile() {
%0 = iree.placeholder for "interace buffer"
{binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<3x4x3x2xf32>
%1 = iree.placeholder for "interace buffer"
@@ -63,6 +65,11 @@
: memref<3x4x3x2xf32>, memref<?x?x?x3xf32>, memref<?x?x?x2xf32>
return
}
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+ }
}
// CHECK: func @conv_no_padding_tile()
// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder for "interace buffer" {binding = @legacy_io::@arg0
diff --git a/iree/compiler/Conversion/LinalgToVector/BUILD b/iree/compiler/Conversion/LinalgToVector/BUILD
new file mode 100644
index 0000000..95676fb
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToVector/BUILD
@@ -0,0 +1,44 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "LinalgToVector",
+ srcs = [
+ "LoadStoreVectorization.cpp",
+ ],
+ hdrs = [
+ "Passes.h",
+ ],
+ deps = [
+ "//iree/compiler/Conversion/CodegenUtils",
+ "//iree/compiler/Dialect/IREE/IR",
+ "//iree/compiler/Dialect/Shape/IR",
+ "//iree/compiler/Dialect/Shape/Transforms",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:LinalgOps",
+ "@llvm-project//mlir:LinalgTransforms",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:StandardOps",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:Transforms",
+ "@llvm-project//mlir:VectorOps",
+ ],
+)
diff --git a/iree/compiler/Conversion/LinalgToVector/CMakeLists.txt b/iree/compiler/Conversion/LinalgToVector/CMakeLists.txt
new file mode 100644
index 0000000..1c7e227
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToVector/CMakeLists.txt
@@ -0,0 +1,39 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ LinalgToVector
+ HDRS
+ "Passes.h"
+ SRCS
+ "LoadStoreVectorization.cpp"
+ DEPS
+ LLVMSupport
+ MLIRIR
+ MLIRLinalgOps
+ MLIRLinalgTransforms
+ MLIRPass
+ MLIRStandardOps
+ MLIRSupport
+ MLIRTransforms
+ MLIRVector
+ iree::compiler::Conversion::CodegenUtils
+ iree::compiler::Dialect::IREE::IR
+ iree::compiler::Dialect::Shape::IR
+ iree::compiler::Dialect::Shape::Transforms
+ PUBLIC
+)
diff --git a/iree/compiler/Conversion/LinalgToVector/LoadStoreVectorization.cpp b/iree/compiler/Conversion/LinalgToVector/LoadStoreVectorization.cpp
new file mode 100644
index 0000000..65c3da7
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToVector/LoadStoreVectorization.cpp
@@ -0,0 +1,328 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Conversion/LinalgToVector/Passes.h"
+#include "iree/compiler/Dialect/IREE/IR/IREEDialect.h"
+#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+/// Returns the bitwidth of a scalar or vector type.
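+/// For example, f32 yields 32 and vector<4xf32> yields 4 * 32 = 128.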
+static Optional<unsigned> getBitWidth(Type type) {
+ if (type.isIntOrFloat()) {
+ return type.getIntOrFloatBitWidth();
+ } else if (type.isa<VectorType>()) {
+ auto vecType = type.cast<VectorType>();
+ auto elementType = vecType.getElementType();
+ return elementType.getIntOrFloatBitWidth() * vecType.getNumElements();
+ }
+ return {};
+}
+
+constexpr int kVectorizationSizeInBits = 128;
+constexpr int kVecSize = kVectorizationSizeInBits / (sizeof(float) * 8);
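+// With 128-bit vectorization over 32-bit elements this gives 128 / 32 = 4
+// lanes per vector.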
+
+/// Returns a VectorType of `kVectorizationSizeInBits` total bits if `t` is a
+/// 32-bit scalar type; otherwise returns a null type.
+static VectorType getVecType(OpBuilder &builder, Type t) {
+ if (!t.isa<IntegerType, FloatType>()) return {};
+ if (t.getIntOrFloatBitWidth() != 32) return {};
+ Type newElemType = t.isa<IntegerType>() ? builder.getI32Type().cast<Type>()
+ : builder.getF32Type().cast<Type>();
+ return VectorType::get(kVecSize, newElemType);
+}
+
+/// Returns `type` rewritten as a memref of vectors, or a null type if it
+/// cannot be vectorized.
+static MemRefType getVectorizedMemRefType(OpBuilder &builder, MemRefType type) {
+ Type elemType = type.getElementType();
+ VectorType vecType = getVecType(builder, elemType);
+ if (!vecType) return {};
+ unsigned elemSize = elemType.getIntOrFloatBitWidth();
+ unsigned vecSize = kVectorizationSizeInBits / elemSize;
+ SmallVector<int64_t, 2> newShape(type.getShape().begin(),
+ type.getShape().end());
+ if (newShape.empty()) return {};
+ if (newShape.back() % vecSize != 0) return {};
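+ // Fold vecSize scalars into one vector element along the innermost
+ // dimension, e.g., memref<3x4xf32> becomes memref<3x1xvector<4xf32>>.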
+ newShape.back() = newShape.back() / vecSize;
+ return MemRefType::get(newShape, vecType, {}, type.getMemorySpace());
+}
+
+/// Returns a vectorized `val`, i.e., a value whose type is a VectorType, or
+/// nullptr if `val` cannot be legalized.
+static Value legalizeToVectorType(OpBuilder &builder, Value val) {
+ Type type = val.getType();
+ if (type.isa<VectorType>()) {
+ return val;
+ } else if (type.isIntOrFloat()) {
+ auto vecType = getVecType(builder, type);
+ if (!vecType) return nullptr;
+ // TODO(hanchung): Add a folder on vector::BroadcastOp so we don't need to
+ // create manually.
+ if (auto cst = val.getDefiningOp<ConstantOp>()) {
+ auto cstVecValue = DenseElementsAttr::get(vecType, cst.value());
+ return builder.create<ConstantOp>(val.getLoc(), vecType, cstVecValue)
+ .getResult();
+ }
+ return builder.create<vector::BroadcastOp>(val.getLoc(), vecType, val)
+ .getResult();
+ }
+ return nullptr;
+}
+
+/// Base class to vectorize std ops. If a generic op is vectorized, all the std
+/// ops in the region should be vectorized as well.
+///
+/// This base class checks the operands and vectorizes each of them before
+/// delegating to the derived class.
+///
+/// All derived classes must implement an `apply` method with the following
+/// signature:
+///
+/// ```c++
+/// LogicalResult apply(SrcOpTy op, ArrayRef<Value> args,
+/// ConversionPatternRewriter& rewriter) const;
+/// ```
+template <typename DerivedTy, typename SrcOpTy>
+struct VectorizeOpBase : public OpConversionPattern<SrcOpTy> {
+ using OpConversionPattern<SrcOpTy>::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ SrcOpTy op, ArrayRef<Value> args,
+ ConversionPatternRewriter &rewriter) const override {
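+ // Bail out when every operand is still a scalar; the rewrite only applies
+ // once at least one operand carries a vector type.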
+ if (llvm::all_of(args, [](Value arg) {
+ return arg.getType().isIntOrIndexOrFloat();
+ })) {
+ return failure();
+ }
+ SmallVector<Value, 4> vecArgs;
+ for (Value arg : args) {
+ Value val = legalizeToVectorType(rewriter, arg);
+ if (!val) return failure();
+ vecArgs.push_back(val);
+ }
+ return static_cast<DerivedTy const *>(this)->apply(op, vecArgs, rewriter);
+ }
+};
+
+template <typename OpTy>
+struct VectorizeElementwiseOp
+ : public VectorizeOpBase<VectorizeElementwiseOp<OpTy>, OpTy> {
+ using VectorizeOpBase<VectorizeElementwiseOp<OpTy>, OpTy>::VectorizeOpBase;
+ LogicalResult apply(OpTy op, ArrayRef<Value> args,
+ ConversionPatternRewriter &rewriter) const {
+ auto vecType = getVecType(rewriter, op.getResult().getType());
+ if (!vecType) return failure();
+ auto newOp = rewriter.create<OpTy>(op.getLoc(), vecType, args);
+ rewriter.replaceOp(op, newOp.getOperation()->getResults());
+ return success();
+ }
+};
+
+template <typename OpTy>
+struct VectorizeCmpOp : public VectorizeOpBase<VectorizeCmpOp<OpTy>, OpTy> {
+ using VectorizeOpBase<VectorizeCmpOp<OpTy>, OpTy>::VectorizeOpBase;
+ LogicalResult apply(OpTy op, ArrayRef<Value> args,
+ ConversionPatternRewriter &rewriter) const {
+ auto newOp =
+ rewriter.create<OpTy>(op.getLoc(), op.predicate(), args[0], args[1]);
+ rewriter.replaceOp(op, newOp.getResult());
+ return success();
+ }
+};
+
+struct VectorizeSelectOp
+ : public VectorizeOpBase<VectorizeSelectOp, mlir::SelectOp> {
+ using VectorizeOpBase<VectorizeSelectOp, mlir::SelectOp>::VectorizeOpBase;
+ LogicalResult apply(mlir::SelectOp op, ArrayRef<Value> args,
+ ConversionPatternRewriter &rewriter) const {
+ auto newOp =
+ rewriter.create<SelectOp>(op.getLoc(), args[0], args[1], args[2]);
+ rewriter.replaceOp(op, newOp.getResult());
+ return success();
+ }
+};
+
+struct VectorizeGenericOp : public OpConversionPattern<linalg::GenericOp> {
+ using OpConversionPattern<linalg::GenericOp>::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ linalg::GenericOp genericOp, ArrayRef<Value> args,
+ ConversionPatternRewriter &rewriter) const override {
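+ // Only fully parallel generic ops are vectorized; any non-parallel
+ // iterator aborts the rewrite.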
+ if (llvm::any_of(genericOp.iterator_types(), [](Attribute attr) {
+ return attr.cast<StringAttr>().getValue() !=
+ getParallelIteratorTypeName();
+ })) {
+ return failure();
+ }
+
+ // Do not vectorize if any operand is 0-D or is not iterated over
+ // contiguous memory.
+ for (auto map : genericOp.getIndexingMaps()) {
+ if (map.getNumResults() == 0) return failure();
+ AffineDimExpr innerMostExpr =
+ map.getResults().back().dyn_cast<AffineDimExpr>();
+ if (!innerMostExpr ||
+ innerMostExpr.getPosition() != map.getNumDims() - 1) {
+ return failure();
+ }
+ }
+
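+ // Each operand must be the sole use of an iree.placeholder whose memref
+ // type has a vectorized counterpart.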
+ SmallVector<IREE::PlaceholderOp, 4> operands;
+ SmallVector<MemRefType, 4> vecMemRefs;
+ for (auto operand : args) {
+ auto op = operand.getDefiningOp<IREE::PlaceholderOp>();
+ if (!op) return failure();
+ if (!op.getOperation()->hasOneUse()) return failure();
+ auto memrefType = op.getResult().getType().dyn_cast<MemRefType>();
+ if (!memrefType) return failure();
+ auto vecMemRef = getVectorizedMemRefType(rewriter, memrefType);
+ if (!vecMemRef) return failure();
+ operands.push_back(op);
+ vecMemRefs.push_back(vecMemRef);
+ }
+
+ SmallVector<Value, 4> newArgs;
+ for (auto it : llvm::zip(operands, vecMemRefs)) {
+ IREE::PlaceholderOp placeholder = std::get<0>(it);
+ MemRefType vecMemRef = std::get<1>(it);
+ auto arg = rewriter.create<IREE::PlaceholderOp>(placeholder.getLoc(),
+ vecMemRef, ValueRange{},
+ placeholder.getAttrs());
+ rewriter.replaceOp(placeholder, arg.getResult());
+ newArgs.push_back(arg.getResult());
+ }
+
+ auto newOp = rewriter.create<linalg::GenericOp>(
+ genericOp.getLoc(), genericOp.getResultTypes(), newArgs,
+ rewriter.getI64IntegerAttr(genericOp.getNumInputs()),
+ rewriter.getI64IntegerAttr(genericOp.getNumOutputs()),
+ genericOp.indexing_mapsAttr(), genericOp.iterator_types(),
+ /*doc=*/nullptr,
+ /*library_call=*/nullptr,
+ /*symbol_source=*/nullptr);
+
+ Region &newRegion = newOp.region();
+ rewriter.inlineRegionBefore(genericOp.getRegion(), newRegion,
+ newRegion.end());
+ Block &newBlock = newOp.region().front();
+ TypeConverter::SignatureConversion signatureConverter(
+ newBlock.getNumArguments());
+ for (auto arg : llvm::enumerate(vecMemRefs)) {
+ signatureConverter.addInputs(arg.index(), arg.value().getElementType());
+ }
+ rewriter.applySignatureConversion(&newOp.region(), signatureConverter);
+ rewriter.replaceOp(genericOp, newOp.getResults());
+ return success();
+ }
+};
+
+struct LoadStoreVectorizationPass
+ : public PassWrapper<LoadStoreVectorizationPass, OperationPass<FuncOp>> {
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<linalg::LinalgDialect, vector::VectorDialect>();
+ }
+
+ void runOnOperation() override {
+ MLIRContext *context = &getContext();
+ OwningRewritePatternList patterns;
+ // clang-format off
+ patterns.insert<
+ VectorizeGenericOp,
+ VectorizeCmpOp<CmpFOp>,
+ VectorizeCmpOp<CmpIOp>,
+ VectorizeSelectOp,
+ VectorizeElementwiseOp<AbsFOp>,
+ VectorizeElementwiseOp<AndOp>,
+ VectorizeElementwiseOp<OrOp>,
+ VectorizeElementwiseOp<XOrOp>,
+ VectorizeElementwiseOp<AddFOp>,
+ VectorizeElementwiseOp<AddIOp>,
+ VectorizeElementwiseOp<CeilFOp>,
+ VectorizeElementwiseOp<CosOp>,
+ VectorizeElementwiseOp<DivFOp>,
+ VectorizeElementwiseOp<ExpOp>,
+ VectorizeElementwiseOp<FPExtOp>,
+ VectorizeElementwiseOp<FPToSIOp>,
+ VectorizeElementwiseOp<FPTruncOp>,
+ VectorizeElementwiseOp<FloorFOp>,
+ VectorizeElementwiseOp<LogOp>,
+ VectorizeElementwiseOp<MulFOp>,
+ VectorizeElementwiseOp<MulIOp>,
+ VectorizeElementwiseOp<NegFOp>,
+ VectorizeElementwiseOp<RemFOp>,
+ VectorizeElementwiseOp<RsqrtOp>,
+ VectorizeElementwiseOp<SIToFPOp>,
+ VectorizeElementwiseOp<ShiftLeftOp>,
+ VectorizeElementwiseOp<SignExtendIOp>,
+ VectorizeElementwiseOp<SignedDivIOp>,
+ VectorizeElementwiseOp<SignedShiftRightOp>,
+ VectorizeElementwiseOp<SinOp>,
+ VectorizeElementwiseOp<SqrtOp>,
+ VectorizeElementwiseOp<SubFOp>,
+ VectorizeElementwiseOp<SubIOp>,
+ VectorizeElementwiseOp<TanhOp>,
+ VectorizeElementwiseOp<TruncateIOp>,
+ VectorizeElementwiseOp<UnsignedDivIOp>,
+ VectorizeElementwiseOp<UnsignedRemIOp>,
+ VectorizeElementwiseOp<UnsignedShiftRightOp>>(context);
+ // clang-format on
+
+ ConversionTarget target(*context);
+ // Mark the vector dialect and the placeholder op as legal.
+ target.addLegalDialect<vector::VectorDialect>();
+ target.addLegalOp<IREE::PlaceholderOp>();
+
+ // If a generic op is vectorized, it is legal.
+ target.addDynamicallyLegalOp<linalg::GenericOp>([](linalg::GenericOp op) {
+ if (!op.hasBufferSemantics()) return false;
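+ // An operand whose element type is still a signless scalar int or float
+ // has not been vectorized, so the op remains illegal.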
+ for (auto arg : op.getOperands()) {
+ if (arg.getType()
+ .cast<MemRefType>()
+ .getElementType()
+ .isSignlessIntOrFloat())
+ return false;
+ }
+ return true;
+ });
+
+ // Mark all standard ops legal if they are operating on vector types.
+ target.addDynamicallyLegalDialect<mlir::StandardOpsDialect>(
+ Optional<ConversionTarget::DynamicLegalityCallbackFn>(
+ [](Operation *op) {
+ auto isVectorType = [](Type t) { return t.isa<VectorType>(); };
+ return llvm::any_of(op->getOperandTypes(), isVectorType) ||
+ llvm::any_of(op->getResultTypes(), isVectorType);
+ }));
+ if (failed(applyPartialConversion(getOperation(), target, patterns)))
+ return signalPassFailure();
+ }
+};
+} // namespace
+
+std::unique_ptr<Pass> createLoadStoreVectorizationPass() {
+ return std::make_unique<LoadStoreVectorizationPass>();
+}
+
+static PassRegistration<LoadStoreVectorizationPass> pass(
+ "iree-codegen-vectorize-linalg-ops", "Vectorize Linalg operations");
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToVector/Passes.h b/iree/compiler/Conversion/LinalgToVector/Passes.h
new file mode 100644
index 0000000..98f07bd
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToVector/Passes.h
@@ -0,0 +1,29 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_CONVERSION_LINALGTOVECTOR_PASSES_H_
+#define IREE_COMPILER_CONVERSION_LINALGTOVECTOR_PASSES_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Creates a pass to vectorize Linalg operations.
+std::unique_ptr<Pass> createLoadStoreVectorizationPass();
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_CONVERSION_LINALGTOVECTOR_PASSES_H_
diff --git a/iree/compiler/Conversion/LinalgToVector/test/BUILD b/iree/compiler/Conversion/LinalgToVector/test/BUILD
new file mode 100644
index 0000000..e124cf5
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToVector/test/BUILD
@@ -0,0 +1,32 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Tests for common transforms.
+
+load("//iree:lit_test.bzl", "iree_lit_test_suite")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_lit_test_suite(
+ name = "lit",
+ srcs = glob(["*.mlir"]),
+ data = [
+ "//iree/tools:IreeFileCheck",
+ "//iree/tools:iree-opt",
+ ],
+)
diff --git a/iree/compiler/Conversion/LinalgToVector/test/CMakeLists.txt b/iree/compiler/Conversion/LinalgToVector/test/CMakeLists.txt
new file mode 100644
index 0000000..fcc538b
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToVector/test/CMakeLists.txt
@@ -0,0 +1,26 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir)
+iree_lit_test_suite(
+ NAME
+ lit
+ SRCS
+ "${_GLOB_X_MLIR}"
+ DATA
+ iree::tools::IreeFileCheck
+ iree::tools::iree-opt
+)
diff --git a/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_ops.mlir b/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_ops.mlir
new file mode 100644
index 0000000..8b72941
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_ops.mlir
@@ -0,0 +1,115 @@
+// RUN: iree-opt -split-input-file -iree-codegen-vectorize-linalg-ops -canonicalize -cse %s | IreeFileCheck %s
+
+func @broadcast_add() {
+ %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<4xf32>
+ %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x4xf32>
+ %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<3x4xf32>
+ linalg.generic {args_in = 2 : i64,
+ args_out = 1 : i64,
+ indexing_maps = [affine_map<(d0, d1) -> (d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]
+ } %0, %1, %2 {
+ ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): // no predecessors
+ %3 = addf %arg0, %arg1 : f32
+ linalg.yield %3 : f32
+ }: memref<4xf32>, memref<3x4xf32>, memref<3x4xf32>
+ return
+}
+// CHECK-LABEL: func @broadcast_add
+// CHECK-DAG: %[[BUF0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1xvector<4xf32>>
+// CHECK-DAG: %[[BUF1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x1xvector<4xf32>>
+// CHECK-DAG: %[[BUF2:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<3x1xvector<4xf32>>
+// CHECK: linalg.generic
+// CHECK-SAME: %[[BUF0]], %[[BUF1]], %[[BUF2]]
+// CHECK: ^bb0(%[[ARG0:.+]]: vector<4xf32>, %[[ARG1:.+]]: vector<4xf32>, %[[ARG2:.+]]: vector<4xf32>)
+// CHECK: %[[RES:.+]] = addf %[[ARG0]], %[[ARG1]] : vector<4xf32>
+// CHECK: linalg.yield %[[RES]] : vector<4xf32>
+
+// -----
+
+func @log_plus_one() {
+ %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<4xf32>
+ %c0 = constant 0 : index
+ %cst = constant 1.000000e+00 : f32
+ %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<4xf32>
+ linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} %1, %0 {
+ ^bb0(%arg0: f32, %arg1: f32): // no predecessors
+ %2 = addf %arg0, %cst : f32
+ %3 = log %2 : f32
+ linalg.yield %3 : f32
+ }: memref<4xf32>, memref<4xf32>
+ return
+}
+// CHECK-LABEL: func @log_plus_one
+// CHECK-DAG: %[[BUF0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1xvector<4xf32>>
+// CHECK-DAG: %[[BUF1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1xvector<4xf32>>
+// CHECK-DAG: %[[CST:.+]] = constant dense<1.000000e+00> : vector<4xf32>
+// CHECK: linalg.generic
+// CHECK-SAME: %[[BUF0]], %[[BUF1]]
+// CHECK: ^bb0(%[[ARG0:.+]]: vector<4xf32>, %[[ARG1:.+]]: vector<4xf32>)
+// CHECK: %[[T1:.+]] = addf %[[ARG0]], %[[CST]] : vector<4xf32>
+// CHECK: %[[T2:.+]] = log %[[T1]] : vector<4xf32>
+// CHECK: linalg.yield %[[T2]] : vector<4xf32>
+
+// -----
+
+func @cmp_and_select() {
+ %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<4xi32>
+ %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<4xi32>
+ %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<4xi32>
+ linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} %1, %2, %0 {
+ ^bb0(%arg0: i32, %arg1: i32, %arg2: i32): // no predecessors
+ %3 = cmpi "sgt", %arg0, %arg1 : i32
+ %4 = select %3, %arg0, %arg1 : i32
+ linalg.yield %4 : i32
+ }: memref<4xi32>, memref<4xi32>, memref<4xi32>
+ return
+}
+// CHECK-LABEL: func @cmp_and_select
+// CHECK-DAG: %[[BUF0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1xvector<4xi32>>
+// CHECK-DAG: %[[BUF1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<1xvector<4xi32>>
+// CHECK-DAG: %[[BUF2:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1xvector<4xi32>>
+// CHECK: linalg.generic
+// CHECK-SAME: %[[BUF0]], %[[BUF1]], %[[BUF2]]
+// CHECK: ^bb0(%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>)
+// CHECK: %[[T1:.+]] = cmpi "sgt", %[[ARG0]], %[[ARG1]] : vector<4xi32>
+// CHECK: %[[T2:.+]] = select %[[T1]], %[[ARG0]], %[[ARG1]] : vector<4xi1>, vector<4xi32>
+// CHECK: linalg.yield %[[T2]] : vector<4xi32>
+
+// -----
+
+func @not_contiguous() {
+ %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<4x4xf32>
+ %c0 = constant 0 : index
+ %cst = constant 1.000000e+00 : f32
+ %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<4x4xf32>
+ linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} %1, %0 {
+ ^bb0(%arg0: f32, %arg1: f32): // no predecessors
+ %2 = addf %arg0, %cst : f32
+ linalg.yield %2 : f32
+ }: memref<4x4xf32>, memref<4x4xf32>
+ return
+}
+// CHECK-LABEL: func @not_contiguous
+// CHECK-DAG: iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<4x4xf32>
+// CHECK-DAG: iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<4x4xf32>
+
+// -----
+
+func @not_4s() {
+ %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<4x3xf32>
+ %c0 = constant 0 : index
+ %cst = constant 1.000000e+00 : f32
+ %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<4x3xf32>
+ linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} %1, %0 {
+ ^bb0(%arg0: f32, %arg1: f32): // no predecessors
+ %2 = addf %arg0, %cst : f32
+ linalg.yield %2 : f32
+ }: memref<4x3xf32>, memref<4x3xf32>
+ return
+}
+// CHECK-LABEL: func @not_4s
+// CHECK-DAG: iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<4x3xf32>
+// CHECK-DAG: iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<4x3xf32>
diff --git a/iree/compiler/Conversion/init_conversions.h b/iree/compiler/Conversion/init_conversions.h
index 2978c69..c06e2d8 100644
--- a/iree/compiler/Conversion/init_conversions.h
+++ b/iree/compiler/Conversion/init_conversions.h
@@ -19,6 +19,7 @@
#include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
#include "iree/compiler/Conversion/LinalgToLLVM/Passes.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
+#include "iree/compiler/Conversion/LinalgToVector/Passes.h"
namespace mlir {
namespace iree_compiler {
@@ -33,11 +34,19 @@
createHLOToLinalgOnTensorsPass();
}
+inline void registerLinalgToVectorPasses() {
+ static bool init_once = []() {
+ createLoadStoreVectorizationPass();
+ return true;
+ }();
+ (void)init_once;
+}
+
inline void registerLinalgToSPIRVPasses() {
static bool init_once = []() {
// LinalgToSPIRV
createConvertToGPUPass();
- createLinalgTileAndFusePass();
+ createLinalgTileAndFusePass(SPIRVCodegenOptions());
createSplitDispatchFunctionPass();
createVectorToGPUPass();
createMatMulTileAndVectorizeGPUPass();
diff --git a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.cpp
index 2017d60..1226491 100644
--- a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.cpp
@@ -221,10 +221,11 @@
// Ordinals are fixed based on the precomputed schedule, so use
// CommandBufferDispatchOp instead of CommandBufferDispatchSymbolOp.
+ int32_t entryPointOrdinal = it.index();
builder.create<IREE::HAL::CommandBufferDispatchOp>(
loc, commandBuffer, executable,
- builder.getI32IntegerAttr(/*entryPointOrdinal=*/it.index()),
- workgroupCount[0], workgroupCount[1], workgroupCount[2]);
+ builder.getI32IntegerAttr(entryPointOrdinal), workgroupCount[0],
+ workgroupCount[1], workgroupCount[2]);
if (it.index() + 1 != spvEntryPointFns.size()) {
recordFullExecutionBarrier(commandBuffer, loc, builder);
}
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
index 4a9ac26..582c165 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
@@ -48,6 +48,12 @@
// llvm::cl::OptionCategory halVulkanSPIRVOptionsCategory(
// "IREE Vulkan/SPIR-V backend options");
+ static llvm::cl::opt<bool> clUseVectorPass(
+ "iree-spirv-use-vector-pass",
+ llvm::cl::desc(
+ "Enable use of Linalg vectorization in SPIR-V code generation"),
+ llvm::cl::init(false));
+
static llvm::cl::opt<bool> clUseWorkgroupMemory(
"iree-spirv-use-workgroup-memory",
llvm::cl::desc(
@@ -77,6 +83,7 @@
targetOptions.codegenOptions.tileSizes.assign(clTileSizes.begin(),
clTileSizes.end());
targetOptions.codegenOptions.useWorkgroupMemory = clUseWorkgroupMemory;
+ targetOptions.codegenOptions.useVectorPass = clUseVectorPass;
targetOptions.vulkanTargetEnv = clVulkanTargetEnv;
return targetOptions;
}
diff --git a/iree/compiler/Dialect/VM/Target/C/CMakeLists.txt b/iree/compiler/Dialect/VM/Target/C/CMakeLists.txt
index ae197cc..0b0b4da 100644
--- a/iree/compiler/Dialect/VM/Target/C/CMakeLists.txt
+++ b/iree/compiler/Dialect/VM/Target/C/CMakeLists.txt
@@ -15,7 +15,7 @@
if(${IREE_ENABLE_EMITC})
iree_add_all_subdirs()
-
+
iree_cc_library(
NAME
C
diff --git a/iree/test/e2e/vulkan_specific/BUILD b/iree/test/e2e/vulkan_specific/BUILD
index 9686098..f4741ac 100644
--- a/iree/test/e2e/vulkan_specific/BUILD
+++ b/iree/test/e2e/vulkan_specific/BUILD
@@ -45,3 +45,15 @@
driver = "vulkan",
target_backend = "vulkan-spirv",
)
+
+iree_check_single_backend_test_suite(
+ name = "check_vulkan-spirv_vulkan_vector",
+ srcs = [
+ "compare.mlir",
+ "log_plus_one.mlir",
+ "pw_add_multiwg.mlir",
+ ],
+ compiler_flags = ["-iree-spirv-use-vector-pass"],
+ driver = "vulkan",
+ target_backend = "vulkan-spirv",
+)
diff --git a/iree/test/e2e/vulkan_specific/CMakeLists.txt b/iree/test/e2e/vulkan_specific/CMakeLists.txt
index 4260b07..d5bd481 100644
--- a/iree/test/e2e/vulkan_specific/CMakeLists.txt
+++ b/iree/test/e2e/vulkan_specific/CMakeLists.txt
@@ -41,3 +41,18 @@
COMPILER_FLAGS
"-iree-spirv-use-workgroup-memory"
)
+
+iree_check_single_backend_test_suite(
+ NAME
+ check_vulkan-spirv_vulkan_vector
+ SRCS
+ "compare.mlir"
+ "log_plus_one.mlir"
+ "pw_add_multiwg.mlir"
+ TARGET_BACKEND
+ vulkan-spirv
+ DRIVER
+ vulkan
+ COMPILER_FLAGS
+ "-iree-spirv-use-vector-pass"
+)
diff --git a/iree/test/e2e/vulkan_specific/compare.mlir b/iree/test/e2e/vulkan_specific/compare.mlir
new file mode 100644
index 0000000..099670a
--- /dev/null
+++ b/iree/test/e2e/vulkan_specific/compare.mlir
@@ -0,0 +1,164 @@
+func @compare_tensor() attributes { iree.module.export } {
+ %lhs = iree.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32>
+ %rhs = iree.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32>
+ %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ %c0 = iree.unfoldable_constant dense<0> : tensor<4xi8>
+ %c1 = iree.unfoldable_constant dense<1> : tensor<4xi8>
+ %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+ check.expect_eq_const(%output, dense<[0, 1, 0, 1]> : tensor<4xi8>) : tensor<4xi8>
+ return
+}
+
+func @compare_scalar() attributes { iree.module.export } {
+ %lhs = iree.unfoldable_constant dense<1> : tensor<i32>
+ %rhs = iree.unfoldable_constant dense<5> : tensor<i32>
+ %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ %c0 = iree.unfoldable_constant dense<0> : tensor<i8>
+ %c1 = iree.unfoldable_constant dense<1> : tensor<i8>
+ %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
+ check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
+ return
+}
+
+func @compare_i8() attributes { iree.module.export } {
+ %lhs = iree.unfoldable_constant dense<1> : tensor<i8>
+ %rhs = iree.unfoldable_constant dense<5> : tensor<i8>
+ %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<i8>, tensor<i8>) -> tensor<i1>
+ %c0 = iree.unfoldable_constant dense<0> : tensor<i8>
+ %c1 = iree.unfoldable_constant dense<1> : tensor<i8>
+ %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
+ check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
+ return
+}
+
+func @compare_i16() attributes { iree.module.export } {
+ %lhs = iree.unfoldable_constant dense<1> : tensor<i16>
+ %rhs = iree.unfoldable_constant dense<5> : tensor<i16>
+ %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<i16>, tensor<i16>) -> tensor<i1>
+ %c0 = iree.unfoldable_constant dense<0> : tensor<i8>
+ %c1 = iree.unfoldable_constant dense<1> : tensor<i8>
+ %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
+ check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
+ return
+}
+
+func @compare_i32() attributes { iree.module.export } {
+ %lhs = iree.unfoldable_constant dense<1> : tensor<i32>
+ %rhs = iree.unfoldable_constant dense<5> : tensor<i32>
+ %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ %c0 = iree.unfoldable_constant dense<0> : tensor<i8>
+ %c1 = iree.unfoldable_constant dense<1> : tensor<i8>
+ %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
+ check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
+ return
+}
+
+func @compare_i64() attributes { iree.module.export } {
+ %lhs = iree.unfoldable_constant dense<1> : tensor<i64>
+ %rhs = iree.unfoldable_constant dense<5> : tensor<i64>
+ %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
+ %c0 = iree.unfoldable_constant dense<0> : tensor<i8>
+ %c1 = iree.unfoldable_constant dense<1> : tensor<i8>
+ %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
+ check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
+ return
+}
+
+func @compare_f32() attributes { iree.module.export } {
+ %lhs = iree.unfoldable_constant dense<1.0> : tensor<f32>
+ %rhs = iree.unfoldable_constant dense<5.0> : tensor<f32>
+ %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+ %c0 = iree.unfoldable_constant dense<0> : tensor<i8>
+ %c1 = iree.unfoldable_constant dense<1> : tensor<i8>
+ %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
+ check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
+ return
+}
+
+func @compare_f64() attributes { iree.module.export } {
+ %lhs = iree.unfoldable_constant dense<1.0> : tensor<f64>
+ %rhs = iree.unfoldable_constant dense<5.0> : tensor<f64>
+ %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<f64>, tensor<f64>) -> tensor<i1>
+ %c0 = iree.unfoldable_constant dense<0> : tensor<i8>
+ %c1 = iree.unfoldable_constant dense<1> : tensor<i8>
+ %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
+ check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
+ return
+}
+
+func @compare_tensor_odd_length() attributes { iree.module.export } {
+ %lhs = iree.unfoldable_constant dense<[1, 2, 7]> : tensor<3xi32>
+ %rhs = iree.unfoldable_constant dense<[5, 2, 3]> : tensor<3xi32>
+ %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
+ %c0 = iree.unfoldable_constant dense<0> : tensor<3xi8>
+ %c1 = iree.unfoldable_constant dense<1> : tensor<3xi8>
+ %output = "mhlo.select"(%result, %c1, %c0) : (tensor<3xi1>, tensor<3xi8>, tensor<3xi8>) -> tensor<3xi8>
+ check.expect_eq_const(%output, dense<[0, 1, 0]> : tensor<3xi8>) : tensor<3xi8>
+ return
+}
+
+func @compare_eq() attributes { iree.module.export } {
+ %lhs = iree.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32>
+ %rhs = iree.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32>
+ %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ %c0 = iree.unfoldable_constant dense<0> : tensor<4xi8>
+ %c1 = iree.unfoldable_constant dense<1> : tensor<4xi8>
+ %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+ check.expect_eq_const(%output, dense<[0, 1, 0, 1]> : tensor<4xi8>) : tensor<4xi8>
+ return
+}
+
+func @compare_ne() attributes { iree.module.export } {
+ %lhs = iree.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32>
+ %rhs = iree.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32>
+ %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "NE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ %c0 = iree.unfoldable_constant dense<0> : tensor<4xi8>
+ %c1 = iree.unfoldable_constant dense<1> : tensor<4xi8>
+ %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+ check.expect_eq_const(%output, dense<[1, 0, 1, 0]> : tensor<4xi8>) : tensor<4xi8>
+ return
+}
+
+func @compare_lt() attributes { iree.module.export } {
+ %lhs = iree.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32>
+ %rhs = iree.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32>
+ %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "LT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ %c0 = iree.unfoldable_constant dense<0> : tensor<4xi8>
+ %c1 = iree.unfoldable_constant dense<1> : tensor<4xi8>
+ %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+ check.expect_eq_const(%output, dense<[1, 0, 0, 0]> : tensor<4xi8>) : tensor<4xi8>
+ return
+}
+
+func @compare_le() attributes { iree.module.export } {
+ %lhs = iree.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32>
+ %rhs = iree.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32>
+ %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "LE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ %c0 = iree.unfoldable_constant dense<0> : tensor<4xi8>
+ %c1 = iree.unfoldable_constant dense<1> : tensor<4xi8>
+ %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+ check.expect_eq_const(%output, dense<[1, 1, 0, 1]> : tensor<4xi8>) : tensor<4xi8>
+ return
+}
+
+func @compare_gt() attributes { iree.module.export } {
+ %lhs = iree.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32>
+ %rhs = iree.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32>
+ %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "GT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ %c0 = iree.unfoldable_constant dense<0> : tensor<4xi8>
+ %c1 = iree.unfoldable_constant dense<1> : tensor<4xi8>
+ %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+ check.expect_eq_const(%output, dense<[0, 0, 1, 0]> : tensor<4xi8>) : tensor<4xi8>
+ return
+}
+
+func @compare_ge() attributes { iree.module.export } {
+ %lhs = iree.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32>
+ %rhs = iree.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32>
+ %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "GE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ %c0 = iree.unfoldable_constant dense<0> : tensor<4xi8>
+ %c1 = iree.unfoldable_constant dense<1> : tensor<4xi8>
+ %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+ check.expect_eq_const(%output, dense<[0, 1, 1, 1]> : tensor<4xi8>) : tensor<4xi8>
+ return
+}
diff --git a/iree/test/e2e/vulkan_specific/log_plus_one.mlir b/iree/test/e2e/vulkan_specific/log_plus_one.mlir
new file mode 100644
index 0000000..3bba4a9
--- /dev/null
+++ b/iree/test/e2e/vulkan_specific/log_plus_one.mlir
@@ -0,0 +1,6 @@
+func @log_plus_one() attributes { iree.module.export } {
+ %input = iree.unfoldable_constant dense<[0.0, 0.5, 1.0, 5.0]> : tensor<4xf32>
+ %result = "mhlo.log_plus_one"(%input) : (tensor<4xf32>) -> tensor<4xf32>
+ check.expect_almost_eq_const(%result, dense<[0.0, 0.4054651, 0.6931472, 1.7917595]> : tensor<4xf32>) : tensor<4xf32>
+ return
+}