Merge google -> main (#7536)
* 912958a14 Update bazel_to_cmake_targets.py llvm-project/lld dep
* 188d61bbb Merge pull request #7530 from not-jenni:main-to-google
* ac20abe4e Integrate LLVM at llvm/llvm-project@7277d2e1c86b
* d10bb8e0c Integrate LLVM at llvm/llvm-project@d36dd1f842c1
* 86d7af1b4 Synchronize submodules with LLVM at llvm/llvm-project@3d32218d1af2
* 97a428b49 Run bazel_to_cmake on iree/compiler/Codegen/Common/CMakeLists.txt
* fe4a6efee Integrate LLVM at llvm/llvm-project@3d32218d1af2
* f9316dae6 [compiler] reenable test broken by MLIR upstream changes
* 5c20099f8 Disable broken comprehensive bufferize test at head
* 4caceffd8 Integrate LLVM at llvm/llvm-project@bcad20bc6591
diff --git a/colab/test_notebooks.py b/colab/test_notebooks.py
index 683d15d..2de3de6 100644
--- a/colab/test_notebooks.py
+++ b/colab/test_notebooks.py
@@ -14,11 +14,7 @@
NOTEBOOKS_TO_SKIP = []
-NOTEBOOKS_EXPECTED_TO_FAIL = [
- # Text classification notebook fails to extract the vocab file on Docker
- # (needs visibility into the tempdir?)
- "tflite_text_classification.ipynb",
-]
+NOTEBOOKS_EXPECTED_TO_FAIL = []
class ColabNotebookTests(absltest.TestCase):
diff --git a/docs/developers/iree_community.md b/docs/developers/iree_community.md
deleted file mode 100644
index c55a23e..0000000
--- a/docs/developers/iree_community.md
+++ /dev/null
@@ -1,8 +0,0 @@
-# IREE Community
-
-## Community Projects
-
-* The [IREE C++ Template ](https://github.com/iml130/iree-template-cpp)
- demonstrates how to integrate IREE into a third-party project with CMake.
- The project demonstrates the usage of runtime support and how to use a
- custom dialect alongside with the runtime.
diff --git a/docs/website/docs/building-from-source/riscv.md b/docs/website/docs/building-from-source/riscv.md
index 7f9e23c..a27a5a3 100644
--- a/docs/website/docs/building-from-source/riscv.md
+++ b/docs/website/docs/building-from-source/riscv.md
@@ -130,12 +130,12 @@
* RISC-V toolchain is built from
[https://github.com/llvm/llvm-project](https://github.com/llvm/llvm-project) (main branch).
- * Currently, the LLVM compiler is built on GNU toolchain, including libgcc,
- GNU linker, and C libraries. You need to build GNU toolchain first.
- * Clone GNU toolchain from:
- [https://github.com/riscv/riscv-gnu-toolchain](https://github.com/riscv/riscv-gnu-toolchain)
- (master branch). Switch the "riscv-binutils" submodule to `rvv-1.0.x-zfh`
- branch manually.
+ * Currently, the LLVM compiler is built on GNU toolchain, including libgcc,
+ GNU linker, and C libraries. You need to build GNU toolchain first.
+ * Clone GNU toolchain from:
+ [https://github.com/riscv/riscv-gnu-toolchain](https://github.com/riscv/riscv-gnu-toolchain)
+ (master branch). Switch the "riscv-binutils" submodule to `rvv-1.0.x-zfh`
+ branch manually.
* RISC-V QEMU is built from
[https://github.com/sifive/qemu/tree/v5.2.0-rvv-rvb-zfh](https://github.com/sifive/qemu/tree/v5.2.0-rvv-rvb-zfh).
diff --git a/docs/website/docs/community/projects.md b/docs/website/docs/community/index.md
similarity index 100%
rename from docs/website/docs/community/projects.md
rename to docs/website/docs/community/index.md
diff --git a/docs/website/mkdocs.yml b/docs/website/mkdocs.yml
index af678ef..8b33b48 100644
--- a/docs/website/mkdocs.yml
+++ b/docs/website/mkdocs.yml
@@ -118,7 +118,7 @@
- 'Extensions':
- 'extensions/index.md'
- 'Community':
- - Projects: 'community/projects.md'
+ - 'community/index.md'
- 'Blog':
- 'blog/index.md'
- CUDA backend: 'blog/2021-10-15-cuda-backend.md'
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index c009f37..ef12029 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -769,15 +769,6 @@
results.insert<FoldSplatReshapeIntoSplat>(context);
}
-OpFoldResult TensorSplatOp::fold(ArrayRef<Attribute> operands) {
- if (operands.size() == 1 && operands.front()) {
- // Splat value is constant and we can fold the operation.
- return SplatElementsAttr::get(result().getType().cast<ShapedType>(),
- operands[0]);
- }
- return {};
-}
-
OpFoldResult TensorCloneOp::fold(ArrayRef<Attribute> operands) {
if (operands[0]) {
// Constants always fold.
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.td b/iree/compiler/Dialect/Flow/IR/FlowOps.td
index be3caed..102d861 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOps.td
+++ b/iree/compiler/Dialect/Flow/IR/FlowOps.td
@@ -839,7 +839,6 @@
}];
let hasCanonicalizer = 1;
- let hasFolder = 1;
}
def FLOW_TensorCloneOp : FLOW_PureOp<"tensor.clone", [
diff --git a/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir b/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
index 792b88d..bd843fb 100644
--- a/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
+++ b/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
@@ -154,28 +154,6 @@
// -----
-// CHECK-LABEL: @splatConst
-func @splatConst() -> tensor<4xi32> {
- %0 = arith.constant 4 : i32
- // CHECK-NEXT: %[[C:.+]] = arith.constant dense<4> : tensor<4xi32>
- %1 = flow.tensor.splat %0 : tensor<4xi32>
- // CHECK-NEXT: return %[[C]]
- return %1 : tensor<4xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @splatConstScalar
-func @splatConstScalar() -> tensor<i32> {
- %0 = arith.constant 4 : i32
- // CHECK-NEXT: %[[C:.+]] = arith.constant dense<4> : tensor<i32>
- %1 = flow.tensor.splat %0 : tensor<i32>
- // CHECK-NEXT: return %[[C]]
- return %1 : tensor<i32>
-}
-
-// -----
-
// CHECK-LABEL: @splatDynamicShape
// CHECK-SAME: (%[[DIM0:.+]]: index, %[[DIM1:.+]]: index)
func @splatDynamicShape(%dim0: index, %dim1: index) -> tensor<?x?xi32> {
diff --git a/iree/compiler/Dialect/Flow/Transforms/ConvertLinalgTensorOps.cpp b/iree/compiler/Dialect/Flow/Transforms/ConvertLinalgTensorOps.cpp
index 869143a..992878b 100644
--- a/iree/compiler/Dialect/Flow/Transforms/ConvertLinalgTensorOps.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/ConvertLinalgTensorOps.cpp
@@ -81,7 +81,6 @@
// Don't convert linalg.fill ops that were fused together with other ops.
return failure();
}
-
SmallVector<Value, 4> dynamicDims =
getDynamicDimValues(rewriter, fillOp.getLoc(), fillOp.output());
rewriter.replaceOpWithNewOp<TensorSplatOp>(
@@ -90,6 +89,26 @@
}
};
+struct ConvertSplatConstantOp : public OpRewritePattern<mlir::ConstantOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(mlir::ConstantOp op,
+ PatternRewriter &rewriter) const override {
+ if (op->getParentOfType<IREE::Flow::DispatchWorkgroupsOp>()) {
+ return rewriter.notifyMatchFailure(op, "ignoring dispatch ops");
+ }
+ auto splatAttr = op.getValue().dyn_cast<SplatElementsAttr>();
+ if (!splatAttr) {
+ return rewriter.notifyMatchFailure(op, "only looking for splats");
+ }
+ auto tensorType = op.getType().cast<TensorType>();
+ auto elementValue = rewriter.createOrFold<mlir::ConstantOp>(
+ op.getLoc(), tensorType.getElementType(), splatAttr.getSplatValue());
+ rewriter.replaceOpWithNewOp<IREE::Flow::TensorSplatOp>(
+ op, tensorType, elementValue, ValueRange{});
+ return success();
+ }
+};
+
/// Converts linalg operations that can map to flow.tensor.* operations.
struct ConvertLinalgTensorOpsPass
: public ConvertLinalgTensorOpsBase<ConvertLinalgTensorOpsPass> {
@@ -116,7 +135,8 @@
LinalgTensorReshapeToFlowTensorReshape<linalg::TensorExpandShapeOp>>(
context);
} else {
- patterns.insert<LinalgFillToFlowTensorSplat>(context);
+ patterns.insert<LinalgFillToFlowTensorSplat, ConvertSplatConstantOp>(
+ context);
}
IREE::Flow::TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
diff --git a/iree/compiler/Dialect/Flow/Transforms/OutlineLargeConstants.cpp b/iree/compiler/Dialect/Flow/Transforms/OutlineLargeConstants.cpp
index 27c3e80..24de678 100644
--- a/iree/compiler/Dialect/Flow/Transforms/OutlineLargeConstants.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/OutlineLargeConstants.cpp
@@ -29,11 +29,15 @@
// more efficient and fewer bindings.
static bool isConstantLarge(arith::ConstantOp constantOp,
size_t minLargeConstantSize) {
+ if (constantOp.value().isa<SplatElementsAttr>()) {
+ // Never outline splats; we want those transient within streams.
+ return false;
+ }
auto type = constantOp.getType();
if (auto shapedType = type.dyn_cast<RankedTensorType>()) {
size_t unpackedByteLength =
(shapedType.getNumElements() * shapedType.getElementTypeBitWidth()) / 8;
- if (unpackedByteLength >= minLargeConstantSize) {
+ if (unpackedByteLength > minLargeConstantSize) {
return true;
}
}
@@ -63,8 +67,6 @@
: public OutlineLargeConstantsBase<OutlineLargeConstantsPass> {
public:
OutlineLargeConstantsPass() = default;
- OutlineLargeConstantsPass(size_t minLargeConstantSize)
- : minLargeConstantSize(minLargeConstantSize){};
  void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<IREE::Flow::FlowDialect, IREE::Util::UtilDialect>();
@@ -84,7 +86,7 @@
std::vector<std::pair<arith::ConstantOp, IREE::Util::GlobalOp>>
replacements;
for (auto &largeConstantOp :
- findLargeConstantsInModule(moduleOp, minLargeConstantSize)) {
+ findLargeConstantsInModule(moduleOp, minStorageSize.getValue())) {
std::string name;
do {
name = baseName + std::to_string(uniqueId++);
@@ -114,14 +116,11 @@
constantOp.erase();
}
}
-
- private:
- size_t minLargeConstantSize;
};
-std::unique_ptr<OperationPass<mlir::ModuleOp>> createOutlineLargeConstantsPass(
- size_t minLargeConstantSize) {
- return std::make_unique<OutlineLargeConstantsPass>(minLargeConstantSize);
+std::unique_ptr<OperationPass<mlir::ModuleOp>>
+createOutlineLargeConstantsPass() {
+ return std::make_unique<OutlineLargeConstantsPass>();
}
} // namespace Flow
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.h b/iree/compiler/Dialect/Flow/Transforms/Passes.h
index 123cfb0..dafb8d6 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.h
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -131,12 +131,8 @@
//===----------------------------------------------------------------------===//
// Outlines large tensor constants into util.globals at the module level.
-//
-// TODO(#5493): implement the support for inlining constants into the command
-// buffer and raise this value to one that is measured to be good.
-static constexpr size_t kMinLargeConstantSize = 1;
-std::unique_ptr<OperationPass<mlir::ModuleOp>> createOutlineLargeConstantsPass(
- size_t minLargeConstantSize = kMinLargeConstantSize);
+std::unique_ptr<OperationPass<mlir::ModuleOp>>
+createOutlineLargeConstantsPass();
// Deduplicates equivalent executables.
std::unique_ptr<OperationPass<mlir::ModuleOp>>
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.td b/iree/compiler/Dialect/Flow/Transforms/Passes.td
index 64941b2..4fee8c0 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -102,8 +102,12 @@
def OutlineLargeConstants :
Pass<"iree-flow-outline-large-constants", "mlir::ModuleOp"> {
let summary = "Outlines large tensor constants into util.globals at the module level.";
- // TODO(#5493): add a flag for this.
- let constructor = "mlir::iree_compiler::IREE::Flow::createOutlineLargeConstantsPass(25)";
+ let constructor = "mlir::iree_compiler::IREE::Flow::createOutlineLargeConstantsPass()";
+ let options = [
+ Option<"minStorageSize", "min-storage-size",
+ "int64_t", /*default=*/"64",
+ "Outlines constants with storage sizes > than this byte size.">
+ ];
}
def PadLinalgOps :
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/outline_large_constants.mlir b/iree/compiler/Dialect/Flow/Transforms/test/outline_large_constants.mlir
index cdc299f..7394119 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/outline_large_constants.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/outline_large_constants.mlir
@@ -1,10 +1,12 @@
-// RUN: iree-opt -split-input-file -iree-flow-outline-large-constants %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -iree-flow-outline-large-constants='min-storage-size=9' %s | IreeFileCheck %s
-// CHECK: util.global private @[[LARGE_VARIABLE:.+]] {noinline} = dense<1.200000e+00> : tensor<512x128xf32>
-func @fn1() -> (tensor<2xf32>, tensor<512x128xf32>) {
+// CHECK: util.global private @[[LARGE_VARIABLE:.+]] {noinline} = dense<{{.+}}> : tensor<8xf32>
+func @fn1() -> (tensor<2xf32>, tensor<512x128xf32>, tensor<8xf32>) {
// CHECK-DAG: %[[SMALL_VALUE:.+]] = arith.constant dense<{{.+}}> : tensor<2xf32>
%cst_0 = arith.constant dense<[0.0287729427, 0.0297581609]> : tensor<2xf32>
- // CHECK-DAG: %[[LARGE_VALUE:.+]] = util.global.load @[[LARGE_VARIABLE]] : tensor<512x128xf32>
+ // CHECK-DAG: %[[SPLATG_VALUE:.+]] = arith.constant dense<{{.+}}> : tensor<512x128xf32>
%cst_1 = arith.constant dense<1.2> : tensor<512x128xf32>
- return %cst_0, %cst_1 : tensor<2xf32>, tensor<512x128xf32>
+ // CHECK-DAG: %[[LARGE_VALUE:.+]] = util.global.load @[[LARGE_VARIABLE]] : tensor<8xf32>
+ %cst_2 = arith.constant dense<[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]> : tensor<8xf32>
+ return %cst_0, %cst_1, %cst_2 : tensor<2xf32>, tensor<512x128xf32>, tensor<8xf32>
}
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp
index f0da141..417ecf2 100644
--- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp
@@ -937,65 +937,16 @@
return success();
}
-// Splats a pattern value of 1, 2, or 4 bytes out to a 4 byte integer value.
-// The bit representation of |baseValue| will be repeated as many times as
-// needed in the returned value to use 4 bytes of storage. For example,
-// a 16-bit value (int or float) will have its native bit representation
-// repeated twice.
-static Value splatFillPattern(Location loc, Value baseValue,
- OpBuilder &builder) {
- // Bitcast to an integer, then use integer math for the rest of the pattern.
- auto baseBitWidth = baseValue.getType().getIntOrFloatBitWidth();
- baseValue = builder.createOrFold<arith::BitcastOp>(
- loc, builder.getIntegerType(baseBitWidth), baseValue);
-
- switch (baseBitWidth) {
- case 8: {
- // (v << 24) | (v << 16) | (v << 8) | v
- auto b0 = builder.createOrFold<arith::ExtUIOp>(
- loc, baseValue, builder.getIntegerType(32));
- auto c8 = builder.create<arith::ConstantIntOp>(loc, 8, 32);
- auto b1 = builder.createOrFold<arith::ShLIOp>(loc, b0, c8);
- auto c16 = builder.create<arith::ConstantIntOp>(loc, 16, 32);
- auto b2 = builder.createOrFold<arith::ShLIOp>(loc, b0, c16);
- auto c24 = builder.create<arith::ConstantIntOp>(loc, 24, 32);
- auto b3 = builder.createOrFold<arith::ShLIOp>(loc, b0, c24);
- return builder.createOrFold<arith::OrIOp>(
- loc, b0,
- builder.createOrFold<arith::OrIOp>(
- loc, b1, builder.createOrFold<arith::OrIOp>(loc, b2, b3)));
- }
- case 16: {
- // (v << 16) | v
- auto c16 = builder.create<arith::ConstantIntOp>(loc, 16, 32);
- auto b0 = builder.createOrFold<arith::ExtUIOp>(
- loc, baseValue, builder.getIntegerType(32));
- auto b1 = builder.createOrFold<arith::ShLIOp>(loc, b0, c16);
- return builder.createOrFold<arith::OrIOp>(loc, b0, b1);
- }
- case 32:
- return baseValue;
- default:
- return {}; // Unsupported (so far)
- }
-}
-
static LogicalResult recordTensorSplat(Value device, Value commandBuffer,
IREE::Flow::TensorSplatOp &splatOp,
StreamSchedulingState &schedulingState,
ConversionPatternRewriter &rewriter) {
auto resultBuffer = schedulingState.lookupTensorBufferRange(splatOp.result());
-
- auto pattern = splatFillPattern(splatOp.getLoc(), splatOp.value(), rewriter);
- if (!pattern) {
- return splatOp.emitError() << ">4 byte/non-byte-aligned fills are not yet "
- "implemented (require special emulation)";
- }
-
auto zeroOffset = schedulingState.lookupOrCreateIndex(0, rewriter);
+
rewriter.create<IREE::HAL::CommandBufferFillBufferOp>(
splatOp.getLoc(), commandBuffer, resultBuffer.buffer, zeroOffset,
- resultBuffer.length, pattern);
+ resultBuffer.length, splatOp.value());
// Full barriers for now as we aren't scheduling things.
recordFullExecutionBarrier(commandBuffer, splatOp.getLoc(), rewriter);
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir
index 46456b5..79a0a72 100644
--- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir
+++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir
@@ -609,14 +609,7 @@
// CHECK: %[[BUFFER:.+]] = hal.allocator.allocate<%allocator : !hal.allocator> type("HostVisible|DeviceVisible|DeviceLocal") usage("Transfer|Mapping|Dispatch") : !hal.buffer{%[[SIZE]]}
%0 = flow.ex.stream.fragment(%value, %dim) : (i8, index) -> tensor<?x128xi8>{%dim} =
(%arg0: i8, %arg1: index) -> tensor<?x128xi8> {
- // CHECK-DAG: %[[B0:.+]] = arith.extui %[[VALUE]] : i8 to i32
- // CHECK-DAG: %[[B1:.+]] = arith.shli %[[B0]], %c8
- // CHECK-DAG: %[[B2:.+]] = arith.shli %[[B0]], %c16
- // CHECK-DAG: %[[B3:.+]] = arith.shli %[[B0]], %c24
- // CHECK-DAG: %[[ORA:.+]] = arith.ori %[[B2]], %[[B3]]
- // CHECK-DAG: %[[ORB:.+]] = arith.ori %[[B1]], %[[ORA]]
- // CHECK-DAG: %[[PATTERN:.+]] = arith.ori %[[B0]], %[[ORB]]
- // CHECK-NEXT: hal.command_buffer.fill_buffer<%cmd : !hal.command_buffer> target(%[[BUFFER]] : !hal.buffer)[%c0, %[[SIZE]]] pattern(%[[PATTERN]] : i32)
+ // CHECK: hal.command_buffer.fill_buffer<%cmd : !hal.command_buffer> target(%[[BUFFER]] : !hal.buffer)[%c0, %[[SIZE]]] pattern(%[[VALUE]] : i8)
%1 = flow.tensor.splat %arg0 : tensor<?x128xi8>{%arg1}
flow.return %1 : tensor<?x128xi8>
}
@@ -635,8 +628,7 @@
// CHECK: %[[BUFFER:.+]] = hal.allocator.allocate<%allocator : !hal.allocator> type("HostVisible|DeviceVisible|DeviceLocal") usage("Transfer|Mapping|Dispatch") : !hal.buffer{%c1024}
%0 = flow.ex.stream.fragment(%value) : (f32) -> tensor<2x128xf32> =
(%arg0: f32) -> tensor<2x128xf32> {
- // CHECK-DAG: %[[PATTERN:.+]] = arith.bitcast %[[VALUE]] : f32 to i32
- // CHECK: hal.command_buffer.fill_buffer<%cmd : !hal.command_buffer> target(%[[BUFFER]] : !hal.buffer)[%c0, %c1024] pattern(%[[PATTERN]] : i32)
+ // CHECK: hal.command_buffer.fill_buffer<%cmd : !hal.command_buffer> target(%[[BUFFER]] : !hal.buffer)[%c0, %c1024] pattern(%[[VALUE]] : f32)
%1 = flow.tensor.splat %arg0 : tensor<2x128xf32>
flow.return %1 : tensor<2x128xf32>
}
@@ -655,11 +647,7 @@
// CHECK: %[[BUFFER:.+]] = hal.allocator.allocate<%allocator : !hal.allocator> type("HostVisible|DeviceVisible|DeviceLocal") usage("Transfer|Mapping|Dispatch") : !hal.buffer{%c512}
%0 = flow.ex.stream.fragment(%value) : (f16) -> tensor<2x128xf16> =
(%arg0: f16) -> tensor<2x128xf16> {
- // CHECK-DAG: %[[BITCAST:.+]] = arith.bitcast %[[VALUE]] : f16 to i16
- // CHECK-DAG: %[[B0:.+]] = arith.extui %[[BITCAST]] : i16 to i32
- // CHECK-DAG: %[[B1:.+]] = arith.shli %[[B0]], %c16
- // CHECK-DAG: %[[PATTERN:.+]] = arith.ori %[[B0]], %[[B1]]
- // CHECK: hal.command_buffer.fill_buffer<%cmd : !hal.command_buffer> target(%[[BUFFER]] : !hal.buffer)[%c0, %c512] pattern(%[[PATTERN]] : i32)
+ // CHECK: hal.command_buffer.fill_buffer<%cmd : !hal.command_buffer> target(%[[BUFFER]] : !hal.buffer)[%c0, %c512] pattern(%[[VALUE]] : f16)
%1 = flow.tensor.splat %arg0 : tensor<2x128xf16>
flow.return %1 : tensor<2x128xf16>
}
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp
index 1543705..9026fa0 100644
--- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp
@@ -13,6 +13,60 @@
namespace iree_compiler {
namespace {
+class CommandBufferFillBufferOpConversion
+ : public OpConversionPattern<IREE::HAL::CommandBufferFillBufferOp> {
+ public:
+ CommandBufferFillBufferOpConversion(MLIRContext *context,
+ SymbolTable &importSymbols,
+ TypeConverter &typeConverter,
+ StringRef importName)
+ : OpConversionPattern(typeConverter, context) {
+ importOp = importSymbols.lookup<IREE::VM::ImportOp>(importName);
+ assert(importOp);
+ }
+
+ LogicalResult matchAndRewrite(
+ IREE::HAL::CommandBufferFillBufferOp op, llvm::ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto importType = importOp.getType();
+ IREE::HAL::CommandBufferFillBufferOp::Adaptor newOperands(operands);
+
+ SmallVector<Value, 8> callOperands = {
+ newOperands.command_buffer(),
+ newOperands.target_buffer(),
+ newOperands.target_offset(),
+ newOperands.length(),
+ };
+
+ // Record the original pattern length then extend it to a 32 bit integer.
+ auto originalPatternType = op.pattern().getType();
+ auto patternBitWidth = originalPatternType.getIntOrFloatBitWidth();
+ auto patternLength = rewriter.createOrFold<mlir::arith::ConstantIntOp>(
+ op.getLoc(), patternBitWidth / 8, 32);
+ Value pattern = op.pattern();
+ if (originalPatternType.isF16() || originalPatternType.isF32()) {
+ pattern = rewriter.createOrFold<arith::BitcastOp>(
+ op.getLoc(), rewriter.getIntegerType(patternBitWidth), pattern);
+ }
+ if (patternBitWidth < 32) {
+ pattern = rewriter.createOrFold<arith::ExtUIOp>(
+ op.getLoc(), pattern, rewriter.getIntegerType(32));
+ }
+ callOperands.push_back(pattern);
+ callOperands.push_back(patternLength);
+
+ auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallOp>(
+ op, SymbolRefAttr::get(importOp), importType.getResults(),
+ callOperands);
+
+ copyImportAttrs(importOp, callOp);
+ return success();
+ }
+
+ private:
+ mutable IREE::VM::ImportOp importOp;
+};
+
class CommandBufferPushDescriptorSetOpConversion
: public OpConversionPattern<IREE::HAL::CommandBufferPushDescriptorSetOp> {
public:
@@ -84,7 +138,7 @@
.insert<VMImportOpConversion<IREE::HAL::CommandBufferExecutionBarrierOp>>(
context, importSymbols, typeConverter,
"hal.command_buffer.execution_barrier");
- patterns.insert<VMImportOpConversion<IREE::HAL::CommandBufferFillBufferOp>>(
+ patterns.insert<CommandBufferFillBufferOpConversion>(
context, importSymbols, typeConverter, "hal.command_buffer.fill_buffer");
patterns.insert<VMImportOpConversion<IREE::HAL::CommandBufferCopyBufferOp>>(
context, importSymbols, typeConverter, "hal.command_buffer.copy_buffer");
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir
index 4699c05..bfb386b 100644
--- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir
@@ -43,7 +43,7 @@
%c100 = arith.constant 100 : index
%c200 = arith.constant 200 : index
%c300 = arith.constant 300 : i32
- // CHECK: vm.call @hal.command_buffer.fill_buffer(%arg0, %arg1, %c100, %c200, %c300) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.buffer>, i32, i32, i32) -> ()
+ // CHECK: vm.call @hal.command_buffer.fill_buffer(%arg0, %arg1, %c100, %c200, %c300, %c4) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.buffer>, i32, i32, i32, i32) -> ()
hal.command_buffer.fill_buffer<%arg0 : !hal.command_buffer>
target(%arg1 : !hal.buffer)[%c100, %c200]
pattern(%c300 : i32)
@@ -52,6 +52,23 @@
// -----
+// CHECK-LABEL: @command_buffer_fill_buffer_i16
+func @command_buffer_fill_buffer_i16(
+ %arg0: !hal.command_buffer,
+ %arg1: !hal.buffer
+) {
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ %cst = arith.constant 1234 : i16
+ // CHECK: vm.call @hal.command_buffer.fill_buffer(%arg0, %arg1, %c100, %c200, %c1234_0, %c2) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.buffer>, i32, i32, i32, i32) -> ()
+ hal.command_buffer.fill_buffer<%arg0 : !hal.command_buffer>
+ target(%arg1 : !hal.buffer)[%c100, %c200]
+ pattern(%cst : i16)
+ return
+}
+
+// -----
+
// CHECK-LABEL: @command_buffer_copy_buffer
func @command_buffer_copy_buffer(
%arg0: !hal.command_buffer,
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.td b/iree/compiler/Dialect/HAL/IR/HALOps.td
index 01d9cf7..f9de603 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -1025,7 +1025,7 @@
HAL_BufferType:$target_buffer,
HAL_DeviceSize:$target_offset,
HAL_DeviceSize:$length,
- I32:$pattern
+ AnyTypeOf<[I8, I16, I32, F16, F32]>:$pattern
);
let assemblyFormat = [{
diff --git a/iree/compiler/Dialect/HAL/hal.imports.mlir b/iree/compiler/Dialect/HAL/hal.imports.mlir
index 40189c4..f5fe69c 100644
--- a/iree/compiler/Dialect/HAL/hal.imports.mlir
+++ b/iree/compiler/Dialect/HAL/hal.imports.mlir
@@ -197,7 +197,8 @@
%target_buffer : !vm.ref<!hal.buffer>,
%target_offset : i32,
%length : i32,
- %pattern : i32
+ %pattern : i32,
+ %pattern_length: i32
)
// Copies a range of one buffer to another.
diff --git a/iree/modules/hal/exports.inl b/iree/modules/hal/exports.inl
index 4d76e83..1a05208 100644
--- a/iree/modules/hal/exports.inl
+++ b/iree/modules/hal/exports.inl
@@ -53,7 +53,7 @@
EXPORT_FN("command_buffer.end", iree_hal_module_command_buffer_end, r, v)
EXPORT_FN("command_buffer.end_debug_group", iree_hal_module_command_buffer_end_debug_group, r, v)
EXPORT_FN("command_buffer.execution_barrier", iree_hal_module_command_buffer_execution_barrier, riii, v)
-EXPORT_FN("command_buffer.fill_buffer", iree_hal_module_command_buffer_fill_buffer, rriii, v)
+EXPORT_FN("command_buffer.fill_buffer", iree_hal_module_command_buffer_fill_buffer, rriiii, v)
EXPORT_FN("command_buffer.push_constants", iree_hal_module_command_buffer_push_constants, rriCiD, v)
EXPORT_FN("command_buffer.push_descriptor_set", iree_hal_module_command_buffer_push_descriptor_set, rriCiriiD, v)
diff --git a/iree/modules/hal/module.c b/iree/modules/hal/module.c
index f8dcfeb..2828372 100644
--- a/iree/modules/hal/module.c
+++ b/iree/modules/hal/module.c
@@ -717,7 +717,7 @@
IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_fill_buffer, //
iree_hal_module_state_t, //
- rriii, v) {
+ rriiii, v) {
iree_hal_command_buffer_t* command_buffer = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
@@ -726,12 +726,13 @@
iree_vm_size_t target_offset = (iree_vm_size_t)args->i2;
iree_vm_size_t length = (iree_vm_size_t)args->i3;
uint32_t pattern = (uint32_t)args->i4;
+ uint32_t pattern_length = (uint32_t)args->i5;
iree_hal_module_ex_defer_release(state, args->r1);
return iree_hal_command_buffer_fill_buffer(command_buffer, target_buffer,
target_offset, length, &pattern,
- sizeof(pattern));
+ pattern_length);
}
IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_copy_buffer, //
diff --git a/iree/test/e2e/models/collatz.mlir b/iree/test/e2e/models/collatz.mlir
index feffff7..eca3d70 100644
--- a/iree/test/e2e/models/collatz.mlir
+++ b/iree/test/e2e/models/collatz.mlir
@@ -1,4 +1,5 @@
// RUN: iree-run-mlir --iree-input-type=mhlo -iree-hal-target-backends=vmvx %s | IreeFileCheck %s
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir --iree-input-type=mhlo %s -iree-hal-target-backends=vulkan-spirv | IreeFileCheck %s)
// CHECK-LABEL: EXEC @collatz
func @collatz() -> tensor<f32> {
diff --git a/iree/vm/shims.c b/iree/vm/shims.c
index 1b8a680..38af0c7 100644
--- a/iree/vm/shims.c
+++ b/iree/vm/shims.c
@@ -38,7 +38,6 @@
IREE_VM_ABI_DEFINE_SHIM(rrCiriiD, r);
IREE_VM_ABI_DEFINE_SHIM(rriCiD, v);
IREE_VM_ABI_DEFINE_SHIM(rriCiriiD, v);
-IREE_VM_ABI_DEFINE_SHIM(rriii, v);
IREE_VM_ABI_DEFINE_SHIM(rriiii, v);
IREE_VM_ABI_DEFINE_SHIM(rrirCiD, v);
IREE_VM_ABI_DEFINE_SHIM(rriri, v);
diff --git a/iree/vm/shims.h b/iree/vm/shims.h
index 5136861..f5d3415 100644
--- a/iree/vm/shims.h
+++ b/iree/vm/shims.h
@@ -253,14 +253,6 @@
int32_t i6;
});
-IREE_VM_ABI_FIXED_STRUCT(rriii, {
- iree_vm_ref_t r0;
- iree_vm_ref_t r1;
- int32_t i2;
- int32_t i3;
- int32_t i4;
-});
-
IREE_VM_ABI_FIXED_STRUCT(rriiii, {
iree_vm_ref_t r0;
iree_vm_ref_t r1;
@@ -427,7 +419,6 @@
IREE_VM_ABI_DECLARE_SHIM(rrCiriiD, r);
IREE_VM_ABI_DECLARE_SHIM(rriCiD, v);
IREE_VM_ABI_DECLARE_SHIM(rriCiriiD, v);
-IREE_VM_ABI_DECLARE_SHIM(rriii, v);
IREE_VM_ABI_DECLARE_SHIM(rriiii, v);
IREE_VM_ABI_DECLARE_SHIM(rrirCiD, v);
IREE_VM_ABI_DECLARE_SHIM(rriri, v);