Merge pull request #2351 from GMNGeoffrey:main-to-google
PiperOrigin-RevId: 318861045
diff --git a/SUBMODULE_VERSIONS b/SUBMODULE_VERSIONS
index 9a89b24..137a304 100644
--- a/SUBMODULE_VERSIONS
+++ b/SUBMODULE_VERSIONS
@@ -3,7 +3,7 @@
4c13807b7d43ff0946b7ffea0ae3aee9e611d778 third_party/dear_imgui
a5d9d0f7d368054fd1691aedf1db4116efcc233e third_party/flatbuffers
f2fb48c3b3d79a75a88a99fba6576b25d42ec528 third_party/googletest
-1becd298b82ed2f1a8ba5e61c5ad2ce7fe32d812 third_party/llvm-project
+e34523c87c3f1cfabcf741568dede026bbb12d3a third_party/llvm-project
17b12a4481daa150e2d1ea3ada086b551b856707 third_party/marl
67f3ccebee84f3488b46a8d3ac005178c52ff264 third_party/mlir-emitc
80d452484c5409444b0ec19383faa84bb7a4d351 third_party/pybind11
diff --git a/docs/GetStarted/getting_started_python.md b/docs/GetStarted/getting_started_python.md
index 1a0452b..3aec4e2 100644
--- a/docs/GetStarted/getting_started_python.md
+++ b/docs/GetStarted/getting_started_python.md
@@ -62,8 +62,7 @@
See
[start_colab_kernel.py](https://github.com/google/iree/blob/main/colab/start_colab_kernel.py)
-and
-[Using Colab](https://github.com/google/iree/blob/main/docs/using_colab.md)
+and [Using Colab](https://github.com/google/iree/blob/main/docs/using_colab.md)
for setup instructions, then take a look through the
[Colab directory](https://github.com/google/iree/tree/main/colab) for some
sample notebooks.
diff --git a/experimental/ModelBuilder/test/BUILD b/experimental/ModelBuilder/test/BUILD
index 7f6d34d..ad27616 100644
--- a/experimental/ModelBuilder/test/BUILD
+++ b/experimental/ModelBuilder/test/BUILD
@@ -128,6 +128,7 @@
"//experimental/ModelBuilder:ModelRunner",
"//experimental/ModelBuilder:VulkanLaunchWrapper",
"//iree/base:initializer",
+ "//iree/compiler/Conversion/CodegenUtils",
"//iree/compiler/Conversion/LinalgToSPIRV",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
diff --git a/experimental/ModelBuilder/test/TestVectorToGPU.cpp b/experimental/ModelBuilder/test/TestVectorToGPU.cpp
index 9c11fae..7b50f12 100644
--- a/experimental/ModelBuilder/test/TestVectorToGPU.cpp
+++ b/experimental/ModelBuilder/test/TestVectorToGPU.cpp
@@ -48,7 +48,7 @@
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"
-#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h"
using namespace mlir; // NOLINT
using namespace mlir::edsc; // NOLINT
diff --git a/integrations/tensorflow/e2e/keras/vision_model_test.py b/integrations/tensorflow/e2e/keras/vision_model_test.py
index 9164d9b..0a1e812 100644
--- a/integrations/tensorflow/e2e/keras/vision_model_test.py
+++ b/integrations/tensorflow/e2e/keras/vision_model_test.py
@@ -136,7 +136,7 @@
np.float32)
input_data = input_data.reshape(input_shape)
self.modules.applications.all.predict(input_data).print().assert_all_close(
- atol=1e-6)
+ atol=3e-5)
if __name__ == '__main__':
diff --git a/iree/compiler/Conversion/CodegenUtils/BUILD b/iree/compiler/Conversion/CodegenUtils/BUILD
index a414146..3696239 100644
--- a/iree/compiler/Conversion/CodegenUtils/BUILD
+++ b/iree/compiler/Conversion/CodegenUtils/BUILD
@@ -23,12 +23,10 @@
name = "CodegenUtils",
srcs = [
"FunctionUtils.cpp",
- "MarkerUtils.cpp",
"MatmulCodegenStrategy.cpp",
],
hdrs = [
"FunctionUtils.h",
- "MarkerUtils.h",
"MatmulCodegenStrategy.h",
],
deps = [
diff --git a/iree/compiler/Conversion/CodegenUtils/CMakeLists.txt b/iree/compiler/Conversion/CodegenUtils/CMakeLists.txt
index 08a7b14..08dfe7c 100644
--- a/iree/compiler/Conversion/CodegenUtils/CMakeLists.txt
+++ b/iree/compiler/Conversion/CodegenUtils/CMakeLists.txt
@@ -19,11 +19,9 @@
CodegenUtils
HDRS
"FunctionUtils.h"
- "MarkerUtils.h"
"MatmulCodegenStrategy.h"
SRCS
"FunctionUtils.cpp"
- "MarkerUtils.cpp"
"MatmulCodegenStrategy.cpp"
DEPS
LLVMSupport
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
index 6a54529..01c0cba 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
@@ -22,7 +22,6 @@
#include <cstddef>
-#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
#include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
@@ -612,8 +611,6 @@
cond ? rewriter.create<SelectOp>(loc, cond, inputVal, paddingVal)
: inputVal;
rewriter.create<linalg::YieldOp>(loc, result);
-
- setNoTileMarker(linalgOp);
return success();
}
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/BUILD b/iree/compiler/Conversion/LinalgToSPIRV/BUILD
index d2308b4..70df4d6 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/BUILD
+++ b/iree/compiler/Conversion/LinalgToSPIRV/BUILD
@@ -23,14 +23,18 @@
"ConvertToGPUPass.cpp",
"ConvertToSPIRVPass.cpp",
"LinalgTileAndFusePass.cpp",
+ "MarkerUtils.cpp",
"Passes.cpp",
"SplitDispatchFunctionPass.cpp",
+ "Utils.cpp",
"VectorToGPUPass.cpp",
],
hdrs = [
"Attributes.h",
+ "MarkerUtils.h",
"MemorySpace.h",
"Passes.h",
+ "Utils.h",
],
deps = [
"//iree/compiler/Conversion/CodegenUtils",
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
index ccc694c..b8821e2 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
@@ -19,14 +19,18 @@
LinalgToSPIRV
HDRS
"Attributes.h"
+ "MarkerUtils.h"
"MemorySpace.h"
"Passes.h"
+ "Utils.h"
SRCS
"ConvertToGPUPass.cpp"
"ConvertToSPIRVPass.cpp"
"LinalgTileAndFusePass.cpp"
+ "MarkerUtils.cpp"
"Passes.cpp"
"SplitDispatchFunctionPass.cpp"
+ "Utils.cpp"
"VectorToGPUPass.cpp"
DEPS
LLVMSupport
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
index 8a85b6a..9009441 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
@@ -17,8 +17,9 @@
// Partition computation within dispatch function to workgroups/workitems.
//
//===----------------------------------------------------------------------===//
-#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/Utils.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
@@ -36,11 +37,21 @@
namespace mlir {
namespace iree_compiler {
+// TODO(#2134): Remove this flag/set it to false always when issue with
+// convolution is resolved (see bug for more details).
+// TODO(#2346): Make this a pass specific option.
+llvm::cl::opt<bool> useLegacyConvLowering{
+ "iree-codegen-use-legacy-conv-lowering",
+ llvm::cl::desc("Use conv lowering that does not assume 1:1 mapping "
+ "between threads within a block and iterations of "
+ "parallel loops distributed to the block"),
+ llvm::cl::init(true)};
+
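Note: since this is a global llvm::cl::opt, the legacy lowering can presumably be toggled from any tool that links this pass, e.g. iree-opt -iree-codegen-convert-to-gpu -iree-codegen-use-legacy-conv-lowering=false ... (the flag name comes from the declaration above; the exact invocation is illustrative). With the default of true, the legacy convolution lowering registered at the bottom of this file remains the default path.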
//===----------------------------------------------------------------------===//
// Loop utilities
//===----------------------------------------------------------------------===//
-/// Builds an empty loop.for operation. The default builder adds an entry basic
+/// Builds an empty scf.for operation. The default builder adds an entry basic
/// block which needs to be avoided here.
static scf::ForOp buildEmptyForOp(Location loc, OpBuilder &builder, Value lb,
Value ub, Value step) {
@@ -50,6 +61,15 @@
return cast<scf::ForOp>(builder.createOperation(state));
}
+/// Builds an empty scf.if operation without the then and else blocks.
+static scf::IfOp buildEmptyIfOp(Location loc, OpBuilder &builder, Value cond) {
+ OperationState state(loc, scf::IfOp::getOperationName());
+ state.addOperands(cond);
+ state.addRegion();
+ state.addRegion();
+ return cast<scf::IfOp>(builder.createOperation(state));
+}
+
namespace {
struct LoopBounds {
Value lb;
@@ -58,10 +78,10 @@
};
} // namespace
-/// Replaces a loop.parallelOp with an optional loop.parallel op and nested
-/// loop.for operations. To create the loop.parallel op as the outermost loop,
+/// Replaces a scf.parallelOp with an optional scf.parallel op and nested
+/// scf.for operations. To create the scf.parallel op as the outermost loop,
/// pass the lower bound, upper bound and steps in `newPLoopLbs`, `newPLoopUbs`,
-/// and `newPLoopStep` respectively. The bounds of the inner loop.for operations
+/// and `newPLoopStep` respectively. The bounds of the inner scf.for operations
/// to be created are passed in `forLbs`, `forUbs`, and `forStep`. The
/// `permutation` vector contains a mapping from the original loop order, to the
/// loop order to be generated.
@@ -70,21 +90,21 @@
ArrayRef<LoopBounds> newPLoopBounds,
ArrayRef<LoopBounds> forBounds,
ArrayRef<unsigned> permutation) {
- assert(!forBounds.empty() && "unhandled case of no loop.for created");
+ assert(!forBounds.empty() && "unhandled case of no scf.for created");
unsigned numLoops = pLoopOp.getNumLoops();
Location loc = pLoopOp.getLoc();
assert(forBounds.size() + newPLoopBounds.size() == numLoops &&
- "cannot drop loops when splitting loop.parallel operation");
+ "cannot drop loops when splitting scf.parallel operation");
assert(permutation.size() == numLoops);
OpBuilder::InsertionGuard guard(rewriter);
- // Need a signature conversion for the body of the loop.parallel operation,
+ // Need a signature conversion for the body of the scf.parallel operation,
// before it can be used as the body of the innermost loop created here.
TypeConverter::SignatureConversion signatureConverter(numLoops);
Operation *outermostLoop = nullptr;
auto permuteIt = permutation.begin();
- // Create the loop.parallel operation as the outermost loop, if specified.
+ // Create the scf.parallel operation as the outermost loop, if specified.
if (!newPLoopBounds.empty()) {
auto lbs = llvm::to_vector<2>(llvm::map_range(
newPLoopBounds, [](LoopBounds bounds) -> Value { return bounds.lb; }));
@@ -101,7 +121,7 @@
outermostLoop = newPLoop.getOperation();
}
- // Generate the nested loop.for operations with the bounds passed.
+ // Generate the nested scf.for operations with the bounds passed.
for (auto it : enumerate(forBounds)) {
Value lb = it.value().lb, ub = it.value().ub, step = it.value().step;
if (it.index() != forBounds.size() - 1) {
@@ -110,7 +130,7 @@
signatureConverter.remapInput(*permuteIt, forOp.getInductionVar());
rewriter.setInsertionPointToStart(forOp.getBody());
} else {
- // For the last loop, move the body of the loop.parallel op as the body of
+ // For the last loop, move the body of the scf.parallel op as the body of
// the loop after signature conversion.
auto forOp = buildEmptyForOp(loc, rewriter, lb, ub, step);
if (!outermostLoop) outermostLoop = forOp.getOperation();
@@ -127,8 +147,8 @@
return outermostLoop;
}
-/// Serializes the dimensions of the loop.parallel specified in
-/// `serializedDimensions`, by creating an nested loop.for operation for each
+/// Serializes the dimensions of the scf.parallel specified in
+/// `serializedDimensions`, by creating a nested scf.for operation for each
/// dimension.
// TODO(ravishankarm): Move this into LoopUtils.h in MLIR.
static Operation *serializeDimensions(ConversionPatternRewriter &rewriter,
@@ -141,7 +161,7 @@
serializedDimSet.insert(serializedDimensions.begin(),
serializedDimensions.end());
assert(serializedDimSet.size() == serializedDimensions.size() &&
- "cannot repeat dimensions during serialization of loop.parallel");
+ "cannot repeat dimensions during serialization of scf.parallel");
SmallVector<LoopBounds, 2> newPLoopBounds, forBounds;
SmallVector<unsigned, 2> permutation;
auto lbs = pLoopOp.lowerBound();
@@ -174,16 +194,85 @@
return serializeDimensions(rewriter, pLoopOp, serializedDimensions);
}
+/// Collapses all loops in a scf.parallel into one scf.parallel operation. This
+/// is done by
+/// 1) Normalizing the loop bounds to [0, (ub - lb) / step).
+/// 2) Computing the total number of iterations.
+/// 3) Computing the values of the original induction variables from the
+///    induction variable of the collapsed loop by de-linearization.
+scf::ParallelOp collapseParallelLoops(ConversionPatternRewriter &rewriter,
+ scf::ParallelOp pLoopOp) {
+ if (pLoopOp.getNumReductions()) return nullptr;
+
+ unsigned numLoops = pLoopOp.getNumLoops();
+ if (numLoops == 1) return pLoopOp;
+
+  // Compute the number of iterations of each loop, starting from the innermost.
+ Location loc = pLoopOp.getLoc();
+ Value totalNumIterations = rewriter.create<ConstantIndexOp>(loc, 1);
+
+ // Track the "stride" of each loop, i.e. product of the total number of
+ // iterations of the inner loops.
+ SmallVector<Value, 2> iterationStride;
+ iterationStride.resize(pLoopOp.getNumLoops());
+ auto lbs = pLoopOp.lowerBound();
+ auto ubs = pLoopOp.upperBound();
+ auto steps = pLoopOp.step();
+ for (int i = numLoops - 1; i >= 0; --i) {
+ Value lb = lbs[i], ub = ubs[i], step = steps[i];
+ Value iterCount = rewriter.create<SignedDivIOp>(
+ loc, rewriter.create<SubIOp>(loc, ub, lb), step);
+ iterationStride[i] = totalNumIterations;
+ totalNumIterations =
+ rewriter.create<MulIOp>(loc, totalNumIterations, iterCount);
+ }
+
+ // Create the collapsed parallel loop op with lowerbound 0, step 1 and upper
+ // bound being the totalNumIterations.
+ Value newLb = rewriter.create<ConstantIndexOp>(loc, 0);
+ Value newStep = rewriter.create<ConstantIndexOp>(loc, 1);
+ scf::ParallelOp newPLoopOp =
+ rewriter.create<scf::ParallelOp>(loc, newLb, totalNumIterations, newStep);
+
+  // Build the body of the collapsed loop by cloning the original loop body. The
+  // replacement values for the original induction variables are computed from
+  // the induction variable of the new loop using
+ // origLoopIv[i] = loopIv / iterationStride[i]
+ // loopIv = loopIv % iterationStride[i]
+ OpBuilder::InsertionGuard guard(rewriter);
+ Block &pLoopBody = pLoopOp.getLoopBody().front();
+ rewriter.setInsertionPointToStart(&newPLoopOp.getLoopBody().front());
+ Value loopIv = *newPLoopOp.getInductionVars().begin();
+ BlockAndValueMapping map;
+ for (int i : llvm::seq<int>(0, numLoops)) {
+ Value iterNum =
+ rewriter.create<SignedDivIOp>(loc, loopIv, iterationStride[i]);
+ Value newIv = rewriter.create<AddIOp>(
+ loc, lbs[i], rewriter.create<MulIOp>(loc, iterNum, steps[i]));
+ map.map(pLoopBody.getArgument(i), newIv);
+ loopIv = rewriter.create<SignedRemIOp>(loc, loopIv, iterationStride[i]);
+ }
+ for (Operation &op : pLoopBody.without_terminator()) {
+ rewriter.clone(op, map);
+ }
+ rewriter.eraseOp(pLoopOp);
+ return newPLoopOp;
+}
+
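As a sanity check on the de-linearization above, here is a minimal standalone C++ sketch (illustrative only, with made-up trip counts and lb = 0, step = 1) of the same index arithmetic that collapseParallelLoops emits as SignedDivIOp/SignedRemIOp:

#include <array>
#include <cstdint>
#include <cstdio>

int main() {
  // Example 3-D iteration space: lb = 0, step = 1, trip counts {2, 3, 4}.
  std::array<int64_t, 3> iterCount = {2, 3, 4};
  std::array<int64_t, 3> stride;
  int64_t total = 1;
  for (int i = 2; i >= 0; --i) {  // innermost first, as in the pass
    stride[i] = total;
    total *= iterCount[i];
  }
  // Recover the original induction variables from the collapsed one.
  for (int64_t iv = 0; iv < total; ++iv) {
    int64_t rem = iv;
    std::array<int64_t, 3> origIv;
    for (int i = 0; i < 3; ++i) {
      origIv[i] = rem / stride[i];  // SignedDivIOp in the pass
      rem = rem % stride[i];        // SignedRemIOp in the pass
    }
    std::printf("%lld -> (%lld, %lld, %lld)\n", (long long)iv,
                (long long)origIv[0], (long long)origIv[1],
                (long long)origIv[2]);
  }
  return 0;
}

For trip counts {2, 3, 4} the strides are {12, 4, 1}, so for example 5 -> (0, 1, 1), matching 0*12 + 1*4 + 1 = 5.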
//===----------------------------------------------------------------------===//
// GPU processor ID mapping utilities
//===----------------------------------------------------------------------===//
-/// Distribute loop.parallel to processors with the processors logically
+/// Distributes scf.parallel to processors with the processors logically
/// arranged with same dimensionality as the number of loops, i.e. a
-/// loop.parallel with 2 loops to a 2D grid of processors. `processorIDs` and
+/// scf.parallel with 2 loops to a 2D grid of processors. `processorIDs` and
/// `numProcessors` must be of same size as the number of loops and are the
/// values to use for process ID and number of processors along each dimension
/// in the distributed code.
+/// This method accounts for the case where the number of processors is not
+/// enough to execute the entire iteration space with one iteration mapped to
+/// each processor, so it implements a block-cyclic distribution with the block
+/// size equal to the number of processors.
static LogicalResult mapToProcessors(ConversionPatternRewriter &rewriter,
scf::ParallelOp pLoopOp,
ArrayRef<Value> processorIDs,
@@ -212,6 +301,43 @@
return success();
}
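The block-cyclic distribution mentioned above corresponds to the familiar grid-stride idiom. A rough standalone C++ analogue (names and signature invented for illustration) of what one processor executes for a single distributed loop dimension:

#include <cstdint>

// What processor `id` out of `count` executes for one distributed loop
// dimension: a block-cyclic (grid-stride) walk over [lb, ub) with `step`.
void processorBody(int64_t id, int64_t count, int64_t lb, int64_t ub,
                   int64_t step, void (*body)(int64_t)) {
  // First iteration owned by this processor, then stride by count * step.
  for (int64_t iv = lb + id * step; iv < ub; iv += count * step) {
    body(iv);
  }
}

Each processor thus touches every count-th iteration, so correctness does not depend on having at least one processor per iteration.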
+/// Distributes scf.parallel to processors with the processors logically
+/// arranged with same dimensionality as the number of loops, i.e. a
+/// scf.parallel with 2 loops to a 2D grid of processors. `processorIDs` must be
+/// of the same size as the number of loops and holds the values to use for the
+/// processor ID along each dimension in the distributed code. This method
+/// assumes that the number of processors is greater than or equal to the number
+/// of iterations, so it just generates an if statement to mask off processors
+/// with no work.
+static LogicalResult mapToProcessorsAndGuard(
+ ConversionPatternRewriter &rewriter, scf::ParallelOp pLoopOp,
+ ArrayRef<Value> processorIDs) {
+ unsigned numLoops = pLoopOp.getNumLoops();
+ Location loc = pLoopOp.getLoc();
+ assert(numLoops == processorIDs.size() &&
+ "expected as many ids as number of loops");
+ Value cond = nullptr;
+ TypeConverter::SignatureConversion signatureConverter(numLoops);
+ auto lbs = pLoopOp.lowerBound();
+ auto step = pLoopOp.step();
+ auto ubs = pLoopOp.upperBound();
+ for (unsigned i : llvm::seq<unsigned>(0, numLoops)) {
+ Value iterValue = rewriter.create<AddIOp>(
+ loc, lbs[i], rewriter.create<MulIOp>(loc, processorIDs[i], step[i]));
+ Value cmp =
+ rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, iterValue, ubs[i]);
+ cond = (cond ? rewriter.create<AndOp>(loc, cond, cmp) : cmp);
+ signatureConverter.remapInput(i, iterValue);
+ }
+ scf::IfOp ifOp = buildEmptyIfOp(loc, rewriter, cond);
+ Region &pLoopOpRegion = pLoopOp.getLoopBody();
+ rewriter.applySignatureConversion(&pLoopOpRegion, signatureConverter);
+ Region &ifOpRegion = ifOp.getRegion(0);
+ rewriter.inlineRegionBefore(pLoopOpRegion, ifOpRegion, ifOpRegion.begin());
+ rewriter.eraseOp(pLoopOp);
+ return success();
+}
+
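For contrast with the block-cyclic mapping, a C++ sketch (again with invented names) of the guarded 1:1 mapping built here, where the scf.if from buildEmptyIfOp becomes a simple bounds check:

#include <cstdint>

// Guarded 1:1 mapping for one loop dimension: processor `id` executes the
// single iteration lb + id * step, masked off when it falls past ub.
void guardedProcessorBody(int64_t id, int64_t lb, int64_t ub, int64_t step,
                          void (*body)(int64_t)) {
  int64_t iv = lb + id * step;  // AddIOp(lb, MulIOp(id, step)) in the pass
  if (iv < ub) {                // CmpIOp slt + scf.if in the pass
    body(iv);
  }
}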
namespace {
struct ProcessorIdAndCount {
Value id;
@@ -251,7 +377,24 @@
rewriter.create<MulIOp>(loc, blockDim, gridDim)};
}
-/// Distribute loop.parallel to processors where `IdOp` is used to get the
+template <typename GPUIdOp, typename GPUCountOp>
+static void getGPUProcessorIdsAndCounts(Location loc,
+ ConversionPatternRewriter &rewriter,
+ unsigned numDims,
+ MutableArrayRef<Value> id,
+ MutableArrayRef<Value> count) {
+ ArrayRef<StringRef> dims = {"x", "y", "z"};
+ assert(id.size() == numDims);
+ assert(count.size() == numDims);
+ for (unsigned i = 0; i < numDims; ++i) {
+ ProcessorIdAndCount idAndCount =
+ getGPUProcessorIdAndCount<GPUIdOp, GPUCountOp>(loc, dims[i], rewriter);
+ id[numDims - 1 - i] = idAndCount.id;
+ count[numDims - 1 - i] = idAndCount.count;
+ }
+}
+
+/// Distributes scf.parallel to processors where `IdOp` is used to get the
/// processor ID and `DimOp` is used to get the number of processors along a
/// dimension.
template <typename GPUIdOp, typename GPUCountOp>
@@ -263,38 +406,51 @@
cast<scf::ParallelOp>(serializeDimensionsFrom(rewriter, pLoopOp, 3));
numLoops = 3;
}
- SmallVector<Value, 2> id, count;
- id.reserve(numLoops);
- count.reserve(numLoops);
- ArrayRef<StringRef> dims = {"x", "y", "z"};
- Location loc = pLoopOp.getLoc();
- for (unsigned i = 0; i < numLoops; ++i) {
- ProcessorIdAndCount idAndCount =
- getGPUProcessorIdAndCount<GPUIdOp, GPUCountOp>(loc, dims[i], rewriter);
- id.insert(id.begin(), idAndCount.id);
- count.insert(count.begin(), idAndCount.count);
- }
+ SmallVector<Value, 2> id(numLoops), count(numLoops);
+ getGPUProcessorIdsAndCounts<GPUIdOp, GPUCountOp>(pLoopOp.getLoc(), rewriter,
+ numLoops, id, count);
return mapToProcessors(rewriter, pLoopOp, id, count);
}
-/// Distribute the loop.parallel to workgroups.
+/// Distributes scf.parallel to processors where `IdOp` is used to get the
+/// processor ID and `DimOp` is used to get the number of processors along a
+/// dimension. Assumes that the number of processors will be greater than or
+/// equal to the number of iterations of the pLoopOp along all dimensions.
+template <typename GPUIdOp, typename GPUCountOp>
+static LogicalResult mapToProcessorsAndGuard(
+ ConversionPatternRewriter &rewriter, scf::ParallelOp pLoopOp) {
+ unsigned numLoops = pLoopOp.getNumLoops();
+ if (numLoops > 3) {
+ pLoopOp =
+ cast<scf::ParallelOp>(serializeDimensionsFrom(rewriter, pLoopOp, 3));
+ numLoops = 3;
+ }
+ SmallVector<Value, 2> id(numLoops), count(numLoops);
+ getGPUProcessorIdsAndCounts<GPUIdOp, GPUCountOp>(pLoopOp.getLoc(), rewriter,
+ numLoops, id, count);
+ return mapToProcessorsAndGuard(rewriter, pLoopOp, id);
+}
+
+/// Distribute the scf.parallel to workgroups.
static LogicalResult mapToWorkgroups(ConversionPatternRewriter &rewriter,
scf::ParallelOp pLoopOp) {
return mapToProcessor<gpu::BlockIdOp, gpu::GridDimOp>(rewriter, pLoopOp);
}
-/// Distribute loop.parallel to workitems using local invocation ID.
+/// Distributes scf.parallel to workitems using local invocation ID.
static LogicalResult mapToLocalInvocationId(ConversionPatternRewriter &rewriter,
scf::ParallelOp pLoopOp) {
- return mapToProcessor<gpu::ThreadIdOp, gpu::BlockDimOp>(rewriter, pLoopOp);
+ return mapToProcessorsAndGuard<gpu::ThreadIdOp, gpu::BlockDimOp>(rewriter,
+ pLoopOp);
}
-/// Distribute loop.parallel to workitems using global invocation ID. The GPU
+/// Distributes scf.parallel to workitems using global invocation ID. The GPU
/// dialect doesn't have a direct operation to do this. This could be done using
/// id = blockIdx * blockDim + threadIdx. count = blockDim * gridDim.
static LogicalResult mapToGlobalInvocationId(
ConversionPatternRewriter &rewriter, scf::ParallelOp pLoopOp) {
- return mapToProcessor<GPUGlobalId, GPUGlobalCount>(rewriter, pLoopOp);
+ return mapToProcessorsAndGuard<GPUGlobalId, GPUGlobalCount>(rewriter,
+ pLoopOp);
}
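As a worked instance of the formula in the comment above (numbers are illustrative): with blockDim.x = 32, the thread at blockIdx.x = 2, threadIdx.x = 5 gets global id 2 * 32 + 5 = 69, and with gridDim.x = 4 the global count is 32 * 4 = 128; this is the muli/addi chain that the CHECK lines in convert_to_gpu.mlir below verify.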
//===----------------------------------------------------------------------===//
@@ -304,10 +460,13 @@
namespace {
/// Pass to convert from tiled and fused linalg ops into gpu.func.
struct ConvertToGPUPass : public PassWrapper<ConvertToGPUPass, FunctionPass> {
+ ConvertToGPUPass() = default;
+ ConvertToGPUPass(const ConvertToGPUPass &pass) {}
+
void runOnFunction() override;
};
-/// Pattern to map loop.parallel to workgroups.
+/// Pattern to map scf.parallel to workgroups.
struct PartitionPLoopToWorkgroups
: public OpConversionPattern<scf::ParallelOp> {
using OpConversionPattern<scf::ParallelOp>::OpConversionPattern;
@@ -318,7 +477,7 @@
}
};
-/// Map tiled linalg op to workitems by lowering it to loop.parallel and
+/// Map tiled linalg op to workitems by lowering it to scf.parallel and
/// partitioning it to workitems.
template <typename LinalgOpTy>
struct MapLinalgOpToLocalInvocationId : public OpConversionPattern<LinalgOpTy> {
@@ -342,6 +501,29 @@
}
};
+/// Legacy path for lowering tiled conv/pooling op to loops.
+// TODO(#2134): Remove this pattern. The default path of using
+// `MapLinalgOpToLocalInvocationId` seems to have a bug. It only shows up
+// currently on Resnet50. Remove this pattern after the bug is triaged/fixed.
+template <typename LinalgOpTy>
+struct MapConvPoolToLocalInvocationId : public OpConversionPattern<LinalgOpTy> {
+ using OpConversionPattern<LinalgOpTy>::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ LinalgOpTy linalgOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!hasWorkItemMarker(linalgOp)) return failure();
+ Optional<linalg::LinalgLoops> loops =
+ linalg::linalgLowerOpToLoops<scf::ParallelOp>(rewriter, linalgOp);
+ if (!loops) return failure();
+ scf::ParallelOp pLoopOp = cast<scf::ParallelOp>(loops.getValue()[0]);
+ if (failed(mapToProcessor<gpu::ThreadIdOp, gpu::BlockDimOp>(rewriter,
+ pLoopOp)))
+ return failure();
+ rewriter.eraseOp(linalgOp);
+ return success();
+ }
+};
+
/// Map linalg operation to execute on GPU in parallel by mapping the parallel
/// loops to "GlobalInvocationId".
template <typename LinalgOpTy>
@@ -351,19 +533,29 @@
LogicalResult matchAndRewrite(
LinalgOpTy linalgOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- // If marker exists and its not no-tile, do nothing.
- if (hasMarker(linalgOp) && !hasNoTileMarker(linalgOp)) return failure();
+ // If marker exists do nothing.
+ if (hasMarker(linalgOp)) return failure();
Optional<linalg::LinalgLoops> loops =
linalg::linalgLowerOpToLoops<scf::ParallelOp>(rewriter, linalgOp);
if (!loops) return failure();
+
+ SmallVector<int64_t, 3> workgroupSize(3, 1);
if (!loops.getValue().empty()) {
scf::ParallelOp pLoopOp = dyn_cast<scf::ParallelOp>(loops.getValue()[0]);
// If there are parallel loops partition them to threads using global
// invocation ID.
- if (pLoopOp && failed(mapToGlobalInvocationId(rewriter, pLoopOp)))
- return failure();
+ if (pLoopOp) {
+ pLoopOp = collapseParallelLoops(rewriter, pLoopOp);
+ if (!pLoopOp) return failure();
+ if (failed(mapToGlobalInvocationId(rewriter, pLoopOp)))
+ return rewriter.notifyMatchFailure(
+ linalgOp, "mapping to GlobalInvocationID failed");
+ workgroupSize = {32, 1, 1};
+ }
}
rewriter.eraseOp(linalgOp);
+ FuncOp funcOp = linalgOp.template getParentOfType<FuncOp>();
+ if (funcOp) updateWorkGroupSize(funcOp, workgroupSize);
return success();
}
};
@@ -392,7 +584,7 @@
MLIRContext *context = &getContext();
ConversionTarget target(*context);
- // After this pass Linalg and loop.parallel ops should be gone.
+ // After this pass Linalg and scf.parallel ops should be gone.
target.addIllegalOp<scf::ParallelOp>();
target.addIllegalDialect<linalg::LinalgDialect>();
// Reshape ops are treated legal since they just change the way the underlying
@@ -404,27 +596,40 @@
OwningRewritePatternList patterns;
- patterns.insert<
// clang-format off
-#define ADD_ALL_LINALG_PATTERNS(OP_NAME) \
- MapLinalgOpToGlobalInvocationId<OP_NAME>, \
- MapLinalgOpToLocalInvocationId<OP_NAME>
+ patterns.insert<
- ADD_ALL_LINALG_PATTERNS(linalg::ConvOp),
+#define ADD_ALL_LINALG_PATTERNS(OP_NAME) \
+ MapLinalgOpToGlobalInvocationId<OP_NAME>, \
+ MapLinalgOpToLocalInvocationId<OP_NAME>
+
ADD_ALL_LINALG_PATTERNS(linalg::CopyOp),
ADD_ALL_LINALG_PATTERNS(linalg::FillOp),
ADD_ALL_LINALG_PATTERNS(linalg::GenericOp),
ADD_ALL_LINALG_PATTERNS(linalg::IndexedGenericOp),
- ADD_ALL_LINALG_PATTERNS(linalg::MatmulOp),
- ADD_ALL_LINALG_PATTERNS(linalg::PoolingMaxOp),
- ADD_ALL_LINALG_PATTERNS(linalg::PoolingMinOp),
- ADD_ALL_LINALG_PATTERNS(linalg::PoolingSumOp),
#undef ADD_ALL_LINALG_PATTERNS
+#define ADD_ALL_CONV_POOL_PATTERNS(OP_NAME) \
+ MapConvPoolToLocalInvocationId<OP_NAME>, \
+ MapLinalgOpToGlobalInvocationId<OP_NAME>
+
+ ADD_ALL_CONV_POOL_PATTERNS(linalg::PoolingMaxOp),
+ ADD_ALL_CONV_POOL_PATTERNS(linalg::PoolingMinOp),
+ ADD_ALL_CONV_POOL_PATTERNS(linalg::PoolingSumOp),
+
+#undef ADD_ALL_CONV_POOL_PATTERNS
+
+ MapLinalgOpToLocalInvocationId<linalg::MatmulOp>,
PartitionPLoopToWorkgroups, RemoveLinalgRange>(context);
// clang-format on
+ patterns.insert<MapLinalgOpToGlobalInvocationId<linalg::ConvOp>>(context);
+ if (useLegacyConvLowering)
+ patterns.insert<MapConvPoolToLocalInvocationId<linalg::ConvOp>>(context);
+ else
+ patterns.insert<MapLinalgOpToLocalInvocationId<linalg::ConvOp>>(context);
+
if (failed(applyFullConversion(funcOp, target, patterns)))
return signalPassFailure();
}
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
index aeb0996..399b3f5 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
@@ -21,7 +21,7 @@
//
//===----------------------------------------------------------------------===//
-#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
#include "llvm/ADT/STLExtras.h"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
index 48e1d57..d50cb40 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -17,9 +17,10 @@
// Implements a pass to tile and fuse linalg operations on buffers.
//
//===----------------------------------------------------------------------===//
-#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/MemorySpace.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/Utils.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
@@ -42,8 +43,6 @@
// Utility functions
//===----------------------------------------------------------------------===//
-static constexpr unsigned kMaxWorkgroupRank = 3;
-
static ArrayRef<int64_t> dropTrailingOnes(ArrayRef<int64_t> vector) {
if (vector.empty()) return vector;
auto numTrailingOnes = 0;
@@ -56,60 +55,14 @@
return vector.drop_back(numTrailingOnes);
}
-/// Returns the number of "outer" parallel loops specified in the `linalgOp`.
-static unsigned getNumOuterParallelLoops(linalg::LinalgOp linalgOp) {
- if (auto convOp = dyn_cast<linalg::ConvOp>(linalgOp.getOperation())) {
- Optional<DenseIntElementsAttr> padding = convOp.padding();
- if (padding) return convOp.getNumBatchDimensions();
- }
- return linalgOp.iterator_types()
- .getValue()
- .take_while([](Attribute attr) {
- return attr.cast<StringAttr>().getValue() ==
- getParallelIteratorTypeName();
- })
- .size();
-}
-
-/// Updates the workgroup size used for the dispatch region.
-static LogicalResult updateWorkGroupSize(FuncOp funcOp,
- ArrayRef<int64_t> workGroupSize) {
- // Need to update both the surrounding FuncOp that has the spv.entry_point_abi
- // attribute, and the hal.executable.
- Region &body = funcOp.getBody();
- if (!llvm::hasSingleElement(body))
- return funcOp.emitError("unhandled dispatch function with multiple blocks");
-
- SmallVector<int32_t, 3> workGroupSizeVec = llvm::to_vector<3>(llvm::map_range(
- workGroupSize, [](int64_t v) { return static_cast<int32_t>(v); }));
-
- // TODO(ravishankarm, antiagainst): We should have at most one scf.parallel
- // op, but that is not the case till the splitting of kernels lands.
- unsigned numParallelLoops = 0;
- auto updateNumParallelLoops = [&numParallelLoops](unsigned nPar) {
- numParallelLoops =
- (!numParallelLoops ? nPar : std::min(numParallelLoops, nPar));
- };
- for (auto parallelLoop : body.front().getOps<scf::ParallelOp>()) {
- updateNumParallelLoops(parallelLoop.getNumLoops());
- }
- // If there are no parallel loops, there might be linalg ops that arent
- // tiled. Use that to get the number of parallel loops.
- for (auto linalgOp : body.front().getOps<linalg::LinalgOp>()) {
- updateNumParallelLoops(getNumOuterParallelLoops(linalgOp));
- }
- workGroupSizeVec.resize(numParallelLoops);
- LLVM_DEBUG({
- llvm::dbgs() << "--- IREE Linalg tile and fuse configuration ---\n";
- llvm::dbgs() << "# workgroup sizes at end: [";
- interleaveComma(workGroupSizeVec, llvm::dbgs());
- llvm::dbgs() << "]\n";
- });
- MLIRContext *context = funcOp.getContext();
- workGroupSizeVec.resize(3, 1);
- funcOp.setAttr(spirv::getEntryPointABIAttrName(),
- spirv::getEntryPointABIAttr(workGroupSizeVec, context));
- return success();
+/// Returns true if the linalg op has a padding attribute and that attribute
+/// has non-zero entries.
+template <typename OpTy>
+static bool hasPadding(OpTy op) {
+ Optional<DenseIntElementsAttr> padding = op.padding();
+ if (!padding) return false;
+ return llvm::any_of(padding.getValue(),
+ [](APInt v) -> bool { return !v.isNullValue(); });
}
namespace {
@@ -119,7 +72,13 @@
class TileSizeCalculator {
public:
TileSizeCalculator(FuncOp funcOp)
- : resourceLimits(spirv::lookupTargetEnv(funcOp).getResourceLimits()) {}
+ : resourceLimits(spirv::lookupTargetEnv(funcOp).getResourceLimits()) {
+ if (DenseIntElementsAttr attr = spirv::lookupLocalWorkGroupSize(funcOp)) {
+ for (auto val : attr.getValues<APInt>())
+ workgroupSize.push_back(val.getSExtValue());
+ }
+ workgroupSize.resize(3, 1);
+ }
/// Compute the tile sizes based on workgroup size specified.
LogicalResult setTileSizesBasedOnWorkgroupSize(
@@ -139,21 +98,10 @@
/// Get the current tile size computed.
ArrayRef<int64_t> getTileSizes() const { return tileSizes; }
- /// Linalg convention is to use 0 for no tiling. If any of the tile dimensions
- /// is set to 1 make it 0.
- SmallVector<int64_t, 3> getTileSizesForLinalg() const {
- return llvm::to_vector<3>(llvm::map_range(
- tileSizes, [](int64_t v) -> int64_t { return v == 1 ? 0 : v; }));
- }
-
/// Returns the workgroup size to use based on the tile sizes.
ArrayRef<int64_t> getWorkGroupSize() const { return workgroupSize; }
private:
- /// Get the default tile sizes based on just number of dimensions, i.e., "x",
- /// "y", and "z".
- void setTileSizesBasedOnDimensions(unsigned numDims);
-
/// Current tile size configuration.
SmallVector<int64_t, 4> tileSizes;
@@ -165,67 +113,72 @@
};
} // namespace
-void TileSizeCalculator::setTileSizesBasedOnDimensions(unsigned numDims) {
- tileSizes.clear();
- workgroupSize.clear();
- tileSizes.reserve(3);
- if (numDims == 0) {
- // Scalar case.
- workgroupSize = {1, 1, 1};
- return;
- }
- unsigned maxWorkGroupSize =
- resourceLimits.max_compute_workgroup_invocations().getInt();
-
- // Make the tile size 32 along the x-dimension, and then split the remaining
- // maxWorkGroupSize threads amongst the y-dimension or z-dimension.
- unsigned tileSizeX = llvm::PowerOf2Floor(std::min(maxWorkGroupSize, 32u));
- maxWorkGroupSize /= tileSizeX;
- if (numDims == 1) {
- tileSizes = {tileSizeX};
- workgroupSize = {tileSizeX, 1, 1};
- return;
- }
- if (numDims == 2) {
- unsigned tileSizeY = llvm::PowerOf2Floor(maxWorkGroupSize);
- tileSizes = {tileSizeY, tileSizeX};
- workgroupSize = {tileSizeX, tileSizeY, 1};
- return;
- }
- unsigned tileSizeYZ =
- llvm::PowerOf2Floor(static_cast<unsigned>(std::sqrt(maxWorkGroupSize)));
- tileSizes = {tileSizeYZ, tileSizeYZ, tileSizeX};
- workgroupSize = {tileSizeX, tileSizeYZ, tileSizeYZ};
-}
-
LogicalResult TileSizeCalculator::setTileSizesBasedOnOps(
ArrayRef<linalg::LinalgOp> linalgOps) {
tileSizes.clear();
+ if (linalgOps.empty()) {
+ tileSizes = {1, 1, 1};
+ workgroupSize = {1, 1, 1};
+ return success();
+ }
// The tile size will be driven by operations like matmul, conv, etc. within
// the list. So see what operation exists in the list to decide the tile size.
// If there are two such operations in the list, return error.
- bool hasMatmul = false;
- unsigned numParallelLoops = kMaxWorkgroupRank;
- for (linalg::LinalgOp op : linalgOps) {
- // If there is no marker on this op (i.e. a marker to prevent tile), add an
- // explicit marker to indicate that the op is to be tiled. Makes subsequent
- // lowering simpler.
- if (isa<linalg::MatmulOp>(op.getOperation())) {
- if (hasMatmul)
- return op.emitError(
- "unhandled multiple matmuls within dispatch region");
- hasMatmul = true;
- }
- numParallelLoops = std::min(numParallelLoops, getNumOuterParallelLoops(op));
+ enum OpInfo : uint32_t {
+ None = 0x0,
+ Convolution = 0x1,
+ Matmul = 0x2,
+ Pooling = 0x4,
+ };
+ uint32_t opInfo = OpInfo::None;
+ for (linalg::LinalgOp linalgOp : linalgOps) {
+ Operation *op = linalgOp.getOperation();
+ if (isa<linalg::ConvOp>(op)) opInfo |= OpInfo::Convolution;
+ if (isa<linalg::MatmulOp>(op)) opInfo |= OpInfo::Matmul;
+ if (isa<linalg::PoolingMaxOp>(op)) opInfo |= OpInfo::Pooling;
+ if (isa<linalg::PoolingMinOp>(op)) opInfo |= OpInfo::Pooling;
+ if (isa<linalg::PoolingSumOp>(op)) opInfo |= OpInfo::Pooling;
}
- if (hasMatmul) {
+ // If there are no tilable ops, there is nothing to do here.
+ if (!opInfo) return success();
+
+ Operation *linalgOp = *(linalgOps.begin());
+ if (llvm::countPopulation(opInfo) != 1)
+ return linalgOp->getParentOfType<FuncOp>().emitError(
+ "unhandled fusion of ops in dispatch function");
+
+  // TODO(ravishankarm, antiagainst): Only the maximum workgroup size is used
+ // here for computing tile sizes. In reality we also need the maximum
+ // workgroup memory size available (per workgroup) to compute the tile sizes
+ // effectively.
+ unsigned maxWorkgroupSize =
+ resourceLimits.max_compute_workgroup_invocations().getInt();
+ if (opInfo & OpInfo::Convolution) {
+ // TODO(ravishankarm): This tiling is meant to enable promotion to workgroup
+    // memory, but doesn't actually get us to a state where we can do this. The
+    // promotion is possible only when the subviews created are constant
+    // size. For now this doesn't really matter. Revisit this later.
+ int64_t tileSizeX = 32;
+ int64_t tileSizeY = maxWorkgroupSize / 32;
+ tileSizes = {1, tileSizeY, tileSizeX};
+ workgroupSize = {tileSizeX, tileSizeY, 1};
+ return success();
+ }
+ if (opInfo & OpInfo::Matmul) {
// TODO: For now just hard wire this, but we can do better.
tileSizes = {8, 8, 4};
workgroupSize = {8, 8, 1};
return success();
}
- setTileSizesBasedOnDimensions(numParallelLoops);
- return success();
+ if (opInfo & OpInfo::Pooling) {
+ int64_t tileSizeX = 32;
+ int64_t tileSizeY = maxWorkgroupSize / 32;
+ tileSizes = {tileSizeY, tileSizeX};
+ workgroupSize = {tileSizeX, tileSizeY, 1};
+ return success();
+ }
+ return linalgOp->getParentOfType<FuncOp>().emitError(
+ "unable to find tile size for ops in this dispatch function");
}
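As a concrete (illustrative) instance of this heuristic with max_compute_workgroup_invocations = 128, the limit used by the tests in this change: a convolution gets tile sizes {1, 128/32, 32} = {1, 4, 32} with workgroup size {32, 4, 1}; a matmul keeps the hard-wired tile sizes {8, 8, 4} with workgroup size {8, 8, 1}; and a pooling op gets tile sizes {4, 32} with workgroup size {32, 4, 1}.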
//===----------------------------------------------------------------------===//
@@ -294,22 +247,41 @@
SmallVector<int64_t, 3> workGroupSize;
};
-/// Pattern to tile linalg operations if they have the workgroup marker.
-template <typename LinalgOp>
-struct TileLinalgOpPattern : public linalg::LinalgTilingPattern<LinalgOp> {
- using linalg::LinalgTilingPattern<LinalgOp>::LinalgTilingPattern;
+/// Pattern for tiling operations. Updates the workgroup size in the surrounding
+/// function operation if tiling succeeds.
+template <typename OpTy>
+struct TilingPattern : public linalg::LinalgTilingPattern<OpTy> {
+ using Base = linalg::LinalgTilingPattern<OpTy>;
+ TilingPattern(MLIRContext *context, linalg::LinalgTilingOptions options,
+ ArrayRef<int64_t> workgroupSize,
+ linalg::LinalgMarker marker = linalg::LinalgMarker(),
+ PatternBenefit benefit = 1)
+ : Base(context, options, marker, benefit),
+ workgroupSize(workgroupSize.begin(), workgroupSize.end()) {}
+
+ virtual LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const {
+ // Find the parent FuncOp before tiling. If tiling succeeds, the op will be
+ // erased.
+ FuncOp funcOp = op->getParentOfType<FuncOp>();
+ return failure(!funcOp || failed(Base::matchAndRewrite(op, rewriter)) ||
+ failed(updateWorkGroupSize(funcOp, workgroupSize)));
+ }
+
+ SmallVector<int64_t, 3> workgroupSize;
+};
+
+/// Pattern for tiling convolution and pooling operations. Currently this is
+/// just a way to avoid tiling when the operation has padding.
+template <typename OpTy>
+struct TileConvPoolPattern : public TilingPattern<OpTy> {
+ using Base = TilingPattern<OpTy>;
+ using Base::TilingPattern;
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
- if (!hasWorkGroupMarker(op)) return failure();
- if (succeeded(linalg::LinalgTilingPattern<LinalgOp>::matchAndRewrite(
- op, rewriter)))
- return success();
- // Update the marker to map to global invocation ID.
- rewriter.startRootUpdate(op);
- setNoTileMarker(op);
- rewriter.finalizeRootUpdate(op);
- return success();
+ if (hasPadding(cast<OpTy>(op))) return failure();
+ return Base::matchAndRewrite(op, rewriter);
}
};
@@ -348,14 +320,7 @@
auto linalgOps = block.getOps<linalg::LinalgOp>();
if (linalgOps.empty()) return;
- // Go through all the Linalg ops and set the marker to trigger tiling./
- // TODO(ravishankarm): Move this to HLOToLinalgOnBuffers so that it is added
- // on op-creation.
- for (auto op : linalgOps)
- if (!hasMarker(op)) setWorkGroupMarker(op);
-
TileSizeCalculator tileSizeCalculator(funcOp);
-
if (workGroupSize.empty()) {
// Get the tile sizes to use for the lowering.
SmallVector<int64_t, 3> tileSizes;
@@ -376,20 +341,17 @@
});
OwningRewritePatternList tilingPatterns;
- tilingPatterns.insert<TileLinalgOpPattern<linalg::ConvOp>,
- TileLinalgOpPattern<linalg::CopyOp>,
- TileLinalgOpPattern<linalg::FillOp>,
- TileLinalgOpPattern<linalg::GenericOp>,
- TileLinalgOpPattern<linalg::IndexedGenericOp>,
- TileLinalgOpPattern<linalg::MatmulOp>,
- TileLinalgOpPattern<linalg::PoolingMaxOp>,
- TileLinalgOpPattern<linalg::PoolingMinOp>,
- TileLinalgOpPattern<linalg::PoolingSumOp>>(
+ tilingPatterns.insert<TileConvPoolPattern<linalg::ConvOp>,
+ TilingPattern<linalg::MatmulOp>,
+ TileConvPoolPattern<linalg::PoolingMaxOp>,
+ TileConvPoolPattern<linalg::PoolingMinOp>,
+ TileConvPoolPattern<linalg::PoolingSumOp>>(
context,
linalg::LinalgTilingOptions()
- .setTileSizes(tileSizeCalculator.getTileSizesForLinalg())
+ .setTileSizes(tileSizeCalculator.getTileSizes())
.setLoopType(linalg::LinalgTilingLoopType::ParallelLoops),
- linalg::LinalgMarker(Identifier::get(getWorkGroupMarker(), context),
+ tileSizeCalculator.getWorkGroupSize(),
+ linalg::LinalgMarker(ArrayRef<Identifier>(),
Identifier::get(getWorkItemMarker(), context)));
applyPatternsAndFoldGreedily(getOperation(), tilingPatterns);
@@ -423,15 +385,6 @@
insertBarrierAfter(builder, linalgOp.getLoc(), linalgOp);
}
});
-
- // Update the workgroup size to be consistent with the tile sizes used. Note
- // the tile sizes are ordered from outer most to inner most loops. The
- // heuristic is to map the inner loops to x, the next outer (if it exists) to
- // y, and the next outer (if it exists) to z. So tile sizes are reversed to
- // get the workgroup size.
- if (failed(
- updateWorkGroupSize(funcOp, tileSizeCalculator.getWorkGroupSize())))
- return signalPassFailure();
}
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Conversion/CodegenUtils/MarkerUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp
similarity index 87%
rename from iree/compiler/Conversion/CodegenUtils/MarkerUtils.cpp
rename to iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp
index cf641e1..a285236 100644
--- a/iree/compiler/Conversion/CodegenUtils/MarkerUtils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/Attributes.h"
@@ -31,11 +31,9 @@
static bool checkMarkerValue(Operation *op, StringRef marker = "") {
StringAttr attr = op->getAttrOfType<StringAttr>(
linalg::LinalgTransforms::kLinalgTransformMarker);
- return attr && (marker == "" || attr.getValue() == marker);
+ return attr && (marker.empty() || attr.getValue() == marker);
}
-StringRef getNoTileMarker() { return "no-tile"; }
-
StringRef getWorkGroupMarker() { return "workgroup"; }
StringRef getWorkItemMarker() { return "workitem"; }
@@ -46,10 +44,6 @@
return checkMarkerValue(op, marker);
}
-bool hasNoTileMarker(Operation *op) {
- return checkMarkerValue(op, getNoTileMarker());
-}
-
bool hasWorkGroupMarker(Operation *op) {
return checkMarkerValue(op, getWorkGroupMarker());
}
@@ -69,8 +63,6 @@
StringAttr::get(marker, op->getContext()));
}
-void setNoTileMarker(Operation *op) { setMarker(op, getNoTileMarker()); }
-
void setCooperativeMatrixMarker(Operation *op) {
op->setAttr(VectorTransforms::kVectorTransformMarker,
StringAttr::get(getCooperativeMatrixMarker(), op->getContext()));
diff --git a/iree/compiler/Conversion/CodegenUtils/MarkerUtils.h b/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h
similarity index 70%
rename from iree/compiler/Conversion/CodegenUtils/MarkerUtils.h
rename to iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h
index fa14263..633bca0 100644
--- a/iree/compiler/Conversion/CodegenUtils/MarkerUtils.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h
@@ -19,8 +19,8 @@
//
//===----------------------------------------------------------------------===//
-#ifndef IREE_COMPILER_CONVERSION_CODEGENUTILS_MARKERUTILS_H_
-#define IREE_COMPILER_CONVERSION_CODEGENUTILS_MARKERUTILS_H_
+#ifndef IREE_COMPILER_CONVERSION_LINALGTOSPIRV_MARKERUTILS_H_
+#define IREE_COMPILER_CONVERSION_LINALGTOSPIRV_MARKERUTILS_H_
#include "llvm/ADT/StringRef.h"
#include "mlir/Support/LLVM.h"
@@ -30,55 +30,34 @@
class Operation;
namespace iree_compiler {
-/// Marker to denote that do not tile the linalg operation.
-StringRef getNoTileMarker();
-
-/// Marker to denote that a linalg operation is to be partitioned to workgroups.
-StringRef getWorkGroupMarker();
-
/// Marker to denote that a linalg operation is to be partitioned to workitems.
StringRef getWorkItemMarker();
-/// Returns true if an operation has the specified `marker`. When `marker` is
-/// empty, returns true if the operation has any marker.
-bool hasMarker(Operation *, StringRef marker = "");
-
-/// Returns true if an operation has marker to denote that it is not to be
-/// tiled.
-bool hasNoTileMarker(Operation *);
-
-/// Returns true if an operation has marker to denote that it is to be
-/// partitioned to workgroups.
-bool hasWorkGroupMarker(Operation *);
-
-/// Returns true if an operation has marker to denote that it is to be
-/// partitioned to workitems.
-bool hasWorkItemMarker(Operation *);
-
/// Returns true if an operation has a marker to denote that it will be mapped
/// to cooperative matrix operations. Markers need to be consistent as
/// cooperative matrices have their own type and load/store operations.
bool hasCooperativeMatrixMarker(Operation *);
-/// Sets a given marker on an operation.
-void setMarker(Operation *, StringRef);
+/// Returns true if an operation has the specified `marker`. When `marker` is
+/// empty, returns true if the operation has any marker.
+bool hasMarker(Operation *, StringRef marker = "");
-/// Sets marker to prevent tiling of a linalg operation.
-void setNoTileMarker(Operation *);
-
-/// Sets marker to denote that a linalg operation is to be partitioned to
-/// workgroups.
-void setWorkGroupMarker(Operation *);
-
-/// Sets marker to denote that a linalg operation is to be partitioned to
-/// workitems.
-void setWorkItemMarker(Operation *);
+/// Returns true if an operation has marker to denote that it is to be
+/// partitioned to workitems.
+bool hasWorkItemMarker(Operation *);
/// Sets marker to denote that a vector operation is to be execute on a
/// cooperative matrix.
void setCooperativeMatrixMarker(Operation *);
+/// Sets a given marker on an operation.
+void setMarker(Operation *, StringRef);
+
+/// Sets marker to denote that a linalg operation is to be partitioned to
+/// workitems.
+void setWorkItemMarker(Operation *);
+
} // namespace iree_compiler
} // namespace mlir
-#endif // IREE_COMPILER_CONVERSION_CODEGENUTILS_MARKERUTILS_H_
+#endif // IREE_COMPILER_CONVERSION_LINALGTOSPIRV_MARKERUTILS_H_
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
index 4ca5f2f..b1bba5e 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
@@ -15,11 +15,11 @@
//===- SplitDispatchFunctionPass.cpp --------------------------------------===//
//
// This file implements a pass to split computation workload to multiple
-// sequential dispatch functions. This pass operates on Linalg ops and prepares
-// for lowering to GPU, where we need to tile the workload to workgroups and
-// workitems. If the workload involves computation A and B, where B is
-// dependent on A and A needs all workgroups to complete, then we need
-// to split A and B into different kernels because there is no mechanism
+// sequential dispatch functions. This pass operates on Linalg ops and
+// scf.parallel ops and prepares for lowering to GPU, where we need to tile the
+// workload to workgroups and workitems. If the workload involves computation A
+// and B, where B is dependent on A and A needs all workgroups to complete, then
+// we need to split A and B into different kernels because there is no mechanism
// to perform cross-workgroup synchronization within a single kernel.
//
//===----------------------------------------------------------------------===//
@@ -35,6 +35,7 @@
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
@@ -51,24 +52,20 @@
namespace {
-/// Returns true if the given `block` contains 0 or 1 Linalg structured ops.
-bool hasZeroOrOneLinalgOp(Block &block) {
- auto ops = block.getOps<linalg::LinalgOp>();
- return std::distance(ops.begin(), ops.end()) <= 1;
-}
-
/// Returns true if the Linalg ops can be separated to multiple kernels.
-bool canSeparateLinalgOps(MutableArrayRef<linalg::LinalgOp> linalgOps) {
- if (llvm::any_of(linalgOps, [](linalg::LinalgOp op) {
- return !op.hasBufferSemantics();
+bool canSeparateOps(ArrayRef<Operation *> ops) {
+ if (llvm::any_of(ops, [](Operation *op) {
+ if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op))
+ return !linalgOp.hasBufferSemantics();
+ return false;
}))
return false;
// Require no other ops interleave with Linalg structured ops for now. This is
// the common case and it simplifies further analysis.
- for (int i = 0, e = linalgOps.size() - 1; i < e; ++i) {
- if (linalgOps[i].getOperation()->getNextNode() != linalgOps[i + 1])
- return false;
+ for (auto currOp = ops.begin(), nextOp = std::next(ops.begin());
+ nextOp != ops.end(); ++currOp, ++nextOp) {
+ if ((*currOp)->getNextNode() != *nextOp) return false;
}
return true;
@@ -144,15 +141,20 @@
return oldFn.emitError("expected only one block");
}
- // The dispatch function should have more than one Linalg structured ops.
- // Otherwise there is nothing to do.
- if (hasZeroOrOneLinalgOp(oldFn.getBlocks().front())) return success();
+  // The dispatch function should have more than one separable op. Otherwise
+ // there is nothing to do.
+ Block &fnBody = oldFn.getBlocks().front();
- // Collect all Linalg ops for distributing.
- SmallVector<linalg::LinalgOp, 4> linalgOps =
- llvm::to_vector<4>(oldFn.getBlocks().front().getOps<linalg::LinalgOp>());
- if (!canSeparateLinalgOps(linalgOps)) {
- return oldFn.emitError("cannot separate Linalg ops into multiple kernels");
+ // Collect all Linalg and scf.parallel ops for distributing.
+ SmallVector<Operation *, 4> separableOps;
+ for (Operation &op : fnBody)
+ if (isa<linalg::LinalgOp>(op) || isa<scf::ParallelOp>(op))
+ separableOps.push_back(&op);
+
+ if (separableOps.size() <= 1) return success();
+ if (!canSeparateOps(separableOps)) {
+ return oldFn.emitError(
+ "cannot separate Linalg/Parallel ops into multiple kernels");
}
ModuleOp moduleOp = cast<ModuleOp>(oldFn.getParentOp());
@@ -160,13 +162,13 @@
Location loc = oldFn.getLoc();
SmallVector<std::string, 4> splitKernels;
- splitKernels.reserve(linalgOps.size());
+ splitKernels.reserve(separableOps.size());
llvm::SmallPtrSet<Operation *, 16> closure;
- for (const auto &linalgOp : llvm::enumerate(linalgOps)) {
- // Create a new function for hosting this Linalg op.
- splitKernels.emplace_back(
- llvm::formatv("{0}_dispatch_{1}", oldFn.getName(), linalgOp.index()));
+ for (const auto &separableOp : llvm::enumerate(separableOps)) {
+ // Create a new function for hosting this op.
+ splitKernels.emplace_back(llvm::formatv("{0}_dispatch_{1}", oldFn.getName(),
+ separableOp.index()));
StringRef newFnName = splitKernels.back();
builder.setInsertionPointToStart(moduleOp.getBody());
auto newFn = builder.create<FuncOp>(loc, newFnName, oldFn.getType(),
@@ -181,7 +183,7 @@
// Collect the closure for the current Linalg op.
closure.clear();
- collectAllReferencedOps(linalgOp.value(), closure);
+ collectAllReferencedOps(separableOp.value(), closure);
// Clone all ops in the closure to the new function.
Block *newFnBlock = newFn.addEntryBlock();
@@ -190,14 +192,14 @@
for (Operation &op : oldFnBlock) {
if (closure.count(&op) == 0) continue;
builder.insert(op.clone(remapper));
- if (&op == linalgOp.value()) break;
+ if (&op == separableOp.value()) break;
}
builder.insert(oldFnBlock.getTerminator()->clone(remapper));
}
// Add the entry point schedule to the module op.
SmallVector<Attribute, 4> entryPoints;
- entryPoints.reserve(linalgOps.size());
+ entryPoints.reserve(separableOps.size());
for (const std::string &kernel : splitKernels) {
entryPoints.emplace_back(builder.getStringAttr(kernel));
}
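To make the renaming scheme concrete: given the llvm::formatv pattern above, a dispatch function foo containing, say, a tiled matmul followed by an scf.parallel op would be split into foo_dispatch_0 and foo_dispatch_1 (names per the {0}_dispatch_{1} format; the op mix is illustrative), with both names recorded in the entry point schedule so the resulting kernels execute sequentially.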
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp
new file mode 100644
index 0000000..6fed42e
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp
@@ -0,0 +1,51 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- Utils.cpp - Utility functions used in Linalg to SPIR-V lowering ----===//
+//
+// Implementation of utility functions used while lowering from Linalg to SPIR-V.
+//
+//===----------------------------------------------------------------------===//
+
+#include "iree/compiler/Conversion/LinalgToSPIRV/Utils.h"
+
+#include "mlir/Dialect/SPIRV/TargetAndABI.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Region.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+LogicalResult updateWorkGroupSize(FuncOp funcOp,
+ ArrayRef<int64_t> workGroupSize) {
+ // Need to update both the surrounding FuncOp that has the spv.entry_point_abi
+ // attribute, and the hal.executable.
+ Region &body = funcOp.getBody();
+ if (!llvm::hasSingleElement(body))
+ return funcOp.emitError("unhandled dispatch function with multiple blocks");
+
+ if (workGroupSize.size() != 3)
+ return funcOp.emitError("expected workgroup size to have three entries");
+ SmallVector<int32_t, 3> workGroupSizeVec = llvm::to_vector<3>(llvm::map_range(
+ workGroupSize, [](int64_t v) { return static_cast<int32_t>(v); }));
+
+ funcOp.setAttr(
+ spirv::getEntryPointABIAttrName(),
+ spirv::getEntryPointABIAttr(workGroupSizeVec, funcOp.getContext()));
+ return success();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
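A minimal usage sketch of the relocated helper, mirroring the call site added in ConvertToGPUPass.cpp above (funcOp and the chosen size are assumed to come from the surrounding pass):

  // Sketch: stamp a 32x1x1 workgroup size onto the dispatch function.
  if (failed(updateWorkGroupSize(funcOp, /*workGroupSize=*/{32, 1, 1})))
    return signalPassFailure();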
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Utils.h b/iree/compiler/Conversion/LinalgToSPIRV/Utils.h
new file mode 100644
index 0000000..bdea68e
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Utils.h
@@ -0,0 +1,38 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- Utils.h - Utility functions used in Linalg to SPIR-V lowering ------===//
+//
+// Utility functions used while lowering from Linalg to SPIRV.
+//
+//===----------------------------------------------------------------------===//
+#ifndef IREE_COMPILER_CONVERSION_LINALGTOSPIRV_UTILS_H_
+#define IREE_COMPILER_CONVERSION_LINALGTOSPIRV_UTILS_H_
+
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+class FuncOp;
+struct LogicalResult;
+
+namespace iree_compiler {
+
+/// Updates the workgroup size used for the dispatch region.
+LogicalResult updateWorkGroupSize(FuncOp funcOp,
+ ArrayRef<int64_t> workGroupSize);
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_CONVERSION_LINALGTOSPIRV_UTILS_H_
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
index 9d81a75..55399f3 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
@@ -1,205 +1,410 @@
-// RUN: iree-opt -iree-codegen-convert-to-gpu -canonicalize -split-input-file %s | IreeFileCheck %s
+// RUN: iree-opt -iree-codegen-convert-to-gpu -canonicalize -cse -split-input-file %s | IreeFileCheck %s
-#map0 = affine_map<(d0, d1, d2) -> (d0, d1 - d2)>
-#map1 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
-#map2 = affine_map<(d0, d1) -> (d0, d1)>
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+module attributes {
+ spv.target_env =
+ #spv.target_env<#spv.vce<v1.3,
+ [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+ {max_compute_workgroup_invocations = 128 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+ func @parallel_4D(%arg0: memref<?x?x?x?xf32>,
+ %arg1 : memref<?x?x?x?xf32>,
+ %arg2 : memref<?x?x?x?xf32>)
+ attributes {iree.dispatch_fn_name = "parallel_4D"} {
+ linalg.generic
+ {args_in = 2 : i64, args_out = 1 : i64,
+ indexing_maps = [#map0, #map0, #map0],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ %arg0, %arg1, %arg2 {
+ ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+ %0 = addf %arg3, %arg4 : f32
+ linalg.yield %0 : f32
+ } : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
+ return
+ }
+}
+// CHECK-LABEL: func @parallel_4D
+// CHECK-SAME: local_size = dense<[32, 1, 1]>
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = constant 3 : index
+// CHECK-DAG: %[[UB0:.+]] = dim %{{.+}}, %[[C0]]
+// CHECK-DAG: %[[UB1:.+]] = dim %{{.+}}, %[[C1]]
+// CHECK-DAG: %[[UB2:.+]] = dim %{{.+}}, %[[C2]]
+// CHECK-DAG: %[[UB3:.+]] = dim %{{.+}}, %[[C3]]
+// CHECK: %[[T4:.+]] = muli %[[UB3]], %[[UB2]]
+// CHECK: %[[T5:.+]] = muli %[[T4]], %[[UB1]]
+// CHECK: %[[UB:.+]] = muli %[[T5]], %[[UB0]]
+// CHECK-DAG: %[[BID:.+]] = "gpu.block_id"() {dimension = "x"}
+// CHECK-DAG: %[[BDIM:.+]] = "gpu.block_dim"() {dimension = "x"}
+// CHECK-DAG: %[[TID:.+]] = "gpu.thread_id"() {dimension = "x"}
+// CHECK: %[[BOFFSET:.+]] = muli %[[BID]], %[[BDIM]]
+// CHECK: %[[IV:.+]] = addi %[[BOFFSET]], %[[TID]]
+// CHECK: %[[COND:.+]] = cmpi "slt", %[[IV]], %[[UB]]
+// CHECK: scf.if %[[COND]]
+// CHECK: %[[IV0:.+]] = divi_signed %[[IV]], %[[T5]]
+// CHECK: %[[T14:.+]] = remi_signed %[[IV]], %[[T5]]
+// CHECK: %[[IV1:.+]] = divi_signed %[[T14]], %[[T4]]
+// CHECK: %[[T16:.+]] = remi_signed %[[T14]], %[[T4]]
+// CHECK: %[[IV2:.+]] = divi_signed %[[T16]], %[[UB3]]
+// CHECK: %[[IV3:.+]] = remi_signed %[[T16]], %[[UB3]]
+// CHECK: load %{{.+}}[%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
+// CHECK: load %{{.+}}[%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
+// CHECK: store %{{.+}}[%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
+
+
+// -----
+
+#map0 = affine_map<() -> ()>
+#accesses = [#map0, #map0, #map0]
+#trait = {
+ args_in = 2 : i64,
+ args_out = 1 : i64,
+ indexing_maps = #accesses,
+ iterator_types = []
+}
+
+module attributes {
+ spv.target_env =
+ #spv.target_env<#spv.vce<v1.3,
+ [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+ {max_compute_workgroup_invocations = 128 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+ func @scalar_add(%arg0 : memref<f32>, %arg1 : memref<f32>,
+ %arg2 : memref<f32>)
+ {
+ linalg.generic #trait %arg0, %arg1, %arg2 {
+ ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+ %0 = addf %arg3, %arg4 : f32
+ linalg.yield %0 : f32
+ } : memref<f32>, memref<f32>, memref<f32>
+ return
+ }
+}
+// CHECK-LABEL: func @scalar_add
+// CHECK-SAME: local_size = dense<1> : vector<3xi32>
+// CHECK-NEXT: load
+// CHECK-NEXT: load
+// CHECK-NEXT: addf
+// CHECK-NEXT: store
+// CHECK-NEXT: return
+
+// -----
module {
- func @pw_add(%arg0: memref<4x8xi32>, %arg1: memref<4x8xi32>,
- %arg2: memref<4x8xi32>)
- attributes {iree.dispatch_fn_name = "pw_add"} {
- %c32 = constant 32 : index
+ func @reduce_sum(%arg0: memref<?x?x?xf32>, %arg1: memref<f32>, %arg2: memref<?xf32>)
+ attributes {iree.dispatch_fn_name = "reduce_sum"} {
+ linalg.indexed_generic
+ {args_in = 2 : i64, args_out = 1 : i64,
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> ()>,
+ affine_map<(d0, d1, d2) -> (d0)>],
+ iterator_types = ["parallel", "parallel", "reduction"]} %arg0, %arg1, %arg2 {
+ ^bb0(%arg3: index, %arg4: index, %arg5: index,
+ %arg6: f32, %arg7: f32, %arg8: f32): // no predecessors
+ %c0 = constant 0 : index
+ %cst = constant true
+ %0 = cmpi "eq", %arg5, %c0 : index
+ %1 = and %cst, %0 : i1
+ %2 = select %1, %arg7, %arg8 : f32
+ %3 = addf %arg6, %2 : f32
+ linalg.yield %3 : f32
+ }: memref<?x?x?xf32>, memref<f32>, memref<?xf32>
+ return
+ }
+}
+
+// CHECK-LABEL: func @reduce_sum
+// CHECK-SAME: local_size = dense<[32, 1, 1]> : vector<3xi32>
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK: %[[UB0:.+]] = dim %{{.+}}, %[[C0]]
+// CHECK: %[[UB1:.+]] = dim %{{.+}}, %[[C1]]
+// CHECK: %[[UB2:.+]] = dim %{{.+}}, %[[C2]]
+// CHECK: %[[UB:.+]] = muli %[[UB1]], %[[UB0]]
+// CHECK: %[[COND:.+]] = cmpi "slt", %{{.+}}, %[[UB]]
+// CHECK: scf.if %[[COND]]
+// CHECK: %[[IV0:.+]] = divi_signed %{{.+}}, %[[UB1]]
+// CHECK: %[[IV1:.+]] = remi_signed %{{.+}}, %[[UB1]]
+// CHECK: scf.for %[[IV:.+]] = %{{.+}} to %[[UB2]]
+// CHECK: %[[ISZERO:.+]] = cmpi "eq", %[[IV]], %[[C0]]
+
+// -----
+
+#map0 = affine_map<(d0)[s0] -> (8, -d0 + s0)>
+#map1 = affine_map<(d0)[s0] -> (4, -d0 + s0)>
+#map2 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+
+module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+ func @matmul(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) attributes {spv.entry_point_abi = {local_size = dense<[8, 8, 1]> : vector<3xi32>}} {
%c0 = constant 0 : index
+ %c1 = constant 1 : index
%c4 = constant 4 : index
%c8 = constant 8 : index
- %c1 = constant 1 : index
- scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c4, %c8) step (%c4, %c32) {
- %0 = affine.min #map0(%c4, %c4, %arg3)
- %1 = affine.min #map0(%c32, %c8, %arg4)
- %2 = subview %arg0[%arg3, %arg4] [%0, %1] [%c1, %c1]
- : memref<4x8xi32> to memref<?x?xi32, #map1>
- %3 = affine.min #map0(%c4, %c4, %arg3)
- %4 = affine.min #map0(%c32, %c8, %arg4)
- %5 = subview %arg1[%arg3, %arg4] [%3, %4] [%c1, %c1]
- : memref<4x8xi32> to memref<?x?xi32, #map1>
- %6 = affine.min #map0(%c4, %c4, %arg3)
- %7 = affine.min #map0(%c32, %c8, %arg4)
- %8 = subview %arg2[%arg3, %arg4] [%6, %7] [%c1, %c1]
- : memref<4x8xi32> to memref<?x?xi32, #map1>
- linalg.generic
- {args_in = 2 : i64, args_out = 1 : i64,
- indexing_maps = [#map2, #map2, #map2],
- iterator_types = ["parallel", "parallel"]}
- {__internal_linalg_transform__ = "workitem"} %2, %5, %8 {
- ^bb0(%arg5: i32, %arg6: i32, %arg7: i32): // no predecessors
- %9 = addi %arg5, %arg6 : i32
- linalg.yield %9 : i32
- } : memref<?x?xi32, #map1>, memref<?x?xi32, #map1>, memref<?x?xi32, #map1>
+ %0 = dim %arg0, %c0 : memref<?x?xf32>
+ %1 = dim %arg0, %c1 : memref<?x?xf32>
+ %2 = dim %arg1, %c1 : memref<?x?xf32>
+ scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%0, %2) step (%c8, %c8) {
+ scf.for %arg5 = %c0 to %1 step %c4 {
+ %3 = affine.min #map0(%arg3)[%0]
+ %4 = affine.min #map1(%arg5)[%1]
+ %5 = subview %arg0[%arg3, %arg5] [%3, %4] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map2>
+ %6 = dim %arg1, %c0 : memref<?x?xf32>
+ %7 = affine.min #map1(%arg5)[%6]
+ %8 = affine.min #map0(%arg4)[%2]
+ %9 = subview %arg1[%arg5, %arg4] [%7, %8] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map2>
+ %10 = dim %arg2, %c0 : memref<?x?xf32>
+ %11 = affine.min #map0(%arg3)[%10]
+ %12 = dim %arg2, %c1 : memref<?x?xf32>
+ %13 = affine.min #map0(%arg4)[%12]
+ %14 = subview %arg2[%arg3, %arg4] [%11, %13] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map2>
+ linalg.matmul %5, %9, %14 {__internal_linalg_transform__ = "workitem"} : (memref<?x?xf32, #map2>, memref<?x?xf32, #map2>, memref<?x?xf32, #map2>)
+ }
scf.yield
}
return
}
}
-// CHECK-DAG: %[[STEPY:.+]] = constant 4 : index
-// CHECK-DAG: %[[STEPX:.+]] = constant 32 : index
+
+// CHECK-LABEL: func @matmul
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
+// CHECK-DAG: %[[C4:.+]] = constant 4 : index
+// CHECK-DAG: %[[C8:.+]] = constant 8 : index
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[UB0:.+]] = dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[UB1:.+]] = dim %[[ARG1]], %[[C1]]
+// CHECK-DAG: %[[UB2:.+]] = dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
+// CHECK-DAG: %[[GDIMX:.+]] = "gpu.grid_dim"() {dimension = "x"}
+// CHECK-DAG: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
+// CHECK-DAG: %[[GDIMY:.+]] = "gpu.grid_dim"() {dimension = "y"}
+// CHECK: %[[BOFFSETY:.+]] = muli %[[BIDY]], %[[C8]]
+// CHECK: %[[BSTEPY:.+]] = muli %[[GDIMY]], %[[C8]]
+// CHECK: %[[BOFFSETX:.+]] = muli %[[BIDX]], %[[C8]]
+// CHECK: %[[BSTEPX:.+]] = muli %[[GDIMX]], %[[C8]]
+// CHECK: scf.for %[[BIV0:.+]] = %[[BOFFSETY]] to %[[UB0]] step %[[BSTEPY]]
+// CHECK: scf.for %[[BIV1:.+]] = %[[BOFFSETX]] to %[[UB1]] step %[[BSTEPX]]
+// CHECK: scf.for %[[BIV2:.+]] = %[[C0]] to %[[UB2]] step %[[C4]]
+// CHECK-DAG: %[[VIEWUB0:.+]] = affine.min #{{.+}}(%[[BIV0]])[%[[UB0]]]
+// CHECK-DAG: %[[VIEWUB1:.+]] = affine.min #{{.+}}(%[[BIV1]])[%[[UB1]]]
+// CHECK-DAG: %[[VIEWUB2:.+]] = affine.min #{{.+}}(%[[BIV2]])[%[[UB2]]]
+// CHECK-DAG: %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
+// CHECK-DAG: %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
+// CHECK: %[[INBOUNDY:.+]] = cmpi "slt", %[[TIDY]], %[[VIEWUB0]]
+// CHECK: %[[INBOUNDX:.+]] = cmpi "slt", %[[TIDX]], %[[VIEWUB1]]
+// CHECK: %[[COND:.+]] = and %[[INBOUNDY]], %[[INBOUNDX]]
+// CHECK: scf.if %[[COND]]
+// CHECK: scf.for %{{.+}} = %[[C0]] to %[[VIEWUB2]] step %[[C1]]
+
+// -----
+
+
+#map0 = affine_map<(d0)[s0] -> (1, -d0 + s0)>
+#map1 = affine_map<(d0)[s0, s1] -> (s0 + 4, -d0 + s1)>
+#map2 = affine_map<(d0)[s0, s1] -> (s0 + 32, -d0 + s1)>
+#map3 = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)>
+#map4 = affine_map<(d0)[s0] -> (4, -d0 + s0)>
+#map5 = affine_map<(d0)[s0] -> (32, -d0 + s0)>
+
+module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+ func @conv_no_padding(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) attributes {spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}} {
+ %c4 = constant 4 : index
+ %c32 = constant 32 : index
+ %c2 = constant 2 : index
+ %c0 = constant 0 : index
+ %c3 = constant 3 : index
+ %c1 = constant 1 : index
+ %0 = dim %arg1, %c0 : memref<?x?x?x?xf32>
+ %1 = dim %arg1, %c1 : memref<?x?x?x?xf32>
+ %2 = dim %arg1, %c2 : memref<?x?x?x?xf32>
+ %3 = dim %arg2, %c1 : memref<?x?x?x?xf32>
+ %4 = dim %arg2, %c2 : memref<?x?x?x?xf32>
+ scf.parallel (%arg3, %arg4, %arg5) = (%c0, %c0, %c0) to (%0, %3, %4) step (%c1, %c4, %c32) {
+ %5 = affine.min #map0(%arg3)[%0]
+ %6 = affine.min #map1(%arg4)[%1, %1]
+ %7 = affine.min #map2(%arg5)[%2, %2]
+ %8 = dim %arg1, %c3 : memref<?x?x?x?xf32>
+ %9 = subview %arg1[%arg3, %arg4, %arg5, 0] [%5, %6, %7, %8] [1, 1, 1, 1] : memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map3>
+ %10 = dim %arg2, %c0 : memref<?x?x?x?xf32>
+ %11 = affine.min #map0(%arg3)[%10]
+ %12 = affine.min #map4(%arg4)[%3]
+ %13 = affine.min #map5(%arg5)[%4]
+ %14 = dim %arg2, %c3 : memref<?x?x?x?xf32>
+ %15 = subview %arg2[%arg3, %arg4, %arg5, 0] [%11, %12, %13, %14] [1, 1, 1, 1] : memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map3>
+ linalg.conv(%arg0, %9, %15) {__internal_linalg_transform__ = "workitem", dilations = [1, 1], strides = [1, 1]} : memref<?x?x?x?xf32>, memref<?x?x?x?xf32, #map3>, memref<?x?x?x?xf32, #map3>
+ scf.yield
+ }
+ return
+ }
+}
+
+// CHECK-LABEL: func @conv_no_padding
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+// CHECK-DAG: %[[C4:.+]] = constant 4 : index
+// CHECK-DAG: %[[C32:.+]] = constant 32 : index
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[UB0:.+]] = dim %[[ARG1]], %[[C0]]
+// CHECK-DAG: %[[UB1:.+]] = dim %[[ARG1]], %[[C1]]
+// CHECK-DAG: %[[UB2:.+]] = dim %[[ARG1]], %[[C2]]
+// CHECK-DAG: %[[UB3:.+]] = dim %[[ARG2]], %[[C1]]
+// CHECK-DAG: %[[UB4:.+]] = dim %[[ARG2]], %[[C2]]
// CHECK-DAG: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
// CHECK-DAG: %[[NBLOCKSX:.+]] = "gpu.grid_dim"() {dimension = "x"}
// CHECK-DAG: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
// CHECK-DAG: %[[NBLOCKSY:.+]] = "gpu.grid_dim"() {dimension = "y"}
-// CHECK: %[[NEWLBY:.+]] = muli %[[BIDY]], %[[STEPY]]
-// CHECK: %[[NEWSTEPY:.+]] = muli %[[NBLOCKSY]], %[[STEPY]]
-// CHECK: %[[NEWLBX:.+]] = muli %[[BIDX]], %[[STEPX]]
-// CHECK: %[[NEWSTEPX:.+]] = muli %[[NBLOCKSX]], %[[STEPX]]
-// CHECK: scf.for %{{.+}} = %[[NEWLBY]] to %{{.+}} step %[[NEWSTEPY]]
-// CHECK: scf.for %{{.+}} = %[[NEWLBX]] to %{{.+}} step %[[NEWSTEPX]]
-// CHECK-DAG: %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
-// CHECK-DAG: %[[NTHREADSX:.+]] = "gpu.block_dim"() {dimension = "x"}
-// CHECK-DAG: %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
-// CHECK-DAG: %[[NTHREADSY:.+]] = "gpu.block_dim"() {dimension = "y"}
-// CHECK: scf.for %{{.+}} = %[[TIDY]] to %{{.+}} step %[[NTHREADSY]]
-// CHECK: scf.for %{{.+}} = %[[TIDX]] to %{{.+}} step %[[NTHREADSX]]
+// CHECK-DAG: %[[BIDZ:.+]] = "gpu.block_id"() {dimension = "z"}
+// CHECK-DAG: %[[NBLOCKSZ:.+]] = "gpu.grid_dim"() {dimension = "z"}
+// CHECK-DAG: %[[BOFFSETY:.+]] = muli %[[BIDY]], %[[C4]]
+// CHECK-DAG: %[[BSTEPY:.+]] = muli %[[NBLOCKSY]], %[[C4]]
+// CHECK-DAG: %[[BOFFSETX:.+]] = muli %[[BIDX]], %[[C32]]
+// CHECK-DAG: %[[BSTEPX:.+]] = muli %[[NBLOCKSX]], %[[C32]]
+// CHECK: scf.for %[[IV3:.+]] = %[[BIDZ]] to %[[UB0]] step %[[NBLOCKSZ]]
+// CHECK: scf.for %[[IV4:.+]] = %[[BOFFSETY]] to %[[UB3]] step %[[BSTEPY]]
+// CHECK: scf.for %[[IV5:.+]] = %[[BOFFSETX]] to %[[UB4]] step %[[BSTEPX]]
+// CHECK: %[[BOUNDSZ:.+]] = affine.min #{{.+}}(%[[IV3]])
+// CHECK: %[[SV1:.+]] = subview %[[ARG1]][%[[IV3]], %[[IV4]], %[[IV5]], 0]
+// CHECK: %[[BOUNDSY:.+]] = affine.min #{{.+}}(%[[IV4]])
+// CHECK: %[[BOUNDSX:.+]] = affine.min #{{.+}}(%[[IV5]])
+// CHECK: %[[SV2:.+]] = subview %[[ARG2]][%[[IV3]], %[[IV4]], %[[IV5]], 0]
+// CHECK-DAG: %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
+// CHECK-DAG: %[[NTHREADSX:.+]] = "gpu.block_dim"() {dimension = "x"}
+// CHECK-DAG: %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
+// CHECK-DAG: %[[NTHREADSY:.+]] = "gpu.block_dim"() {dimension = "y"}
+// CHECK-DAG: %[[TIDZ:.+]] = "gpu.thread_id"() {dimension = "z"}
+// CHECK-DAG: %[[NTHREADSZ:.+]] = "gpu.block_dim"() {dimension = "z"}
+// CHECK: scf.for %[[IV3:.+]] = %[[TIDZ]] to %[[BOUNDSZ]] step %[[NTHREADSZ]]
+// CHECK: scf.for %[[IV4:.+]] = %[[TIDY]] to %[[BOUNDSY]] step %[[NTHREADSY]]
+// CHECK: scf.for %[[IV5:.+]] = %[[TIDX]] to %[[BOUNDSX]] step %[[NTHREADSX]]
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK-NOT: linalg.conv
// -----
-module {
- func @reduce_sum(%arg0: memref<4xf32>, %arg1: memref<f32>, %arg2: memref<f32>)
- attributes {iree.dispatch_fn_name = "reduce_sum"} {
- linalg.indexed_generic
- {args_in = 2 : i64, args_out = 1 : i64,
- indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>,
- affine_map<(d0) -> ()>],
- iterator_types = ["reduction"]} %arg0, %arg1, %arg2 {
- ^bb0(%arg3: index, %arg4: f32, %arg5: f32, %arg6: f32): // no predecessors
- %c0 = constant 0 : index
- %cst = constant true
- %0 = cmpi "eq", %arg3, %c0 : index
- %1 = and %cst, %0 : i1
- %2 = select %1, %arg5, %arg6 : f32
- %3 = addf %arg4, %2 : f32
- linalg.yield %3 : f32
- }: memref<4xf32>, memref<f32>, memref<f32>
+#map0 = affine_map<(d0, d1, d2) -> (32, d1 - d2)>
+#map1 = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4)>
+
+
+module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+ func @conv_padding(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) attributes {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} {
+ linalg.conv(%arg0, %arg1, %arg2) {dilations = [1, 1], padding = dense<[[1, 1], [0, 1]]> : tensor<2x2xi64>, strides = [1, 1]} : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
return
}
}
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C4:.+]] = constant 4 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK: scf.for %{{.+}} = %[[C0]] to %[[C4]] step %[[C1]]
-// CHECK-NOT: scf
+
+// CHECK-LABEL: func @conv_padding
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = constant 3 : index
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[UB0:.+]] = dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[UB1:.+]] = dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[UB2:.+]] = dim %[[ARG0]], %[[C2]]
+// CHECK-DAG: %[[UB3:.+]] = dim %[[ARG0]], %[[C3]]
+// CHECK-DAG: %[[UB4:.+]] = dim %[[ARG1]], %[[C0]]
+// CHECK-DAG: %[[UB5:.+]] = dim %[[ARG2]], %[[C1]]
+// CHECK-DAG: %[[UB6:.+]] = dim %[[ARG2]], %[[C2]]
+// CHECK: %[[T7:.+]] = muli %[[UB3]], %[[UB6]]
+// CHECK: %[[T8:.+]] = muli %[[T7]], %[[UB5]]
+// CHECK: %[[UB:.+]] = muli %[[T8]], %[[UB4]]
+// CHECK-DAG: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
+// CHECK-DAG: %[[NTHREADSX:.+]] = "gpu.block_dim"() {dimension = "x"}
+// CHECK-DAG: %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
+// CHECK: %[[T13:.+]] = muli %[[BIDX]], %[[NTHREADSX]]
+// CHECK: %[[PROCID:.+]] = addi %[[T13]], %[[TIDX]]
+// CHECK: %[[COND:.+]] = cmpi "slt", %[[PROCID]], %[[UB]]
+// CHECK: scf.if %[[COND]]
+// CHECK: %[[IV0:.+]] = divi_signed %[[PROCID]], %[[T8]]
+// CHECK: %[[T17:.+]] = remi_signed %[[PROCID]], %[[T8]]
+// CHECK: %[[IV1:.+]] = divi_signed %[[T17]], %[[T7]]
+// CHECK: %[[T19:.+]] = remi_signed %[[T17]], %[[T7]]
+// CHECK: %[[IV2:.+]] = divi_signed %[[T19]], %[[UB3]]
+// CHECK: %[[T21:.+]] = remi_signed %[[T19]], %[[UB3]]
+// CHECK: scf.for %[[IV3:.+]] = %[[C0]] to %[[UB2]] step %[[C1]]
+// CHECK: scf.for %[[IV4:.+]] = %[[C0]] to %[[UB0]] step %[[C1]]
+// CHECK: scf.for %[[IV5:.+]] = %[[C0]] to %[[UB1]] step %[[C1]]
+// CHECK-NOT: linalg.conv
+
// -----
-#map0 = affine_map<(d0)[s0] -> (2, -d0 + s0)>
-#map1 = affine_map<(d0)[s0] -> (32, -d0 + s0)>
-#map2 = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4)>
-#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+#map0 = affine_map<(d0)[s0, s1] -> (s0 + 4, -d0 + s1)>
+#map1 = affine_map<(d0)[s0, s1] -> (s0 + 32, -d0 + s1)>
+#map2 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+#map3 = affine_map<(d0)[s0] -> (4, -d0 + s0)>
+#map4 = affine_map<(d0)[s0] -> (32, -d0 + s0)>
-module {
- func @parallel_4D(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) attributes {iree.dispatch_fn_name = "parallel_4D", spv.entry_point_abi = {local_size = dense<[32, 2, 2]> : vector<3xi32>}} {
- %c2 = constant 2 : index
- %c32 = constant 32 : index
+module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+ func @pooling_no_padding(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) attributes {spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}} {
+ %c4 = constant 4 : index
%c0 = constant 0 : index
+ %c32 = constant 32 : index
%c1 = constant 1 : index
- %c3 = constant 3 : index
- %0 = dim %arg0, %c0 : memref<?x?x?x?xf32>
- %1 = dim %arg0, %c1 : memref<?x?x?x?xf32>
- %2 = dim %arg0, %c2 : memref<?x?x?x?xf32>
- %3 = dim %arg0, %c3 : memref<?x?x?x?xf32>
- scf.parallel (%arg3, %arg4, %arg5, %arg6) = (%c0, %c0, %c0, %c0) to (%0, %1, %2, %3) step (%c2, %c2, %c2, %c32) {
- %12 = affine.min #map0(%arg3)[%0]
- %13 = affine.min #map0(%arg4)[%1]
- %14 = affine.min #map0(%arg5)[%2]
- %15 = affine.min #map1(%arg6)[%3]
- %16 = subview %arg0[%arg3, %arg4, %arg5, %c0] [%12, %13, %14, %15] [%c1, %c1, %c1, %c1] : memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map2>
- %17 = subview %arg1[%arg3, %arg4, %arg5, %c0] [%12, %13, %14, %15] [%c1, %c1, %c1, %c1] : memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map2>
- %18 = subview %arg2[%arg3, %arg4, %arg5, %c0] [%12, %13, %14, %15] [%c1, %c1, %c1, %c1] : memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map2>
- linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
- indexing_maps = [#map3, #map3, #map3],
- iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
- {__internal_linalg_transform__ = "workitem"}
- %16, %17, %18
- {
- ^bb0(%arg7: f32, %arg8: f32, %arg9: f32): // no predecessors
- %19 = addf %arg7, %arg8 : f32
- linalg.yield %19 : f32
- } : memref<?x?x?x?xf32, #map2>, memref<?x?x?x?xf32, #map2>, memref<?x?x?x?xf32, #map2>
+ %0 = dim %arg1, %c0 : memref<?x?xf32>
+ %1 = dim %arg1, %c1 : memref<?x?xf32>
+ %2 = dim %arg2, %c0 : memref<?x?xf32>
+ %3 = dim %arg2, %c1 : memref<?x?xf32>
+ scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%2, %3) step (%c4, %c32) {
+ %4 = dim %arg0, %c0 : memref<?x?xf32>
+ %5 = affine.min #map0(%arg3)[%0, %4]
+ %6 = dim %arg0, %c1 : memref<?x?xf32>
+ %7 = affine.min #map1(%arg4)[%1, %6]
+ %8 = subview %arg0[%arg3, %arg4] [%5, %7] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map2>
+ %9 = affine.min #map3(%arg3)[%2]
+ %10 = affine.min #map4(%arg4)[%3]
+ %11 = subview %arg2[%arg3, %arg4] [%9, %10] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map2>
+ linalg.pooling_max(%8, %arg1, %11) {__internal_linalg_transform__ = "workitem", dilations = [1, 1], strides = [1, 1]} : memref<?x?xf32, #map2>, memref<?x?xf32>, memref<?x?xf32, #map2>
scf.yield
}
return
}
}
-// CHECK-DAG: %[[C2:.+]] = constant 2 : index
-// CHECK-DAG: %[[C32:.+]] = constant 32 : index
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[C3:.+]] = constant 3 : index
-// CHECK-DAG: %[[SERIALDIMOUTER:.+]] = dim %{{.+}}, %[[C3]]
-// CHECK-DAG: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"} : () -> index
-// CHECK-DAG: %[[NBLOCKSX:.+]] = "gpu.grid_dim"() {dimension = "x"} : () -> index
-// CHECK-DAG: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"} : () -> index
-// CHECK-DAG: %[[NBLOCKSY:.+]] = "gpu.grid_dim"() {dimension = "y"} : () -> index
-// CHECK-DAG: %[[BIDZ:.+]] = "gpu.block_id"() {dimension = "z"} : () -> index
-// CHECK-DAG: %[[NBLOCKSZ:.+]] = "gpu.grid_dim"() {dimension = "z"} : () -> index
-// CHECK-DAG: %[[LB0:.+]] = muli %[[BIDZ]], %[[C2]]
-// CHECK-DAG: %[[STEP0:.+]] = muli %[[NBLOCKSZ]], %[[C2]]
-// CHECK-DAG: %[[LB1:.+]] = muli %[[BIDY]], %[[C2]]
-// CHECK-DAG: %[[STEP1:.+]] = muli %[[NBLOCKSY]], %[[C2]]
-// CHECK-DAG: %[[LB2:.+]] = muli %[[BIDX]], %[[C2]]
-// CHECK-DAG: %[[STEP2:.+]] = muli %[[NBLOCKSX]], %[[C2]]
-// CHECK: scf.for %{{.+}} = %[[LB0]] to %{{.+}} step %[[STEP0]]
-// CHECK: scf.for %{{.+}} = %[[LB1]] to %{{.+}} step %[[STEP1]]
-// CHECK: scf.for %{{.+}} = %[[LB2]] to %{{.+}} step %[[STEP2]]
-// CHECK: scf.for %{{.+}} = %[[C0]] to %[[SERIALDIMOUTER]] step %[[C32]]
-// CHECK-DAG: %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"} : () -> index
-// CHECK-DAG: %[[NTHREADSX:.+]] = "gpu.block_dim"() {dimension = "x"} : () -> index
-// CHECK-DAG: %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"} : () -> index
-// CHECK-DAG: %[[NTHREADSY:.+]] = "gpu.block_dim"() {dimension = "y"} : () -> index
-// CHECK-DAG: %[[TIDZ:.+]] = "gpu.thread_id"() {dimension = "z"} : () -> index
-// CHECK-DAG: %[[NTHREADSZ:.+]] = "gpu.block_dim"() {dimension = "z"} : () -> index
-// CHECK: scf.for %{{.+}} = %[[TIDZ]] to %{{.+}} step %[[NTHREADSZ]]
-// CHECK: scf.for %{{.+}} = %[[TIDY]] to %{{.+}} step %[[NTHREADSY]]
-// CHECK: scf.for %{{.+}} = %[[TIDX]] to %{{.+}} step %[[NTHREADSX]]
-// CHECK: scf.for %{{.+}} = %[[C0]] to %{{.+}} step %[[C1]]
-
-// -----
-
-module {
- func @no_tile(%arg0: memref<?x?xf32>, %arg1 : memref<?x?xf32>,
- %arg2 : memref<?x?xf32>)
- attributes {iree.dispatch_fn_name = "reduce_sum"} {
- linalg.generic
- {args_in = 2 : i64, args_out = 1 : i64,
- indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
- affine_map<(d0, d1) -> (d0, d1)>,
- affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"]}
- {__internal_linalg_tranform__ = "no-tile"} %arg0, %arg1, %arg2 {
- ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
- %0 = addf %arg3, %arg4 : f32
- linalg.yield %0 : f32
- }: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
- return
- }
-}
-
-// CHECK-DAG: %[[C0:.*]] = constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = constant 1 : index
-// CHECK-DAG: %[[UBY:.+]] = dim %{{.*}}, %[[C0]]
-// CHECK-DAG: %[[UBX:.+]] = dim %{{.*}}, %[[C1]]
-// CHECK-DAG: %[[NBLOCKSX:.+]] = "gpu.grid_dim"() {dimension = "x"}
-// CHECK-DAG: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
-// CHECK-DAG: %[[BLOCKSIZEX:.+]] = "gpu.block_dim"() {dimension = "x"}
-// CHECK-DAG: %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
-// CHECK: %[[T6:.+]] = muli %[[BIDX]], %[[BLOCKSIZEX]]
-// CHECK: %[[GIDX:.+]] = addi %[[T6]], %[[TIDX]]
-// CHECK: %[[NPROCSX:.+]] = muli %[[BLOCKSIZEX]], %[[NBLOCKSX]]
-// CHECK-DAG: %[[NBLOCKSY:.+]] = "gpu.grid_dim"() {dimension = "y"}
-// CHECK-DAG: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
-// CHECK-DAG: %[[BLOCKSIZEY:.+]] = "gpu.block_dim"() {dimension = "y"}
-// CHECK-DAG: %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
-// CHECK: %[[T6:.+]] = muli %[[BIDY]], %[[BLOCKSIZEY]]
-// CHECK: %[[GIDY:.+]] = addi %[[T6]], %[[TIDY]]
-// CHECK: %[[NPROCSY:.+]] = muli %[[BLOCKSIZEY]], %[[NBLOCKSY]]
-// CHECK: scf.for %{{.+}} = %[[GIDY]] to %[[UBY]] step %[[NPROCSY]]
-// CHECK: scf.for %{{.+}} = %[[GIDX]] to %[[UBX]] step %[[NPROCSX]]
+// CHECK-LABEL: func @pooling_no_padding
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
+// CHECK-DAG: %[[C4:.+]] = constant 4 : index
+// CHECK-DAG: %[[C32:.+]] = constant 32 : index
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[UB0:.+]] = dim %[[ARG1]], %[[C0]]
+// CHECK-DAG: %[[UB1:.+]] = dim %[[ARG1]], %[[C1]]
+// CHECK-DAG: %[[UB2:.+]] = dim %[[ARG2]], %[[C0]]
+// CHECK-DAG: %[[UB3:.+]] = dim %[[ARG2]], %[[C1]]
+// CHECK-DAG: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
+// CHECK-DAG: %[[NBLOCKSX:.+]] = "gpu.grid_dim"() {dimension = "x"}
+// CHECK-DAG: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
+// CHECK-DAG: %[[NBLOCKSY:.+]] = "gpu.grid_dim"() {dimension = "y"}
+// CHECK-DAG: %[[BOFFSETY:.+]] = muli %[[BIDY]], %[[C4]]
+// CHECK-DAG: %[[BSTEPY:.+]] = muli %[[NBLOCKSY]], %[[C4]]
+// CHECK-DAG: %[[BOFFSETX:.+]] = muli %[[BIDX]], %[[C32]]
+// CHECK-DAG: %[[BSTEPX:.+]] = muli %[[NBLOCKSX]], %[[C32]]
+// CHECK: scf.for %[[IV3:.+]] = %[[BOFFSETY]] to %[[UB2]] step %[[BSTEPY]]
+// CHECK: scf.for %[[IV4:.+]] = %[[BOFFSETX]] to %[[UB3]] step %[[BSTEPX]]
+// CHECK: %[[SV1:.+]] = subview %[[ARG0]][%[[IV3]], %[[IV4]]]
+// CHECK: %[[BOUNDSY:.+]] = affine.min #{{.+}}(%[[IV3]])
+// CHECK: %[[BOUNDSX:.+]] = affine.min #{{.+}}(%[[IV4]])
+// CHECK: %[[SV2:.+]] = subview %[[ARG2]][%[[IV3]], %[[IV4]]]
+// CHECK-DAG: %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
+// CHECK-DAG: %[[NTHREADSX:.+]] = "gpu.block_dim"() {dimension = "x"}
+// CHECK-DAG: %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
+// CHECK-DAG: %[[NTHREADSY:.+]] = "gpu.block_dim"() {dimension = "y"}
+// CHECK: scf.for %[[IV5:.+]] = %[[TIDY]] to %[[BOUNDSY]] step %[[NTHREADSY]]
+// CHECK: scf.for %[[IV6:.+]] = %[[TIDX]] to %[[BOUNDSX]] step %[[NTHREADSX]]
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK-NOT: linalg.pooling_max
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu_option.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu_option.mlir
new file mode 100644
index 0000000..be33748
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu_option.mlir
@@ -0,0 +1,87 @@
+// RUN: iree-opt -iree-codegen-convert-to-gpu -iree-codegen-use-legacy-conv-lowering=false -canonicalize -cse -split-input-file %s | IreeFileCheck %s
+
+#map0 = affine_map<(d0)[s0] -> (1, -d0 + s0)>
+#map1 = affine_map<(d0)[s0, s1] -> (s0 + 4, -d0 + s1)>
+#map2 = affine_map<(d0)[s0, s1] -> (s0 + 32, -d0 + s1)>
+#map3 = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)>
+#map4 = affine_map<(d0)[s0] -> (4, -d0 + s0)>
+#map5 = affine_map<(d0)[s0] -> (32, -d0 + s0)>
+
+module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+ func @conv_no_padding(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) attributes {spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}} {
+ %c4 = constant 4 : index
+ %c32 = constant 32 : index
+ %c2 = constant 2 : index
+ %c0 = constant 0 : index
+ %c3 = constant 3 : index
+ %c1 = constant 1 : index
+ %0 = dim %arg1, %c0 : memref<?x?x?x?xf32>
+ %1 = dim %arg1, %c1 : memref<?x?x?x?xf32>
+ %2 = dim %arg1, %c2 : memref<?x?x?x?xf32>
+ %3 = dim %arg2, %c1 : memref<?x?x?x?xf32>
+ %4 = dim %arg2, %c2 : memref<?x?x?x?xf32>
+ scf.parallel (%arg3, %arg4, %arg5) = (%c0, %c0, %c0) to (%0, %3, %4) step (%c1, %c4, %c32) {
+ %5 = affine.min #map0(%arg3)[%0]
+ %6 = affine.min #map1(%arg4)[%1, %1]
+ %7 = affine.min #map2(%arg5)[%2, %2]
+ %8 = dim %arg1, %c3 : memref<?x?x?x?xf32>
+ %9 = subview %arg1[%arg3, %arg4, %arg5, 0] [%5, %6, %7, %8] [1, 1, 1, 1] : memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map3>
+ %10 = dim %arg2, %c0 : memref<?x?x?x?xf32>
+ %11 = affine.min #map0(%arg3)[%10]
+ %12 = affine.min #map4(%arg4)[%3]
+ %13 = affine.min #map5(%arg5)[%4]
+ %14 = dim %arg2, %c3 : memref<?x?x?x?xf32>
+ %15 = subview %arg2[%arg3, %arg4, %arg5, 0] [%11, %12, %13, %14] [1, 1, 1, 1] : memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map3>
+ linalg.conv(%arg0, %9, %15) {__internal_linalg_transform__ = "workitem", dilations = [1, 1], strides = [1, 1]} : memref<?x?x?x?xf32>, memref<?x?x?x?xf32, #map3>, memref<?x?x?x?xf32, #map3>
+ scf.yield
+ }
+ return
+ }
+}
+
+// CHECK-LABEL: func @conv_no_padding
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+// CHECK-DAG: %[[C4:.+]] = constant 4 : index
+// CHECK-DAG: %[[C32:.+]] = constant 32 : index
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[UB0:.+]] = dim %[[ARG1]], %[[C0]]
+// CHECK-DAG: %[[UB1:.+]] = dim %[[ARG1]], %[[C1]]
+// CHECK-DAG: %[[UB2:.+]] = dim %[[ARG1]], %[[C2]]
+// CHECK-DAG: %[[UB3:.+]] = dim %[[ARG2]], %[[C1]]
+// CHECK-DAG: %[[UB4:.+]] = dim %[[ARG2]], %[[C2]]
+// CHECK-DAG: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
+// CHECK-DAG: %[[NBLOCKSX:.+]] = "gpu.grid_dim"() {dimension = "x"}
+// CHECK-DAG: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
+// CHECK-DAG: %[[NBLOCKSY:.+]] = "gpu.grid_dim"() {dimension = "y"}
+// CHECK-DAG: %[[BIDZ:.+]] = "gpu.block_id"() {dimension = "z"}
+// CHECK-DAG: %[[NBLOCKSZ:.+]] = "gpu.grid_dim"() {dimension = "z"}
+// CHECK-DAG: %[[BOFFSETY:.+]] = muli %[[BIDY]], %[[C4]]
+// CHECK-DAG: %[[BSTEPY:.+]] = muli %[[NBLOCKSY]], %[[C4]]
+// CHECK-DAG: %[[BOFFSETX:.+]] = muli %[[BIDX]], %[[C32]]
+// CHECK-DAG: %[[BSTEPX:.+]] = muli %[[NBLOCKSX]], %[[C32]]
+// CHECK: scf.for %[[IV3:.+]] = %[[BIDZ]] to %[[UB0]] step %[[NBLOCKSZ]]
+// CHECK: scf.for %[[IV4:.+]] = %[[BOFFSETY]] to %[[UB3]] step %[[BSTEPY]]
+// CHECK: scf.for %[[IV5:.+]] = %[[BOFFSETX]] to %[[UB4]] step %[[BSTEPX]]
+// CHECK: %[[BOUNDSZ:.+]] = affine.min #{{.+}}(%[[IV3]])
+// CHECK: %[[SV1:.+]] = subview %[[ARG1]][%[[IV3]], %[[IV4]], %[[IV5]], 0]
+// CHECK: %[[BOUNDSY:.+]] = affine.min #{{.+}}(%[[IV4]])
+// CHECK: %[[BOUNDSX:.+]] = affine.min #{{.+}}(%[[IV5]])
+// CHECK: %[[SV2:.+]] = subview %[[ARG2]][%[[IV3]], %[[IV4]], %[[IV5]], 0]
+// CHECK-DAG: %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
+// CHECK-DAG: %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
+// CHECK-DAG: %[[TIDZ:.+]] = "gpu.thread_id"() {dimension = "z"}
+// CHECK: %[[INBOUNDSZ:.+]] = cmpi "slt", %[[TIDZ]], %[[BOUNDSZ]]
+// CHECK: %[[INBOUNDSY:.+]] = cmpi "slt", %[[TIDY]], %[[BOUNDSY]]
+// CHECK: %[[T35:.+]] = and %[[INBOUNDSZ]], %[[INBOUNDSY]]
+// CHECK: %[[INBOUNDSX:.+]] = cmpi "slt", %[[TIDX]], %[[BOUNDSX]]
+// CHECK: %[[INBOUNDS:.+]] = and %[[T35]], %[[INBOUNDSX]]
+// CHECK: scf.if %[[INBOUNDS]]
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK-NOT: linalg.conv
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
index 70f3d17..1b5ddda 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
@@ -1,65 +1,14 @@
// RUN: iree-opt -split-input-file -iree-codegen-linalg-tile-and-fuse %s | IreeFileCheck %s
+// Test to check that convolution with padding is not tiled.
module attributes {
spv.target_env =
#spv.target_env<#spv.vce<v1.3,
[Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- // CHECK-LABEL: func @tile_only
- // CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<4x8xi32>
- // CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<4x8xi32>
- // CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<4x8xi32>
- // CHECK-SAME: local_size = dense<[32, 4, 1]>
- // CHECK: scf.parallel
- // CHECK: %[[VIEW0:.+]] = subview %[[ARG0]]
- // CHECK: %[[VIEW1:.+]] = subview %[[ARG1]]
- // CHECK: %[[VIEW2:.+]] = subview %[[ARG2]]
- // CHECK: linalg.generic
- // CHECK-SAME: "workitem"
- // CHECK-SAME: %[[VIEW0]]
- // CHECK-SAME: %[[VIEW1]]
- // CHECK-SAME: %[[VIEW2]]
- func @tile_only(%arg0: memref<4x8xi32>, %arg1: memref<4x8xi32>,
- %arg2: memref<4x8xi32>) {
- linalg.generic
- {args_in = 2 : i64, args_out = 1 : i64,
- indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
- affine_map<(d0, d1) -> (d0, d1)>,
- affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"]} %arg0, %arg1, %arg2 {
- ^bb0(%arg3: i32, %arg4: i32, %arg5: i32):
- %0 = addi %arg3, %arg4 : i32
- linalg.yield %0 : i32
- }: memref<4x8xi32>, memref<4x8xi32>, memref<4x8xi32>
- return
- }
-}
-
-// -----
-
-module attributes {
- spv.target_env =
- #spv.target_env<#spv.vce<v1.3,
- [Shader], [SPV_KHR_storage_buffer_storage_class]>,
- {max_compute_workgroup_invocations = 128 : i32,
- max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- // CHECK-LABEL: func @conv_padding
- // CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
- // CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
- // CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
- // CHECK-SAME: local_size = dense<[32, 1, 1]>
- // CHECK: scf.parallel (%{{.+}})
- // CHECK: %[[VIEW1:.+]] = subview %[[ARG1]]
- // CHECK: %[[VIEW2:.+]] = subview %[[ARG2]]
- // CHECK: linalg.conv
- // CHECK-SAME: %[[VIEW1]]
- // CHECK-SAME: %[[VIEW2]]
- // CHECK-SAME: "workitem"
func @conv_padding(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref<?x?x?x?xf32>,
- %arg2 : memref<?x?x?x?xf32>)
- attributes
- {iree.dispatch_fn_name = "conv_padding"} {
+ %arg2 : memref<?x?x?x?xf32>) {
linalg.conv(%arg0, %arg1, %arg2)
{dilations = [1, 1],
padding = dense<[[1, 1], [0, 1]]> : tensor<2x2xi64>, strides = [1, 1]} :
@@ -67,6 +16,14 @@
return
}
}
+// CHECK-LABEL: func @conv_padding
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+// CHECK: linalg.conv
+// CHECK-SAME: %[[ARG0]]
+// CHECK-SAME: %[[ARG1]]
+// CHECK-SAME: %[[ARG2]]
// -----
@@ -76,55 +33,24 @@
[Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- // CHECK-LABEL: func @conv_no_padding
- // CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
- // CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
- // CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
- // CHECK-SAME: local_size = dense<[32, 2, 2]>
- // CHECK: scf.parallel (%{{.+}}, %{{.+}}, %{{.+}})
- // CHECK: %[[VIEW1:.+]] = subview %[[ARG1]]
- // CHECK: %[[VIEW2:.+]] = subview %[[ARG2]]
- // CHECK: linalg.conv
- // CHECK-SAME: %[[VIEW1]]
- // CHECK-SAME: %[[VIEW2]]
- // CHECK-SAME: "workitem"
func @conv_no_padding(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref<?x?x?x?xf32>,
- %arg2 : memref<?x?x?x?xf32>)
- attributes
- {iree.dispatch_fn_name = "conv_no_padding"} {
+ %arg2 : memref<?x?x?x?xf32>) {
linalg.conv(%arg0, %arg1, %arg2) {dilations = [1, 1], strides = [1, 1]} :
memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
return
}
}
-
-// -----
-
-#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-module attributes {
- spv.target_env =
- #spv.target_env<#spv.vce<v1.3,
- [Shader], [SPV_KHR_storage_buffer_storage_class]>,
- {max_compute_workgroup_invocations = 128 : i32,
- max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- // CHECK-LABEL: func @parallel_4D
- // CHECK: scf.parallel (%{{.+}}, %{{.+}}, %{{.+}})
- func @parallel_4D(%arg0: memref<?x?x?x?xf32>,
- %arg1 : memref<?x?x?x?xf32>,
- %arg2 : memref<?x?x?x?xf32>)
- attributes {iree.dispatch_fn_name = "parallel_4D"} {
- linalg.generic
- {args_in = 2 : i64, args_out = 1 : i64,
- indexing_maps = [#map0, #map0, #map0],
- iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
- %arg0, %arg1, %arg2 {
- ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
- %0 = addf %arg3, %arg4 : f32
- linalg.yield %0 : f32
- } : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
- return
- }
-}
+// CHECK-LABEL: func @conv_no_padding
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+// CHECK-SAME: local_size = dense<[32, 4, 1]>
+// CHECK: scf.parallel (%{{.+}}, %{{.+}}, %{{.+}})
+// CHECK: %[[VIEW1:.+]] = subview %[[ARG1]]
+// CHECK: %[[VIEW2:.+]] = subview %[[ARG2]]
+// CHECK: linalg.conv
+// CHECK-SAME: %[[ARG0]], %[[VIEW1]], %[[VIEW2]]
+// CHECK-SAME: "workitem"
// -----
@@ -134,54 +60,52 @@
[Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- func @no_tile(%arg0: memref<?x?xf32>,
+ func @matmul(%arg0: memref<?x?xf32>,
%arg1: memref<?x?xf32>,
%ret0: memref<?x?xf32>) {
- linalg.matmul %arg0, %arg1, %ret0 {__internal_linalg_transform__ = "no-tile"} :
+ linalg.matmul %arg0, %arg1, %ret0 :
(memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
return
}
}
-// CHECK-LABEL: func @no_tile
+
+// CHECK-LABEL: func @matmul
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
// CHECK-SAME: local_size = dense<[8, 8, 1]>
-// CHECK-NOT: scf
-// CHECK: linalg.matmul
-// CHECK-NOT: scf
-// CHECK: return
+// CHECK: scf.parallel (%{{.+}}, %{{.+}}, %{{.+}})
+// CHECK: %[[VIEW0:.+]] = subview %[[ARG0]]
+// CHECK: %[[VIEW1:.+]] = subview %[[ARG1]]
+// CHECK: %[[VIEW2:.+]] = subview %[[ARG2]]
+// CHECK: linalg.matmul
+// CHECK-SAME: "workitem"
+// CHECK-SAME: %[[VIEW0]], %[[VIEW1]], %[[VIEW2]]
// -----
-#map0 = affine_map<() -> ()>
-#accesses = [#map0, #map0]
-#trait = {
- args_in = 2 : i64,
- args_out = 1 : i64,
- indexing_maps = #accesses,
- iterator_types = []
-}
-
module attributes {
spv.target_env =
#spv.target_env<#spv.vce<v1.3,
[Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- func @scalar_add(%arg0 : memref<f32>, %arg1 : memref<f32>,
- %arg2 : memref<f32>)
- {
- linalg.generic #trait %arg0, %arg1, %arg2 {
- ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
- %0 = addf %arg3, %arg4 : f32
- linalg.yield %0 : f32
- } : memref<f32>, memref<f32>, memref<f32>
- return
+ func @pooling_sum_no_padding(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>,
+ %arg2 : memref<?x?xf32>) {
+ linalg.pooling_max(%arg0, %arg1, %arg2) {dilations = [1, 1], strides = [1, 1]} :
+ memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
+ return
}
}
-// CHECK-LABEL: func @scalar_add
-// CHECK-NOT: scf.parallel
-// CHECK-NOT: scf.for
-// CHECK: linalg.generic
-// CHECK-SAME: "no-tile"
-// CHECK-NOT: scf.parallel
-// CHECK-NOT: scf.for
-// CHECK: return
+
+// CHECK-LABEL: func @pooling_sum_no_padding
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
+// CHECK-SAME: local_size = dense<[32, 4, 1]>
+// CHECK: scf.parallel (%{{.+}}, %{{.+}})
+// CHECK: %[[VIEW0:.+]] = subview %[[ARG0]]
+// CHECK: %[[VIEW2:.+]] = subview %[[ARG2]]
+// CHECK: linalg.pooling_max
+// CHECK-SAME: %[[VIEW0]], %[[ARG1]], %[[VIEW2]]
+// CHECK-SAME: "workitem"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
index 81db628..637fd7b 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
@@ -46,6 +46,65 @@
// -----
+// CHECK: module attributes {vkspv.entry_point_schedule = ["kernel_dispatch_0", "kernel_dispatch_1", "kernel_dispatch_2"]}
+module {
+ // CHECK: func @kernel_dispatch_2()
+ // CHECK: %[[DIM:.+]] = hal.interface.load.constant
+ // CHECK: %[[SHAPE1:.+]] = shapex.make_ranked_shape %[[DIM]]
+ // CHECK: %[[SHAPE2:.+]] = shapex.make_ranked_shape %[[DIM]]
+ // CHECK: %[[IN1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x2x2x512xf32>
+ // CHECK: %[[TS1:.+]] = shapex.tie_shape %[[IN1]], %[[SHAPE1]]
+ // CHECK: %[[IN2:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
+ // CHECK: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
+ // CHECK: %[[TS2:.+]] = shapex.tie_shape %[[OUT]], %[[SHAPE2]]
+ // CHECK: linalg.conv(%[[IN2]], %[[TS1]], %[[TS2]])
+ // CHECK: return
+
+ // CHECK: func @kernel_dispatch_1() {
+ // CHECK: %[[C0:.+]] = constant 0 : index
+ // CHECK: %[[C1:.+]] = constant 1 : index
+ // CHECK: scf.parallel (%{{.*}}) = (%[[C0]]) to (%[[C1]]) step (%[[C1]])
+ // CHECK: scf.yield
+ // CHECK: return
+
+ // CHECK: func @kernel_dispatch_0()
+ // CHECK: %[[ZERO:.+]] = constant
+ // CHECK: %[[DIM:.+]] = hal.interface.load.constant
+ // CHECK: %[[SHAPE:.+]] = shapex.make_ranked_shape %[[DIM]]
+ // CHECK: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
+ // CHECK: %[[TS:.+]] = shapex.tie_shape %[[OUT]], %[[SHAPE]]
+ // CHECK: linalg.fill(%[[TS]], %[[ZERO]])
+ // CHECK: return
+
+ func @kernel() {
+ %cst = constant 0.000000e+00 : f32
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %dim = hal.interface.load.constant offset = 0 : index
+ %shape1 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,2,2,512]>
+ %shape2 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,1,1,512]>
+ %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x2x2x512xf32>
+ %ts1 = shapex.tie_shape %0, %shape1 : memref<?x2x2x512xf32>, !shapex.ranked_shape<[?,2,2,512]>
+ %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
+ %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
+ %ts2 = shapex.tie_shape %2, %shape2 : memref<?x1x1x512xf32>, !shapex.ranked_shape<[?,1,1,512]>
+ linalg.fill(%ts2, %cst) : memref<?x1x1x512xf32>, f32
+ scf.parallel (%iv) = (%c0) to (%c1) step (%c1) {
+ scf.yield
+ }
+ linalg.conv(%1, %ts1, %ts2) {dilations = [1, 1], padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, strides = [2, 2]} : memref<3x3x512x1xf32>, memref<?x2x2x512xf32>, memref<?x1x1x512xf32>
+ return
+ }
+ hal.interface @legacy_io attributes {push_constants = 1 : i32, sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ }
+}
+
+
+// -----
+
// Nothing to do if there is just one Linalg op.
// CHECK-NOT: vkspv.entry_point_schedule
@@ -71,7 +130,7 @@
// Do not split when Linalg and non-Linalg ops are interleaving each other.
module {
- // expected-error @+1 {{cannot separate Linalg ops into multiple kernels}}
+ // expected-error @+1 {{cannot separate Linalg/Parallel ops into multiple kernels}}
func @kernel() {
%cst = constant 0.000000e+00 : f32
%0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x2x2x512xf32>
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
index 060dc5a..76cfcb8 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
@@ -5,7 +5,7 @@
%arg0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<96x96xf32>
%arg1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<96x96xf32>
%arg2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<96x96xf32>
- linalg.matmul %arg0, %arg1, %arg2 {__internal_linalg_transform__ = "workgroup"} :
+ linalg.matmul %arg0, %arg1, %arg2 :
(memref<96x96xf32>, memref<96x96xf32>, memref<96x96xf32>)
return
}
diff --git a/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp b/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
index ac18c5f..c2367a7 100644
--- a/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
@@ -32,13 +32,23 @@
static llvm::cl::opt<bool> extractPadFromConv(
"iree-extract-pad-from-conv",
llvm::cl::desc("Extract padding attributes from conv op"),
- llvm::cl::init(false));
+ llvm::cl::init(true));
static bool isAllZero(DenseIntElementsAttr attr) {
if (!attr.isSplat()) return false;
return attr.getSplatValue<IntegerAttr>().getInt() == 0;
}
+/// Returns true if the op has a padding attribute and that attribute has
+/// non-zero entries.
+template <typename OpTy>
+static bool hasPadding(OpTy op) {
+ Optional<DenseIntElementsAttr> padding = op.padding();
+ if (!padding) return false;
+ return llvm::any_of(padding.getValue(),
+ [](APInt v) -> bool { return !v.isNullValue(); });
+}
+
class ExtractConvOpPaddingAttributes
: public OpRewritePattern<xla_hlo::ConvOp> {
public:
@@ -46,7 +56,7 @@
LogicalResult matchAndRewrite(xla_hlo::ConvOp op,
PatternRewriter &rewriter) const override {
- if (!op.padding()) return failure();
+ if (!hasPadding(op)) return failure();
auto inputType = op.lhs().getType().cast<ShapedType>();
int rank = inputType.getRank();
SmallVector<int64_t, 4> paddingLow, paddingHigh, interiorPadding, shape;
diff --git a/iree/hal/vmla/op_kernels_ruy.h b/iree/hal/vmla/op_kernels_ruy.h
index 981ee72..003fb80 100644
--- a/iree/hal/vmla/op_kernels_ruy.h
+++ b/iree/hal/vmla/op_kernels_ruy.h
@@ -15,10 +15,13 @@
#ifndef IREE_HAL_VMLA_OP_KERNELS_RUY_H_
#define IREE_HAL_VMLA_OP_KERNELS_RUY_H_
+#include <type_traits>
+
#include "absl/base/thread_annotations.h"
#include "absl/memory/memory.h"
#include "iree/base/status.h"
#include "ruy/context.h"
+#include "ruy/mul_params.h"
#include "ruy/ruy.h"
namespace iree {
@@ -37,6 +40,56 @@
return absl::make_unique<RuntimeState>();
}
+// Floating-point case.
+template <typename ACC, typename T>
+struct MakeRuyMulParamsImpl {
+ static_assert(std::is_floating_point<ACC>::value, "");
+ static_assert(std::is_floating_point<T>::value, "");
+ static void Run(const MatMul::Buffers<T, ACC>& buffers,
+ ruy::MulParams<ACC, T>* mul_params) {
+ mul_params->set_bias(buffers.bias_buffer.data());
+ }
+};
+
+// Integer quantized case with downquantization to a destination T narrower than
+// int32.
+template <typename T>
+struct MakeRuyMulParamsImpl<std::int32_t, T> {
+ static_assert(std::is_integral<T>::value, "");
+ static_assert(sizeof(T) < sizeof(std::int32_t), "");
+ static void Run(const MatMul::Buffers<T, std::int32_t>& buffers,
+ ruy::MulParams<std::int32_t, T>* mul_params) {
+ mul_params->set_bias(buffers.bias_buffer.data());
+ if (buffers.multiplier_mantissa_buffer.size() == 1) {
+ mul_params->set_multiplier_fixedpoint(
+ buffers.multiplier_mantissa_buffer[0]);
+ mul_params->set_multiplier_exponent(
+ buffers.multiplier_exponent_buffer[0]);
+ } else {
+ mul_params->set_multiplier_fixedpoint_perchannel(
+ buffers.multiplier_mantissa_buffer.data());
+ mul_params->set_multiplier_exponent_perchannel(
+ buffers.multiplier_exponent_buffer.data());
+ }
+ }
+};
+
+// Raw integer case with int32 destination. This case does not support any
+// output operation besides bias-addition.
+template <>
+struct MakeRuyMulParamsImpl<std::int32_t, std::int32_t> {
+ static void Run(const MatMul::Buffers<std::int32_t, std::int32_t>& buffers,
+ ruy::MulParams<std::int32_t, std::int32_t>* mul_params) {
+ mul_params->set_bias(buffers.bias_buffer.data());
+ }
+};
+
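+// Creates ruy::MulParams for the given buffers by dispatching to the matching
+// specialization above based on the accumulator and destination types.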
+template <typename ACC, typename T>
+void MakeRuyMulParams(const MatMul::Buffers<T, ACC>& buffers,
+ ruy::MulParams<ACC, T>* mul_params) {
+ MakeRuyMulParamsImpl<ACC, T>::Run(buffers, mul_params);
+}
+
template <typename T, typename ACC>
Status MatMul::Execute(RuntimeState* runtime_state,
const Buffers<T, ACC>& buffers) {
@@ -56,17 +109,7 @@
ruy::Order::kColMajor, dst.mutable_layout());
ruy::MulParams<ACC, T> mul_params;
- mul_params.set_bias(buffers.bias_buffer.data());
-
- if (buffers.multiplier_mantissa_buffer.size() == 1) {
- mul_params.set_multiplier_fixedpoint(buffers.multiplier_mantissa_buffer[0]);
- mul_params.set_multiplier_exponent(buffers.multiplier_exponent_buffer[0]);
- } else {
- mul_params.set_multiplier_fixedpoint_perchannel(
- buffers.multiplier_mantissa_buffer.data());
- mul_params.set_multiplier_exponent_perchannel(
- buffers.multiplier_exponent_buffer.data());
- }
+ MakeRuyMulParams(buffers, &mul_params);
ruy::Mul(lhs, rhs, mul_params, &runtime_state->context, &dst);
diff --git a/iree/test/e2e/vulkan_specific/BUILD b/iree/test/e2e/vulkan_specific/BUILD
index d937456..a3abc81 100644
--- a/iree/test/e2e/vulkan_specific/BUILD
+++ b/iree/test/e2e/vulkan_specific/BUILD
@@ -29,14 +29,37 @@
target_backend = "vulkan-spirv",
)
-# TODO(hanchung): Merge two tests into one single file.
+# TODO(#2345): Merge two tests into one single file.
iree_check_single_backend_test_suite(
- name = "check_vulkan-spirv-pad-conv_vulkan",
+ name = "check_vulkan-spirv-split-pad-conv_vulkan",
srcs = [
"convolution1.mlir",
"convolution2.mlir",
],
- compiler_flags = ["-iree-extract-pad-from-conv"],
+ driver = "vulkan",
+ target_backend = "vulkan-spirv",
+)
+
+# TODO(#2345): Merge two tests into one single file.
+iree_check_single_backend_test_suite(
+ name = "check_vulkan-spirv-nosplit-pad-conv_vulkan",
+ srcs = [
+ "convolution1.mlir",
+ "convolution2.mlir",
+ ],
+ compiler_flags = ["-iree-extract-pad-from-conv=false"],
+ driver = "vulkan",
+ target_backend = "vulkan-spirv",
+)
+
+# TODO(#2345): Merge two tests into one single file.
+iree_check_single_backend_test_suite(
+ name = "check_vulkan-spirv-conv-nocontrol_vulkan",
+ srcs = [
+ "convolution1.mlir",
+ "convolution2.mlir",
+ ],
+ compiler_flags = ["-iree-codegen-use-legacy-conv-lowering=false"],
driver = "vulkan",
target_backend = "vulkan-spirv",
)
diff --git a/iree/test/e2e/vulkan_specific/CMakeLists.txt b/iree/test/e2e/vulkan_specific/CMakeLists.txt
index 27e65fc..cca6c58 100644
--- a/iree/test/e2e/vulkan_specific/CMakeLists.txt
+++ b/iree/test/e2e/vulkan_specific/CMakeLists.txt
@@ -28,7 +28,19 @@
iree_check_single_backend_test_suite(
NAME
- check_vulkan-spirv-pad-conv_vulkan
+ check_vulkan-spirv-split-pad-conv_vulkan
+ SRCS
+ "convolution1.mlir"
+ "convolution2.mlir"
+ TARGET_BACKEND
+ vulkan-spirv
+ DRIVER
+ vulkan
+)
+
+iree_check_single_backend_test_suite(
+ NAME
+ check_vulkan-spirv-nosplit-pad-conv_vulkan
SRCS
"convolution1.mlir"
"convolution2.mlir"
@@ -37,5 +49,19 @@
DRIVER
vulkan
COMPILER_FLAGS
- "-iree-extract-pad-from-conv"
+ "-iree-extract-pad-from-conv=false"
+)
+
+iree_check_single_backend_test_suite(
+ NAME
+ check_vulkan-spirv-conv-nocontrol_vulkan
+ SRCS
+ "convolution1.mlir"
+ "convolution2.mlir"
+ TARGET_BACKEND
+ vulkan-spirv
+ DRIVER
+ vulkan
+ COMPILER_FLAGS
+ "-iree-codegen-use-legacy-conv-lowering=false"
)
diff --git a/iree/test/e2e/xla_ops/BUILD b/iree/test/e2e/xla_ops/BUILD
index 6f95110..31967b0 100644
--- a/iree/test/e2e/xla_ops/BUILD
+++ b/iree/test/e2e/xla_ops/BUILD
@@ -71,8 +71,7 @@
"reduce_window.mlir",
"remainder.mlir",
"reshape.mlir",
- # TODO(#1699): Enable after xla_hlo.reverse can be lowered to linalg.
- # "reverse.mlir",
+ "reverse.mlir",
"rsqrt.mlir",
"select.mlir",
"sine.mlir",
diff --git a/iree/test/e2e/xla_ops/CMakeLists.txt b/iree/test/e2e/xla_ops/CMakeLists.txt
index b2eec93..7f65b06 100644
--- a/iree/test/e2e/xla_ops/CMakeLists.txt
+++ b/iree/test/e2e/xla_ops/CMakeLists.txt
@@ -56,6 +56,7 @@
"reduce_window.mlir"
"remainder.mlir"
"reshape.mlir"
+ "reverse.mlir"
"rsqrt.mlir"
"select.mlir"
"sine.mlir"
diff --git a/iree/test/e2e/xla_ops/convolution.mlir b/iree/test/e2e/xla_ops/convolution.mlir
index 183fc10..f1880fd 100644
--- a/iree/test/e2e/xla_ops/convolution.mlir
+++ b/iree/test/e2e/xla_ops/convolution.mlir
@@ -65,47 +65,51 @@
return
}
-func @conv2d_2451x2311_same() attributes { iree.module.export } {
- %inputs = iree.unfoldable_constant dense<[
- [[[ 1.0], [ 2.0], [ 3.0], [ 4.0], [ 5.0]],
- [[ 6.0], [ 7.0], [ 8.0], [ 9.0], [10.0]],
- [[11.0], [12.0], [13.0], [14.0], [15.0]],
- [[16.0], [17.0], [18.0], [19.0], [20.0]]],
- [[[21.0], [22.0], [23.0], [24.0], [25.0]],
- [[26.0], [27.0], [28.0], [29.0], [30.0]],
- [[31.0], [32.0], [33.0], [34.0], [35.0]],
- [[36.0], [37.0], [38.0], [39.0], [40.0]]]]> : tensor <2x4x5x1xf32>
- %weights = iree.unfoldable_constant dense<[
- [[[1.0]], [[2.0]], [[3.0]]],
- [[[4.0]], [[5.0]], [[6.0]]]]> : tensor <2x3x1x1xf32>
- %res = "xla_hlo.convolution"(%inputs, %weights) {
- batch_group_count = 1 : i64,
- dimension_numbers = {
- input_batch_dimension = 0 : i64,
- input_feature_dimension = 3 : i64,
- input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
- kernel_input_feature_dimension = 2 : i64,
- kernel_output_feature_dimension = 3 : i64,
- kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
- output_batch_dimension = 0 : i64,
- output_feature_dimension = 3 : i64,
- output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
- feature_group_count = 1 : i64,
- padding = dense<[[0, 1], [1, 1]]> : tensor<2x2xi64>,
- rhs_dilation = dense<1> : tensor<2xi64>,
- window_strides = dense<1> : tensor<2xi64>} :
- (tensor<2x4x5x1xf32>, tensor<2x3x1x1xf32>) -> tensor<2x4x5x1xf32>
- check.expect_almost_eq_const(%res, dense<[
- [[[ 80.0], [121.0], [142.0], [163.0], [100.0]],
- [[160.0], [226.0], [247.0], [268.0], [160.0]],
- [[240.0], [331.0], [352.0], [373.0], [220.0]],
- [[ 83.0], [104.0], [110.0], [116.0], [ 59.0]]],
- [[[400.0], [541.0], [562.0], [583.0], [340.0]],
- [[480.0], [646.0], [667.0], [688.0], [400.0]],
- [[560.0], [751.0], [772.0], [793.0], [460.0]],
- [[183.0], [224.0], [230.0], [236.0], [119.0]]]]> : tensor<2x4x5x1xf32>) : tensor<2x4x5x1xf32>
- return
-}
+// TODO(#2345): This test seems to fail when executed with another
+// test from this file, but passes as a standalone test. Needs further
+// investigation.
+
+// func @conv2d_2451x2311_same() attributes { iree.module.export } {
+// %inputs = iree.unfoldable_constant dense<[
+// [[[ 1.0], [ 2.0], [ 3.0], [ 4.0], [ 5.0]],
+// [[ 6.0], [ 7.0], [ 8.0], [ 9.0], [10.0]],
+// [[11.0], [12.0], [13.0], [14.0], [15.0]],
+// [[16.0], [17.0], [18.0], [19.0], [20.0]]],
+// [[[21.0], [22.0], [23.0], [24.0], [25.0]],
+// [[26.0], [27.0], [28.0], [29.0], [30.0]],
+// [[31.0], [32.0], [33.0], [34.0], [35.0]],
+// [[36.0], [37.0], [38.0], [39.0], [40.0]]]]> : tensor <2x4x5x1xf32>
+// %weights = iree.unfoldable_constant dense<[
+// [[[1.0]], [[2.0]], [[3.0]]],
+// [[[4.0]], [[5.0]], [[6.0]]]]> : tensor <2x3x1x1xf32>
+// %res = "xla_hlo.convolution"(%inputs, %weights) {
+// batch_group_count = 1 : i64,
+// dimension_numbers = {
+// input_batch_dimension = 0 : i64,
+// input_feature_dimension = 3 : i64,
+// input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
+// kernel_input_feature_dimension = 2 : i64,
+// kernel_output_feature_dimension = 3 : i64,
+// kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+// output_batch_dimension = 0 : i64,
+// output_feature_dimension = 3 : i64,
+// output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
+// feature_group_count = 1 : i64,
+// padding = dense<[[0, 1], [1, 1]]> : tensor<2x2xi64>,
+// rhs_dilation = dense<1> : tensor<2xi64>,
+// window_strides = dense<1> : tensor<2xi64>} :
+// (tensor<2x4x5x1xf32>, tensor<2x3x1x1xf32>) -> tensor<2x4x5x1xf32>
+// check.expect_almost_eq_const(%res, dense<[
+// [[[ 80.0], [121.0], [142.0], [163.0], [100.0]],
+// [[160.0], [226.0], [247.0], [268.0], [160.0]],
+// [[240.0], [331.0], [352.0], [373.0], [220.0]],
+// [[ 83.0], [104.0], [110.0], [116.0], [ 59.0]]],
+// [[[400.0], [541.0], [562.0], [583.0], [340.0]],
+// [[480.0], [646.0], [667.0], [688.0], [400.0]],
+// [[560.0], [751.0], [772.0], [793.0], [460.0]],
+// [[183.0], [224.0], [230.0], [236.0], [119.0]]]]> : tensor<2x4x5x1xf32>) : tensor<2x4x5x1xf32>
+// return
+// }
func @conv2d_no_padding() attributes { iree.module.export } {
%inputs = iree.unfoldable_constant dense<[
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 1becd29..e34523c 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 1becd298b82ed2f1a8ba5e61c5ad2ce7fe32d812
+Subproject commit e34523c87c3f1cfabcf741568dede026bbb12d3a