Merge pull request #4870 from KoolJBlack:main-to-google
PiperOrigin-RevId: 358026286
diff --git a/CMakeLists.txt b/CMakeLists.txt
index d586279..243e3e6 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -114,6 +114,7 @@
# List of all HAL drivers to be built by default:
set(IREE_ALL_HAL_DRIVERS
+ Cuda
DyLib
VMLA
Vulkan
@@ -126,6 +127,10 @@
if(APPLE)
list(REMOVE_ITEM IREE_HAL_DRIVERS_TO_BUILD Vulkan)
endif()
+ # Remove Cuda from Android and Apple platforms.
+ if(ANDROID OR APPLE)
+ list(REMOVE_ITEM IREE_HAL_DRIVERS_TO_BUILD Cuda)
+ endif()
endif()
message(STATUS "Building HAL drivers: ${IREE_HAL_DRIVERS_TO_BUILD}")
@@ -382,6 +387,7 @@
include(flatbuffer_c_library)
add_subdirectory(third_party/benchmark EXCLUDE_FROM_ALL)
+add_subdirectory(build_tools/third_party/cuda_headers EXCLUDE_FROM_ALL)
add_subdirectory(build_tools/third_party/flatcc EXCLUDE_FROM_ALL)
add_subdirectory(build_tools/third_party/half EXCLUDE_FROM_ALL)
add_subdirectory(build_tools/third_party/pffft EXCLUDE_FROM_ALL)
diff --git a/build_tools/bazel/workspace.bzl b/build_tools/bazel/workspace.bzl
index ed68102..7e1afb4 100644
--- a/build_tools/bazel/workspace.bzl
+++ b/build_tools/bazel/workspace.bzl
@@ -123,3 +123,10 @@
build_file = iree_repo_alias + "//:build_tools/third_party/spirv_cross/BUILD.overlay",
path = paths.join(iree_path, "third_party/spirv_cross"),
)
+
+ maybe(
+ native.new_local_repository,
+ name = "cuda_headers",
+ build_file = iree_repo_alias + "//:build_tools/third_party/cuda_headers/BUILD.overlay",
+ path = paths.join(iree_path, "third_party/cuda_headers"),
+ )
diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
index a7a6ff1..d4d7c11 100644
--- a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
+++ b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
@@ -52,6 +52,8 @@
"@llvm-project//mlir:TensorDialect": ["MLIRTensor"],
# Vulkan
"@iree_vulkan_headers//:vulkan_headers": ["Vulkan::Headers"],
+ # Cuda
+ "@cuda_headers//:cuda_headers": ["cuda_headers"],
# The Bazel target maps to the IMPORTED target defined by FindVulkan().
"@vulkan_sdk//:sdk": ["Vulkan::Vulkan"],
# Misc single targets
diff --git a/build_tools/buildkite/cmake/android/arm64-v8a/benchmark.yml b/build_tools/buildkite/cmake/android/arm64-v8a/benchmark.yml
index d36bc46..33e7833 100644
--- a/build_tools/buildkite/cmake/android/arm64-v8a/benchmark.yml
+++ b/build_tools/buildkite/cmake/android/arm64-v8a/benchmark.yml
@@ -29,6 +29,8 @@
- wait
- label: "benchmark on snapdragon-855 (adreno-640) (Pixel 4)"
+ # TODO(#4861): Re-enable when phone is fixed
+ skip: "Phone is borked. See https://github.com/google/iree/issues/4861"
commands:
- "buildkite-agent artifact download --step build model-artifacts.tgz ./"
- "tar xzvf model-artifacts.tgz"
diff --git a/build_tools/third_party/cuda_headers/BUILD.overlay b/build_tools/third_party/cuda_headers/BUILD.overlay
new file mode 100644
index 0000000..ba40179
--- /dev/null
+++ b/build_tools/third_party/cuda_headers/BUILD.overlay
@@ -0,0 +1,21 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = ["//visibility:public"])
+
+cc_library(
+ name = "cuda_headers",
+ hdrs = ["cuda.h"],
+)
+
diff --git a/build_tools/third_party/cuda_headers/CMakeLists.txt b/build_tools/third_party/cuda_headers/CMakeLists.txt
new file mode 100644
index 0000000..8c3992e
--- /dev/null
+++ b/build_tools/third_party/cuda_headers/CMakeLists.txt
@@ -0,0 +1,29 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set(CUDA_HEADERS_API_ROOT "${IREE_ROOT_DIR}/third_party/cuda_headers/")
+
+external_cc_library(
+ PACKAGE
+ cuda_headers
+ NAME
+ cuda_headers
+ ROOT
+ ${CUDA_HEADERS_API_ROOT}
+ HDRS
+ "cuda.h"
+ INCLUDES
+ ${CUDA_HEADERS_API_ROOT}
+)
+
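For orientation, the new cuda_headers target only exposes cuda.h; a hedged sketch of a consumer follows. cuInit and cuDriverGetVersion are standard CUDA driver-API calls (not part of this change), and an installed CUDA driver is assumed at runtime:

    // Hedged sketch: a translation unit built against the cuda_headers target.
    #include "cuda.h"

    int main() {
      // Initialize the driver API; requires a CUDA driver on the machine.
      if (cuInit(/*Flags=*/0) != CUDA_SUCCESS) return 1;
      int version = 0;
      if (cuDriverGetVersion(&version) != CUDA_SUCCESS) return 1;
      return version > 0 ? 0 : 1;
    }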
diff --git a/integrations/tensorflow/iree_tf_compiler/iree-import-tflite-main.cpp b/integrations/tensorflow/iree_tf_compiler/iree-import-tflite-main.cpp
index 63398d0..c978a7f 100644
--- a/integrations/tensorflow/iree_tf_compiler/iree-import-tflite-main.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/iree-import-tflite-main.cpp
@@ -61,6 +61,7 @@
// Register any command line options.
registerAsmPrinterCLOptions();
registerMLIRContextCLOptions();
+ registerPassManagerCLOptions();
cl::ParseCommandLineOptions(argc, argv);
// Initialize dialects.
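For context, registerPassManagerCLOptions() hooks the global pass-manager flags (IR printing, crash reproducers, and similar) into the importer's command line. A minimal sketch of the same registration pattern, with header paths assumed from upstream MLIR of this era:

    // Hedged sketch of a tool main mirroring the registration order above.
    #include "llvm/Support/CommandLine.h"
    #include "mlir/IR/AsmState.h"
    #include "mlir/IR/MLIRContext.h"
    #include "mlir/Pass/PassManager.h"

    int main(int argc, char **argv) {
      mlir::registerAsmPrinterCLOptions();
      mlir::registerMLIRContextCLOptions();
      mlir::registerPassManagerCLOptions();  // e.g. enables --mlir-print-ir-after-all
      llvm::cl::ParseCommandLineOptions(argc, argv);
      // ... construct an MLIRContext, parse the input, run the import pipeline ...
      return 0;
    }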
diff --git a/iree/base/CMakeLists.txt b/iree/base/CMakeLists.txt
index a01db40..048bf22 100644
--- a/iree/base/CMakeLists.txt
+++ b/iree/base/CMakeLists.txt
@@ -266,8 +266,7 @@
DEPS
::core_headers
DEFINES
- # TODO(#2114): Change the mode to 2.
- "IREE_TRACING_MODE=1"
+ "IREE_TRACING_MODE=2"
PUBLIC
)
else()
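IREE_TRACING_MODE is a compile-time level where higher values enable progressively more tracing features; this change promotes the default from the minimal mode 1 to mode 2, retiring the TODO(#2114) above. As an illustration only (the real macros live in iree/base/tracing.h; every name below is invented for the sketch), such a level typically gates instrumentation like this:

    // Hedged sketch: EXAMPLE_TRACE_SCOPE / ExampleScopedZone are made-up names,
    // not IREE's actual tracing API.
    struct ExampleScopedZone {
      explicit ExampleScopedZone(const char * /*name*/) { /* begin zone */ }
      ~ExampleScopedZone() { /* end zone */ }
    };
    #if defined(IREE_TRACING_MODE) && IREE_TRACING_MODE >= 2
    #define EXAMPLE_TRACE_SCOPE(name) ExampleScopedZone _example_zone(name)
    #else
    #define EXAMPLE_TRACE_SCOPE(name) (void)0
    #endif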
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
index fa55fda..c707431 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
@@ -57,25 +57,6 @@
// Utility functions.
// -----------------------------------------------------------------------------
-/// Returns the constant value associated with the init value if the defining
-/// operation is a constant.
-static Attribute getInitValueAsConst(Value init) {
- Attribute attr;
- if (!matchPattern(init, m_Constant(&attr))) return {};
- if (attr.getType().isa<IntegerType, FloatType>()) return attr;
-
- auto splatAttr = attr.dyn_cast<SplatElementsAttr>();
- if (!splatAttr) return {};
- auto type = splatAttr.getType().dyn_cast<ShapedType>();
- if (!type) return {};
- if (auto intType = type.getElementType().dyn_cast<IntegerType>()) {
- return IntegerAttr::get(intType, splatAttr.getSplatValue<APInt>());
- } else if (auto floatType = type.getElementType().dyn_cast<FloatType>()) {
- return FloatAttr::get(floatType, splatAttr.getSplatValue<APFloat>());
- }
- return {};
-}
-
/// Returns an ArrayAttr that contains `nLoops` attributes. All the attributes
/// are "parallel" except the last `nReduction` elements, where are "reduction"
/// attributes.
@@ -554,51 +535,36 @@
}
//===----------------------------------------------------------------------===//
-// mhlo.pad conversion patterns and utility functions.
+// linalg.pad_tensor conversion patterns and utility functions.
//===----------------------------------------------------------------------===//
namespace {
-/// Converts mhlo.pad operation to linalg.indexed_generic op.
-// TODO(#1604): Lower the pad op to a Linalg named op.
-struct PadOpConversion
- : public ConvertToLinalgBufferOp<PadOpConversion, mhlo::PadOp> {
- using ConvertToLinalgBufferOp<PadOpConversion,
- mhlo::PadOp>::ConvertToLinalgBufferOp;
+/// Converts linalg.pad_tensor operation to fill + subview + copy ops.
+struct PadTensorOpConversion
+ : public ConvertToLinalgBufferOp<PadTensorOpConversion,
+ linalg::PadTensorOp> {
+ using ConvertToLinalgBufferOp<PadTensorOpConversion,
+ linalg::PadTensorOp>::ConvertToLinalgBufferOp;
- LogicalResult apply(mhlo::PadOp op, ArrayRef<Value> inputBuffers,
+ LogicalResult apply(linalg::PadTensorOp op, ArrayRef<Value> inputBuffers,
ArrayRef<Value> resultBuffers,
ConversionPatternRewriter &rewriter) const;
};
} // namespace
-LogicalResult PadOpConversion::apply(
- mhlo::PadOp op, ArrayRef<Value> inputBuffers, ArrayRef<Value> resultBuffers,
- ConversionPatternRewriter &rewriter) const {
- mhlo::PadOp::Adaptor adaptor(inputBuffers);
+LogicalResult PadTensorOpConversion::apply(
+ linalg::PadTensorOp op, ArrayRef<Value> inputBuffers,
+ ArrayRef<Value> resultBuffers, ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
-
- Attribute paddingConstVal = getInitValueAsConst(adaptor.padding_value());
- Value paddingVal =
- paddingConstVal
- ? rewriter.create<ConstantOp>(loc, paddingConstVal).getResult()
- : rewriter.create<LoadOp>(loc, adaptor.padding_value());
-
- const auto &edgePaddingLow = op.edge_padding_low();
- const auto &interiorPadding = op.interior_padding();
- SmallVector<Value, 3> offsets, sizes, strides;
- for (auto it : llvm::enumerate(llvm::zip(edgePaddingLow, interiorPadding))) {
- Value startIndex = rewriter.create<ConstantIndexOp>(
- loc, std::get<0>(it.value()).getZExtValue());
- offsets.push_back(startIndex);
- Value size = rewriter.create<DimOp>(loc, inputBuffers[0], it.index());
- sizes.push_back(size);
- Value stride = rewriter.create<ConstantIndexOp>(
- loc, std::get<1>(it.value()).getZExtValue() + 1);
- strides.push_back(stride);
+ auto yieldOp = cast<linalg::YieldOp>(op.region().begin()->getTerminator());
+ rewriter.create<linalg::FillOp>(loc, resultBuffers[0], yieldOp.values()[0]);
+ SmallVector<Value, 4> sizes, strides;
+ int rank = op.getSourceType().getRank();
+ for (int i = 0; i < rank; ++i) {
+ sizes.push_back(rewriter.create<DimOp>(loc, inputBuffers[0], i));
+ strides.push_back(rewriter.create<ConstantIndexOp>(loc, 1));
}
-
- rewriter.create<linalg::FillOp>(loc, resultBuffers[0], paddingVal);
- auto subViewOp = rewriter.create<SubViewOp>(loc, resultBuffers[0], offsets,
+ auto subViewOp = rewriter.create<SubViewOp>(loc, resultBuffers[0], op.low(),
sizes, strides);
if (auto cstOp = dyn_cast<ConstantOp>(inputBuffers[0].getDefiningOp())) {
auto inputConstAttr =
@@ -1229,7 +1195,7 @@
LinalgOpOnTensorConversion<linalg::IndexedGenericOp>,
MatmulOnTensorConversion<linalg::MatmulOp>,
MatmulOnTensorConversion<linalg::BatchMatmulOp>,
- PadOpConversion, ReduceWindowOpConversion,
+ PadTensorOpConversion, ReduceWindowOpConversion,
SubTensorOpConversion, TensorReshapeOpConversion>(
context, resultTensorToBufferMap);
}
@@ -1266,8 +1232,8 @@
[](Shape::TieShapeOp op) -> bool {
return op.operand().getType().isa<MemRefType>();
});
- // Also convert away linalg.tensor_reshape.
- target.addIllegalOp<linalg::TensorReshapeOp>();
+ // Also convert away linalg.tensor_reshape and linalg.pad_tensor.
+ target.addIllegalOp<linalg::TensorReshapeOp, linalg::PadTensorOp>();
target.addDynamicallyLegalDialect<linalg::LinalgDialect>(
Optional<ConversionTarget::DynamicLegalityCallbackFn>([](Operation *op) {
// The generated structured Linalg ops should have buffer
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
index 2cb4a48..585570f 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
@@ -36,6 +36,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
@@ -150,6 +151,69 @@
} // namespace
//===----------------------------------------------------------------------===//
+// mhlo.pad conversion patterns.
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Returns the constant value associated with the init value if the defining
+/// operation is a constant.
+static Attribute getInitValueAsConst(Value init) {
+ Attribute attr;
+ if (!matchPattern(init, m_Constant(&attr))) return {};
+ if (attr.getType().isa<IntegerType, FloatType>()) return attr;
+
+ auto splatAttr = attr.dyn_cast<SplatElementsAttr>();
+ if (!splatAttr) return {};
+ auto type = splatAttr.getType().dyn_cast<ShapedType>();
+ if (!type) return {};
+ if (auto intType = type.getElementType().dyn_cast<IntegerType>()) {
+ return IntegerAttr::get(intType, splatAttr.getSplatValue<APInt>());
+ } else if (auto floatType = type.getElementType().dyn_cast<FloatType>()) {
+ return FloatAttr::get(floatType, splatAttr.getSplatValue<APFloat>());
+ }
+ return {};
+}
+
+/// Converts mhlo.pad operation to linalg.pad_tensor op.
+struct PadOpConversion : public OpConversionPattern<mhlo::PadOp> {
+ using OpConversionPattern<mhlo::PadOp>::OpConversionPattern;
+
+ LogicalResult matchAndRewrite(
+ mhlo::PadOp op, ArrayRef<Value> args,
+ ConversionPatternRewriter &rewriter) const override {
+ mhlo::PadOp::Adaptor adaptor(args);
+ if (llvm::any_of(op.interior_padding().getValues<APInt>(),
+ [](APInt intVal) { return intVal.getZExtValue() != 0; })) {
+ return rewriter.notifyMatchFailure(op, "expected no interior padding");
+ }
+ auto loc = op.getLoc();
+
+ Attribute paddingConstVal = getInitValueAsConst(adaptor.padding_value());
+ Value paddingVal =
+ paddingConstVal
+ ? rewriter.create<ConstantOp>(loc, paddingConstVal).getResult()
+ : rewriter.create<tensor::ExtractOp>(loc, adaptor.padding_value());
+
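+    // Example (cf. the updated pad.mlir test below): edge_padding_low = [4, 5]
+    // and edge_padding_high = [2, 3] pad a tensor<12x4xf32> to tensor<18x12xf32>.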
+ const auto &edgePaddingLow = op.edge_padding_low();
+ const auto &edgePaddingHigh = op.edge_padding_high();
+ SmallVector<OpFoldResult, 4> low, high;
+ for (auto it :
+ llvm::enumerate(llvm::zip(edgePaddingLow, edgePaddingHigh))) {
+ low.push_back(rewriter.createOrFold<ConstantIndexOp>(
+ loc, std::get<0>(it.value()).getZExtValue()));
+ high.push_back(rewriter.createOrFold<ConstantIndexOp>(
+ loc, std::get<1>(it.value()).getZExtValue()));
+ }
+ Type resultType = op.getResult().getType();
+ auto padTensorOp = linalg::PadTensorOp::createPadScalarOp(
+ resultType, adaptor.operand(), paddingVal, low, high, loc, rewriter);
+ rewriter.replaceOp(op, padTensorOp.getResult());
+ return success();
+ }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
// mhlo.slice conversion patterns.
//===----------------------------------------------------------------------===//
@@ -233,7 +297,7 @@
MLIRContext *context, OwningRewritePatternList &patterns) {
mhlo::populateHLOToLinalgConversionPattern(context, &patterns);
patterns.insert<TorchIndexSelectOpConversion, SliceOpConversion,
- ConstOpConversion>(context);
+ ConstOpConversion, PadOpConversion>(context);
}
std::unique_ptr<OperationPass<FuncOp>> createHLOToLinalgOnTensorsPass() {
diff --git a/iree/compiler/Conversion/HLOToLinalg/test/pad.mlir b/iree/compiler/Conversion/HLOToLinalg/test/pad.mlir
index 29c700b..2f44eeb 100644
--- a/iree/compiler/Conversion/HLOToLinalg/test/pad.mlir
+++ b/iree/compiler/Conversion/HLOToLinalg/test/pad.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -iree-codegen-hlo-to-linalg-on-buffers -canonicalize %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -iree-codegen-hlo-to-linalg-on-tensors -canonicalize %s | IreeFileCheck %s
module {
func @pad_cst() {
@@ -18,13 +18,16 @@
hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write"
}
}
-// CHECK_LABEL: @pad_cst
-// CHECK-DAG: %[[CST:.+]] = constant 0.000000e+00 : f32
-// CHECK-DAG: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<18x12xf32>
-// CHECK-DAG: %[[IN:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<12x4xf32>
-// CHECK: linalg.fill(%[[OUT]], %[[CST]])
-// CHECK: %[[SUBVIEW:.+]] = subview %[[OUT]][4, 5] [12, 4] [1, 1]
-// CHECK: linalg.copy(%[[IN]], %[[SUBVIEW]])
+// CHECK-LABEL: func @pad_cst
+// CHECK-DAG: %[[PAD:.+]] = constant 0.000000e+00 : f32
+// CHECK-DAG: %[[C4:.+]] = constant 4 : index
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[C5:.+]] = constant 5 : index
+// CHECK-DAG: %[[C3:.+]] = constant 3 : index
+// CHECK-DAG: %[[IN:.+]] = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<12x4xf32>
+// CHECK: linalg.pad_tensor %[[IN]] low[%[[C4]], %[[C5]]] high[%[[C2]], %[[C3]]]
+// CHECK: linalg.yield %[[PAD]] : f32
+// CHECK: } : tensor<12x4xf32> to tensor<18x12xf32>
// -----
@@ -47,14 +50,17 @@
hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
}
}
-// CHECK_LABEL: @pad_memref
-// CHECK-DAG: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<18x12xf32>
-// CHECK-DAG: %[[IN:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<12x4xf32>
-// CHECK-DAG: %[[PAD_BUF:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<f32>
-// CHECK: %[[PAD_VAL:.+]] = load %[[PAD_BUF]][] : memref<f32>
-// CHECK: linalg.fill(%[[OUT]], %[[PAD_VAL]])
-// CHECK: %[[SUBVIEW:.+]] = subview %[[OUT]][4, 5] [12, 4] [1, 1]
-// CHECK: linalg.copy(%[[IN]], %[[SUBVIEW]])
+// CHECK-LABEL: func @pad_memref
+// CHECK-DAG: %[[C4:.+]] = constant 4 : index
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[C5:.+]] = constant 5 : index
+// CHECK-DAG: %[[C3:.+]] = constant 3 : index
+// CHECK-DAG: %[[IN1:.+]] = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<12x4xf32>
+// CHECK-DAG: %[[IN2:.+]] = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<f32>
+// CHECK-DAG: %[[PAD:.+]] = tensor.extract %[[IN2]][] : tensor<f32>
+// CHECK: linalg.pad_tensor %[[IN1]] low[%[[C4]], %[[C5]]] high[%[[C2]], %[[C3]]]
+// CHECK: linalg.yield %[[PAD]] : f32
+// CHECK: } : tensor<12x4xf32> to tensor<18x12xf32>
// -----
@@ -76,34 +82,28 @@
hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write"
}
}
-// CHECK_LABEL: @pad_no_op
-// CHECK-DAG: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<12x4xf32>
-// CHECK-DAG: %[[IN:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<12x4xf32>
-// CHECK: linalg.copy(%[[IN]], %[[OUT]])
+// CHECK-LABEL: func @pad_no_op
+// CHECK-NOT: linalg.pad_tensor
// -----
module {
func @cst_pad_cst() {
%c0 = constant 0 : index
- %0 = constant dense<1.0> : tensor<12x4xf32>
+ %0 = constant dense<1.0> : tensor<1xf32>
%1 = constant dense<0.0> : tensor<f32>
%2 = "mhlo.pad"(%0, %1) {
- edge_padding_high = dense<[2, 3]> : tensor<2xi64>,
- edge_padding_low = dense<[4, 5]> : tensor<2xi64>,
- interior_padding = dense<0> : tensor<2xi64>
- } : (tensor<12x4xf32>, tensor<f32>) -> tensor<18x12xf32>
- hal.interface.store.tensor %2, @legacy_io::@ret0, offset = %c0 : tensor<18x12xf32>
+ edge_padding_high = dense<[1]> : tensor<1xi64>,
+ edge_padding_low = dense<[2]> : tensor<1xi64>,
+ interior_padding = dense<0> : tensor<1xi64>
+ } : (tensor<1xf32>, tensor<f32>) -> tensor<4xf32>
+ hal.interface.store.tensor %2, @legacy_io::@ret0, offset = %c0 : tensor<4xf32>
return
}
hal.interface @legacy_io attributes {sym_visibility = "private"} {
hal.interface.binding @ret0, set=0, binding=0, type="StorageBuffer", access="Write"
}
}
-// CHECK_LABEL: @cst_pad_cst
-// CHECK-DAG: %[[ZERO:.+]] = constant 0.000000e+00 : f32
-// CHECK-DAG: %[[ONE:.+]] = constant 1.000000e+00 : f32
-// CHECK-DAG: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<18x12xf32>
-// CHECK: linalg.fill(%[[OUT]], %[[ZERO]])
-// CHECK: %[[SUBVIEW:.+]] = subview %[[OUT]][4, 5] [12, 4] [1, 1]
-// CHECK: linalg.fill(%[[SUBVIEW]], %[[ONE]])
+// CHECK-LABEL: func @cst_pad_cst
+// CHECK: constant dense<[0.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00]> : tensor<4xf32>
+// CHECK-NOT: linalg.pad_tensor
diff --git a/iree/compiler/Conversion/HLOToLinalg/test/pad_tensor.mlir b/iree/compiler/Conversion/HLOToLinalg/test/pad_tensor.mlir
new file mode 100644
index 0000000..8d77fc1
--- /dev/null
+++ b/iree/compiler/Conversion/HLOToLinalg/test/pad_tensor.mlir
@@ -0,0 +1,64 @@
+// RUN: iree-opt -split-input-file -iree-codegen-hlo-to-linalg-on-buffers -canonicalize %s | IreeFileCheck %s
+
+module {
+ func @pad_cst() {
+ %c0 = constant 0 : index
+ %cst = constant 0.000000e+00 : f32
+ %c4 = constant 4 : index
+ %c2 = constant 2 : index
+ %c5 = constant 5 : index
+ %c3 = constant 3 : index
+ %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<12x4xf32>
+ %1 = linalg.pad_tensor %0 low[%c4, %c5] high[%c2, %c3] {
+ ^bb0(%arg0: index, %arg1: index): // no predecessors
+ linalg.yield %cst : f32
+ } : tensor<12x4xf32> to tensor<18x12xf32>
+ hal.interface.store.tensor %1, @legacy_io::@ret0, offset = %c0 : tensor<18x12xf32>
+ return
+ }
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write"
+ }
+}
+// CHECK-LABEL: @pad_cst
+// CHECK-DAG: %[[CST:.+]] = constant 0.000000e+00 : f32
+// CHECK-DAG: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<18x12xf32>
+// CHECK-DAG: %[[IN:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<12x4xf32>
+// CHECK: linalg.fill(%[[OUT]], %[[CST]])
+// CHECK: %[[SUBVIEW:.+]] = subview %[[OUT]][4, 5] [12, 4] [1, 1]
+// CHECK: linalg.copy(%[[IN]], %[[SUBVIEW]])
+
+// -----
+
+module {
+ func @pad_memref() {
+ %c0 = constant 0 : index
+ %c4 = constant 4 : index
+ %c2 = constant 2 : index
+ %c5 = constant 5 : index
+ %c3 = constant 3 : index
+ %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<12x4xf32>
+ %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<f32>
+ %2 = tensor.extract %1[] : tensor<f32>
+ %3 = linalg.pad_tensor %0 low[%c4, %c5] high[%c2, %c3] {
+ ^bb0(%arg0: index, %arg1: index): // no predecessors
+ linalg.yield %2 : f32
+ } : tensor<12x4xf32> to tensor<18x12xf32>
+ hal.interface.store.tensor %3, @legacy_io::@ret0, offset = %c0 : tensor<18x12xf32>
+ return
+ }
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+ }
+}
+// CHECK-LABEL: @pad_memref
+// CHECK-DAG: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<18x12xf32>
+// CHECK-DAG: %[[IN:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<12x4xf32>
+// CHECK-DAG: %[[PAD_BUF:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<f32>
+// CHECK: %[[PAD_VAL:.+]] = load %[[PAD_BUF]][] : memref<f32>
+// CHECK: linalg.fill(%[[OUT]], %[[PAD_VAL]])
+// CHECK: %[[SUBVIEW:.+]] = subview %[[OUT]][4, 5] [12, 4] [1, 1]
+// CHECK: linalg.copy(%[[IN]], %[[SUBVIEW]])
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
index a1144ff..1c4bcd5 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
@@ -344,6 +344,12 @@
return success();
}
+struct TileWorkgroupSizePair {
+ // How many scalar elements each workgroup should handle along each dimension.
+ std::array<int64_t, 3> tileSize;
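+  // The workgroup size to launch along each dimension for such a tile.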
+ std::array<int64_t, 3> workgroupSize;
+};
+
static LogicalResult getMaliSpecificConfig(linalg::ConvOp op,
TileSizesListType &tileSizes,
LaunchConfigInfo &config) {
@@ -352,42 +358,52 @@
if (!inputType.hasStaticShape() || !outputType.hasStaticShape())
return failure();
- const int tileWidth = 8;
- const int tileChannel = 32;
-
- auto outputShape = outputType.getShape();
bool isInputTilable = inputType.getDimSize(3) % 4 == 0;
- bool isOutputTilable = outputShape[0] == 1 &&
- outputShape[2] % tileWidth == 0 &&
- outputShape[3] % tileChannel == 0;
- if (!isInputTilable || !isOutputTilable) return failure();
+ if (!isInputTilable) return failure();
- config.workgroupSize = {8, 2, 1};
+ static const TileWorkgroupSizePair tileWorkgroupSizePairs[] = {
+ {{1, 8, 32}, {8, 2, 1}},
+ };
- SmallVector<int64_t, 4> workgroupLevel = {/*batch=*/0, /*output_height=*/1,
- /*output_width=*/tileWidth,
- /*output_channel=*/tileChannel};
- tileSizes.emplace_back(std::move(workgroupLevel));
+ for (const auto &pair : tileWorkgroupSizePairs) {
+ const std::array<int64_t, 3> &tileSize = pair.tileSize;
+ const std::array<int64_t, 3> &workgroupSize = pair.workgroupSize;
- // No tiling at the subgroup level given that we don't use subgroup
- // level syncrhonization or shared memory.
- tileSizes.emplace_back();
+ auto outputShape = outputType.getShape();
+ bool isOutputTilable = (outputShape[0] == 1) &&
+ (outputShape[1] % tileSize[0] == 0) &&
+ (outputShape[2] % tileSize[1] == 0) &&
+ (outputShape[3] % tileSize[2] == 0);
+ if (!isOutputTilable) continue;
- SmallVector<int64_t, 4> invocationLevel = {
- /*batch=*/0, /*output_height=*/1,
- /*output_width=*/tileWidth / config.workgroupSize[1],
- /*output_channel=*/tileChannel / config.workgroupSize[0]};
- tileSizes.emplace_back(invocationLevel);
+ SmallVector<int64_t, 4> workgroupLevel = {
+ /*batch=*/0, /*output_height=*/tileSize[0],
+ /*output_width=*/tileSize[1], /*output_channel=*/tileSize[2]};
+ tileSizes.emplace_back(std::move(workgroupLevel));
- // Finally, for each invocation, we use tiling to generate loops to loop over
- // the filter's height (step 1), width (step 1), and input channel (step 4)
- // dimensions.
- SmallVector<int64_t, 4> fourthLevel = {0, 0, 0, 0, 4, 1, 1};
- tileSizes.emplace_back(fourthLevel);
+ // No tiling at the subgroup level given that we don't use subgroup
+    // level synchronization or shared memory.
+ tileSizes.emplace_back();
- config.vectorize = true;
+ SmallVector<int64_t, 4> invocationLevel = {
+ /*batch=*/0, /*output_height=*/tileSize[0] / workgroupSize[2],
+ /*output_width=*/tileSize[1] / workgroupSize[1],
+ /*output_channel=*/tileSize[2] / workgroupSize[0]};
+ tileSizes.emplace_back(invocationLevel);
- return success();
+ // Finally, for each invocation, we use tiling to generate loops to loop
+ // over the filter's height (step 1), width (step 1), and input channel
+ // (step 4) dimensions.
+ SmallVector<int64_t, 4> fourthLevel = {0, 0, 0, 0, 4, 1, 1};
+ tileSizes.emplace_back(fourthLevel);
+
+ config.workgroupSize = workgroupSize;
+ config.vectorize = true;
+
+ return success();
+ }
+
+ return failure();
}
template <>
@@ -412,6 +428,80 @@
return success();
}
+static LogicalResult getMaliSpecificConfig(
+ linalg::DepthwiseConvInputNHWCFilterHWCOp op, TileSizesListType &tileSizes,
+ LaunchConfigInfo &config) {
+ auto inputType = op.getInput(0).getType().cast<MemRefType>();
+ auto outputType = op.getOutputBufferTypes()[0].cast<MemRefType>();
+ if (!inputType.hasStaticShape() || !outputType.hasStaticShape())
+ return failure();
+
+ static const TileWorkgroupSizePair tileWorkgroupSizePairs[] = {
+ {{2, 2, 32}, {8, 2, 2}},
+ };
+
+ for (const auto &pair : tileWorkgroupSizePairs) {
+ const std::array<int64_t, 3> &tileSize = pair.tileSize;
+ const std::array<int64_t, 3> &workgroupSize = pair.workgroupSize;
+
+ auto outputShape = outputType.getShape();
+ bool isOutputTilable = outputShape[0] == 1 &&
+ (outputShape[1] % tileSize[0] == 0) &&
+ (outputShape[2] % tileSize[1] == 0) &&
+ (outputShape[3] % tileSize[2] == 0);
+ if (!isOutputTilable) continue;
+
+ SmallVector<int64_t, 4> workgroupLevel = {/*batch=*/0,
+ /*output_height=*/tileSize[0],
+ /*output_width=*/tileSize[1],
+ /*output_channel=*/tileSize[2]};
+ tileSizes.emplace_back(std::move(workgroupLevel));
+
+ // No tiling at the subgroup level given that we don't use subgroup
+    // level synchronization or shared memory.
+ tileSizes.emplace_back();
+
+ SmallVector<int64_t, 4> invocationLevel = {
+ /*batch=*/0, /*output_height=*/tileSize[0] / workgroupSize[2],
+ /*output_width=*/tileSize[1] / workgroupSize[1],
+ /*output_channel=*/tileSize[2] / workgroupSize[0]};
+ tileSizes.emplace_back(invocationLevel);
+
+ // Finally, for each invocation, we use tiling to generate loops to loop
+ // over the filter's height (step 1) and width (step 1) dimensions.
+ SmallVector<int64_t, 4> fourthLevel = {0, 0, 0, 0, 1, 1};
+ tileSizes.emplace_back(fourthLevel);
+
+ config.workgroupSize = workgroupSize;
+ config.vectorize = true;
+
+ return success();
+ }
+ return failure();
+}
+
+template <>
+LogicalResult getOpLaunchConfig(linalg::DepthwiseConvInputNHWCFilterHWCOp op,
+ const spirv::TargetEnv &targetEnv,
+ const SPIRVCodegenOptions &options,
+ TileSizesListType &tileSizes,
+ LaunchConfigInfo &config) {
+ if (targetEnv.getVendorID() == spirv::Vendor::ARM &&
+ succeeded(getMaliSpecificConfig(op, tileSizes, config))) {
+ return success();
+ }
+
+ unsigned maxWorkgroupSize = targetEnv.getResourceLimits()
+ .max_compute_workgroup_invocations()
+ .getInt();
+ const int64_t tileSizeX = 32;
+ int64_t tileSizeY = maxWorkgroupSize / tileSizeX;
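+  // e.g., with max_compute_workgroup_invocations = 512: tileSizeY = 512 / 32
+  // = 16, yielding a (32, 16, 1) workgroup.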
+ SmallVector<int64_t, 4> ts = {1, tileSizeY, tileSizeX};
+ tileSizes.emplace_back(std::move(ts));
+ config.workgroupSize = {tileSizeX, tileSizeY, 1};
+ return success();
+}
+
template <typename PoolingOpTy>
static LogicalResult getPoolingOpLaunchConfig(
PoolingOpTy op, const spirv::TargetEnv &targetEnv,
@@ -492,6 +582,7 @@
DISPATCH(linalg::BatchMatmulOp)
DISPATCH(linalg::ConvOp)
+ DISPATCH(linalg::DepthwiseConvInputNHWCFilterHWCOp)
DISPATCH(linalg::MatmulOp)
DISPATCH(linalg::PoolingMaxOp)
DISPATCH(linalg::PoolingMinOp)
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
index 886acc5..69e32b9 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -318,7 +318,9 @@
{getWorkgroupMemoryMarker(), getWorkgroupMarker()},
getVectorizeMarker(), context));
- patterns.insert<linalg::LinalgTilingPattern<linalg::ConvOp>>(
+ patterns.insert<
+ linalg::LinalgTilingPattern<linalg::ConvOp>,
+ linalg::LinalgTilingPattern<linalg::DepthwiseConvInputNHWCFilterHWCOp>>(
context, tilingOptions,
getLinalgMatchAndReplaceMarker(
{getWorkgroupMemoryMarker(), getWorkgroupMarker()},
@@ -411,6 +413,15 @@
return tileSizes;
};
+ auto depthWiseConvTilingOptions =
+ linalg::LinalgTilingOptions()
+ .setLoopType(linalg::LinalgTilingLoopType::Loops)
+ .setTileSizeComputationFunction(getTileSizeFn);
+
+ patterns.insert<
+ linalg::LinalgTilingPattern<linalg::DepthwiseConvInputNHWCFilterHWCOp>>(
+ context, depthWiseConvTilingOptions, marker);
+
// TODO(antiagainst): move this to launch configuration.
SmallVector<unsigned, 8> loopOrder = {
/*batch=*/0,
@@ -422,13 +433,13 @@
/*input_channel=*/4,
};
- auto tilingOptions = linalg::LinalgTilingOptions()
- .setLoopType(linalg::LinalgTilingLoopType::Loops)
- .setInterchange(loopOrder)
- .setTileSizeComputationFunction(getTileSizeFn);
+ auto convTilingOptions = linalg::LinalgTilingOptions()
+ .setLoopType(linalg::LinalgTilingLoopType::Loops)
+ .setInterchange(loopOrder)
+ .setTileSizeComputationFunction(getTileSizeFn);
patterns.insert<linalg::LinalgTilingPattern<linalg::ConvOp>>(
- context, tilingOptions, marker);
+ context, convTilingOptions, marker);
}
//====---------------------------------------------------------------------===//
@@ -597,7 +608,7 @@
applyCanonicalizationPatternsForTiling(context, funcOp);
LLVM_DEBUG({
- llvm::dbgs() << "--- After tiling linalg.conv ---\n";
+ llvm::dbgs() << "--- After tiling convolution filter ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
index 49e2661..af2f432 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
@@ -636,14 +636,76 @@
}
// CHECK-LABEL: func @conv_tiled_and_vectorized()
+// For linalg.fill
+// CHECK-COUNT-4: vector.transfer_write
+
+// For linalg.conv
// CHECK-COUNT-4: vector.transfer_read
// check tiling loop along filter height/width and input channel
-// CHECK: scf.for %{{.*}} = %c0 to %c3 step %c1
-// CHECK: scf.for %{{.*}} = %c0 to %c3 step %c1
-// CHECK: scf.for %{{.*}} = %c0 to %c16 step %c4
+// CHECK: scf.for %{{.*}} = %c0 to %c3 step %c1
+// CHECK-SAME: -> (vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>)
+// CHECK: scf.for %{{.*}} = %c0 to %c3 step %c1
+// CHECK-SAME: -> (vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>)
+// CHECK: scf.for %{{.*}} = %c0 to %c16 step %c4
+// CHECK-SAME: -> (vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>)
// CHECK-COUNT-16: vector.contract
// CHECK-COUNT-3: scf.yield
+
+// For linalg.conv
// CHECK-COUNT-4: vector.transfer_write
+
+// -----
+
+hal.executable @conv_tiled_and_vectorized attributes {sym_visibility = "private"} {
+ hal.interface @legacy_io {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ }
+ hal.executable.target @vulkan, filter="dylib*" {
+ hal.executable.entry_point @conv_tiled_and_vectorized attributes {
+ interface = @legacy_io, ordinal = 0 : i32,
+ signature = (!flow.dispatch.input<?x?xf32>, !flow.dispatch.input<?x?xf32>,
+ !flow.dispatch.output<?x?xf32>) -> ()}
+ module attributes {
+ spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader, Float16, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, VariablePointers, VariablePointersStorageBuffer], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, ARM:IntegratedGPU, {max_compute_shared_memory_size = 32768 : i32, max_compute_workgroup_invocations = 512 : i32, max_compute_workgroup_size = dense<512> : vector<3xi32>, subgroup_size = 16 : i32}>
+ } {
+ func @depthwise_conv_tiled_and_vectorized() attributes {hal.num_workgroups_fn = @get_num_workgroups} {
+ %cst = constant 0.000000e+00 : f32
+ %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x56x56x96xf32>
+ %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x113x113x96xf32>
+ %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x96xf32>
+ linalg.fill(%0, %cst) : memref<1x56x56x96xf32>, f32
+ linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : tensor<2xi64>} ins(%1, %2 : memref<1x113x113x96xf32>, memref<3x3x96xf32>) outs(%0 : memref<1x56x56x96xf32>)
+ return
+ }
+
+ func private @get_num_workgroups(!shapex.ranked_shape<[1,113,113,96]>, !shapex.ranked_shape<[3,3,96]>, !shapex.ranked_shape<[1,56,56,96]>) -> (index, index, index)
+ }
+ }
+}
+
+// CHECK-LABEL: func @depthwise_conv_tiled_and_vectorized()
+
+// For linalg.fill
+// CHECK: vector.transfer_write
+
+// For linalg.depthwise_conv_2d_input_nhwc_filter_hwc
+// CHECK: vector.transfer_read
+
+// check tiling loop along filter height/width and input channel
+// CHECK: scf.for %{{.+}} = %c0 to %c3 step %c1
+// CHECK-SAME: -> (vector<4xf32>)
+// CHECK: scf.for %{{.+}} = %c0 to %c3 step %c1
+// CHECK-SAME: -> (vector<4xf32>)
+
+
+// CHECK: vector.fma
+
+// CHECK-COUNT-2: scf.yield
+
+// For linalg.depthwise_conv_2d_input_nhwc_filter_hwc
+// CHECK: vector.transfer_write
diff --git a/iree/compiler/Conversion/LinalgToVector/VectorizeConv.cpp b/iree/compiler/Conversion/LinalgToVector/VectorizeConv.cpp
index 8e5ee34..4335354 100644
--- a/iree/compiler/Conversion/LinalgToVector/VectorizeConv.cpp
+++ b/iree/compiler/Conversion/LinalgToVector/VectorizeConv.cpp
@@ -47,7 +47,6 @@
/// - For filter:
/// - Hf must be 1.
/// - Wf must be 1.
-/// - Ci must be 4.
/// - No dilation.
/// - No padding.
///
@@ -209,6 +208,132 @@
}
};
+/// Vectorizes linalg.depthwise_conv_2d_input_nhwc_filter_hwc for a single GPU
+/// invocation. Therefore, the linalg.depthwise_conv_2d_input_nhwc_filter_hwc op
+/// should have a very specific form; other patterns are expected to tile and
+/// distribute larger convolutions into this form for a single GPU invocation.
+///
+/// The linalg.depthwise_conv_2d_input_nhwc_filter_hwc op should follow:
+/// - Filter: HfWfC format
+/// - Input : NHiWiC format
+/// - Output: NHoWoC format
+/// - For output:
+/// - N must be 1.
+/// - C must be a multiple of 4.
+/// - For filter:
+/// - Hf must be 1.
+/// - Wf must be 1.
+/// - No dilation.
+/// - No padding.
+///
+/// The channel count is required to be a multiple of 4 so that we can process
+/// channels with load4/store4, which are native to GPUs.
+struct VectorizeLinalgDepthwiseConv
+ : OpRewritePattern<linalg::DepthwiseConvInputNHWCFilterHWCOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(
+ linalg::DepthwiseConvInputNHWCFilterHWCOp convOp,
+ PatternRewriter &rewriter) const override {
+ LLVM_DEBUG(llvm::dbgs() << "inspecting " << convOp << "\n");
+
+ auto inputViewOp = convOp.getInput(0).getDefiningOp<SubViewOp>();
+ auto filterViewOp = convOp.getInput(1).getDefiningOp<SubViewOp>();
+ auto outputViewOp = convOp.getOutput(0).getDefiningOp<SubViewOp>();
+ if (!filterViewOp || !inputViewOp || !outputViewOp) return failure();
+
+ // The filter/input/output view should have static sizes to vectorize.
+ if (!llvm::empty(filterViewOp.getDynamicSizes()) ||
+ !llvm::empty(inputViewOp.getDynamicSizes()) ||
+ !llvm::empty(outputViewOp.getDynamicSizes())) {
+ return failure();
+ }
+
+ // The output batch dimension should be 1.
+ if (outputViewOp.getStaticSize(0) != 1) return failure();
+
+    // We additionally expect the filter height/width dimensions to both be 1 to
+    // simplify vectorization. Other patterns can generate loops to create 1x1
+    // filter subviews.
+ if (filterViewOp.getStaticSize(0) != 1 ||
+ filterViewOp.getStaticSize(1) != 1) {
+ return failure();
+ }
+
+ int64_t numChannels = outputViewOp.getStaticSize(3);
+ if (numChannels % 4 != 0) return failure();
+
+ int64_t numOutputHeights = outputViewOp.getStaticSize(1);
+ int64_t numOutputWidths = outputViewOp.getStaticSize(2);
+ int64_t heightStride = convOp.strides().getValue<int64_t>({0});
+ int64_t widthStride = convOp.strides().getValue<int64_t>({1});
+
+ // This invocation handles a batch of (numOutputHeights * numOutputWidths *
+ // numChannels).
+ LLVM_DEBUG({
+ llvm::dbgs() << "# output height: " << numOutputHeights << "\n";
+ llvm::dbgs() << "# output width: " << numOutputWidths << "\n";
+ llvm::dbgs() << "# channels: " << numChannels << "\n";
+ llvm::dbgs() << "height stride: " << heightStride << "\n";
+ llvm::dbgs() << "width stride: " << widthStride << "\n";
+ });
+
+ Location loc = convOp.getLoc();
+
+ Type elementType = filterViewOp.getType().getElementType();
+ auto vector4Type = VectorType::get(4, elementType);
+ auto filterVectorType = VectorType::get({numChannels}, elementType);
+ Value zero = rewriter.createOrFold<ConstantIndexOp>(loc, 0);
+
+ // Load the entire filter subview.
+ SmallVector<Value, 4> filterIndices(3, zero);
+ Value wholeFilter = rewriter.create<vector::TransferReadOp>(
+ loc, filterVectorType, filterViewOp, filterIndices);
+
+ // Compute the (numOutputHeights * numOutputWidths * numChannels) output
+    // batch. We only contribute numChannels accumulations along the reduction
+ // dimension.
+ for (int oc = 0; oc < numChannels / 4; ++oc) {
+ Value filterVector = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, wholeFilter, /*offsets=*/oc * 4, /*sizes=*/4, /*strides=*/1);
+
+ for (int oh = 0; oh < numOutputHeights; ++oh) {
+ for (int ow = 0; ow < numOutputWidths; ++ow) {
+ // Read in the initial value for this output vector.
+ SmallVector<Value, 4> outputIndices(4, zero);
+ outputIndices[1] = rewriter.createOrFold<ConstantIndexOp>(loc, oh);
+ outputIndices[2] = rewriter.createOrFold<ConstantIndexOp>(loc, ow);
+ outputIndices[3] =
+ rewriter.createOrFold<ConstantIndexOp>(loc, oc * 4);
+ Value outputVector = rewriter.create<vector::TransferReadOp>(
+ loc, vector4Type, outputViewOp, outputIndices);
+
+          // Read in the input vector for these 4 input channels as a batch.
+ SmallVector<Value, 4> inputIndices(4, zero);
+ inputIndices[1] =
+ rewriter.createOrFold<ConstantIndexOp>(loc, oh * heightStride);
+ inputIndices[2] =
+ rewriter.createOrFold<ConstantIndexOp>(loc, ow * widthStride);
+ inputIndices[3] = rewriter.createOrFold<ConstantIndexOp>(loc, oc * 4);
+ Value inputVector = rewriter.create<vector::TransferReadOp>(
+ loc, vector4Type, inputViewOp, inputIndices);
+
+          // Perform the element-wise product and accumulation.
+ outputVector = rewriter.create<vector::FMAOp>(
+ loc, inputVector, filterVector, outputVector);
+
+ // Write out the output vector.
+ rewriter.create<vector::TransferWriteOp>(loc, outputVector,
+ outputViewOp, outputIndices);
+ }
+ }
+ }
+
+ rewriter.eraseOp(convOp);
+ return success();
+ }
+};
+
struct VectorizeLinalgConvPass
: public PassWrapper<VectorizeLinalgConvPass, OperationPass<FuncOp>> {
void getDependentDialects(DialectRegistry ®istry) const override {
@@ -218,7 +343,7 @@
void runOnOperation() override {
MLIRContext *context = &getContext();
OwningRewritePatternList patterns;
- patterns.insert<VectorizeLinalgConv>(context);
+ patterns.insert<VectorizeLinalgConv, VectorizeLinalgDepthwiseConv>(context);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
@@ -227,7 +352,7 @@
void populateVectorizeLinalgConvPatterns(MLIRContext *context,
OwningRewritePatternList &patterns) {
- patterns.insert<VectorizeLinalgConv>(context);
+ patterns.insert<VectorizeLinalgConv, VectorizeLinalgDepthwiseConv>(context);
}
std::unique_ptr<Pass> createVectorizeLinalgConvPass() {
diff --git a/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_conv.mlir b/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_conv.mlir
index 76d835c..2ca5b2f 100644
--- a/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_conv.mlir
+++ b/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_conv.mlir
@@ -119,3 +119,95 @@
linalg.conv(%0, %1, %2) {dilations = [2, 1], strides = [2, 2]} : memref<1x1x4x4xf32>, memref<1x1x7x4xf32>, memref<1x1x4x4xf32>
return
}
+
+// -----
+
+func @vectorize_depthwise_conv(%input: memref<1x3x3x8xf32>, %filter: memref<1x1x8xf32>, %output: memref<1x2x2x8xf32>) {
+ %0 = subview %input[0, 0, 0, 0] [1, 3, 3, 8] [1, 1, 1, 1] : memref<1x3x3x8xf32> to memref<1x3x3x8xf32>
+ %1 = subview %filter[0, 0, 0] [1, 1, 8] [1, 1, 1] : memref<1x1x8xf32> to memref<1x1x8xf32>
+ %2 = subview %output[0, 0, 0, 0] [1, 2, 2, 8] [1, 1, 1, 1] : memref<1x2x2x8xf32> to memref<1x2x2x8xf32>
+ linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : tensor<2xi64>} ins(%0, %1 : memref<1x3x3x8xf32>, memref<1x1x8xf32>) outs(%2 : memref<1x2x2x8xf32>)
+ return
+}
+
+// CHECK-LABEL: func @vectorize_depthwise_conv
+// CHECK-SAME: %[[INPUT_ARG:.+]]: memref<1x3x3x8xf32>,
+// CHECK-SAME: %[[FILTER_ARG:.+]]: memref<1x1x8xf32>,
+// CHECK-SAME: %[[OUTPUT_ARG:.+]]: memref<1x2x2x8xf32>
+
+// CHECK: %[[FLOAT_ZERO:.+]] = constant 0.000000e+00 : f32
+// CHECK: %[[INPUT:.+]] = subview %[[INPUT_ARG]]
+// CHECK: %[[FILTER:.+]] = subview %[[FILTER_ARG]]
+// CHECK: %[[OUTPUT:.+]] = subview %[[OUTPUT_ARG]]
+
+// CHECK: %[[FILTER_VECTOR:.+]] = vector.transfer_read %[[FILTER]][%c0, %c0, %c0], %cst {masked = [false]} : memref<1x1x8xf32>, vector<8xf32>
+
+// Common filter #0
+// CHECK: %[[FILTER_0:.+]] = vector.extract_strided_slice %[[FILTER_VECTOR]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xf32> to vector<4xf32>
+
+// CHECK: %[[OUTPUT_0_0:.+]] = vector.transfer_read %[[OUTPUT]][%c0, %c0, %c0, %c0], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32>
+// CHECK: %[[INPUT_0_0:.+]] = vector.transfer_read %[[INPUT]][%c0, %c0, %c0, %c0], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32>
+// CHECK: %[[FMA_0_0:.+]] = vector.fma %[[INPUT_0_0]], %[[FILTER_0]], %[[OUTPUT_0_0]] : vector<4xf32>
+// CHECK: vector.transfer_write %[[FMA_0_0]], %[[OUTPUT]][%c0, %c0, %c0, %c0] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32>
+
+// CHECK: %[[OUTPUT_0_1:.+]] = vector.transfer_read %[[OUTPUT]][%c0, %c0, %c1, %c0], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32>
+// CHECK: %[[INPUT_0_1:.+]] = vector.transfer_read %[[INPUT]][%c0, %c0, %c2, %c0], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32>
+// CHECK: %[[FMA_0_1:.+]] = vector.fma %[[INPUT_0_1]], %[[FILTER_0]], %[[OUTPUT_0_1]] : vector<4xf32>
+// CHECK: vector.transfer_write %[[FMA_0_1]], %[[OUTPUT]][%c0, %c0, %c1, %c0] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32>
+
+// CHECK: %[[OUTPUT_1_0:.+]] = vector.transfer_read %[[OUTPUT]][%c0, %c1, %c0, %c0], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32>
+// CHECK: %[[INPUT_1_0:.+]] = vector.transfer_read %[[INPUT]][%c0, %c2, %c0, %c0], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32>
+// CHECK: %[[FMA_1_0:.+]] = vector.fma %[[INPUT_1_0]], %[[FILTER_0]], %[[OUTPUT_1_0]] : vector<4xf32>
+// CHECK: vector.transfer_write %[[FMA_1_0]], %[[OUTPUT]][%c0, %c1, %c0, %c0] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32>
+
+// CHECK: %[[OUTPUT_1_1:.+]] = vector.transfer_read %[[OUTPUT]][%c0, %c1, %c1, %c0], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32>
+// CHECK: %[[INPUT_1_1:.+]] = vector.transfer_read %[[INPUT]][%c0, %c2, %c2, %c0], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32>
+// CHECK: %[[FMA_1_1:.+]] = vector.fma %[[INPUT_1_1]], %[[FILTER_0]], %[[OUTPUT_1_1]] : vector<4xf32>
+// CHECK: vector.transfer_write %[[FMA_1_1]], %[[OUTPUT]][%c0, %c1, %c1, %c0] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32>
+
+// Common filter #1
+// CHECK: %[[FILTER_1:.+]] = vector.extract_strided_slice %[[FILTER_VECTOR]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf32> to vector<4xf32>
+
+// CHECK: %[[OUTPUT_0_0:.+]] = vector.transfer_read %[[OUTPUT]][%c0, %c0, %c0, %c4], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32>
+// CHECK: %[[INPUT_0_0:.+]] = vector.transfer_read %[[INPUT]][%c0, %c0, %c0, %c4], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32>
+// CHECK: %[[FMA_0_0:.+]] = vector.fma %[[INPUT_0_0]], %[[FILTER_1]], %[[OUTPUT_0_0]] : vector<4xf32>
+// CHECK: vector.transfer_write %[[FMA_0_0]], %[[OUTPUT]][%c0, %c0, %c0, %c4] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32>
+
+// CHECK: %[[OUTPUT_0_1:.+]] = vector.transfer_read %[[OUTPUT]][%c0, %c0, %c1, %c4], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32>
+// CHECK: %[[INPUT_0_1:.+]] = vector.transfer_read %[[INPUT]][%c0, %c0, %c2, %c4], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32>
+// CHECK: %[[FMA_0_1:.+]] = vector.fma %[[INPUT_0_1]], %[[FILTER_1]], %[[OUTPUT_0_1]] : vector<4xf32>
+// CHECK: vector.transfer_write %[[FMA_0_1]], %[[OUTPUT]][%c0, %c0, %c1, %c4] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32>
+
+// CHECK: %[[OUTPUT_1_0:.+]] = vector.transfer_read %[[OUTPUT]][%c0, %c1, %c0, %c4], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32>
+// CHECK: %[[INPUT_1_0:.+]] = vector.transfer_read %[[INPUT]][%c0, %c2, %c0, %c4], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32>
+// CHECK: %[[FMA_1_0:.+]] = vector.fma %[[INPUT_1_0]], %[[FILTER_1]], %[[OUTPUT_1_0]] : vector<4xf32>
+// CHECK: vector.transfer_write %[[FMA_1_0]], %[[OUTPUT]][%c0, %c1, %c0, %c4] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32>
+
+// CHECK: %[[OUTPUT_1_1:.+]] = vector.transfer_read %[[OUTPUT]][%c0, %c1, %c1, %c4], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32>
+// CHECK: %[[INPUT_1_1:.+]] = vector.transfer_read %[[INPUT]][%c0, %c2, %c2, %c4], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32>
+// CHECK: %[[FMA_1_1:.+]] = vector.fma %[[INPUT_1_1]], %[[FILTER_1]], %[[OUTPUT_1_1]] : vector<4xf32>
+// CHECK: vector.transfer_write %[[FMA_1_1]], %[[OUTPUT]][%c0, %c1, %c1, %c4] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32>
+
+// -----
+
+// CHECK-LABEL: func @do_not_vectorize_depthwise_conv_with_non_1_filter_height
+func @do_not_vectorize_depthwise_conv_with_non_1_filter_height(%input: memref<1x2x3x4xf32>, %filter: memref<2x1x4xf32>, %output: memref<1x1x2x4xf32>) {
+ %0 = subview %input[0, 0, 0, 0] [1, 2, 3, 4] [1, 1, 1, 1] : memref<1x2x3x4xf32> to memref<1x2x3x4xf32>
+ %1 = subview %filter[0, 0, 0] [2, 1, 4] [1, 1, 1] : memref<2x1x4xf32> to memref<2x1x4xf32>
+ %2 = subview %output[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : memref<1x1x2x4xf32> to memref<1x1x2x4xf32>
+ // CHECK: linalg.depthwise_conv_2d_input_nhwc_filter_hwc
+ linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : tensor<2xi64>} ins(%0, %1 : memref<1x2x3x4xf32>, memref<2x1x4xf32>) outs(%2 : memref<1x1x2x4xf32>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @do_not_vectorize_depthwise_conv_with_non_1_filter_width
+func @do_not_vectorize_depthwise_conv_with_non_1_filter_width(%input: memref<1x1x4x4xf32>, %filter: memref<1x2x4xf32>, %output: memref<1x1x2x4xf32>) {
+ %0 = subview %input[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<1x1x4x4xf32> to memref<1x1x4x4xf32>
+ %1 = subview %filter[0, 0, 0] [1, 2, 4] [1, 1, 1] : memref<1x2x4xf32> to memref<1x2x4xf32>
+ %2 = subview %output[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : memref<1x1x2x4xf32> to memref<1x1x2x4xf32>
+ // CHECK: linalg.depthwise_conv_2d_input_nhwc_filter_hwc
+ linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : tensor<2xi64>} ins(%0, %1 : memref<1x1x4x4xf32>, memref<1x2x4xf32>) outs(%2 : memref<1x1x2x4xf32>)
+ return
+}
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
index c40b199..8b00c68 100644
--- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
+++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
@@ -72,6 +72,10 @@
void populateVMToCPatterns(MLIRContext *context,
OwningRewritePatternList &patterns) {
+ // Constants
+ patterns.insert<CallOpConversion<IREE::VM::ConstI32Op>>(context,
+ "vm_const_i32");
+
// Conditional assignment ops
patterns.insert<CallOpConversion<IREE::VM::SelectI32Op>>(context,
"vm_select_i32");
@@ -132,9 +136,36 @@
patterns.insert<CallOpConversion<IREE::VM::CheckEQOp>>(context,
"VM_CHECK_EQ");
- // Const
- patterns.insert<CallOpConversion<IREE::VM::ConstI32Op>>(context,
- "vm_const_i32");
+ // ExtI64: Constants
+ patterns.insert<CallOpConversion<IREE::VM::ConstI64Op>>(context,
+ "vm_const_i64");
+
+ // ExtI64: Conditional assignment ops
+ patterns.insert<CallOpConversion<IREE::VM::SelectI64Op>>(context,
+ "vm_select_i64");
+ // ExtI64: Native integer arithmetic ops
+ patterns.insert<CallOpConversion<IREE::VM::AddI64Op>>(context, "vm_add_i64");
+ patterns.insert<CallOpConversion<IREE::VM::SubI64Op>>(context, "vm_sub_i64");
+ patterns.insert<CallOpConversion<IREE::VM::MulI64Op>>(context, "vm_mul_i64");
+ patterns.insert<CallOpConversion<IREE::VM::DivI64SOp>>(context,
+ "vm_div_i64s");
+ patterns.insert<CallOpConversion<IREE::VM::DivI64UOp>>(context,
+ "vm_div_i64u");
+ patterns.insert<CallOpConversion<IREE::VM::RemI64SOp>>(context,
+ "vm_rem_i64s");
+ patterns.insert<CallOpConversion<IREE::VM::RemI64UOp>>(context,
+ "vm_rem_i64u");
+ patterns.insert<CallOpConversion<IREE::VM::NotI64Op>>(context, "vm_not_i64");
+ patterns.insert<CallOpConversion<IREE::VM::AndI64Op>>(context, "vm_and_i64");
+ patterns.insert<CallOpConversion<IREE::VM::OrI64Op>>(context, "vm_or_i64");
+ patterns.insert<CallOpConversion<IREE::VM::XorI64Op>>(context, "vm_xor_i64");
+
+ // ExtI64: Native bitwise shift and rotate ops
+ patterns.insert<CallOpConversion<IREE::VM::ShlI64Op>>(context, "vm_shl_i64");
+ patterns.insert<CallOpConversion<IREE::VM::ShrI64SOp>>(context,
+ "vm_shr_i64s");
+ patterns.insert<CallOpConversion<IREE::VM::ShrI64UOp>>(context,
+ "vm_shr_i64u");
}
namespace IREE {
@@ -160,6 +191,9 @@
target.addLegalDialect<iree_compiler::IREEDialect>();
target.addLegalDialect<IREE::VM::VMDialect>();
+ // Constants
+ target.addIllegalOp<IREE::VM::ConstI32Op>();
+
// Conditional assignment ops
target.addIllegalOp<IREE::VM::SelectI32Op>();
@@ -201,8 +235,29 @@
// support for control flow ops has landed in the c module target
target.addIllegalOp<IREE::VM::CheckEQOp>();
- // Const ops
- target.addIllegalOp<IREE::VM::ConstI32Op>();
+ // ExtI64: Constants
+ target.addIllegalOp<IREE::VM::ConstI64Op>();
+
+ // ExtI64: Conditional assignment ops
+ target.addIllegalOp<IREE::VM::SelectI64Op>();
+
+ // ExtI64: Native integer arithmetic ops
+ target.addIllegalOp<IREE::VM::AddI64Op>();
+ target.addIllegalOp<IREE::VM::SubI64Op>();
+ target.addIllegalOp<IREE::VM::MulI64Op>();
+ target.addIllegalOp<IREE::VM::DivI64SOp>();
+ target.addIllegalOp<IREE::VM::DivI64UOp>();
+ target.addIllegalOp<IREE::VM::RemI64SOp>();
+ target.addIllegalOp<IREE::VM::RemI64UOp>();
+ target.addIllegalOp<IREE::VM::NotI64Op>();
+ target.addIllegalOp<IREE::VM::AndI64Op>();
+ target.addIllegalOp<IREE::VM::OrI64Op>();
+ target.addIllegalOp<IREE::VM::XorI64Op>();
+
+ // ExtI64: Native bitwise shift and rotate ops
+ target.addIllegalOp<IREE::VM::ShlI64Op>();
+ target.addIllegalOp<IREE::VM::ShrI64SOp>();
+ target.addIllegalOp<IREE::VM::ShrI64UOp>();
if (failed(
applyFullConversion(getOperation(), target, std::move(patterns)))) {
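For context, each rewritten op becomes an emitc.call to a C helper with the name given above (see the new tests below). A hedged sketch of the shape such helpers plausibly take; the bodies are illustrative assumptions, not the verbatim IREE runtime definitions:

    // Hedged sketch only; not the actual IREE VM C helpers.
    #include <stdint.h>

    static inline int64_t vm_add_i64(int64_t lhs, int64_t rhs) { return lhs + rhs; }
    static inline int64_t vm_xor_i64(int64_t lhs, int64_t rhs) { return lhs ^ rhs; }
    // Unsigned variants would reinterpret bits to get unsigned C semantics.
    static inline int64_t vm_div_i64u(int64_t lhs, int64_t rhs) {
      return (int64_t)((uint64_t)lhs / (uint64_t)rhs);
    }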
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/arithmetic_ops_i64.mlir b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/arithmetic_ops_i64.mlir
new file mode 100644
index 0000000..9f5a582
--- /dev/null
+++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/arithmetic_ops_i64.mlir
@@ -0,0 +1,120 @@
+// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s
+
+// CHECK-LABEL: @add_i64
+vm.module @my_module {
+ vm.func @add_i64(%arg0: i64, %arg1: i64) {
+ // CHECK-NEXT: %0 = emitc.call "vm_add_i64"(%arg0, %arg1) {args = [0 : index, 1 : index]} : (i64, i64) -> i64
+ %0 = vm.add.i64 %arg0, %arg1 : i64
+ vm.return %0 : i64
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @sub_i64
+vm.module @my_module {
+ vm.func @sub_i64(%arg0: i64, %arg1: i64) {
+ // CHECK: %0 = emitc.call "vm_sub_i64"(%arg0, %arg1) {args = [0 : index, 1 : index]} : (i64, i64) -> i64
+ %0 = vm.sub.i64 %arg0, %arg1 : i64
+ vm.return %0 : i64
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @mul_i64
+vm.module @my_module {
+ vm.func @mul_i64(%arg0: i64, %arg1: i64) {
+ // CHECK: %0 = emitc.call "vm_mul_i64"(%arg0, %arg1) {args = [0 : index, 1 : index]} : (i64, i64) -> i64
+ %0 = vm.mul.i64 %arg0, %arg1 : i64
+ vm.return %0 : i64
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @div_i64_s
+vm.module @my_module {
+ vm.func @div_i64_s(%arg0: i64, %arg1: i64) {
+ // CHECK: %0 = emitc.call "vm_div_i64s"(%arg0, %arg1) {args = [0 : index, 1 : index]} : (i64, i64) -> i64
+ %0 = vm.div.i64.s %arg0, %arg1 : i64
+ vm.return %0 : i64
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @div_i64_u
+vm.module @my_module {
+ vm.func @div_i64_u(%arg0: i64, %arg1: i64) {
+ // CHECK: %0 = emitc.call "vm_div_i64u"(%arg0, %arg1) {args = [0 : index, 1 : index]} : (i64, i64) -> i64
+ %0 = vm.div.i64.u %arg0, %arg1 : i64
+ vm.return %0 : i64
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @rem_i64_s
+vm.module @my_module {
+ vm.func @rem_i64_s(%arg0: i64, %arg1: i64) {
+ // CHECK: %0 = emitc.call "vm_rem_i64s"(%arg0, %arg1) {args = [0 : index, 1 : index]} : (i64, i64) -> i64
+ %0 = vm.rem.i64.s %arg0, %arg1 : i64
+ vm.return %0 : i64
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @rem_i64_u
+vm.module @my_module {
+ vm.func @rem_i64_u(%arg0: i64, %arg1: i64) {
+ // CHECK: %0 = emitc.call "vm_rem_i64u"(%arg0, %arg1) {args = [0 : index, 1 : index]} : (i64, i64) -> i64
+ %0 = vm.rem.i64.u %arg0, %arg1 : i64
+ vm.return %0 : i64
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @not_i64
+vm.module @my_module {
+ vm.func @not_i64(%arg0 : i64) -> i64 {
+ // CHECK: %0 = emitc.call "vm_not_i64"(%arg0) {args = [0 : index]} : (i64) -> i64
+ %0 = vm.not.i64 %arg0 : i64
+ vm.return %0 : i64
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @and_i64
+vm.module @my_module {
+ vm.func @and_i64(%arg0 : i64, %arg1 : i64) -> i64 {
+ // CHECK: %0 = emitc.call "vm_and_i64"(%arg0, %arg1) {args = [0 : index, 1 : index]} : (i64, i64) -> i64
+ %0 = vm.and.i64 %arg0, %arg1 : i64
+ vm.return %0 : i64
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @or_i64
+vm.module @my_module {
+ vm.func @or_i64(%arg0 : i64, %arg1 : i64) -> i64 {
+ // CHECK: %0 = emitc.call "vm_or_i64"(%arg0, %arg1) {args = [0 : index, 1 : index]} : (i64, i64) -> i64
+ %0 = vm.or.i64 %arg0, %arg1 : i64
+ vm.return %0 : i64
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @xor_i64
+vm.module @my_module {
+ vm.func @xor_i64(%arg0 : i64, %arg1 : i64) -> i64 {
+ // CHECK: %0 = emitc.call "vm_xor_i64"(%arg0, %arg1) {args = [0 : index, 1 : index]} : (i64, i64) -> i64
+ %0 = vm.xor.i64 %arg0, %arg1 : i64
+ vm.return %0 : i64
+ }
+}
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/assignment_ops.mlir b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/assignment_ops.mlir
new file mode 100644
index 0000000..26940ca
--- /dev/null
+++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/assignment_ops.mlir
@@ -0,0 +1,10 @@
+// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s
+
+// CHECK-LABEL: vm.func @select_i32
+vm.module @my_module {
+ vm.func @select_i32(%arg0 : i32, %arg1 : i32, %arg2 : i32) -> i32 {
+ // CHECK: %0 = emitc.call "vm_select_i32"(%arg0, %arg1, %arg2) {args = [0 : index, 1 : index, 2 : index]} : (i32, i32, i32) -> i32
+ %0 = vm.select.i32 %arg0, %arg1, %arg2 : i32
+ vm.return %0 : i32
+ }
+}
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/assignment_ops_i64.mlir b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/assignment_ops_i64.mlir
new file mode 100644
index 0000000..67927c4
--- /dev/null
+++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/assignment_ops_i64.mlir
@@ -0,0 +1,10 @@
+// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s
+
+// CHECK-LABEL: vm.func @select_i64
+vm.module @my_module {
+ vm.func @select_i64(%arg0 : i32, %arg1 : i64, %arg2 : i64) -> i64 {
+ // CHECK: %0 = emitc.call "vm_select_i64"(%arg0, %arg1, %arg2) {args = [0 : index, 1 : index, 2 : index]} : (i32, i64, i64) -> i64
+ %0 = vm.select.i64 %arg0, %arg1, %arg2 : i64
+ vm.return %0 : i64
+ }
+}
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const.mlir b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops.mlir
similarity index 100%
rename from iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const.mlir
rename to iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops.mlir
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops_i64.mlir b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops_i64.mlir
new file mode 100644
index 0000000..5b412fd
--- /dev/null
+++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops_i64.mlir
@@ -0,0 +1,16 @@
+// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s
+
+// CHECK: vm.module @module {
+vm.module @module {
+ // CHECK-LABEL: vm.func @const_i64
+ vm.func @const_i64() {
+ // CHECK-NEXT: %0 = emitc.call "vm_const_i64"() {args = [0]} : () -> i64
+ %0 = vm.const.i64 0 : i64
+ // CHECK-NEXT: %1 = emitc.call "vm_const_i64"() {args = [2]} : () -> i64
+ %1 = vm.const.i64 2 : i64
+ // CHECK-NEXT: %2 = emitc.call "vm_const_i64"() {args = [-2]} : () -> i64
+ %2 = vm.const.i64 -2 : i64
+ // CHECK-NEXT: vm.return
+ vm.return
+ }
+}
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/shift_ops_i64.mlir b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/shift_ops_i64.mlir
new file mode 100644
index 0000000..a98a494
--- /dev/null
+++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/shift_ops_i64.mlir
@@ -0,0 +1,32 @@
+// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s
+
+// CHECK-LABEL: @shl_i64
+vm.module @my_module {
+ vm.func @shl_i64(%arg0 : i64) -> i64 {
+ // CHECK: %0 = emitc.call "vm_shl_i64"(%arg0) {args = [0 : index, 2 : i8]} : (i64) -> i64
+ %0 = vm.shl.i64 %arg0, 2 : i64
+ vm.return %0 : i64
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @shr_i64_s
+vm.module @my_module {
+ vm.func @shr_i64_s(%arg0 : i64) -> i64 {
+ // CHECK: %0 = emitc.call "vm_shr_i64s"(%arg0) {args = [0 : index, 2 : i8]} : (i64) -> i64
+ %0 = vm.shr.i64.s %arg0, 2 : i64
+ vm.return %0 : i64
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @shr_i64_u
+vm.module @my_module {
+ vm.func @shr_i64_u(%arg0 : i64) -> i64 {
+ // CHECK: %0 = emitc.call "vm_shr_i64u"(%arg0) {args = [0 : index, 2 : i8]} : (i64) -> i64
+ %0 = vm.shr.i64.u %arg0, 2 : i64
+ vm.return %0 : i64
+ }
+}
diff --git a/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp b/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp
index 0be1bbf..3d38624 100644
--- a/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp
+++ b/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp
@@ -164,7 +164,7 @@
return failure();
}
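+  // Emit the ", " separator between the input parameters and the result
+  // out-parameters only when the function has both.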
- if (funcOp.getNumResults() > 0) {
+ if (funcOp.getNumResults() > 0 && funcOp.getNumArguments() > 0) {
output << ", ";
}
diff --git a/iree/compiler/Dialect/VM/Target/C/test/calling_convention.mlir b/iree/compiler/Dialect/VM/Target/C/test/calling_convention.mlir
new file mode 100644
index 0000000..8869bf1
--- /dev/null
+++ b/iree/compiler/Dialect/VM/Target/C/test/calling_convention.mlir
@@ -0,0 +1,34 @@
+// RUN: iree-translate -iree-vm-ir-to-c-module %s | IreeFileCheck %s
+
+// CHECK: #include "iree/vm/ops.h"
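+//
+// The generated C turns VM function inputs into value parameters, appends one
+// out-pointer parameter per result, and returns an iree_status_t.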
+vm.module @calling_convention_test {
+ // CHECK: iree_status_t calling_convention_test_no_in_no_return_impl() {
+ vm.func @no_in_no_return() -> () {
+ // CHECK-NEXT: return iree_ok_status();
+ vm.return
+ }
+
+ // CHECK: iree_status_t calling_convention_test_i32_in_no_return_impl(int32_t v1) {
+ vm.func @i32_in_no_return(%arg0 : i32) -> () {
+ // CHECK-NEXT: return iree_ok_status();
+ vm.return
+ }
+
+ // CHECK: iree_status_t calling_convention_test_no_in_i32_return_impl(int32_t *out0) {
+ vm.func @no_in_i32_return() -> (i32) {
+ // CHECK-NEXT: int32_t v1 = vm_const_i32(32);
+ %0 = vm.const.i32 32 : i32
+ // CHECK-NEXT: *out0 = v1;
+ // CHECK-NEXT: return iree_ok_status();
+ vm.return %0 : i32
+ }
+
+ // CHECK: iree_status_t calling_convention_test_i32_in_i32_return_impl(int32_t v1, int32_t *out0) {
+ vm.func @i32_in_i32_return(%arg0 : i32) -> (i32) {
+ // CHECK-NEXT: int32_t v2 = vm_const_i32(32);
+ %0 = vm.const.i32 32 : i32
+ // CHECK-NEXT: *out0 = v2;
+ // CHECK-NEXT: return iree_ok_status();
+ vm.return %0 : i32
+ }
+}
diff --git a/iree/hal/cuda/BUILD.bazel b/iree/hal/cuda/BUILD.bazel
new file mode 100644
index 0000000..18fc69b
--- /dev/null
+++ b/iree/hal/cuda/BUILD.bazel
@@ -0,0 +1,60 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_cmake_extra_content(
+ content = """
+if(NOT ${IREE_HAL_DRIVER_CUDA})
+ return()
+endif()
+""",
+)
+
+cc_library(
+ name = "dynamic_symbols",
+ srcs = [
+ "cuda_headers.h",
+ "dynamic_symbols.cc",
+ "dynamic_symbols_tables.h",
+ ],
+ hdrs = [
+ "dynamic_symbols.h",
+ ],
+ deps = [
+ "//iree/base:core_headers",
+ "//iree/base:dynamic_library",
+ "//iree/base:status",
+ "//iree/base:tracing",
+ "@com_google_absl//absl/types:span",
+ "@cuda_headers",
+ ],
+)
+
+cc_test(
+ name = "dynamic_symbols_test",
+ srcs = ["dynamic_symbols_test.cc"],
+ tags = ["driver=cuda"],
+ deps = [
+ ":dynamic_symbols",
+ "//iree/testing:gtest",
+ "//iree/testing:gtest_main",
+ ],
+)
diff --git a/iree/hal/cuda/CMakeLists.txt b/iree/hal/cuda/CMakeLists.txt
new file mode 100644
index 0000000..cc7667f
--- /dev/null
+++ b/iree/hal/cuda/CMakeLists.txt
@@ -0,0 +1,51 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(NOT ${IREE_HAL_DRIVER_CUDA})
+ return()
+endif()
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ dynamic_symbols
+ HDRS
+ "dynamic_symbols.h"
+ SRCS
+ "cuda_headers.h"
+ "dynamic_symbols.cc"
+ "dynamic_symbols_tables.h"
+ DEPS
+ absl::span
+ cuda_headers
+ iree::base::core_headers
+ iree::base::dynamic_library
+ iree::base::status
+ iree::base::tracing
+ PUBLIC
+)
+
+iree_cc_test(
+ NAME
+ dynamic_symbols_test
+ SRCS
+ "dynamic_symbols_test.cc"
+ DEPS
+ ::dynamic_symbols
+ iree::testing::gtest
+ iree::testing::gtest_main
+ LABELS
+ "driver=cuda"
+)
diff --git a/iree/hal/cuda/cuda_headers.h b/iree/hal/cuda/cuda_headers.h
new file mode 100644
index 0000000..f5fd736
--- /dev/null
+++ b/iree/hal/cuda/cuda_headers.h
@@ -0,0 +1,20 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_CUDA_CUDA_HEADERS_H_
+#define IREE_HAL_CUDA_CUDA_HEADERS_H_
+
+#include "cuda.h"
+
+#endif // IREE_HAL_CUDA_CUDA_HEADERS_H_
diff --git a/iree/hal/cuda/dynamic_symbols.cc b/iree/hal/cuda/dynamic_symbols.cc
new file mode 100644
index 0000000..0927116
--- /dev/null
+++ b/iree/hal/cuda/dynamic_symbols.cc
@@ -0,0 +1,60 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/hal/cuda/dynamic_symbols.h"
+
+#include <cstddef>
+
+#include "absl/types/span.h"
+#include "iree/base/status.h"
+#include "iree/base/target_platform.h"
+#include "iree/base/tracing.h"
+
+namespace iree {
+namespace hal {
+namespace cuda {
+
+static const char* kCudaLoaderSearchNames[] = {
+#if defined(IREE_PLATFORM_WINDOWS)
+ "nvcuda.dll",
+#else
+ "libcuda.so",
+#endif
+};
+
+Status DynamicSymbols::LoadSymbols() {
+ IREE_TRACE_SCOPE();
+
+ IREE_RETURN_IF_ERROR(DynamicLibrary::Load(
+ absl::MakeSpan(kCudaLoaderSearchNames), &loader_library_));
+
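+  // Resolve every function listed in dynamic_symbols_tables.h: the X-macro
+  // below expands once per CU_PFN_DECL entry and fails if any symbol is
+  // missing from the loaded driver library.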
+#define CU_PFN_DECL(cudaSymbolName) \
+ { \
+ using FuncPtrT = std::add_pointer<decltype(::cudaSymbolName)>::type; \
+ static const char* kName = #cudaSymbolName; \
+ cudaSymbolName = loader_library_->GetSymbol<FuncPtrT>(kName); \
+ if (!cudaSymbolName) { \
+ return iree_make_status(IREE_STATUS_UNAVAILABLE, "symbol not found"); \
+ } \
+ }
+
+#include "dynamic_symbols_tables.h"
+#undef CU_PFN_DECL
+
+ return OkStatus();
+}
+
+} // namespace cuda
+} // namespace hal
+} // namespace iree
diff --git a/iree/hal/cuda/dynamic_symbols.h b/iree/hal/cuda/dynamic_symbols.h
new file mode 100644
index 0000000..9d2c40e
--- /dev/null
+++ b/iree/hal/cuda/dynamic_symbols.h
@@ -0,0 +1,52 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_CUDA_DYNAMIC_SYMBOLS_H_
+#define IREE_HAL_CUDA_DYNAMIC_SYMBOLS_H_
+
+#include <cstdint>
+#include <functional>
+#include <memory>
+
+#include "iree/base/dynamic_library.h"
+#include "iree/base/status.h"
+#include "iree/hal/cuda/cuda_headers.h"
+
+namespace iree {
+namespace hal {
+namespace cuda {
+
+/// DynamicSymbols allows dynamically loading a subset of the CUDA driver API.
+/// It loads all functions declared in `dynamic_symbols_tables.h` and fails if
+/// any symbol is unavailable. The function signatures match the declarations
+/// in `cuda.h`.
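+///
+/// Usage (a minimal sketch, mirroring dynamic_symbols_test.cc):
+///   DynamicSymbols syms;
+///   IREE_RETURN_IF_ERROR(syms.LoadSymbols());
+///   syms.cuInit(0);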
+struct DynamicSymbols {
+ Status LoadSymbols();
+
+#define CU_PFN_DECL(cudaSymbolName) \
+ std::add_pointer<decltype(::cudaSymbolName)>::type cudaSymbolName;
+
+#include "dynamic_symbols_tables.h"
+#undef CU_PFN_DECL
+
+ private:
+  // Handle to the loaded CUDA driver dynamic library.
+ std::unique_ptr<DynamicLibrary> loader_library_;
+};
+
+} // namespace cuda
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_CUDA_DYNAMIC_SYMBOLS_H_
diff --git a/iree/hal/cuda/dynamic_symbols_tables.h b/iree/hal/cuda/dynamic_symbols_tables.h
new file mode 100644
index 0000000..5adece6
--- /dev/null
+++ b/iree/hal/cuda/dynamic_symbols_tables.h
@@ -0,0 +1,90 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
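+// X-macro table: one CU_PFN_DECL(symbol) entry per CUDA driver API function
+// to load. Including files define CU_PFN_DECL before including this header
+// and #undef it afterwards.
+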
+CU_PFN_DECL(cuCtxCreate)
+CU_PFN_DECL(cuCtxDestroy)
+CU_PFN_DECL(cuCtxEnablePeerAccess)
+CU_PFN_DECL(cuCtxGetCurrent)
+CU_PFN_DECL(cuCtxGetDevice)
+CU_PFN_DECL(cuCtxGetSharedMemConfig)
+CU_PFN_DECL(cuCtxSetCurrent)
+CU_PFN_DECL(cuCtxSetSharedMemConfig)
+CU_PFN_DECL(cuCtxSynchronize)
+CU_PFN_DECL(cuDeviceCanAccessPeer)
+CU_PFN_DECL(cuDeviceGet)
+CU_PFN_DECL(cuDeviceGetAttribute)
+CU_PFN_DECL(cuDeviceGetCount)
+CU_PFN_DECL(cuDeviceGetName)
+CU_PFN_DECL(cuDeviceGetPCIBusId)
+CU_PFN_DECL(cuDevicePrimaryCtxGetState)
+CU_PFN_DECL(cuDevicePrimaryCtxRelease)
+CU_PFN_DECL(cuDevicePrimaryCtxRetain)
+CU_PFN_DECL(cuDevicePrimaryCtxSetFlags)
+CU_PFN_DECL(cuDeviceTotalMem)
+CU_PFN_DECL(cuDriverGetVersion)
+CU_PFN_DECL(cuEventCreate)
+CU_PFN_DECL(cuEventDestroy)
+CU_PFN_DECL(cuEventElapsedTime)
+CU_PFN_DECL(cuEventQuery)
+CU_PFN_DECL(cuEventRecord)
+CU_PFN_DECL(cuEventSynchronize)
+CU_PFN_DECL(cuFuncGetAttribute)
+CU_PFN_DECL(cuFuncSetCacheConfig)
+CU_PFN_DECL(cuGetErrorName)
+CU_PFN_DECL(cuGetErrorString)
+CU_PFN_DECL(cuGraphAddMemcpyNode)
+CU_PFN_DECL(cuGraphAddMemsetNode)
+CU_PFN_DECL(cuGraphAddKernelNode)
+CU_PFN_DECL(cuGraphCreate)
+CU_PFN_DECL(cuGraphDestroy)
+CU_PFN_DECL(cuGraphExecDestroy)
+CU_PFN_DECL(cuGraphGetNodes)
+CU_PFN_DECL(cuGraphInstantiate)
+CU_PFN_DECL(cuGraphLaunch)
+CU_PFN_DECL(cuInit)
+CU_PFN_DECL(cuLaunchKernel)
+CU_PFN_DECL(cuMemAlloc)
+CU_PFN_DECL(cuMemAllocManaged)
+CU_PFN_DECL(cuMemFree)
+CU_PFN_DECL(cuMemFreeHost)
+CU_PFN_DECL(cuMemGetAddressRange)
+CU_PFN_DECL(cuMemGetInfo)
+CU_PFN_DECL(cuMemHostAlloc)
+CU_PFN_DECL(cuMemHostGetDevicePointer)
+CU_PFN_DECL(cuMemHostRegister)
+CU_PFN_DECL(cuMemHostUnregister)
+CU_PFN_DECL(cuMemcpyDtoD)
+CU_PFN_DECL(cuMemcpyDtoDAsync)
+CU_PFN_DECL(cuMemcpyDtoH)
+CU_PFN_DECL(cuMemcpyDtoHAsync)
+CU_PFN_DECL(cuMemcpyHtoD)
+CU_PFN_DECL(cuMemcpyHtoDAsync)
+CU_PFN_DECL(cuMemsetD32)
+CU_PFN_DECL(cuMemsetD32Async)
+CU_PFN_DECL(cuMemsetD8)
+CU_PFN_DECL(cuMemsetD8Async)
+CU_PFN_DECL(cuModuleGetFunction)
+CU_PFN_DECL(cuModuleGetGlobal)
+CU_PFN_DECL(cuModuleLoadDataEx)
+CU_PFN_DECL(cuModuleLoadFatBinary)
+CU_PFN_DECL(cuModuleUnload)
+CU_PFN_DECL(cuOccupancyMaxActiveBlocksPerMultiprocessor)
+CU_PFN_DECL(cuOccupancyMaxPotentialBlockSize)
+CU_PFN_DECL(cuPointerGetAttribute)
+CU_PFN_DECL(cuStreamAddCallback)
+CU_PFN_DECL(cuStreamCreate)
+CU_PFN_DECL(cuStreamDestroy)
+CU_PFN_DECL(cuStreamQuery)
+CU_PFN_DECL(cuStreamSynchronize)
+CU_PFN_DECL(cuStreamWaitEvent)
diff --git a/iree/hal/cuda/dynamic_symbols_test.cc b/iree/hal/cuda/dynamic_symbols_test.cc
new file mode 100644
index 0000000..6a7967c
--- /dev/null
+++ b/iree/hal/cuda/dynamic_symbols_test.cc
@@ -0,0 +1,51 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/hal/cuda/dynamic_symbols.h"
+
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace iree {
+namespace hal {
+namespace cuda {
+namespace {
+
+#define CUDA_CHECK_ERRORS(expr)      \
+  {                                  \
+    CUresult status = expr;          \
+    ASSERT_EQ(CUDA_SUCCESS, status); \
+  }
+
+TEST(DynamicSymbolsTest, CreateFromSystemLoader) {
+ DynamicSymbols symbols;
+ Status status = symbols.LoadSymbols();
+ if (!status.ok()) {
+    IREE_LOG(WARNING) << "Symbols could not be loaded; skipping test.";
+ GTEST_SKIP();
+ }
+
+ int device_count = 0;
+  CUDA_CHECK_ERRORS(symbols.cuInit(0));
+  CUDA_CHECK_ERRORS(symbols.cuDeviceGetCount(&device_count));
+ if (device_count > 0) {
+ CUdevice device;
+    CUDA_CHECK_ERRORS(symbols.cuDeviceGet(&device, /*ordinal=*/0));
+ }
+}
+
+} // namespace
+} // namespace cuda
+} // namespace hal
+} // namespace iree
diff --git a/iree/hal/local/arena.c b/iree/hal/local/arena.c
index 7b9529a..dabc746 100644
--- a/iree/hal/local/arena.c
+++ b/iree/hal/local/arena.c
@@ -52,6 +52,8 @@
iree_status_t iree_arena_block_pool_acquire(iree_arena_block_pool_t* block_pool,
iree_arena_block_t** out_block) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+
iree_arena_block_t* block =
iree_atomic_arena_block_slist_pop(&block_pool->available_slist);
@@ -63,23 +65,28 @@
// that's fine - it's just one block and the contention means there's likely
// to be a need for more anyway.
uint8_t* block_base = NULL;
- IREE_RETURN_IF_ERROR(iree_allocator_malloc(block_pool->block_allocator,
- block_pool->total_block_size,
- (void**)&block_base));
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_allocator_malloc(block_pool->block_allocator,
+ block_pool->total_block_size,
+ (void**)&block_base));
block = (iree_arena_block_t*)(block_base + (block_pool->total_block_size -
sizeof(iree_arena_block_t)));
}
block->next = NULL;
*out_block = block;
+
+ IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
}
void iree_arena_block_pool_release(iree_arena_block_pool_t* block_pool,
iree_arena_block_t* block_head,
iree_arena_block_t* block_tail) {
+ IREE_TRACE_ZONE_BEGIN(z0);
iree_atomic_arena_block_slist_concat(&block_pool->available_slist, block_head,
block_tail);
+ IREE_TRACE_ZONE_END(z0);
}
//===----------------------------------------------------------------------===//
diff --git a/iree/modules/hal/hal_module.c b/iree/modules/hal/hal_module.c
index 06c0baf..02f31fe 100644
--- a/iree/modules/hal/hal_module.c
+++ b/iree/modules/hal/hal_module.c
@@ -226,8 +226,8 @@
// Block and wait for the semaphore to be signaled (or fail).
status = iree_hal_semaphore_wait_with_deadline(semaphore, 1ull,
IREE_TIME_INFINITE_FUTURE);
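+  // Release our reference before checking the status so the semaphore is not
+  // leaked on the failure path.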
+ iree_hal_semaphore_release(semaphore);
if (!iree_status_is_ok(status)) {
- iree_hal_semaphore_release(semaphore);
return status;
}
diff --git a/iree/test/e2e/xla_ops/pad.mlir b/iree/test/e2e/xla_ops/pad.mlir
index 537e684..7f6df37 100644
--- a/iree/test/e2e/xla_ops/pad.mlir
+++ b/iree/test/e2e/xla_ops/pad.mlir
@@ -20,19 +20,3 @@
check.expect_eq(%res, %input) : tensor<2x3xi32>
return
}
-
-func @pad_with_interior_padding() attributes { iree.module.export } {
- %input = iree.unfoldable_constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
- %c0 = iree.unfoldable_constant dense<0> : tensor<i32>
- %res = "mhlo.pad"(%input, %c0) {
- edge_padding_low = dense<[0, 1]> : tensor<2xi64>,
- edge_padding_high = dense<[1, 5]> : tensor<2xi64>,
- interior_padding = dense<[1, 2]> : tensor<2xi64>
- } : (tensor<2x3xi32>, tensor<i32>) -> tensor<4x13xi32>
- check.expect_eq_const(%res, dense<[
- [0, 1, 0, 0, 2, 0, 0, 3, 0, 0, 0, 0, 0],
- [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- [0, 4, 0, 0, 5, 0, 0, 6, 0, 0, 0, 0, 0],
- [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]> : tensor<4x13xi32>) : tensor<4x13xi32>
- return
-}
diff --git a/iree/tools/iree-benchmark-module-main.cc b/iree/tools/iree-benchmark-module-main.cc
index 0bc1825..b567c1d 100644
--- a/iree/tools/iree-benchmark-module-main.cc
+++ b/iree/tools/iree-benchmark-module-main.cc
@@ -125,21 +125,17 @@
// benchmarking.
class IREEBenchmark {
public:
- IREEBenchmark()
- : instance_(nullptr),
- device_(nullptr),
- hal_module_(nullptr),
- context_(nullptr),
- input_module_(nullptr){};
+ IREEBenchmark() = default;
+
~IREEBenchmark() {
IREE_TRACE_SCOPE0("IREEBenchmark::dtor");
// Order matters.
inputs_.reset();
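+    // The context retains the registered modules; release it before them.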
+ iree_vm_context_release(context_);
iree_vm_module_release(hal_module_);
iree_vm_module_release(input_module_);
iree_hal_device_release(device_);
- iree_vm_context_release(context_);
iree_vm_instance_release(instance_);
};
@@ -247,11 +243,11 @@
}
std::string module_data_;
- iree_vm_instance_t* instance_;
- iree_hal_device_t* device_;
- iree_vm_module_t* hal_module_;
- iree_vm_context_t* context_;
- iree_vm_module_t* input_module_;
+ iree_vm_instance_t* instance_ = nullptr;
+ iree_hal_device_t* device_ = nullptr;
+ iree_vm_module_t* hal_module_ = nullptr;
+ iree_vm_context_t* context_ = nullptr;
+ iree_vm_module_t* input_module_ = nullptr;
iree::vm::ref<iree_vm_list_t> inputs_;
};
} // namespace
diff --git a/iree/vm/bytecode_dispatch.c b/iree/vm/bytecode_dispatch.c
index 86c6509..17253a3 100644
--- a/iree/vm/bytecode_dispatch.c
+++ b/iree/vm/bytecode_dispatch.c
@@ -402,8 +402,7 @@
} break;
case IREE_VM_CCONV_TYPE_REF: {
uint16_t src_reg = src_reg_list->registers[reg_i++];
- iree_vm_ref_retain_or_move(
- src_reg & IREE_REF_REGISTER_MOVE_BIT,
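+        // Assign without retaining: per the iree_vm_function_call_t contract
+        // in iree/vm/module.h the caller retains ownership of argument refs.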
+ iree_vm_ref_assign(
&caller_registers.ref[src_reg & caller_registers.ref_mask],
(iree_vm_ref_t*)p);
p += sizeof(iree_vm_ref_t);
@@ -447,8 +446,7 @@
} break;
case IREE_VM_CCONV_TYPE_REF: {
uint16_t src_reg = src_reg_list->registers[reg_i++];
- iree_vm_ref_retain_or_move(
- src_reg & IREE_REF_REGISTER_MOVE_BIT,
+ iree_vm_ref_assign(
&caller_registers.ref[src_reg & caller_registers.ref_mask],
(iree_vm_ref_t*)p);
p += sizeof(iree_vm_ref_t);
@@ -1362,7 +1360,7 @@
int64_t true_value = VM_DecOperandRegI64("true_value");
int64_t false_value = VM_DecOperandRegI64("false_value");
int64_t* result = VM_DecResultRegI64("result");
- *result = condition ? true_value : false_value;
+ *result = vm_select_i64(condition, true_value, false_value);
});
DISPATCH_OP(EXT_I64, SwitchI64, {
@@ -1383,63 +1381,53 @@
// ExtI64: Native integer arithmetic
//===----------------------------------------------------------------===//
-#define DISPATCH_OP_EXT_I64_UNARY_ALU_I64(op_name, type, op) \
- DISPATCH_OP(EXT_I64, op_name, { \
- int64_t operand = VM_DecOperandRegI64("operand"); \
- int64_t* result = VM_DecResultRegI64("result"); \
- *result = (int64_t)(op((type)operand)); \
- });
-
-#define DISPATCH_OP_EXT_I64_BINARY_ALU_I64(op_name, type, op) \
- DISPATCH_OP(EXT_I64, op_name, { \
- int64_t lhs = VM_DecOperandRegI64("lhs"); \
- int64_t rhs = VM_DecOperandRegI64("rhs"); \
- int64_t* result = VM_DecResultRegI64("result"); \
- *result = (int64_t)(((type)lhs)op((type)rhs)); \
- });
-
- DISPATCH_OP_EXT_I64_BINARY_ALU_I64(AddI64, int64_t, +);
- DISPATCH_OP_EXT_I64_BINARY_ALU_I64(SubI64, int64_t, -);
- DISPATCH_OP_EXT_I64_BINARY_ALU_I64(MulI64, int64_t, *);
- DISPATCH_OP_EXT_I64_BINARY_ALU_I64(DivI64S, int64_t, /);
- DISPATCH_OP_EXT_I64_BINARY_ALU_I64(DivI64U, uint64_t, /);
- DISPATCH_OP_EXT_I64_BINARY_ALU_I64(RemI64S, int64_t, %);
- DISPATCH_OP_EXT_I64_BINARY_ALU_I64(RemI64U, uint64_t, %);
- DISPATCH_OP_EXT_I64_UNARY_ALU_I64(NotI64, uint64_t, ~);
- DISPATCH_OP_EXT_I64_BINARY_ALU_I64(AndI64, uint64_t, &);
- DISPATCH_OP_EXT_I64_BINARY_ALU_I64(OrI64, uint64_t, |);
- DISPATCH_OP_EXT_I64_BINARY_ALU_I64(XorI64, uint64_t, ^);
+ DISPATCH_OP_EXT_I64_BINARY_I64(AddI64, vm_add_i64);
+ DISPATCH_OP_EXT_I64_BINARY_I64(SubI64, vm_sub_i64);
+ DISPATCH_OP_EXT_I64_BINARY_I64(MulI64, vm_mul_i64);
+ DISPATCH_OP_EXT_I64_BINARY_I64(DivI64S, vm_div_i64s);
+ DISPATCH_OP_EXT_I64_BINARY_I64(DivI64U, vm_div_i64u);
+ DISPATCH_OP_EXT_I64_BINARY_I64(RemI64S, vm_rem_i64s);
+ DISPATCH_OP_EXT_I64_BINARY_I64(RemI64U, vm_rem_i64u);
+ DISPATCH_OP_EXT_I64_UNARY_I64(NotI64, vm_not_i64);
+ DISPATCH_OP_EXT_I64_BINARY_I64(AndI64, vm_and_i64);
+ DISPATCH_OP_EXT_I64_BINARY_I64(OrI64, vm_or_i64);
+ DISPATCH_OP_EXT_I64_BINARY_I64(XorI64, vm_xor_i64);
//===----------------------------------------------------------------===//
// ExtI64: Casting and type conversion/emulation
//===----------------------------------------------------------------===//
-#define DISPATCH_OP_EXT_I64_CAST_I64(op_name, src_type, dst_type) \
- DISPATCH_OP(EXT_I64, op_name, { \
- int64_t operand = VM_DecOperandRegI64("operand"); \
- int64_t* result = VM_DecResultRegI64("result"); \
- *result = (dst_type)((src_type)operand); \
- });
-
- DISPATCH_OP_EXT_I64_CAST_I64(TruncI64I32, uint64_t, uint32_t);
- DISPATCH_OP_EXT_I64_CAST_I64(ExtI32I64S, int32_t, int64_t);
- DISPATCH_OP_EXT_I64_CAST_I64(ExtI32I64U, uint32_t, uint64_t);
+ DISPATCH_OP(EXT_I64, TruncI64I32, {
+ int64_t operand = VM_DecOperandRegI64("operand");
+ int32_t* result = VM_DecResultRegI32("result");
+ *result = (uint32_t)((uint64_t)operand);
+ });
+ DISPATCH_OP(EXT_I64, ExtI32I64S, {
+ int32_t operand = VM_DecOperandRegI32("operand");
+ int64_t* result = VM_DecResultRegI64("result");
+ *result = (int64_t)((int32_t)operand);
+ });
+ DISPATCH_OP(EXT_I64, ExtI32I64U, {
+ int32_t operand = VM_DecOperandRegI32("operand");
+ int64_t* result = VM_DecResultRegI64("result");
+ *result = (uint64_t)((uint32_t)operand);
+ });
//===----------------------------------------------------------------===//
// ExtI64: Native bitwise shifts and rotates
//===----------------------------------------------------------------===//
-#define DISPATCH_OP_EXT_I64_SHIFT_I64(op_name, type, op) \
- DISPATCH_OP(EXT_I64, op_name, { \
- int64_t operand = VM_DecOperandRegI64("operand"); \
- int8_t amount = VM_DecConstI8("amount"); \
- int64_t* result = VM_DecResultRegI64("result"); \
- *result = (int64_t)(((type)operand)op amount); \
+#define DISPATCH_OP_EXT_I64_SHIFT_I64(op_name, op_func) \
+ DISPATCH_OP(EXT_I64, op_name, { \
+ int64_t operand = VM_DecOperandRegI64("operand"); \
+ int8_t amount = VM_DecConstI8("amount"); \
+ int64_t* result = VM_DecResultRegI64("result"); \
+ *result = op_func(operand, amount); \
});
- DISPATCH_OP_EXT_I64_SHIFT_I64(ShlI64, int64_t, <<);
- DISPATCH_OP_EXT_I64_SHIFT_I64(ShrI64S, int64_t, >>);
- DISPATCH_OP_EXT_I64_SHIFT_I64(ShrI64U, uint64_t, >>);
+ DISPATCH_OP_EXT_I64_SHIFT_I64(ShlI64, vm_shl_i64);
+ DISPATCH_OP_EXT_I64_SHIFT_I64(ShrI64S, vm_shr_i64s);
+ DISPATCH_OP_EXT_I64_SHIFT_I64(ShrI64U, vm_shr_i64u);
//===----------------------------------------------------------------===//
// ExtI64: Comparison ops
diff --git a/iree/vm/bytecode_dispatch_util.h b/iree/vm/bytecode_dispatch_util.h
index a826be4..c0be859 100644
--- a/iree/vm/bytecode_dispatch_util.h
+++ b/iree/vm/bytecode_dispatch_util.h
@@ -377,4 +377,19 @@
*result = op_func(lhs, rhs); \
});
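+// EXT_I64 dispatch helpers: decode the i64 operand/result registers and defer
+// the computation to the shared vm_*_i64 helpers in iree/vm/ops.h so the
+// interpreter and the generated C stay in sync.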
+#define DISPATCH_OP_EXT_I64_UNARY_I64(op_name, op_func) \
+ DISPATCH_OP(EXT_I64, op_name, { \
+ int64_t operand = VM_DecOperandRegI64("operand"); \
+ int64_t* result = VM_DecResultRegI64("result"); \
+ *result = op_func(operand); \
+ });
+
+#define DISPATCH_OP_EXT_I64_BINARY_I64(op_name, op_func) \
+ DISPATCH_OP(EXT_I64, op_name, { \
+ int64_t lhs = VM_DecOperandRegI64("lhs"); \
+ int64_t rhs = VM_DecOperandRegI64("rhs"); \
+ int64_t* result = VM_DecResultRegI64("result"); \
+ *result = op_func(lhs, rhs); \
+ });
+
#endif // IREE_VM_BYTECODE_DISPATCH_UTIL_H_
diff --git a/iree/vm/list.c b/iree/vm/list.c
index 64fbdc8..24e3158 100644
--- a/iree/vm/list.c
+++ b/iree/vm/list.c
@@ -70,21 +70,27 @@
iree_host_size_t offset,
iree_host_size_t length) {
switch (list->storage_mode) {
- case IREE_VM_LIST_STORAGE_MODE_VALUE:
- // Nothing special, freeing the storage is all we need.
+ case IREE_VM_LIST_STORAGE_MODE_VALUE: {
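+      // Zero the value storage so a later extension over this range yields
+      // zero-initialized elements rather than stale data.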
+ void* base_ptr =
+ (void*)((uintptr_t)list->storage + offset * list->element_size);
+ memset(base_ptr, 0, length * list->element_size);
break;
+ }
case IREE_VM_LIST_STORAGE_MODE_REF: {
iree_vm_ref_t* ref_storage = (iree_vm_ref_t*)list->storage;
- for (iree_host_size_t i = offset; i < length; ++i) {
+ for (iree_host_size_t i = offset; i < offset + length; ++i) {
iree_vm_ref_release(&ref_storage[i]);
}
break;
}
case IREE_VM_LIST_STORAGE_MODE_VARIANT: {
iree_vm_variant_t* variant_storage = (iree_vm_variant_t*)list->storage;
- for (iree_host_size_t i = offset; i < length; ++i) {
+ for (iree_host_size_t i = offset; i < offset + length; ++i) {
if (iree_vm_type_def_is_ref(&variant_storage[i].type)) {
iree_vm_ref_release(&variant_storage[i].ref);
+ memset(&variant_storage[i].type, 0, sizeof(variant_storage[i].type));
+ } else {
+ memset(&variant_storage[i], 0, sizeof(variant_storage[i]));
}
}
break;
@@ -243,7 +249,7 @@
return iree_ok_status();
} else if (new_size < list->count) {
// Truncating.
- iree_vm_list_reset_range(list, new_size + 1, list->count - new_size);
+ iree_vm_list_reset_range(list, new_size, list->count - new_size);
list->count = new_size;
} else if (new_size > list->capacity) {
// Extending beyond capacity.
diff --git a/iree/vm/list_test.cc b/iree/vm/list_test.cc
index 4423cbc..3bad3ac 100644
--- a/iree/vm/list_test.cc
+++ b/iree/vm/list_test.cc
@@ -229,11 +229,161 @@
iree_vm_list_release(list);
}
-// TODO(benvanik): test resize value.
+// Tests the behavior of resize for truncation and extension on primitives.
+TEST_F(VMListTest, ResizeI32) {
+ iree_vm_type_def_t element_type =
+ iree_vm_type_def_make_value_type(IREE_VM_VALUE_TYPE_I32);
+ iree_host_size_t initial_capacity = 4;
+ iree_vm_list_t* list = nullptr;
+ IREE_ASSERT_OK(iree_vm_list_create(&element_type, initial_capacity,
+ iree_allocator_system(), &list));
+ EXPECT_LE(initial_capacity, iree_vm_list_capacity(list));
+ EXPECT_EQ(0, iree_vm_list_size(list));
-// TODO(benvanik): test resize ref.
+ // Extend and zero-initialize.
+ IREE_ASSERT_OK(iree_vm_list_resize(list, 5));
+ for (iree_host_size_t i = 0; i < 5; ++i) {
+ iree_vm_value_t value;
+ IREE_ASSERT_OK(
+ iree_vm_list_get_value_as(list, i, IREE_VM_VALUE_TYPE_I32, &value));
+ EXPECT_EQ(0, value.i32);
+ }
-// TODO(benvanik): test resize variant.
+ // Overwrite with [0, 5).
+ for (iree_host_size_t i = 0; i < 5; ++i) {
+ iree_vm_value_t value = iree_vm_value_make_i32((int32_t)i);
+ IREE_ASSERT_OK(iree_vm_list_set_value(list, i, &value));
+ }
+
+ // Truncate to [0, 2) and then extend again.
+ // This ensures that we test the primitive clearing path during cleanup:
+ // [int, int, int, int, int]
+ // |___________| <- truncation region
+ IREE_ASSERT_OK(iree_vm_list_resize(list, 2));
+ IREE_ASSERT_OK(iree_vm_list_resize(list, 5));
+
+  // Ensure that elements 2+ were zeroed by the reset while elements 0 and 1
+  // retain their previous values.
+ for (iree_host_size_t i = 0; i < 2; ++i) {
+ iree_vm_value_t value;
+ IREE_ASSERT_OK(
+ iree_vm_list_get_value_as(list, i, IREE_VM_VALUE_TYPE_I32, &value));
+ EXPECT_EQ(i, value.i32);
+ }
+ for (iree_host_size_t i = 2; i < 5; ++i) {
+ iree_vm_value_t value;
+ IREE_ASSERT_OK(
+ iree_vm_list_get_value_as(list, i, IREE_VM_VALUE_TYPE_I32, &value));
+ EXPECT_EQ(0, value.i32);
+ }
+
+ iree_vm_list_release(list);
+}
+
+// Tests the behavior of resize for truncation and extension on refs.
+TEST_F(VMListTest, ResizeRef) {
+ iree_vm_type_def_t element_type =
+ iree_vm_type_def_make_ref_type(test_a_type_id());
+ iree_host_size_t initial_capacity = 4;
+ iree_vm_list_t* list = nullptr;
+ IREE_ASSERT_OK(iree_vm_list_create(&element_type, initial_capacity,
+ iree_allocator_system(), &list));
+ EXPECT_LE(initial_capacity, iree_vm_list_capacity(list));
+ EXPECT_EQ(0, iree_vm_list_size(list));
+
+ // Extend and zero-initialize.
+ IREE_ASSERT_OK(iree_vm_list_resize(list, 5));
+ for (iree_host_size_t i = 0; i < 5; ++i) {
+ iree_vm_ref_t ref_a{0};
+ IREE_ASSERT_OK(iree_vm_list_get_ref_assign(list, i, &ref_a));
+ EXPECT_TRUE(iree_vm_ref_is_null(&ref_a));
+ }
+
+ // Overwrite with [0, 5).
+ for (iree_host_size_t i = 0; i < 5; ++i) {
+ iree_vm_ref_t ref_a = MakeRef<A>((float)i);
+ IREE_ASSERT_OK(iree_vm_list_set_ref_move(list, i, &ref_a));
+ }
+
+ // Truncate to [0, 2) and then extend again.
+ // This ensures that we test the ref path during cleanup:
+ // [ref, ref, ref, ref, ref]
+ // |___________| <- truncation region
+ IREE_ASSERT_OK(iree_vm_list_resize(list, 2));
+ IREE_ASSERT_OK(iree_vm_list_resize(list, 5));
+
+  // Ensure that elements 2+ were reset by the truncation while elements 0 and
+  // 1 retain their previous values.
+ for (iree_host_size_t i = 0; i < 2; ++i) {
+ iree_vm_ref_t ref_a{0};
+ IREE_ASSERT_OK(iree_vm_list_get_ref_retain(list, i, &ref_a));
+ EXPECT_TRUE(test_a_isa(ref_a));
+ auto* a = test_a_deref(ref_a);
+ EXPECT_EQ(i, a->data());
+ iree_vm_ref_release(&ref_a);
+ }
+ for (iree_host_size_t i = 2; i < 5; ++i) {
+ iree_vm_ref_t ref_a{0};
+ IREE_ASSERT_OK(iree_vm_list_get_ref_assign(list, i, &ref_a));
+ EXPECT_TRUE(iree_vm_ref_is_null(&ref_a));
+ }
+
+ iree_vm_list_release(list);
+}
+
+// Tests the behavior of resize for truncation and extension on variants.
+TEST_F(VMListTest, ResizeVariant) {
+ iree_vm_type_def_t element_type = iree_vm_type_def_make_variant_type();
+ iree_host_size_t initial_capacity = 4;
+ iree_vm_list_t* list = nullptr;
+ IREE_ASSERT_OK(iree_vm_list_create(&element_type, initial_capacity,
+ iree_allocator_system(), &list));
+ EXPECT_LE(initial_capacity, iree_vm_list_capacity(list));
+ EXPECT_EQ(0, iree_vm_list_size(list));
+
+ // Extend and zero-initialize.
+ IREE_ASSERT_OK(iree_vm_list_resize(list, 5));
+ for (iree_host_size_t i = 0; i < 5; ++i) {
+ iree_vm_variant_t value = iree_vm_variant_empty();
+ IREE_ASSERT_OK(iree_vm_list_get_variant(list, i, &value));
+ EXPECT_TRUE(iree_vm_variant_is_empty(value));
+ }
+
+ // Overwrite with [0, 5) in mixed types.
+ for (iree_host_size_t i = 0; i < 4; ++i) {
+ iree_vm_ref_t ref_a = MakeRef<A>((float)i);
+ IREE_ASSERT_OK(iree_vm_list_set_ref_move(list, i, &ref_a));
+ }
+ for (iree_host_size_t i = 4; i < 5; ++i) {
+ iree_vm_value_t value = iree_vm_value_make_i32((int32_t)i);
+ IREE_ASSERT_OK(iree_vm_list_set_value(list, i, &value));
+ }
+
+ // Truncate to [0, 2) and then extend again.
+ // This ensures that we test the variant path during cleanup:
+ // [ref, ref, ref, ref, int]
+ // |___________| <- truncation region
+ IREE_ASSERT_OK(iree_vm_list_resize(list, 2));
+ IREE_ASSERT_OK(iree_vm_list_resize(list, 5));
+
+  // Ensure that elements 2+ were reset by the truncation while elements 0 and
+  // 1 retain their previous values.
+ for (iree_host_size_t i = 0; i < 2; ++i) {
+ iree_vm_ref_t ref_a{0};
+ IREE_ASSERT_OK(iree_vm_list_get_ref_retain(list, i, &ref_a));
+ EXPECT_TRUE(test_a_isa(ref_a));
+ auto* a = test_a_deref(ref_a);
+ EXPECT_EQ(i, a->data());
+ iree_vm_ref_release(&ref_a);
+ }
+ for (iree_host_size_t i = 2; i < 5; ++i) {
+ iree_vm_variant_t value = iree_vm_variant_empty();
+ IREE_ASSERT_OK(iree_vm_list_get_variant(list, i, &value));
+ EXPECT_TRUE(iree_vm_variant_is_empty(value));
+ }
+
+ iree_vm_list_release(list);
+}
// TODO(benvanik): test value get/set.
diff --git a/iree/vm/module.h b/iree/vm/module.h
index d8635e0..8eabc0a 100644
--- a/iree/vm/module.h
+++ b/iree/vm/module.h
@@ -192,9 +192,8 @@
// Argument buffer in the format described above.
// This is only read on beginning the function and need not live beyond that.
//
- // Refs contained will be moved into the target function or released if not
- // needed. Callers must ensure they move or retain arguments when populating
- // the arguments buffer.
+ // Refs contained are retained by the caller and callees must retain them if
+ // they need them to live beyond the call.
iree_byte_span_t arguments;
// Storage for the result buffer; assumed undefined and then populated with
diff --git a/iree/vm/module_abi_packing.h b/iree/vm/module_abi_packing.h
index 0b00b30..d679964 100644
--- a/iree/vm/module_abi_packing.h
+++ b/iree/vm/module_abi_packing.h
@@ -362,7 +362,7 @@
struct ParamUnpack<opaque_ref> {
using storage_type = opaque_ref;
static void Load(Status& status, params_ptr_t& ptr, storage_type& out_param) {
- iree_vm_ref_move(reinterpret_cast<iree_vm_ref_t*>(ptr), &out_param);
+ iree_vm_ref_retain(reinterpret_cast<iree_vm_ref_t*>(ptr), &out_param);
ptr += sizeof(iree_vm_ref_t);
}
};
@@ -376,7 +376,7 @@
auto* reg_ptr = reinterpret_cast<iree_vm_ref_t*>(ptr);
ptr += sizeof(iree_vm_ref_t);
if (reg_ptr->type == ref_type_descriptor<T>::get()->type) {
- out_param = vm::assign_ref(reinterpret_cast<T*>(reg_ptr->ptr));
+ out_param = vm::retain_ref(reinterpret_cast<T*>(reg_ptr->ptr));
memset(reg_ptr, 0, sizeof(*reg_ptr));
} else if (IREE_UNLIKELY(reg_ptr->type != IREE_VM_REF_TYPE_NULL)) {
status =
@@ -401,7 +401,7 @@
auto* reg_ptr = reinterpret_cast<iree_vm_ref_t*>(ptr);
ptr += sizeof(iree_vm_ref_t);
if (reg_ptr->type == ref_type_descriptor<T>::get()->type) {
- out_param = vm::assign_ref(reinterpret_cast<T*>(reg_ptr->ptr));
+ out_param = vm::retain_ref(reinterpret_cast<T*>(reg_ptr->ptr));
memset(reg_ptr, 0, sizeof(*reg_ptr));
} else if (IREE_UNLIKELY(reg_ptr->type != IREE_VM_REF_TYPE_NULL)) {
status =
diff --git a/iree/vm/ops.h b/iree/vm/ops.h
index b205445..0e6e181 100644
--- a/iree/vm/ops.h
+++ b/iree/vm/ops.h
@@ -18,6 +18,12 @@
#include <stdint.h>
//===------------------------------------------------------------------===//
+// Constants
+//===------------------------------------------------------------------===//
+
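+// Note: helpers are `static inline` rather than plain `inline`; in C, a plain
+// inline function without an extern definition may fail to link when a call
+// site is not inlined.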
+static inline int32_t vm_const_i32(int32_t a) { return a; }
+
+//===------------------------------------------------------------------===//
// Conditional assignment
//===------------------------------------------------------------------===//
@@ -58,22 +64,22 @@
static inline int32_t vm_trunc_i32i8(int32_t operand) {
return (uint8_t)((uint32_t)operand);
-};
+}
static inline int32_t vm_trunc_i32i16(int32_t operand) {
return (uint16_t)((uint32_t)operand);
-};
+}
static inline int32_t vm_ext_i8i32s(int32_t operand) {
return (int32_t)((int8_t)operand);
-};
+}
static inline int32_t vm_ext_i8i32u(int32_t operand) {
return (uint32_t)((uint8_t)operand);
-};
+}
static inline int32_t vm_ext_i16i32s(int32_t operand) {
return (int32_t)((int16_t)operand);
-};
+}
static inline int32_t vm_ext_i16i32u(int32_t operand) {
return (uint32_t)((uint16_t)operand);
-};
+}
//===------------------------------------------------------------------===//
// Native bitwise shifts and rotates
@@ -81,13 +87,13 @@
static inline int32_t vm_shl_i32(int32_t operand, int8_t amount) {
return (int32_t)(operand << amount);
-};
+}
static inline int32_t vm_shr_i32s(int32_t operand, int8_t amount) {
return (int32_t)(operand >> amount);
-};
+}
static inline int32_t vm_shr_i32u(int32_t operand, int8_t amount) {
return (int32_t)(((uint32_t)operand) >> amount);
-};
+}
//===------------------------------------------------------------------===//
// Comparison ops
@@ -109,6 +115,9 @@
return (operand != 0) ? 1 : 0;
}
+//===------------------------------------------------------------------===//
+// Additional ops
+//===------------------------------------------------------------------===//
// Check ops
// TODO(simon-camp): These macros should be removed once control flow ops are
// supported in the c module target
@@ -118,7 +127,59 @@
iree_make_cstring_view("message")); \
}
-// Const ops
-inline int32_t vm_const_i32(int32_t a) { return a; }
+//===------------------------------------------------------------------===//
+// ExtI64: Constants
+//===------------------------------------------------------------------===//
+
+static inline int64_t vm_const_i64(int64_t a) { return a; }
+
+//===------------------------------------------------------------------===//
+// ExtI64: Conditional assignment
+//===------------------------------------------------------------------===//
+
+static inline int64_t vm_select_i64(int32_t condition, int64_t true_value,
+ int64_t false_value) {
+ return condition ? true_value : false_value;
+}
+
+//===------------------------------------------------------------------===//
+// ExtI64: Native integer arithmetic ops
+//===------------------------------------------------------------------===//
+
+static inline int64_t vm_add_i64(int64_t lhs, int64_t rhs) { return lhs + rhs; }
+static inline int64_t vm_sub_i64(int64_t lhs, int64_t rhs) { return lhs - rhs; }
+static inline int64_t vm_mul_i64(int64_t lhs, int64_t rhs) { return lhs * rhs; }
+static inline int64_t vm_div_i64s(int64_t lhs, int64_t rhs) {
+ return lhs / rhs;
+}
+static inline int64_t vm_div_i64u(int64_t lhs, int64_t rhs) {
+ return (int64_t)(((uint64_t)lhs) / ((uint64_t)rhs));
+}
+static inline int64_t vm_rem_i64s(int64_t lhs, int64_t rhs) {
+ return lhs % rhs;
+}
+static inline int64_t vm_rem_i64u(int64_t lhs, int64_t rhs) {
+ return (int64_t)(((uint64_t)lhs) % ((uint64_t)rhs));
+}
+static inline int64_t vm_not_i64(int64_t operand) {
+ return (int64_t)(~((uint64_t)operand));
+}
+static inline int64_t vm_and_i64(int64_t lhs, int64_t rhs) { return lhs & rhs; }
+static inline int64_t vm_or_i64(int64_t lhs, int64_t rhs) { return lhs | rhs; }
+static inline int64_t vm_xor_i64(int64_t lhs, int64_t rhs) { return lhs ^ rhs; }
+
+//===------------------------------------------------------------------===//
+// ExtI64: Native bitwise shifts and rotates
+//===------------------------------------------------------------------===//
+
+static inline int64_t vm_shl_i64(int64_t operand, int8_t amount) {
+ return (int64_t)(operand << amount);
+}
+static inline int64_t vm_shr_i64s(int64_t operand, int8_t amount) {
+ return (int64_t)(operand >> amount);
+}
+static inline int64_t vm_shr_i64u(int64_t operand, int8_t amount) {
+ return (int64_t)(((uint64_t)operand) >> amount);
+}
#endif // IREE_VM_OPS_H_
diff --git a/iree/vm/test/BUILD b/iree/vm/test/BUILD
index a826400..ac52126 100644
--- a/iree/vm/test/BUILD
+++ b/iree/vm/test/BUILD
@@ -37,11 +37,14 @@
":arithmetic_ops.vmfb",
":arithmetic_ops_i64.vmfb",
":assignment_ops.vmfb",
+ ":assignment_ops_i64.vmfb",
":comparison_ops.vmfb",
":control_flow_ops.vmfb",
":conversion_ops.vmfb",
+ ":conversion_ops_i64.vmfb",
":list_ops.vmfb",
":shift_ops.vmfb",
+ ":shift_ops_i64.vmfb",
],
cc_file_output = "all_bytecode_modules.cc",
cpp_namespace = "iree::vm::test",
@@ -68,6 +71,12 @@
)
iree_bytecode_module(
+ name = "assignment_ops_i64",
+ src = "assignment_ops_i64.mlir",
+ flags = ["-iree-vm-ir-to-bytecode-module"],
+)
+
+iree_bytecode_module(
name = "comparison_ops",
src = "comparison_ops.mlir",
flags = ["-iree-vm-ir-to-bytecode-module"],
@@ -86,6 +95,12 @@
)
iree_bytecode_module(
+ name = "conversion_ops_i64",
+ src = "conversion_ops_i64.mlir",
+ flags = ["-iree-vm-ir-to-bytecode-module"],
+)
+
+iree_bytecode_module(
name = "list_ops",
src = "list_ops.mlir",
cc_namespace = "iree::vm::test",
@@ -98,3 +113,10 @@
cc_namespace = "iree::vm::test",
flags = ["-iree-vm-ir-to-bytecode-module"],
)
+
+iree_bytecode_module(
+ name = "shift_ops_i64",
+ src = "shift_ops_i64.mlir",
+ cc_namespace = "iree::vm::test",
+ flags = ["-iree-vm-ir-to-bytecode-module"],
+)
diff --git a/iree/vm/test/CMakeLists.txt b/iree/vm/test/CMakeLists.txt
index 3351bd3..a6546c6 100644
--- a/iree/vm/test/CMakeLists.txt
+++ b/iree/vm/test/CMakeLists.txt
@@ -25,11 +25,14 @@
"arithmetic_ops.vmfb"
"arithmetic_ops_i64.vmfb"
"assignment_ops.vmfb"
+ "assignment_ops_i64.vmfb"
"comparison_ops.vmfb"
"control_flow_ops.vmfb"
"conversion_ops.vmfb"
+ "conversion_ops_i64.vmfb"
"list_ops.vmfb"
"shift_ops.vmfb"
+ "shift_ops_i64.vmfb"
CC_FILE_OUTPUT
"all_bytecode_modules.cc"
H_FILE_OUTPUT
@@ -72,6 +75,16 @@
iree_bytecode_module(
NAME
+ assignment_ops_i64
+ SRC
+ "assignment_ops_i64.mlir"
+ FLAGS
+ "-iree-vm-ir-to-bytecode-module"
+ PUBLIC
+)
+
+iree_bytecode_module(
+ NAME
comparison_ops
SRC
"comparison_ops.mlir"
@@ -102,6 +115,16 @@
iree_bytecode_module(
NAME
+ conversion_ops_i64
+ SRC
+ "conversion_ops_i64.mlir"
+ FLAGS
+ "-iree-vm-ir-to-bytecode-module"
+ PUBLIC
+)
+
+iree_bytecode_module(
+ NAME
list_ops
SRC
"list_ops.mlir"
@@ -123,3 +146,15 @@
"-iree-vm-ir-to-bytecode-module"
PUBLIC
)
+
+iree_bytecode_module(
+ NAME
+ shift_ops_i64
+ SRC
+ "shift_ops_i64.mlir"
+ CC_NAMESPACE
+ "iree::vm::test"
+ FLAGS
+ "-iree-vm-ir-to-bytecode-module"
+ PUBLIC
+)
diff --git a/iree/vm/test/assignment_ops_i64.mlir b/iree/vm/test/assignment_ops_i64.mlir
new file mode 100644
index 0000000..895efb0
--- /dev/null
+++ b/iree/vm/test/assignment_ops_i64.mlir
@@ -0,0 +1,25 @@
+vm.module @assignment_ops_i64 {
+
+ //===--------------------------------------------------------------------===//
+ // ExtI64: Conditional assignment
+ //===--------------------------------------------------------------------===//
+
+  // TODO: The CModuleTarget requires exports to be ordered.
+ vm.export @test_select_i64
+
+ vm.func @test_select_i64() {
+ %c0 = vm.const.i32 0 : i32
+ %c0dno = iree.do_not_optimize(%c0) : i32
+ %c1 = vm.const.i32 1 : i32
+ %c1dno = iree.do_not_optimize(%c1) : i32
+ %c2 = vm.const.i64 0 : i64
+ %c2dno = iree.do_not_optimize(%c2) : i64
+ %c3 = vm.const.i64 1 : i64
+ %c3dno = iree.do_not_optimize(%c3) : i64
+ %v1 = vm.select.i64 %c0dno, %c2dno, %c3dno : i64
+ vm.check.eq %v1, %c3, "0 ? 0 : 1 = 1" : i64
+ %v2 = vm.select.i64 %c1dno, %c2dno, %c3dno : i64
+ vm.check.eq %v2, %c2, "1 ? 0 : 1 = 0" : i64
+ vm.return
+ }
+}
diff --git a/iree/vm/test/conversion_ops_i64.mlir b/iree/vm/test/conversion_ops_i64.mlir
new file mode 100644
index 0000000..6d420c9
--- /dev/null
+++ b/iree/vm/test/conversion_ops_i64.mlir
@@ -0,0 +1,17 @@
+vm.module @conversion_ops_i64 {
+
+ //===----------------------------------------------------------------------===//
+ // ExtI64: Casting and type conversion/emulation
+ //===----------------------------------------------------------------------===//
+
+ vm.export @test_trunc_i64_i32
+
+ vm.func @test_trunc_i64_i32() {
+ %c1 = vm.const.i64 9223372036854775807 : i64
+ %c1dno = iree.do_not_optimize(%c1) : i64
+ %v = vm.trunc.i64.i32 %c1dno : i64 -> i32
+ %c2 = vm.const.i32 4294967295 : i32
+ vm.check.eq %v, %c2, "truncate unsigned i64 to unsigned i32" : i32
+ vm.return
+ }
+}
diff --git a/iree/vm/test/emitc/CMakeLists.txt b/iree/vm/test/emitc/CMakeLists.txt
index c64226b..2bd1e7e 100644
--- a/iree/vm/test/emitc/CMakeLists.txt
+++ b/iree/vm/test/emitc/CMakeLists.txt
@@ -31,9 +31,12 @@
iree::vm::ops
iree::vm::shims
::arithmetic_ops_cc
+ ::arithmetic_ops_i64_cc
::assignment_ops_cc
+ ::assignment_ops_i64_cc
::conversion_ops_cc
::shift_ops_cc
+ ::shift_ops_i64_cc
)
iree_bytecode_module(
@@ -50,6 +53,18 @@
iree_bytecode_module(
NAME
+ arithmetic_ops_i64
+ SRC
+ "../arithmetic_ops_i64.mlir"
+ CC_NAMESPACE
+ "iree::vm::test::emitc"
+ FLAGS
+ "-iree-vm-ir-to-c-module"
+ PUBLIC
+)
+
+iree_bytecode_module(
+ NAME
assignment_ops
SRC
"../assignment_ops.mlir"
@@ -62,6 +77,18 @@
iree_bytecode_module(
NAME
+ assignment_ops_i64
+ SRC
+ "../assignment_ops_i64.mlir"
+ CC_NAMESPACE
+ "iree::vm::test::emitc"
+ FLAGS
+ "-iree-vm-ir-to-c-module"
+ PUBLIC
+)
+
+iree_bytecode_module(
+ NAME
conversion_ops
SRC
"../conversion_ops.mlir"
@@ -84,4 +111,16 @@
PUBLIC
)
+iree_bytecode_module(
+ NAME
+ shift_ops_i64
+ SRC
+ "../shift_ops_i64.mlir"
+ CC_NAMESPACE
+ "iree::vm::test::emitc"
+ FLAGS
+ "-iree-vm-ir-to-c-module"
+ PUBLIC
+)
+
endif()
diff --git a/iree/vm/test/emitc/module_test.cc b/iree/vm/test/emitc/module_test.cc
index f0bb3a3..51140bd 100644
--- a/iree/vm/test/emitc/module_test.cc
+++ b/iree/vm/test/emitc/module_test.cc
@@ -19,9 +19,12 @@
#include "iree/testing/gtest.h"
#include "iree/vm/api.h"
#include "iree/vm/test/emitc/arithmetic_ops.vmfb"
+#include "iree/vm/test/emitc/arithmetic_ops_i64.vmfb"
#include "iree/vm/test/emitc/assignment_ops.vmfb"
+#include "iree/vm/test/emitc/assignment_ops_i64.vmfb"
#include "iree/vm/test/emitc/conversion_ops.vmfb"
#include "iree/vm/test/emitc/shift_ops.vmfb"
+#include "iree/vm/test/emitc/shift_ops_i64.vmfb"
namespace {
@@ -49,9 +52,12 @@
// TODO(simon-camp): get these automatically
std::vector<ModuleDescription> modules = {
{arithmetic_ops_descriptor_, arithmetic_ops_create},
+ {arithmetic_ops_i64_descriptor_, arithmetic_ops_i64_create},
{assignment_ops_descriptor_, assignment_ops_create},
+ {assignment_ops_i64_descriptor_, assignment_ops_i64_create},
{conversion_ops_descriptor_, conversion_ops_create},
- {shift_ops_descriptor_, shift_ops_create}};
+ {shift_ops_descriptor_, shift_ops_create},
+ {shift_ops_i64_descriptor_, shift_ops_i64_create}};
for (size_t i = 0; i < modules.size(); i++) {
iree_vm_native_module_descriptor_t descriptor = modules[i].descriptor;
diff --git a/iree/vm/test/shift_ops_i64.mlir b/iree/vm/test/shift_ops_i64.mlir
new file mode 100644
index 0000000..04da8af
--- /dev/null
+++ b/iree/vm/test/shift_ops_i64.mlir
@@ -0,0 +1,36 @@
+vm.module @shift_ops_i64 {
+
+ //===--------------------------------------------------------------------===//
+ // ExtI64: Native bitwise shifts and rotates
+ //===--------------------------------------------------------------------===//
+
+ vm.export @test_shl_i64
+ vm.func @test_shl_i64() {
+ %c1 = vm.const.i64 1 : i64
+ %c1dno = iree.do_not_optimize(%c1) : i64
+ %v = vm.shl.i64 %c1dno, 2 : i64
+ %c2 = vm.const.i64 4 : i64
+ vm.check.eq %v, %c2, "1<<2=4" : i64
+ vm.return
+ }
+
+ vm.export @test_shr_i64s
+ vm.func @test_shr_i64s() {
+ %c1 = vm.const.i64 -1 : i64
+ %c1dno = iree.do_not_optimize(%c1) : i64
+ %v = vm.shr.i64.s %c1dno, 2 : i64
+ %c2 = vm.const.i64 -1 : i64
+    vm.check.eq %v, %c2, "-1>>2=-1" : i64
+ vm.return
+ }
+
+ vm.export @test_shr_i64u
+ vm.func @test_shr_i64u() {
+ %c1 = vm.const.i64 4 : i64
+ %c1dno = iree.do_not_optimize(%c1) : i64
+ %v = vm.shr.i64.u %c1dno, 2 : i64
+ %c2 = vm.const.i64 1 : i64
+ vm.check.eq %v, %c2, "4>>2=1" : i64
+ vm.return
+ }
+}
diff --git a/iree/vm/type_def.h b/iree/vm/type_def.h
index 5592494..c3294d6 100644
--- a/iree/vm/type_def.h
+++ b/iree/vm/type_def.h
@@ -87,6 +87,7 @@
{ {IREE_VM_VALUE_TYPE_NONE, IREE_VM_REF_TYPE_NULL}, {0}, }
#define iree_vm_variant_is_value(v) iree_vm_type_def_is_value(&(v).type)
#define iree_vm_variant_is_ref(v) iree_vm_type_def_is_ref(&(v).type)
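+// A variant is empty when neither a value nor a ref type has been assigned.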
+#define iree_vm_variant_is_empty(v) iree_vm_type_def_is_variant(&(v).type)
#ifdef __cplusplus
} // extern "C"