Merge pull request #3901 from GMNGeoffrey:main-to-google
PiperOrigin-RevId: 343112174
diff --git a/docs/get_started/getting_started_android_cmake.md b/docs/get_started/getting_started_android_cmake.md
index 05cbd71..874b183 100644
--- a/docs/get_started/getting_started_android_cmake.md
+++ b/docs/get_started/getting_started_android_cmake.md
@@ -276,8 +276,8 @@
```shell
$ ../iree-build-android/host/bin/iree-translate \
-iree-mlir-to-vm-bytecode-module \
- -iree-llvm-target-triple=aarch64-linux-android \
-iree-hal-target-backends=dylib-llvm-aot \
+ -iree-llvm-target-triple=aarch64-linux-android \
$PWD/iree/tools/test/simple.mlir \
-o /tmp/simple-llvm_aot.vmfb
```
diff --git a/docs/get_started/getting_started_windows_cmake.md b/docs/get_started/getting_started_windows_cmake.md
index a6025b2..04a48e0 100644
--- a/docs/get_started/getting_started_windows_cmake.md
+++ b/docs/get_started/getting_started_windows_cmake.md
@@ -89,6 +89,51 @@
> cmake --build ..\iree-build\
```
+## Target Configuration
+
+### LLVM AOT Backend
+
+`-iree-hal-target-backends=dylib-llvm-aot` can be used to generate modules with
+ahead-of-time compiled kernels stored in DLLs. Run the iree-opt/iree-translate
+tools from a command prompt with the `lld-link.exe` or `link.exe` tools on the
+`PATH` and the MSVC/Windows SDK environment variables set; the easiest way to
+get this configured is to use the `vcvarsall.bat` or `vcvars64.bat` files to
+set up your environment. See
+[the Microsoft documentation](https://docs.microsoft.com/en-us/cpp/build/building-on-the-command-line?view=vs-2019)
+for details on configuring the toolchain.
+
+If you want to manually specify the linker used, set the
+`IREE_LLVMAOT_LINKER_PATH` environment variable to the path of the linker:
+
+```powershell
+> set "IREE_LLVMAOT_LINKER_PATH=C:\Tools\LLVM\bin\lld-link.exe"
+```
+
+Translate a source MLIR file into an IREE module:
+
+```powershell
+> ..\iree-build\iree\tools\iree-translate.exe ^
+ -iree-mlir-to-vm-bytecode-module ^
+ -iree-hal-target-backends=dylib-llvm-aot ^
+ iree/tools/test/simple.mlir ^
+ -o %TMP%/simple-llvm_aot.vmfb
+```
+
+Note that this will use the host machine as the target by default, and the
+exact target triple and architecture can be specified with flags when
+cross-compiling:
+
+```powershell
+> ..\iree-build\iree\tools\iree-translate.exe ^
+ -iree-mlir-to-vm-bytecode-module ^
+ -iree-hal-target-backends=dylib-llvm-aot ^
+ -iree-llvm-target-triple=x86_64-pc-windows-msvc ^
+ -iree-llvm-target-cpu=host ^
+ -iree-llvm-target-cpu-features=host ^
+ iree/tools/test/simple.mlir ^
+ -o %TMP%/simple-llvm_aot.vmfb
+```
+
## What's next?
### Take a Look Around
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/BUILD b/integrations/tensorflow/bindings/python/pyiree/tf/support/BUILD
index fc67c96..9b1c869 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/BUILD
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/BUILD
@@ -60,7 +60,6 @@
],
python_version = "PY3",
tags = [
- "driver=llvm",
"driver=vmla",
],
deps = INTREE_TENSORFLOW_PY_DEPS + [
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
index 187b765..591df43 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -992,11 +992,6 @@
"driver": "vmla",
"compiler_targets": ["vmla"]
},
- "iree_llvmjit": {
- "compiled_module_class": IreeCompiledModule,
- "driver": "llvm",
- "compiler_targets": ["llvm-ir"]
- },
"iree_vulkan": {
"compiled_module_class": IreeCompiledModule,
"driver": "vulkan",
@@ -1009,13 +1004,13 @@
Args:
backend_name: a str specifying which backend to use. Should be one of
- 'tf', 'iree_vmla', 'iree_llvmjit', 'iree_vulkan'.
+ 'tf', 'iree_vmla', 'iree_vulkan'.
backend_id: an optional str specifying what name to use when saving
compiled artifacts. Must satisfy `backend_id.startswith(backend_name)`.
Raises:
- KeyError: if backend_name is not one of ['tf', 'iree_vmla',
- 'iree_llvmjit', 'iree_vulkan'].
+ KeyError: if backend_name is not one of
+ ['tf', 'iree_vmla', 'iree_vulkan'].
ValueError: if backend_id doesn't start with backend_name.
"""
if backend_name not in self._name_to_info:
diff --git a/integrations/tensorflow/e2e/README.md b/integrations/tensorflow/e2e/README.md
index 0ac1605..4b457b5 100644
--- a/integrations/tensorflow/e2e/README.md
+++ b/integrations/tensorflow/e2e/README.md
@@ -18,7 +18,7 @@
If you do not have your environment setup to use IREE with Vulkan (see
[this doc](https://google.github.io/iree/get-started/generic-vulkan-env-setup)),
then you can run the manual test targets with
-`--target_backends=tf,iree_vmla,iree_llvmjit` (that is, by omitting
+`--target_backends=tf,iree_vmla` (that is, by omitting
`iree_vulkan` from the list of backends to run the tests on).
The test suites can be run excluding Vulkan by specifying
@@ -234,8 +234,8 @@
to check numerical correctness against TensorFlow. Tests targets that pass are
placed into the `e2e_tests` test suite. Tests that fail on particular backends
are recorded in lists in the `BUILD` files. For example, if
-`experimental_new_test.py` fails on the `iree_llvmjit` and `iree_vulkan`
-backends then the following lines should be added to the `BUILD` file:
+`experimental_new_test.py` fails on the `iree_vulkan` backend then the following
+lines should be added to the `BUILD` file:
```build
LLVM_FAILING = [
diff --git a/integrations/tensorflow/e2e/iree_e2e_cartesian_product_test_suite.bzl b/integrations/tensorflow/e2e/iree_e2e_cartesian_product_test_suite.bzl
index ee6d4f1..7f49f83 100644
--- a/integrations/tensorflow/e2e/iree_e2e_cartesian_product_test_suite.bzl
+++ b/integrations/tensorflow/e2e/iree_e2e_cartesian_product_test_suite.bzl
@@ -164,6 +164,13 @@
tests = []
for flags in all_flag_configurations:
+ if len(flags["target_backends"].split(",")) > 1:
+ fail("Multiple target backends cannot be specified at once, but " +
+ "got `{}`".format(flags["target_backends"]))
+ driver = get_driver(flags["target_backends"])
+ if not driver:
+ continue
+
# Check if this is a failing configuration.
failing = flags in failing_flag_configurations
@@ -180,12 +187,6 @@
tests.append(test_name)
args = ["--{}={}".format(k, v) for k, v in flags.items()]
-
- if len(flags["target_backends"].split(",")) > 1:
- fail("Multiple target backends cannot be specified at once, but " +
- "got `{}`".format(flags["target_backends"]))
-
- driver = get_driver(flags["target_backends"])
py_test_tags = ["driver={}".format(driver)]
if tags != None: # `is` is not supported.
py_test_tags += tags
diff --git a/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl b/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl
index 8fb35a4..58be0f6 100644
--- a/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl
+++ b/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl
@@ -19,8 +19,10 @@
def get_driver(backend):
# TODO(#2175): Simplify this after backend names are standardized.
driver = backend.replace("iree_", "") # "iree_<driver>" --> "<driver>"
+
+ # TODO(#2673): enable LLVM AOT for these tests. JIT is deprecated.
if driver == "llvmjit":
- driver = "llvm"
+ driver = ""
return driver
def iree_e2e_test_suite(
@@ -70,6 +72,8 @@
]
driver = get_driver(backend)
+ if not driver:
+ continue
py_test_tags = ["driver={}".format(driver)]
if tags != None: # `is` is not supported.
py_test_tags += tags
diff --git a/integrations/tensorflow/e2e/keras/BUILD b/integrations/tensorflow/e2e/keras/BUILD
index d6e61ad..468eb79 100644
--- a/integrations/tensorflow/e2e/keras/BUILD
+++ b/integrations/tensorflow/e2e/keras/BUILD
@@ -44,13 +44,13 @@
vision_model_test_manual is for manual testing of all keras vision models.
Test will run only manually with all parameters specified manually, for example:
bazel run -c opt integrations/tensorflow/e2e/keras:vision_model_test_manual -- \
---target_backends=tf,iree_vmla,iree_llvmjit \
+--target_backends=tf,iree_vmla \
--data=imagenet \
--url=https://storage.googleapis.com/iree_models/ \
--model=ResNet50
Command arguments description:
---target_backends: can be combination of these: tf,iree_vmla,iree_llvmjit
+--target_backends: can be combination of these: tf,iree_vmla
--data: can be 'imagenet' or 'cifar10'.
imagenet - input image size (1, 224, 224, 3)
cifar10 - input image size (1, 32, 32, 3) - it is used for quick tests
diff --git a/integrations/tensorflow/e2e/slim_vision_models/BUILD b/integrations/tensorflow/e2e/slim_vision_models/BUILD
index 88f60b0..881aff3 100644
--- a/integrations/tensorflow/e2e/slim_vision_models/BUILD
+++ b/integrations/tensorflow/e2e/slim_vision_models/BUILD
@@ -76,7 +76,6 @@
"resnet_v2_152",
],
"target_backends": [
- "iree_llvmjit",
"iree_vulkan",
],
},
@@ -159,7 +158,6 @@
"tf",
"tflite",
"iree_vmla",
- "iree_llvmjit",
"iree_vulkan",
],
},
diff --git a/iree/base/wait_handle.h b/iree/base/wait_handle.h
index adff1c9..b95bf7e 100644
--- a/iree/base/wait_handle.h
+++ b/iree/base/wait_handle.h
@@ -32,8 +32,7 @@
#endif // IREE_PLATFORM_*
// TODO(benvanik): see if we can get sync file on linux too:
-// TODO(scotttodd): fix include on Android (missing linkopts?)
-#if 0 && defined(IREE_PLATFORM_ANDROID)
+#if defined(IREE_PLATFORM_ANDROID)
#define IREE_HAVE_WAIT_TYPE_SYNC_FILE 1
#endif // IREE_PLATFORM_ANDROID
diff --git a/iree/base/wait_handle_posix.c b/iree/base/wait_handle_posix.c
index 8f180d1..088f6cc 100644
--- a/iree/base/wait_handle_posix.c
+++ b/iree/base/wait_handle_posix.c
@@ -27,7 +27,7 @@
#include <sys/eventfd.h>
#endif // IREE_HAVE_WAIT_TYPE_EVENTFD
#if defined(IREE_HAVE_WAIT_TYPE_SYNC_FILE)
-#include <sync.h>
+#include <android/sync.h>
#endif // IREE_HAVE_WAIT_TYPE_SYNC_FILE
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h b/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h
index e8cd9db..01e4b3d 100644
--- a/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h
+++ b/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h
@@ -28,7 +28,7 @@
namespace mlir {
class Location;
class FuncOp;
-class LogicalResult;
+struct LogicalResult;
class PatternRewriter;
class ConversionPatternRewriter;
class Value;
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
index 9f33bd1..7381818 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
@@ -665,44 +665,55 @@
namespace {
/// Converts mhlo.slice operation to linalg.subview + linalg.copy
-struct SliceOpConversion
- : public ConvertToLinalgBufferOp<SliceOpConversion, mhlo::SliceOp> {
- using ConvertToLinalgBufferOp<SliceOpConversion,
- mhlo::SliceOp>::ConvertToLinalgBufferOp;
+struct SliceOpConversion : public OpConversionPattern<mhlo::SliceOp> {
+ SliceOpConversion(MLIRContext *context,
+ TensorToBufferMap const &resultTensorToBufferMap,
+ PatternBenefit benefit = 1)
+ : OpConversionPattern<mhlo::SliceOp>(context, benefit),
+ resultTensorToBufferMap(resultTensorToBufferMap) {}
- LogicalResult apply(mhlo::SliceOp op, ArrayRef<Value> inputBuffers,
- ArrayRef<Value> resultBuffers,
- ConversionPatternRewriter &rewriter) const;
+ LogicalResult matchAndRewrite(
+ mhlo::SliceOp op, ArrayRef<Value> inputBuffers,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto argType = inputBuffers[0].getType().template dyn_cast<ShapedType>();
+ if (!argType || !argType.hasStaticShape()) {
+ return op.emitError("expected static shape");
+ }
+
+ auto resultShape = op.getResult().getType().cast<ShapedType>().getShape();
+ SmallVector<Value, 3> offsets, sizes, strides;
+ for (int i = 0, e = argType.getRank(); i < e; ++i) {
+ Value startIndex = rewriter.create<ConstantIndexOp>(
+ loc, op.start_indices().getValue<int64_t>(i));
+ offsets.push_back(startIndex);
+ Value size = rewriter.create<ConstantIndexOp>(loc, resultShape[i]);
+ sizes.push_back(size);
+ Value stride = rewriter.create<ConstantIndexOp>(
+ loc, op.strides().getValue<int64_t>(i));
+ strides.push_back(stride);
+ }
+ auto subViewOp = rewriter.create<SubViewOp>(loc, inputBuffers[0], offsets,
+ sizes, strides);
+
+ // If the slice result is already mapped to a buffer, copy the subview into
+ // that buffer; otherwise the subview itself replaces the slice result.
+ if (Value bufferForResult =
+ resultTensorToBufferMap.lookup(op.getResult())) {
+ rewriter.create<linalg::CopyOp>(loc, subViewOp, bufferForResult);
+ rewriter.replaceOp(op, bufferForResult);
+ } else {
+ rewriter.replaceOp(op, subViewOp.getResult());
+ }
+
+ return success();
+ }
+
+ private:
+ TensorToBufferMap const &resultTensorToBufferMap;
};
} // namespace
-LogicalResult SliceOpConversion::apply(
- mhlo::SliceOp op, ArrayRef<Value> inputBuffers,
- ArrayRef<Value> resultBuffers, ConversionPatternRewriter &rewriter) const {
- auto loc = op.getLoc();
- auto argType = inputBuffers[0].getType().template dyn_cast<ShapedType>();
- if (!argType || !argType.hasRank()) {
- return op.emitError("expected known-rank args");
- }
-
- SmallVector<Value, 3> offsets, sizes, strides;
- for (int i = 0, e = argType.getRank(); i < e; ++i) {
- Value startIndex = rewriter.create<ConstantIndexOp>(
- loc, op.start_indices().getValue<int64_t>(i));
- offsets.push_back(startIndex);
- Value size = rewriter.create<DimOp>(loc, resultBuffers[0], i);
- sizes.push_back(size);
- Value stride = rewriter.create<ConstantIndexOp>(
- loc, op.strides().getValue<int64_t>(i));
- strides.push_back(stride);
- }
- auto subViewOp =
- rewriter.create<SubViewOp>(loc, inputBuffers[0], offsets, sizes, strides);
- rewriter.create<linalg::CopyOp>(loc, subViewOp, resultBuffers[0]);
-
- return success();
-}
-
//===----------------------------------------------------------------------===//
// mhlo.reduce_window conversion patterns and utility functions.
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Conversion/HLOToLinalg/test/slice.mlir b/iree/compiler/Conversion/HLOToLinalg/test/slice.mlir
index 406a62b..0de43b3 100644
--- a/iree/compiler/Conversion/HLOToLinalg/test/slice.mlir
+++ b/iree/compiler/Conversion/HLOToLinalg/test/slice.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -iree-codegen-hlo-to-linalg-on-buffers -cse %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -iree-codegen-hlo-to-linalg-pipeline -canonicalize %s | IreeFileCheck %s
module {
// CHECK_LABEL: @slice_whole_buffer
@@ -25,19 +25,10 @@
// -----
module {
- // CHECK: #[[MAP:.+]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+ // CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0 * 4 + d1 + 4)>
// CHECK: @slice_whole_stride
- // CHECK-DAG: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x4xi32>
// CHECK-DAG: %[[IN:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<3x4xi32>
- // CHECK-DAG: %[[ZERO:.+]] = constant 0 : index
- // CHECK-DAG: %[[ONE:.+]] = constant 1 : index
- // CHECK-DAG: %[[DIM0:.+]] = dim %[[OUT]], %[[ZERO]] : memref<1x4xi32>
- // CHECK-DAG: %[[DIM1:.+]] = dim %[[OUT]], %[[ONE]] : memref<1x4xi32>
- // CHECK: subview %[[IN]]
- // CHECK-SAME: [%[[ONE]], %[[ZERO]]]
- // CHECK-SAME: [%[[DIM0]], %[[DIM1]]]
- // CHECK-SAME: [%[[ONE]], %[[ONE]]]
- // CHECK-SAME: : memref<3x4xi32> to memref<?x?xi32, #[[MAP]]>
+ // CHECK: subview %[[IN]][1, 0] [1, 4] [1, 1] : memref<3x4xi32> to memref<1x4xi32, #[[MAP]]>
// CHECK: linalg.copy
func @slice_whole_stride()
attributes {signature = (tensor<3x4xi32>) -> (tensor<1x4xi32>)} {
@@ -60,19 +51,10 @@
// -----
module {
- // CHECK: #[[MAP:.+]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+ // CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0 * 4 + d1 + 5)>
// CHECK: @slice_stride_part
- // CHECK-DAG: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x2xi32>
// CHECK-DAG: %[[IN:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<3x4xi32>
- // CHECK-DAG: %[[ZERO:.+]] = constant 0 : index
- // CHECK-DAG: %[[ONE:.+]] = constant 1 : index
- // CHECK-DAG: %[[DIM0:.+]] = dim %[[OUT]], %[[ZERO]] : memref<1x2xi32>
- // CHECK-DAG: %[[DIM1:.+]] = dim %[[OUT]], %[[ONE]] : memref<1x2xi32>
- // CHECK: subview %[[IN]]
- // CHECK-SAME: [%[[ONE]], %[[ONE]]]
- // CHECK-SAME: [%[[DIM0]], %[[DIM1]]]
- // CHECK-SAME: [%[[ONE]], %[[ONE]]]
- // CHECK-SAME: : memref<3x4xi32> to memref<?x?xi32, #map>
+ // CHECK: subview %[[IN]][1, 1] [1, 2] [1, 1] : memref<3x4xi32> to memref<1x2xi32, #[[MAP]]>
// CHECK: linalg.copy
func @slice_stride_part()
attributes {signature = (tensor<3x4xi32>) -> (tensor<1x2xi32>)} {
@@ -91,3 +73,32 @@
hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write"
}
}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0 * 4 + d1 + 5)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: func @slice_stride_part
+// CHECK: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0, operand_result_index = 2 : i32} : memref<1x2xi32>
+// CHECK: %[[IN0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<3x4xi32>
+// CHECK: %[[SUBVIEW:.+]] = subview %[[IN0]][1, 1] [1, 2] [1, 1] : memref<3x4xi32> to memref<1x2xi32, #[[MAP0]]>
+// CHECK: %[[IN1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1, operand_result_index = 1 : i32} : memref<1x2xi32>
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[SUBVIEW]], %[[IN1]]
+// CHECK-SAME: outs(%[[OUT]]
+module {
+ func @slice_stride_part() {
+ %c0 = constant 0 : index
+ %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 {operand_result_index = 0 : i32} : tensor<3x4xi32>
+ %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 {operand_result_index = 1 : i32} : tensor<1x2xi32>
+ %2 = "mhlo.slice"(%0) {limit_indices = dense<[2, 3]> : tensor<2xi64>, start_indices = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32>
+ %3 = mhlo.add %2, %1 : tensor<1x2xi32>
+ hal.interface.store.tensor %3, @legacy_io::@ret0, offset = %c0 {operand_result_index = 2 : i32} : tensor<1x2xi32>
+ return
+ }
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ }
+}
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
index 04be0ff..17f3bc5 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
@@ -32,7 +32,7 @@
namespace mlir {
class FuncOp;
-class LogicalResult;
+struct LogicalResult;
class Operation;
class PatternRewriter;
class ShapedType;
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Utils.h b/iree/compiler/Conversion/LinalgToSPIRV/Utils.h
index 972e116..ddaecf3 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Utils.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Utils.h
@@ -29,7 +29,7 @@
class SubViewOp;
class OperationFolder;
class OpBuilder;
-class LogicalResult;
+struct LogicalResult;
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
index 176f09d..9964eae 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
@@ -31,12 +31,6 @@
namespace {
// TODO(laurenzo): Every one of these should have better support and removed
// from this exclusion list eventually.
-bool isUnsupportedFusionOp(Operation *op) {
- return isa<mhlo::ConcatenateOp, mhlo::ConvOp, mhlo::DotGeneralOp, mhlo::DotOp,
- mhlo::PadOp, mhlo::ReduceOp, mhlo::ReduceWindowOp, mhlo::SliceOp,
- mhlo::TorchIndexSelectOp>(op);
-}
-
// Allowlist of ops that materialize to a an index-permuted copy of some kind
// if they exist standalone. Generally we try to avoid anchoring on these,
// letting them fuse into more meaningful ops as possible.
@@ -182,6 +176,18 @@
return FusionType::DISABLED;
}
+// TODO(b/144530470): replace with tablegen attributes/interfaces.
+bool OpDispatchPolicy::isUnsupportedFusionOp(Operation *op) {
+ return isa<mhlo::ConcatenateOp, mhlo::ConvOp, mhlo::DotGeneralOp, mhlo::DotOp,
+ mhlo::PadOp, mhlo::ReduceOp, mhlo::ReduceWindowOp,
+ mhlo::TorchIndexSelectOp>(op) ||
+ isRootOnlyOp(op);
+}
+
+bool OpDispatchPolicy::isRootOnlyOp(Operation *op) {
+ return isa<mhlo::SliceOp>(op);
+}
+
} // namespace Flow
} // namespace IREE
} // namespace iree_compiler
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h
index 22666b3..ee9299f 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h
@@ -41,6 +41,12 @@
OpDispatchPolicy(Dispatchability &dispatchability)
: dispatchability(dispatchability) {}
+ // Returns true if |op| has restricted fusion support; includes root-only ops.
+ static bool isUnsupportedFusionOp(Operation *op);
+
+ // Returns true if |op| can only be a root op.
+ static bool isRootOnlyOp(Operation *op);
+
// Returns true if the given |op| can be dispatched in all cases.
// Other passes may handle special cases of these ops but this initial
// identification is conservative.
diff --git a/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp b/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
index 984b8c7..9565747 100644
--- a/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
@@ -13,6 +13,7 @@
// limitations under the License.
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
@@ -200,10 +201,9 @@
// that substituting library calls is easier.
for (auto &block : regionOp.body().getBlocks()) {
for (auto &op : block) {
- // TODO(b/144530470): replace with tablegen attributes/interfaces.
- if (isa<mhlo::ConcatenateOp, mhlo::ConvOp, mhlo::DotGeneralOp,
- mhlo::DotOp, mhlo::PadOp, mhlo::ReduceOp, mhlo::ReduceWindowOp,
- mhlo::SliceOp, mhlo::TorchIndexSelectOp>(op)) {
+ // Root-only ops do not prevent merging.
+ if (OpDispatchPolicy::isUnsupportedFusionOp(&op) &&
+ !OpDispatchPolicy::isRootOnlyOp(&op)) {
return false;
}
}
@@ -211,6 +211,24 @@
return regionOp.body().getBlocks().size() == 1;
}
+// Returns true if |rhs| uses a result of |lhs| in a root-only op; merging the
+// two dispatch regions would then make that op lose its root position.
+bool rhsHasRootOnlyOp(DispatchRegionOp &lhs, DispatchRegionOp &rhs) {
+ auto &rhsBlock = rhs.body().front();
+ auto lhsArgs = llvm::to_vector<8>(lhs.args());
+ auto rhsArgs = llvm::to_vector<8>(rhs.args());
+ for (int rhsOpIdx = 0; rhsOpIdx < rhsArgs.size(); ++rhsOpIdx) {
+ for (int lhsResultIdx = 0; lhsResultIdx < lhs.getNumResults();
+ ++lhsResultIdx) {
+ if (rhsArgs[rhsOpIdx] != lhs.getResult(lhsResultIdx)) continue;
+ for (auto *user : rhsBlock.getArgument(rhsOpIdx).getUsers()) {
+ if (OpDispatchPolicy::isRootOnlyOp(user)) return true;
+ }
+ }
+ }
+ return false;
+}
+
// Merges |rhs| into |lhs| and returns the new |lhs| op.
// Precondition: !areDispatchRegionsTransitivelyDependent
DispatchRegionOp mergeDispatchRegions(DispatchRegionOp &lhs,
@@ -345,6 +363,10 @@
LLVM_DEBUG(llvm::dbgs()
<< " -REGION CONTAINS NON-TRIVIAL CONTROL FLOW-\n");
}
+ if (rhsHasRootOnlyOp(lhs, rhs)) {
+ LLVM_DEBUG(llvm::dbgs() << " -RHS REGION HAS ROOT OP-\n");
+ continue;
+ }
mergableRegions[i] = mergeDispatchRegions(lhs, rhs);
if (!mergableRegions[i]) {
return failure();
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir b/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir
index 938c92a..c324ef7 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir
@@ -132,3 +132,33 @@
// CHECK-LABEL: func @dominate
// CHECK: flow.dispatch.region
// CHECK-NOT: flow.dispatch.region
+
+// -----
+
+// Tests that an op that can only be a root op fuses with its consumer but not
+// with its producer. This test uses a dummy workload to exercise the root-only
+// op functionality.
+module {
+ func @rootOnlyOp(%arg0: tensor<3x4xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
+ %c0 = constant 0 : index
+ %0 = flow.dispatch.region[%c0 : index](%arg2 = %arg0 : tensor<3x4xi32>) -> tensor<3x4xi32> {
+ %3 = mhlo.add %arg2, %arg2 : tensor<3x4xi32>
+ flow.return %3 : tensor<3x4xi32>
+ }
+ %1 = flow.dispatch.region[%c0 : index](%arg2 = %0 : tensor<3x4xi32>) -> tensor<1x2xi32> {
+ %3 = "mhlo.slice"(%arg2) {limit_indices = dense<[2, 3]> : tensor<2xi64>, start_indices = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32>
+ flow.return %3 : tensor<1x2xi32>
+ }
+ %2 = flow.dispatch.region[%c0 : index](%arg2 = %1 : tensor<1x2xi32>, %arg3 = %arg1 : tensor<1x2xi32>) -> tensor<1x2xi32> {
+ %3 = mhlo.multiply %arg2, %arg3 : tensor<1x2xi32>
+ flow.return %3 : tensor<1x2xi32>
+ }
+ return %2 : tensor<1x2xi32>
+ }
+}
+// CHECK-LABEL: func @rootOnlyOp
+// CHECK: flow.dispatch.region
+// CHECK-NEXT: mhlo.add
+// CHECK: flow.dispatch.region
+// CHECK-NEXT: mhlo.slice
+// CHECK-NEXT: mhlo.multiply
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/BUILD b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/BUILD
index 594a0be..0c81851 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/BUILD
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/BUILD
@@ -38,7 +38,7 @@
"LLVMAOTTarget.h",
],
deps = [
- ":LLVMAOTTargetLinker",
+ ":LinkerTool",
"//iree/compiler/Dialect/HAL/Target",
"//iree/compiler/Dialect/HAL/Target/LLVM:LLVMBaseTarget",
"//iree/compiler/Dialect/HAL/Target/LLVM:LLVMIRPasses",
@@ -57,16 +57,19 @@
)
cc_library(
- name = "LLVMAOTTargetLinker",
- hdrs = ["LLVMAOTTargetLinker.h"],
- deps = platform_trampoline_deps("LLVMAOTTargetLinker", "compiler/Dialect/HAL/Target/LLVM/AOT"),
+ name = "LinkerTool",
+ srcs = ["LinkerTool.cpp"],
+ hdrs = ["LinkerTool.h"],
+ deps = platform_trampoline_deps("LinkerTools", "compiler/Dialect/HAL/Target/LLVM/AOT"),
)
cc_library(
- name = "LLVMAOTTargetLinker_hdrs",
- hdrs = ["LLVMAOTTargetLinker.h"],
+ name = "LinkerTool_hdrs",
+ hdrs = ["LinkerTool.h"],
deps = [
- "//iree/base:status",
"//iree/compiler/Dialect/HAL/Target/LLVM:LLVMTargetOptions",
+ "@llvm-project//llvm:Core",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:Support",
],
)
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/CMakeLists.txt
index ff46a62..7f69d4d 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/CMakeLists.txt
@@ -26,7 +26,7 @@
SRCS
"LLVMAOTTarget.cpp"
DEPS
- ::LLVMAOTTargetLinker
+ ::LinkerTool
LLVMAArch64AsmParser
LLVMAArch64CodeGen
LLVMARMAsmParser
@@ -46,21 +46,25 @@
iree_cc_library(
NAME
- LLVMAOTTargetLinker
+ LinkerTool
HDRS
- "LLVMAOTTargetLinker.h"
+ "LinkerTool.h"
+ SRCS
+ "LinkerTool.cpp"
DEPS
- iree::compiler::Dialect::HAL::Target::LLVM::AOT::internal::LLVMAOTTargetLinker_internal
+ iree::compiler::Dialect::HAL::Target::LLVM::AOT::internal::LinkerTools_internal
PUBLIC
)
iree_cc_library(
NAME
- LLVMAOTTargetLinker_hdrs
+ LinkerTool_hdrs
HDRS
- "LLVMAOTTargetLinker.h"
+ "LinkerTool.h"
DEPS
- iree::base::status
+ LLVMCore
+ LLVMSupport
+ MLIRSupport
iree::compiler::Dialect::HAL::Target::LLVM::LLVMTargetOptions
PUBLIC
)
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTarget.cpp
index 6060600..d86c12c 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTarget.cpp
@@ -16,7 +16,7 @@
#include <cstdlib>
-#include "iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTargetLinker.h"
+#include "iree/compiler/Dialect/HAL/Target/LLVM/AOT/LinkerTool.h"
#include "iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.h"
#include "iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
@@ -46,16 +46,20 @@
// multi-threading issues.
llvm::LLVMContext context;
- // Remove all private functions, e.g tile size calcuations.
- SmallVector<FuncOp, 4> nonPublicFn;
- for (auto func : targetOp.getInnerModule().getOps<FuncOp>()) {
- if (SymbolTable::getSymbolVisibility(func) !=
- SymbolTable::Visibility::Public) {
- nonPublicFn.push_back(func);
- }
- }
- for (auto func : nonPublicFn) {
- func.erase();
+ // We name our files after the executable name so that they are easy to
+ // track both during compilation (logs/artifacts/etc), as outputs (final
+ // intermediate code/binary files), and at runtime (loaded
+ // libraries/symbols/etc).
+ auto libraryName =
+ targetOp.getParentOfType<IREE::HAL::ExecutableOp>().getName().str();
+
+ // TODO(#3737): don't add functions we don't want to serialize to the
+ // module. Right now workgroup count calculation functions end up in here
+ // as std.func ops and not just the llvm.func ops we expect.
+ auto illegalFuncOps =
+ llvm::to_vector<4>(targetOp.getInnerModule().getOps<FuncOp>());
+ for (auto funcOp : illegalFuncOps) {
+ funcOp.erase();
}
// At this moment we are leaving MLIR LLVM dialect land translating module
@@ -63,63 +67,107 @@
auto llvmModule =
mlir::translateModuleToLLVMIR(targetOp.getInnerModule(), context);
if (!llvmModule) {
- return failure();
+ return targetOp.emitError() << "failed to translate the MLIR LLVM "
+ "dialect to the native llvm::Module";
}
+ // Export all entry points such that they are accessible on the dynamic
+ // libraries we generate.
iree::DyLibExecutableDefT dyLibExecutableDef;
- // Create invocation function an populate entry_points.
- auto entryPointOps = targetOp.getBlock().getOps<ExecutableEntryPointOp>();
-
- for (auto entryPointOp : entryPointOps) {
+ SmallVector<StringRef, 8> entryPointNames;
+ for (auto entryPointOp :
+ targetOp.getBlock().getOps<ExecutableEntryPointOp>()) {
dyLibExecutableDef.entry_points.push_back(
std::string(entryPointOp.sym_name()));
+ entryPointNames.push_back(entryPointOp.sym_name());
}
- // LLVMIR opt passes.
+ // Try to grab a linker tool based on the options (and target environment).
+ llvm::Triple targetTriple(options_.targetTriple);
+ auto linkerTool = LinkerTool::getForTarget(targetTriple, options_);
+ if (!linkerTool) {
+ return mlir::emitError(targetOp.getLoc())
+ << "failed to find a target linker for the given target triple '"
+ << options_.targetTriple << "'";
+ }
+
+ // Configure the module with any code generation options required later by
+ // linking (such as initializer functions).
+ if (failed(
+ linkerTool->configureModule(llvmModule.get(), entryPointNames))) {
+ return targetOp.emitError()
+ << "failed to configure LLVM module for target linker";
+ }
+
+ // LLVM opt passes that perform code generation optimizations/transformation
+ // similar to what a frontend would do before passing to linking.
auto targetMachine = createTargetMachine(options_);
if (!targetMachine) {
- targetOp.emitError("Can't create target machine for target triple: " +
- options_.targetTriple);
- return failure();
+ return mlir::emitError(targetOp.getLoc())
+ << "failed to create target machine for target triple '"
+ << options_.targetTriple << "'";
}
-
llvmModule->setDataLayout(targetMachine->createDataLayout());
llvmModule->setTargetTriple(targetMachine->getTargetTriple().str());
-
if (failed(
runLLVMIRPasses(options_, targetMachine.get(), llvmModule.get()))) {
- return targetOp.emitError(
- "Can't build LLVMIR opt passes for ExecutableOp module");
+ return targetOp.emitError()
+ << "failed to run LLVM-IR opt passes for IREE::HAL::ExecutableOp "
+ "targeting '"
+ << options_.targetTriple << "'";
}
- std::string objData;
- if (failed(runEmitObjFilePasses(targetMachine.get(), llvmModule.get(),
- &objData))) {
- return targetOp.emitError("Can't compile LLVMIR module to an obj");
+ // Emit object files.
+ SmallVector<Artifact, 4> objectFiles;
+ {
+ // NOTE: today we just use a single object file, however if we wanted to
+ // scale code generation and linking we'd want to generate one per
+ // function (or something like that).
+ std::string objectData;
+ if (failed(runEmitObjFilePasses(targetMachine.get(), llvmModule.get(),
+ &objectData))) {
+ return targetOp.emitError()
+ << "failed to compile LLVM-IR module to an object file";
+ }
+ auto objectFile = Artifact::createTemporary(libraryName, "obj");
+ auto &os = objectFile.outputFile->os();
+ os << objectData;
+ os.flush();
+ os.close();
+ objectFiles.push_back(std::move(objectFile));
}
- std::string sharedLibData;
- const char *linkerToolPath = std::getenv("IREE_LLVMAOT_LINKER_PATH");
- if (linkerToolPath != nullptr) {
- auto sharedLibDataStatus = linkLLVMAOTObjects(linkerToolPath, objData);
- if (!sharedLibDataStatus.ok()) {
- return targetOp.emitError(
- "Can't link executable and generate target dylib, using linker "
- "toolchain: '" +
- std::string(linkerToolPath) + "'");
- }
- sharedLibData = sharedLibDataStatus.value();
- } else {
- auto sharedLibDataStatus = linkLLVMAOTObjectsWithLLDElf(objData);
- if (!sharedLibDataStatus.ok()) {
- return targetOp.emitError(
- "Can't link executable and generate target dylib using "
- "lld::elf::link");
- }
- sharedLibData = sharedLibDataStatus.value();
+ // Link the generated object files into a dylib.
+ auto linkArtifactsOr =
+ linkerTool->linkDynamicLibrary(libraryName, objectFiles);
+ if (!linkArtifactsOr.hasValue()) {
+ return mlir::emitError(targetOp.getLoc())
+ << "failed to link executable and generate target dylib using "
+ "linker toolchain "
+ << linkerTool->getToolPath();
}
- dyLibExecutableDef.library_embedded = {sharedLibData.begin(),
- sharedLibData.end()};
+ auto &linkArtifacts = linkArtifactsOr.getValue();
+ dyLibExecutableDef.library_embedded =
+ linkArtifacts.libraryFile.read().getValueOr(std::vector<int8_t>());
+ if (dyLibExecutableDef.library_embedded.empty()) {
+ return targetOp.emitError() << "failed to read back dylib temp file at "
+ << linkArtifacts.libraryFile.path;
+ }
+
+ if (options_.debugSymbols && linkArtifacts.debugFile.outputFile) {
+ dyLibExecutableDef.debug_database_embedded =
+ linkArtifacts.debugFile.read().getValue();
+ assert(!dyLibExecutableDef.debug_database_embedded.empty());
+ dyLibExecutableDef.debug_database_filename =
+ llvm::sys::path::filename(linkArtifacts.debugFile.path).str();
+ }
+
+ if (options_.keepLinkerArtifacts) {
+ return mlir::emitRemark(targetOp.getLoc())
+ << "Linker artifacts for " << targetOp.getName() << " preserved:\n"
+ << " " << linkArtifacts.libraryFile.path;
+ linkArtifacts.keepAllFiles();
+ }
::flatbuffers::FlatBufferBuilder fbb;
auto executableOffset =
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTargetLinker.h b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTargetLinker.h
deleted file mode 100644
index d6d5220..0000000
--- a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTargetLinker.h
+++ /dev/null
@@ -1,41 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-
-#ifndef IREE_COMPILER_DIALECT_HAL_TARGET_LLVM_AOT_LLVMAOTTARGETLINKER_H_
-#define IREE_COMPILER_DIALECT_HAL_TARGET_LLVM_AOT_LLVMAOTTARGETLINKER_H_
-
-#include <string>
-
-#include "iree/base/status.h"
-#include "iree/compiler/Dialect/HAL/Target/LLVM/LLVMTargetOptions.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREE {
-namespace HAL {
-
-// Calls linker tool to link objData and returns shared library blob.
-iree::StatusOr<std::string> linkLLVMAOTObjects(
- const std::string& linkerToolPath, const std::string& objData);
-// Use lld::elf::link for linking objData and returns shared library blob.
-iree::StatusOr<std::string> linkLLVMAOTObjectsWithLLDElf(
- const std::string& objData);
-
-} // namespace HAL
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_DIALECT_HAL_TARGET_LLVM_AOT_LLVMAOTTARGETLINKER_H_
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LinkerTool.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LinkerTool.cpp
new file mode 100644
index 0000000..083d281
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LinkerTool.cpp
@@ -0,0 +1,129 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/HAL/Target/LLVM/AOT/LinkerTool.h"
+
+#include <cstdlib>
+
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/Path.h"
+#include "llvm/Support/raw_ostream.h"
+
+#define DEBUG_TYPE "llvmaot-linker"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+// static
+Artifact Artifact::createTemporary(StringRef prefix, StringRef suffix) {
+  // Create a uniquely-named file in the system temporary directory, then
+  // reopen it through a ToolOutputFile, which deletes it on destruction
+  // unless ToolOutputFile::keep() is called.
+  llvm::SmallString<32> filePath;
+  if (std::error_code error =
+          llvm::sys::fs::createTemporaryFile(prefix, suffix, filePath)) {
+    llvm::errs() << "failed to generate temporary file: " << error.message()
+                 << "\n";
+    return {};
+  }
+  std::error_code error;
+  auto file = std::make_unique<llvm::ToolOutputFile>(filePath, error,
+                                                     llvm::sys::fs::OF_None);
+  if (error) {
+    llvm::errs() << "failed to open temporary file '" << filePath
+                 << "': " << error.message() << "\n";
+    return {};
+  }
+  return {filePath.str().str(), std::move(file)};
+}
+
+// static
+Artifact Artifact::createVariant(StringRef basePath, StringRef suffix) {
+  // Derive a sibling path by swapping the extension (e.g. foo.obj -> foo.dll).
+  SmallString<32> filePath(basePath);
+  llvm::sys::path::replace_extension(filePath, suffix);
+  // OF_Append avoids truncating the file if it already exists (for example a
+  // .pdb that an external linker wrote before this was called).
+  std::error_code error;
+  auto file = std::make_unique<llvm::ToolOutputFile>(filePath, error,
+                                                     llvm::sys::fs::OF_Append);
+  if (error) {
+    llvm::errs() << "failed to open temporary file '" << filePath
+                 << "': " << error.message() << "\n";
+    return {};
+  }
+  return {filePath.str().str(), std::move(file)};
+}
+
+Optional<std::vector<int8_t>> Artifact::read() const {
+  auto fileData = llvm::MemoryBuffer::getFile(path);
+  if (!fileData) {
+    llvm::errs() << "failed to load library output file '" << path << "'\n";
+    return llvm::None;
+  }
+  // Copy the bytes out so the result outlives the MemoryBuffer. Using the
+  // range constructor avoids calling memcpy with a null destination when the
+  // file is empty.
+  StringRef sourceBuffer = fileData.get()->getBuffer();
+  return std::vector<int8_t>(sourceBuffer.begin(), sourceBuffer.end());
+}
+
+void Artifact::close() {
+  // Artifacts returned from a failed create*() call have no stream to close.
+  if (outputFile) outputFile->os().close();
+}
+
+void Artifacts::keepAllFiles() {
+  // Skip artifacts whose creation failed (null outputFile).
+  if (libraryFile.outputFile) libraryFile.outputFile->keep();
+  if (debugFile.outputFile) debugFile.outputFile->keep();
+  for (auto &file : otherFiles) {
+    if (file.outputFile) file.outputFile->keep();
+  }
+}
+
+std::string LinkerTool::getToolPath() const {
+  // IREE_LLVMAOT_LINKER_PATH overrides the linker binary; an empty result
+  // means no override was provided.
+  const char *linkerPath = std::getenv("IREE_LLVMAOT_LINKER_PATH");
+  return linkerPath ? std::string(linkerPath) : std::string();
+}
+
+LogicalResult LinkerTool::runLinkCommand(const std::string &commandLine) {
+  LLVM_DEBUG(llvm::dbgs() << "Running linker command:\n"
+                          << commandLine << "\n");
+#if defined(_MSC_VER)
+  // It's easy to run afoul of quoting rules on Windows (such as when using
+  // spaces in the linker environment variable). See:
+  // https://stackoverflow.com/a/9965141
+  auto quotedCommandLine = "\"" + commandLine + "\"";
+  int exitCode = system(quotedCommandLine.c_str());
+#else
+  int exitCode = system(commandLine.c_str());
+#endif  // _MSC_VER
+  if (exitCode == 0) return success();
+  llvm::errs() << "Linking failed; command line returned exit code " << exitCode
+               << ":\n\n"
+               << commandLine << "\n\n";
+  return failure();
+}
+
+}  // namespace HAL
+}  // namespace IREE
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LinkerTool.h b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LinkerTool.h
new file mode 100644
index 0000000..6f3c39a
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LinkerTool.h
@@ -0,0 +1,135 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+#ifndef IREE_COMPILER_DIALECT_HAL_TARGET_LLVM_AOT_LINKERTOOL_H_
+#define IREE_COMPILER_DIALECT_HAL_TARGET_LLVM_AOT_LINKERTOOL_H_
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "iree/compiler/Dialect/HAL/Target/LLVM/LLVMTargetOptions.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Triple.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/ToolOutputFile.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+// A file produced or consumed while linking: its path on disk plus the
+// llvm::ToolOutputFile that owns it. A default-constructed Artifact (which
+// is also what the create*() factories return on failure) has an empty path
+// and a null outputFile.
+struct Artifact {
+  // Creates an output file path/container pair.
+  // The file is deleted when the Artifact is destroyed; callers must use
+  // llvm::ToolOutputFile::keep() to prevent deletion upon success (or if
+  // leaving artifacts for debugging).
+  static Artifact createTemporary(StringRef prefix, StringRef suffix);
+
+  // Creates an output file derived from the given file's path with a new
+  // suffix. A file already present at that path is opened for append rather
+  // than truncated.
+  static Artifact createVariant(StringRef basePath, StringRef suffix);
+
+  Artifact() = default;
+  Artifact(std::string path, std::unique_ptr<llvm::ToolOutputFile> outputFile)
+      : path(std::move(path)), outputFile(std::move(outputFile)) {}
+
+  // Path of the file on disk; empty if creation failed.
+  std::string path;
+  // Owns the file (and its output stream); null if creation failed.
+  std::unique_ptr<llvm::ToolOutputFile> outputFile;
+
+  // Reads the artifact file contents as bytes. Returns llvm::None (after
+  // logging to stderr) if the file cannot be read.
+  Optional<std::vector<int8_t>> read() const;
+
+  // Closes the ostream of the file while preserving the temporary entry on
+  // disk. Use this if files need to be modified by external tools that may
+  // require exclusive access.
+  void close();
+};
+
+// The set of files produced by LinkerTool::linkDynamicLibrary.
+struct Artifacts {
+  // File containing the linked library (DLL, ELF, etc).
+  Artifact libraryFile;
+
+  // Optional file containing associated debug information (if stored
+  // separately, such as PDB files).
+  Artifact debugFile;
+
+  // Other files associated with linking.
+  SmallVector<Artifact, 4> otherFiles;
+
+  // Keeps all of the artifacts around after linking completes. Useful for
+  // debugging.
+  void keepAllFiles();
+};
+
+// Base type for linker tools that can turn object files into shared objects.
+class LinkerTool {
+ public:
+  // Gets an instance of a linker tool for the given target options. This may
+  // be a completely different toolchain than that of the host.
+  static std::unique_ptr<LinkerTool> getForTarget(
+      llvm::Triple& targetTriple, LLVMTargetOptions& targetOptions);
+
+  explicit LinkerTool(llvm::Triple targetTriple,
+                      LLVMTargetOptions targetOptions)
+      : targetTriple(std::move(targetTriple)),
+        targetOptions(std::move(targetOptions)) {}
+
+  virtual ~LinkerTool() = default;
+
+  // Returns the path to the linker tool binary. The base implementation
+  // returns the value of IREE_LLVMAOT_LINKER_PATH, or an empty string if
+  // that environment variable is unset.
+  virtual std::string getToolPath() const;
+
+  // Configures a module prior to compilation with any additional
+  // functions/exports it may need, such as shared object initializer functions.
+  virtual LogicalResult configureModule(
+      llvm::Module* llvmModule, ArrayRef<StringRef> entryPointNames) = 0;
+
+  // Links the given object files into a dynamically loadable library.
+  // The resulting library (and other associated artifacts) will be returned on
+  // success and llvm::None on failure.
+  virtual Optional<Artifacts> linkDynamicLibrary(
+      StringRef libraryName, ArrayRef<Artifact> objectFiles) = 0;
+
+ protected:
+  // Runs the given command line on the shell, logging failures.
+  LogicalResult runLinkCommand(const std::string& commandLine);
+
+  llvm::Triple targetTriple;
+  LLVMTargetOptions targetOptions;
+};
+
+}  // namespace HAL
+}  // namespace IREE
+}  // namespace iree_compiler
+}  // namespace mlir
+
+#endif  // IREE_COMPILER_DIALECT_HAL_TARGET_LLVM_AOT_LINKERTOOL_H_
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/BUILD b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/BUILD
index 6dadc26..bc8a6e0 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/BUILD
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/BUILD
@@ -19,11 +19,16 @@
)
cc_library(
- name = "LLVMAOTTargetLinker_internal",
- srcs = ["LLVMAOTTargetLinker.cpp"],
+ name = "LinkerTools_internal",
+ srcs = [
+ "LinkerTools.cpp",
+ "UnixLinkerTool.cpp",
+ "WindowsLinkerTool.cpp",
+ ],
deps = [
- "//iree/base:status",
- "//iree/compiler/Dialect/HAL/Target/LLVM/AOT:LLVMAOTTargetLinker_hdrs",
+ "//iree/compiler/Dialect/HAL/Target/LLVM/AOT:LinkerTool_hdrs",
+ "@llvm-project//llvm:Core",
"@llvm-project//llvm:Support",
+ "@llvm-project//mlir:Support",
],
)
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/CMakeLists.txt
index 01cb459..aafc4a5 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/CMakeLists.txt
@@ -16,12 +16,15 @@
iree_cc_library(
NAME
- LLVMAOTTargetLinker_internal
+ LinkerTools_internal
SRCS
- "LLVMAOTTargetLinker.cpp"
+ "LinkerTools.cpp"
+ "UnixLinkerTool.cpp"
+ "WindowsLinkerTool.cpp"
DEPS
+ LLVMCore
LLVMSupport
- iree::base::status
- iree::compiler::Dialect::HAL::Target::LLVM::AOT::LLVMAOTTargetLinker_hdrs
+ MLIRSupport
+ iree::compiler::Dialect::HAL::Target::LLVM::AOT::LinkerTool_hdrs
PUBLIC
)
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/LLVMAOTTargetLinker.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/LLVMAOTTargetLinker.cpp
deleted file mode 100644
index 8acd2d6..0000000
--- a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/LLVMAOTTargetLinker.cpp
+++ /dev/null
@@ -1,79 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTargetLinker.h"
-
-#include "iree/base/status.h"
-#include "llvm/Support/ToolOutputFile.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREE {
-namespace HAL {
-
-iree::StatusOr<std::string> linkLLVMAOTObjects(
- const std::string& linkerToolPath, const std::string& objData) {
- llvm::SmallString<32> objFilePath, dylibFilePath;
- if (std::error_code error = llvm::sys::fs::createTemporaryFile(
- "llvmaot_dylibs", "objfile", objFilePath)) {
- return iree::InternalErrorBuilder(IREE_LOC)
- << "Failed to generate temporary file for objfile : '"
- << error.message() << "'";
- }
- if (std::error_code error = llvm::sys::fs::createTemporaryFile(
- "llvmaot_dylibs", "dylibfile", dylibFilePath)) {
- return iree::InternalErrorBuilder(IREE_LOC)
- << "Failed to generate temporary file for dylib : '"
- << error.message() << "'";
- }
- std::error_code error;
- auto outputFile = std::make_unique<llvm::ToolOutputFile>(
- objFilePath, error, llvm::sys::fs::F_None);
- if (error) {
- return iree::InternalErrorBuilder(IREE_LOC)
- << "Failed to open temporary objfile '" << objFilePath.c_str()
- << "' for dylib : '" << error.message() << "'";
- }
-
- outputFile->os() << objData;
- outputFile->os().flush();
-
- auto linkingCmd =
- (linkerToolPath + " -shared " + objFilePath + " -o " + dylibFilePath)
- .str();
- int systemRet = system(linkingCmd.c_str());
- if (systemRet != 0) {
- return iree::InternalErrorBuilder(IREE_LOC)
- << linkingCmd << " failed with exit code " << systemRet;
- }
-
- auto dylibData = llvm::MemoryBuffer::getFile(dylibFilePath);
- if (!dylibData) {
- return iree::InternalErrorBuilder(IREE_LOC)
- << "Failed to read temporary dylib file '" << dylibFilePath.c_str()
- << "'";
- }
- return dylibData.get()->getBuffer().str();
-}
-
-iree::StatusOr<std::string> linkLLVMAOTObjectsWithLLDElf(
- const std::string& objData) {
- return iree::UnimplementedErrorBuilder(IREE_LOC)
- << "linkLLVMAOTObjectsWithLLD not implemented yet!";
-}
-
-} // namespace HAL
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/LinkerTools.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/LinkerTools.cpp
new file mode 100644
index 0000000..6fc4b93
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/LinkerTools.cpp
@@ -0,0 +1,47 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/HAL/Target/LLVM/AOT/LinkerTool.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+// Per-platform factories, each defined in its own translation unit
+// (UnixLinkerTool.cpp and WindowsLinkerTool.cpp).
+std::unique_ptr<LinkerTool> createUnixLinkerTool(
+    llvm::Triple &targetTriple, LLVMTargetOptions &targetOptions);
+std::unique_ptr<LinkerTool> createWindowsLinkerTool(
+    llvm::Triple &targetTriple, LLVMTargetOptions &targetOptions);
+
+// TODO(benvanik): add other platforms:
+//   createMacLinkerTool using ld64.lld
+//   createWasmLinkerTool wasm-ld
+
+// static
+std::unique_ptr<LinkerTool> LinkerTool::getForTarget(
+    llvm::Triple &targetTriple, LLVMTargetOptions &targetOptions) {
+  // Windows targets (including the MSVC environment) produce DLLs with a
+  // link.exe-compatible linker; all other targets use the ELF linker.
+  const bool targetsWindows =
+      targetTriple.isOSWindows() || targetTriple.isWindowsMSVCEnvironment();
+  return targetsWindows ? createWindowsLinkerTool(targetTriple, targetOptions)
+                        : createUnixLinkerTool(targetTriple, targetOptions);
+}
+
+}  // namespace HAL
+}  // namespace IREE
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/UnixLinkerTool.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/UnixLinkerTool.cpp
new file mode 100644
index 0000000..fcd1986
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/UnixLinkerTool.cpp
@@ -0,0 +1,94 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/HAL/Target/LLVM/AOT/LinkerTool.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/Support/FormatVariadic.h"
+
+#define DEBUG_TYPE "llvmaot-linker"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+// Unix linker (ld-like); for ELF files.
+class UnixLinkerTool : public LinkerTool {
+ public:
+  using LinkerTool::LinkerTool;
+
+  // Uses the IREE_LLVMAOT_LINKER_PATH override when set, otherwise ld.lld
+  // (resolved by the shell when the link command runs).
+  std::string getToolPath() const override {
+    auto toolPath = LinkerTool::getToolPath();
+    return toolPath.empty() ? "ld.lld" : toolPath;
+  }
+
+  LogicalResult configureModule(llvm::Module *llvmModule,
+                                ArrayRef<StringRef> entryPointNames) override {
+    // Possibly a no-op in ELF files; needs to be verified.
+    return success();
+  }
+
+  Optional<Artifacts> linkDynamicLibrary(
+      StringRef libraryName, ArrayRef<Artifact> objectFiles) override {
+    Artifacts artifacts;
+
+    // Create the shared object name; if we only have a single input object we
+    // can just reuse that.
+    if (objectFiles.size() == 1) {
+      artifacts.libraryFile =
+          Artifact::createVariant(objectFiles.front().path, "so");
+    } else {
+      artifacts.libraryFile = Artifact::createTemporary(libraryName, "so");
+    }
+    // The create*() factories log and return an empty Artifact on failure;
+    // don't invoke the linker with an empty output path.
+    if (!artifacts.libraryFile.outputFile) return llvm::None;
+
+    // The linker writes the output itself, so release our stream on it.
+    artifacts.libraryFile.close();
+
+    SmallVector<std::string, 8> flags = {
+        getToolPath(),
+        "-shared",
+        "-o " + artifacts.libraryFile.path,
+    };
+
+    // TODO(ataei): add flags based on targetTriple.isAndroid(), like
+    //   -static-libstdc++ (if this is needed, which it shouldn't be).
+
+    // Link all input objects. Note that we are not linking whole-archive as we
+    // want to allow dropping of unused codegen outputs.
+    for (auto &objectFile : objectFiles) {
+      flags.push_back(objectFile.path);
+    }
+
+    auto commandLine = llvm::join(flags, " ");
+    if (failed(runLinkCommand(commandLine))) return llvm::None;
+    return artifacts;
+  }
+};
+
+std::unique_ptr<LinkerTool> createUnixLinkerTool(
+    llvm::Triple &targetTriple, LLVMTargetOptions &targetOptions) {
+  return std::make_unique<UnixLinkerTool>(targetTriple, targetOptions);
+}
+
+}  // namespace HAL
+}  // namespace IREE
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/WindowsLinkerTool.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/WindowsLinkerTool.cpp
new file mode 100644
index 0000000..40632e3
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/WindowsLinkerTool.cpp
@@ -0,0 +1,311 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cstdlib>
+
+#include "iree/compiler/Dialect/HAL/Target/LLVM/AOT/LinkerTool.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/Path.h"
+
+#define DEBUG_TYPE "llvmaot-linker"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+// Windows linker (MSVC link.exe-like); for DLL files.
+class WindowsLinkerTool : public LinkerTool {
+ public:
+  using LinkerTool::LinkerTool;
+
+  // Uses the IREE_LLVMAOT_LINKER_PATH override when set, otherwise lld-link.
+  std::string getToolPath() const override {
+    auto toolPath = LinkerTool::getToolPath();
+    return toolPath.empty() ? "lld-link" : toolPath;
+  }
+
+  LogicalResult configureModule(llvm::Module *llvmModule,
+                                ArrayRef<StringRef> entryPointNames) override {
+    auto &ctx = llvmModule->getContext();
+
+    // X86_StdCall is an x86-only calling convention; other backends (such as
+    // ARM/ARM64) do not support it, so only apply it when targeting x86.
+    const bool useStdCall = targetTriple.isX86();
+
+    // Create a _DllMainCRTStartup replacement that does not initialize the CRT.
+    // This is required to prevent a bunch of CRT junk (locale, errno, TLS, etc)
+    // from getting emitted in such a way that it cannot be stripped by LTCG.
+    // Since we don't emit code using the CRT (beyond memset/memcpy) this is
+    // fine and can reduce binary sizes by 50-100KB.
+    //
+    // More info:
+    // https://docs.microsoft.com/en-us/cpp/build/run-time-library-behavior?view=vs-2019
+    {
+      // i32 iree_dll_main(i32*, i32, i32*), shaped like DllMain, that
+      // unconditionally returns 1 (TRUE).
+      auto dwordType = llvm::IntegerType::get(ctx, 32);
+      auto ptrType = llvm::PointerType::getUnqual(dwordType);
+      auto entry = cast<llvm::Function>(
+          llvmModule
+              ->getOrInsertFunction("iree_dll_main", dwordType, ptrType,
+                                    dwordType, ptrType)
+              .getCallee());
+      if (useStdCall) entry->setCallingConv(llvm::CallingConv::X86_StdCall);
+      entry->setDLLStorageClass(
+          llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
+      entry->setLinkage(llvm::GlobalValue::LinkageTypes::ExternalLinkage);
+      auto *block = llvm::BasicBlock::Create(ctx, "entry", entry);
+      llvm::IRBuilder<> builder(block);
+      auto one = llvm::ConstantInt::get(dwordType, 1, false);
+      builder.CreateRet(one);
+    }
+
+    // For now we ensure that our entry points are exported (via linker
+    // directives embedded in the object file) and in a compatible calling
+    // convention.
+    // TODO(benvanik): switch to executable libraries w/ internal functions.
+    for (auto entryPointName : entryPointNames) {
+      auto *entryPointFn = llvmModule->getFunction(entryPointName);
+      if (!entryPointFn) {
+        llvm::errs() << "entry point '" << entryPointName
+                     << "' not found in the LLVM module\n";
+        return failure();
+      }
+      if (useStdCall) {
+        entryPointFn->setCallingConv(llvm::CallingConv::X86_StdCall);
+      }
+      entryPointFn->setDLLStorageClass(
+          llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
+      entryPointFn->setLinkage(
+          llvm::GlobalValue::LinkageTypes::ExternalLinkage);
+      entryPointFn->setVisibility(
+          llvm::GlobalValue::VisibilityTypes::DefaultVisibility);
+      entryPointFn->addFnAttr(llvm::Attribute::UWTable);
+    }
+
+    return success();
+  }
+
+  Optional<Artifacts> linkDynamicLibrary(
+      StringRef libraryName, ArrayRef<Artifact> objectFiles) override {
+    Artifacts artifacts;
+
+    // Create the shared object name; if we only have a single input object we
+    // can just reuse that.
+    if (objectFiles.size() == 1) {
+      artifacts.libraryFile =
+          Artifact::createVariant(objectFiles.front().path, "dll");
+    } else {
+      artifacts.libraryFile = Artifact::createTemporary(libraryName, "dll");
+    }
+    // The create*() factories log and return an empty Artifact on failure;
+    // don't invoke the linker with an empty output path.
+    if (!artifacts.libraryFile.outputFile) return llvm::None;
+
+    // link.exe doesn't like the files being opened. We don't use them as
+    // streams so close them all now before running the linker.
+    artifacts.libraryFile.close();
+
+    // /pdb requires a full path; place the PDB next to the DLL with the same
+    // base name.
+    SmallString<32> pdbPath(artifacts.libraryFile.path);
+    llvm::sys::path::replace_extension(pdbPath, "pdb");
+
+    SmallVector<std::string, 8> flags = {
+        getToolPath(),
+
+        // Useful when debugging linking/loading issues:
+        // "/verbose",
+
+        // https://docs.microsoft.com/en-us/cpp/build/reference/dll-build-a-dll?view=vs-2019
+        // Builds a DLL and exports functions with the dllexport storage class.
+        "/dll",
+
+        // Forces a fixed timestamp to ensure files are reproducible across
+        // builds. Undocumented but accepted by both link and lld-link.
+        // https://blog.conan.io/2019/09/02/Deterministic-builds-with-C-C++.html
+        "/Brepro",
+
+        // https://docs.microsoft.com/en-us/cpp/build/reference/nodefaultlib-ignore-libraries?view=vs-2019
+        // Ignore any libraries that are specified by the platform as we
+        // directly provide the ones we want.
+        "/nodefaultlib",
+
+        // https://docs.microsoft.com/en-us/cpp/build/reference/incremental-link-incrementally?view=vs-2019
+        // Disable incremental linking as we are only ever linking in one-shot
+        // mode to temp files. This avoids additional file padding and ordering
+        // restrictions that enable incremental linking. Our other options will
+        // prevent incremental linking in most cases, but it doesn't hurt to be
+        // explicit.
+        "/incremental:no",
+
+        // https://docs.microsoft.com/en-us/cpp/build/reference/guard-enable-guard-checks?view=vs-2019
+        // No control flow guard lookup (indirect branch verification).
+        "/guard:no",
+
+        // https://docs.microsoft.com/en-us/cpp/build/reference/safeseh-image-has-safe-exception-handlers?view=vs-2019
+        // We don't want exception unwind tables in our output.
+        "/safeseh:no",
+
+        // https://docs.microsoft.com/en-us/cpp/build/reference/entry-entry-point-symbol?view=vs-2019
+        // Use our entry point instead of the standard CRT one; ensures that we
+        // pull in no global state from the CRT.
+        "/entry:iree_dll_main",
+
+        // https://docs.microsoft.com/en-us/cpp/build/reference/debug-generate-debug-info?view=vs-2019
+        // Copies all PDB information into the final PDB so that we can use the
+        // same PDB across multiple machines.
+        "/debug:full",
+
+        // https://docs.microsoft.com/en-us/cpp/build/reference/pdb-use-program-database
+        // Generates the PDB file containing the debug information.
+        ("/pdb:" + pdbPath).str(),
+
+        // https://docs.microsoft.com/en-us/cpp/build/reference/pdbaltpath-use-alternate-pdb-path?view=vs-2019
+        // Forces the PDB we generate to be referenced in the DLL as just a
+        // relative path to the DLL itself. This allows us to move the PDBs
+        // along with the build DLLs across machines.
+        "/pdbaltpath:%_PDB%",
+
+        // https://docs.microsoft.com/en-us/cpp/build/reference/out-output-file-name?view=vs-2019
+        // Target for linker output. The base name of this path will be used for
+        // additional output files (like the map and pdb).
+        "/out:" + artifacts.libraryFile.path,
+    };
+
+    if (targetOptions.optLevel.getSpeedupLevel() >= 2 ||
+        targetOptions.optLevel.getSizeLevel() >= 2) {
+      // https://docs.microsoft.com/en-us/cpp/build/reference/opt-optimizations?view=vs-2019
+      // Enable all the fancy optimizations.
+      flags.push_back("/opt:ref,icf,lbr");
+    }
+
+    // SDK and MSVC paths.
+    // These rely on the environment variables provided by the
+    // vcvarsall or VsDevCmd ("Developer Command Prompt") scripts. They can also
+    // be specified manually.
+    //
+    // We could also check to see if vswhere is installed and query that in the
+    // event of missing environment variables; that would eliminate the need for
+    // specifying things from for example IDEs that may not bring in the vcvars.
+    //
+    /* Example values:
+      UCRTVersion=10.0.18362.0
+      UniversalCRTSdkDir=C:\Program Files (x86)\Windows Kits\10\
+      VCToolsInstallDir=C:\Program Files (x86)\Microsoft Visual
+      Studio\2019\Preview\VC\Tools\MSVC\14.28.29304\
+    */
+    if (!getenv("VCToolsInstallDir") || !getenv("UniversalCRTSdkDir")) {
+      llvm::errs() << "required environment for lld-link/link not specified; "
+                      "ensure you are building from a shell where "
+                      "vcvarsall/VsDevCmd.bat/etc has been used\n";
+      return llvm::None;
+    }
+
+    // Architecture subdirectory name used by the MSVC and Windows SDK lib
+    // folders.
+    const char *arch;
+    if (targetTriple.isARM() || targetTriple.isThumb()) {
+      arch = "arm";
+    } else if (targetTriple.isAArch64()) {
+      arch = "arm64";
+    } else if (targetTriple.isX86() && targetTriple.isArch32Bit()) {
+      arch = "x86";
+    } else if (targetTriple.isX86()) {
+      arch = "x64";
+    } else {
+      llvm::errs() << "unsupported Windows target triple (no arch libs): "
+                   << targetTriple.str() << "\n";
+      return llvm::None;
+    }
+    flags.push_back(
+        llvm::formatv("/libpath:\"{0}\\lib\\{1}\"", "%VCToolsInstallDir%", arch)
+            .str());
+    flags.push_back(llvm::formatv("/libpath:\"{0}\\Lib\\{1}\\ucrt\\{2}\"",
+                                  "%UniversalCRTSdkDir%", "%UCRTVersion%", arch)
+                        .str());
+    flags.push_back(llvm::formatv("/libpath:\"{0}\\Lib\\{1}\\um\\{2}\"",
+                                  "%UniversalCRTSdkDir%", "%UCRTVersion%", arch)
+                        .str());
+
+    // We need to link against different libraries based on our configuration
+    // matrix (dynamic/static and debug/release); the index layout matches the
+    // tables below.
+    int libIndex = targetOptions.optLevel.getSpeedupLevel() == 0 ? 0 : 2;
+    libIndex += targetOptions.linkStatic ? 1 : 0;
+
+    // The required libraries for linking DLLs:
+    // https://docs.microsoft.com/en-us/cpp/c-runtime-library/crt-library-features?view=msvc-160
+    //
+    // NOTE: there are only static versions of msvcrt as it's the startup code.
+    static const char *const kMSVCRTLibs[4] = {
+        /*   debug/dynamic */ "msvcrtd.lib",
+        /*    debug/static */ "msvcrtd.lib",
+        /* release/dynamic */ "msvcrt.lib",
+        /*  release/static */ "msvcrt.lib",
+    };
+    static const char *const kVCRuntimeLibs[4] = {
+        /*   debug/dynamic */ "vcruntimed.lib",
+        /*    debug/static */ "libvcruntimed.lib",
+        /* release/dynamic */ "vcruntime.lib",
+        /*  release/static */ "libvcruntime.lib",
+    };
+    static const char *const kUCRTLibs[4] = {
+        /*   debug/dynamic */ "ucrtd.lib",
+        /*    debug/static */ "libucrtd.lib",
+        /* release/dynamic */ "ucrt.lib",
+        /*  release/static */ "libucrt.lib",
+    };
+    flags.push_back(kMSVCRTLibs[libIndex]);
+    flags.push_back(kVCRuntimeLibs[libIndex]);
+    flags.push_back(kUCRTLibs[libIndex]);
+    flags.push_back("kernel32.lib");
+
+    // Link all input objects. Note that we are not linking whole-archive as we
+    // want to allow dropping of unused codegen outputs.
+    for (auto &objectFile : objectFiles) {
+      flags.push_back(objectFile.path);
+    }
+
+    auto commandLine = llvm::join(flags, " ");
+    if (failed(runLinkCommand(commandLine))) return llvm::None;
+
+    // The PDB file is generated with the same path + .pdb.
+    artifacts.debugFile =
+        Artifact::createVariant(artifacts.libraryFile.path, "pdb");
+
+    // We currently discard some of the other file outputs (like the .exp
+    // listing the exported symbols) as we don't need them.
+    artifacts.otherFiles.push_back(
+        Artifact::createVariant(artifacts.libraryFile.path, "exp"));
+    artifacts.otherFiles.push_back(
+        Artifact::createVariant(artifacts.libraryFile.path, "lib"));
+
+    return artifacts;
+  }
+};
+
+std::unique_ptr<LinkerTool> createWindowsLinkerTool(
+    llvm::Triple &targetTriple, LLVMTargetOptions &targetOptions) {
+  return std::make_unique<WindowsLinkerTool>(targetTriple, targetOptions);
+}
+
+}  // namespace HAL
+}  // namespace IREE
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/BUILD b/iree/compiler/Dialect/HAL/Target/LLVM/BUILD
index 9c608c2..9040b96 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/BUILD
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/BUILD
@@ -80,6 +80,7 @@
"LLVMTargetOptions.h",
],
deps = [
+ "@llvm-project//llvm:MC",
"@llvm-project//llvm:Passes",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:Target",
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt
index 1f0f77f..416f32e 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt
@@ -66,6 +66,7 @@
SRCS
"LLVMTargetOptions.cpp"
DEPS
+ LLVMMC
LLVMPasses
LLVMSupport
LLVMTarget
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.cpp
index 8efcf6c..4861dd3 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.cpp
@@ -80,13 +80,37 @@
buildLLVMTransformPassPipeline(passManager);
}
+static FileLineColLoc findFirstFileLoc(Location baseLoc) {
+ if (auto loc = baseLoc.dyn_cast<FusedLoc>()) {
+ for (auto &childLoc : loc.getLocations()) {
+ auto childResult = findFirstFileLoc(childLoc);
+ if (childResult) return childResult;
+ }
+ } else if (auto loc = baseLoc.dyn_cast<FileLineColLoc>()) {
+ return loc;
+ }
+ return FileLineColLoc{};
+}
+
+static std::string guessModuleName(mlir::ModuleOp moduleOp) {
+  auto name = moduleOp.getName();
+  if (name.hasValue() && !name->empty()) return name->str();
+  // Fall back to the stem of the first file location; "" if there is none.
+  FileLineColLoc loc = findFirstFileLoc(moduleOp.getLoc());
+  return loc ? llvm::sys::path::stem(loc.getFilename()).str() : std::string();
+}
+
LogicalResult LLVMBaseTargetBackend::linkExecutables(mlir::ModuleOp moduleOp) {
OpBuilder builder = OpBuilder::atBlockBegin(moduleOp.getBody());
auto executableOps =
llvm::to_vector<8>(moduleOp.getOps<IREE::HAL::ExecutableOp>());
+ // Guess a module name, if needed, to make the output files readable.
+ auto moduleName = guessModuleName(moduleOp);
+
// Create our new "linked" hal.executable.
- std::string linkedExecutableName = llvm::formatv("linked_{0}", name());
+ std::string linkedExecutableName =
+ llvm::formatv("{0}_linked_{1}", moduleName, name());
auto linkedExecutableOp = builder.create<IREE::HAL::ExecutableOp>(
moduleOp.getLoc(), linkedExecutableName);
linkedExecutableOp.setPrivate();
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.cpp
index de66446..65cd442 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.cpp
@@ -31,16 +31,33 @@
namespace IREE {
namespace HAL {
+static llvm::CodeGenOpt::Level passBuilderOptLevelToCodeGenOptLevel(
+ const llvm::PassBuilder::OptimizationLevel &level) {
+ switch (level.getSpeedupLevel()) {
+ case 0:
+ return llvm::CodeGenOpt::None;
+ case 1:
+ return llvm::CodeGenOpt::Less;
+ case 2:
+ default:
+ return llvm::CodeGenOpt::Default;
+ case 3:
+ return llvm::CodeGenOpt::Aggressive;
+ }
+}
+
std::unique_ptr<llvm::TargetMachine> createTargetMachine(
const LLVMTargetOptions &targetOptions) {
std::string errorMessage;
auto target = llvm::TargetRegistry::lookupTarget(targetOptions.targetTriple,
errorMessage);
if (!target) return nullptr;
- // TODO(ataei): Once we have an AOT backend pass cpu and cpu-features
std::unique_ptr<llvm::TargetMachine> machine(target->createTargetMachine(
- targetOptions.targetTriple, "generic" /* cpu e.g k8*/,
- "" /* cpu features e.g avx512fma*/, targetOptions.options, {}));
+ targetOptions.targetTriple, targetOptions.targetCPU /* cpu e.g k8*/,
+ targetOptions.targetCPUFeatures /* cpu features e.g avx512fma*/,
+ targetOptions.options, {}, {},
+ passBuilderOptLevelToCodeGenOptLevel(targetOptions.optLevel),
+ /*JIT=*/false));
return machine;
}
@@ -68,10 +85,12 @@
passBuilder.registerLoopAnalyses(loopAnalysisManager);
passBuilder.crossRegisterProxies(loopAnalysisManager, functionAnalysisManager,
cGSCCAnalysisManager, moduleAnalysisManager);
- llvm::ModulePassManager modulePassManager;
- modulePassManager =
- passBuilder.buildPerModuleDefaultPipeline(options.optLevel);
- modulePassManager.run(*module, moduleAnalysisManager);
+ if (options.optLevel != llvm::PassBuilder::OptimizationLevel::O0) {
+ llvm::ModulePassManager modulePassManager;
+ modulePassManager =
+ passBuilder.buildPerModuleDefaultPipeline(options.optLevel);
+ modulePassManager.run(*module, moduleAnalysisManager);
+ }
if (llvm::verifyModule(*module)) return failure();
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMTargetOptions.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMTargetOptions.cpp
index e71beac..649dbf0 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMTargetOptions.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMTargetOptions.cpp
@@ -15,6 +15,7 @@
#include "iree/compiler/Dialect/HAL/Target/LLVM/LLVMTargetOptions.h"
#include "llvm/ADT/APFloat.h"
+#include "llvm/MC/SubtargetFeature.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Host.h"
#include "llvm/Target/TargetOptions.h"
@@ -26,17 +27,34 @@
LLVMTargetOptions getDefaultLLVMTargetOptions() {
LLVMTargetOptions targetOptions;
+
// Host target triple.
targetOptions.targetTriple = llvm::sys::getDefaultTargetTriple();
+ targetOptions.targetCPU = llvm::sys::getHostCPUName().str();
+ {
+ llvm::SubtargetFeatures features;
+ llvm::StringMap<bool> hostFeatures;
+ if (llvm::sys::getHostCPUFeatures(hostFeatures)) {
+ for (auto &feature : hostFeatures) {
+ features.AddFeature(feature.first(), feature.second);
+ }
+ }
+ targetOptions.targetCPUFeatures = features.getString();
+ }
+
// LLVM loop optimization options.
targetOptions.pipelineTuningOptions.LoopInterleaving = true;
targetOptions.pipelineTuningOptions.LoopVectorization = true;
targetOptions.pipelineTuningOptions.LoopUnrolling = true;
+
// LLVM SLP Auto vectorizer.
targetOptions.pipelineTuningOptions.SLPVectorization = true;
+
// LLVM -O3.
+ // TODO(benvanik): add an option for this.
targetOptions.optLevel = llvm::PassBuilder::OptimizationLevel::O3;
targetOptions.options.FloatABIType = llvm::FloatABI::Hard;
+
return targetOptions;
}
@@ -46,16 +64,57 @@
static llvm::cl::opt<std::string> clTargetTriple(
"iree-llvm-target-triple", llvm::cl::desc("LLVM target machine triple"),
llvm::cl::init(llvmTargetOptions.targetTriple));
- static llvm::cl::opt<bool> clSoftFloat(
- "iree-llvm-enable-msoft-float-abi",
+ static llvm::cl::opt<std::string> clTargetCPU(
+ "iree-llvm-target-cpu",
+ llvm::cl::desc(
+ "LLVM target machine CPU; use 'host' for your host native CPU"),
+ llvm::cl::init("generic"));
+ static llvm::cl::opt<std::string> clTargetCPUFeatures(
+ "iree-llvm-target-cpu-features",
+ llvm::cl::desc("LLVM target machine CPU features; use 'host' for your "
+ "host native CPU"),
+ llvm::cl::init(""));
+ llvmTargetOptions.targetTriple = clTargetTriple;
+ if (clTargetCPU != "host") {
+ llvmTargetOptions.targetCPU = clTargetCPU;
+ }
+ if (clTargetCPUFeatures != "host") {
+ llvmTargetOptions.targetCPUFeatures = clTargetCPUFeatures;
+ }
+
+ static llvm::cl::opt<llvm::FloatABI::ABIType> clTargetFloatABI(
+ "iree-llvm-target-float-abi",
llvm::cl::desc("LLVM target codegen enables soft float abi e.g "
"-mfloat-abi=softfp"),
- llvm::cl::init(false));
+ llvm::cl::init(llvmTargetOptions.options.FloatABIType),
+ llvm::cl::values(
+ clEnumValN(llvm::FloatABI::Default, "default", "Default (softfp)"),
+ clEnumValN(llvm::FloatABI::Soft, "soft",
+ "Software floating-point emulation"),
+ clEnumValN(llvm::FloatABI::Hard, "hard",
+ "Hardware floating-point instructions")));
+ llvmTargetOptions.options.FloatABIType = clTargetFloatABI;
- llvmTargetOptions.targetTriple = clTargetTriple;
- if (clSoftFloat) {
- llvmTargetOptions.options.FloatABIType = llvm::FloatABI::Soft;
- }
+ static llvm::cl::opt<bool> clDebugSymbols(
+ "iree-llvm-debug-symbols",
+ llvm::cl::desc("Generate and embed debug information (DWARF, PDB, etc)"),
+ llvm::cl::init(llvmTargetOptions.debugSymbols));
+ llvmTargetOptions.debugSymbols = clDebugSymbols;
+
+ static llvm::cl::opt<bool> clLinkStatic(
+ "iree-llvm-link-static",
+ llvm::cl::desc(
+ "Links system libraries into binaries statically to isolate them "
+ "from platform dependencies needed at runtime"),
+ llvm::cl::init(llvmTargetOptions.linkStatic));
+ llvmTargetOptions.linkStatic = clLinkStatic;
+
+ static llvm::cl::opt<bool> clKeepLinkerArtifacts(
+ "iree-llvm-keep-linker-artifacts",
+ llvm::cl::desc("Keep LLVM linker target artifacts (.so/.dll/etc)"),
+ llvm::cl::init(llvmTargetOptions.keepLinkerArtifacts));
+ llvmTargetOptions.keepLinkerArtifacts = clKeepLinkerArtifacts;
+
return llvmTargetOptions;
}
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMTargetOptions.h b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMTargetOptions.h
index 4893566..ba454b6 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMTargetOptions.h
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMTargetOptions.h
@@ -24,10 +24,28 @@
namespace HAL {
struct LLVMTargetOptions {
+ // Target machine configuration.
+ std::string targetTriple;
+ std::string targetCPU;
+ std::string targetCPUFeatures;
+
llvm::PipelineTuningOptions pipelineTuningOptions;
llvm::PassBuilder::OptimizationLevel optLevel;
llvm::TargetOptions options;
- std::string targetTriple;
+
+ // Include debug information in output files (PDB, DWARF, etc).
+ // Though this can be set independently from the optLevel (so -O3 with debug
+ // information is valid) it may significantly change the output program
+ // and benchmarking results.
+ bool debugSymbols = true;
+
+ // Link any required runtime libraries into the produced binaries statically.
+ // This increases resulting binary size but enables the binaries to be used on
+ // any machine without requiring matching system libraries to be installed.
+ bool linkStatic = false;
+
+ // True to keep linker artifacts for debugging.
+ bool keepLinkerArtifacts = false;
};
// Returns LLVMTargetOptions struct intialized with the
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/test/binaryop_test.mlir b/iree/compiler/Dialect/HAL/Target/LLVM/test/binary_op.mlir
similarity index 90%
rename from iree/compiler/Dialect/HAL/Target/LLVM/test/binaryop_test.mlir
rename to iree/compiler/Dialect/HAL/Target/LLVM/test/binary_op.mlir
index 0a66d94..41f1939 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/test/binaryop_test.mlir
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/test/binary_op.mlir
@@ -11,7 +11,7 @@
}
}
-// CHECK-LABEL: hal.executable @linked_llvm_ir
+// CHECK-LABEL: hal.executable @binary_op_linked_llvm_ir
// CHECK-DAG: hal.executable.binary attributes {
// CHECK-SAME: data = dense
// CHECK-SAME: format = 1280071245 : i32} {
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/test/matmul_op.mlir b/iree/compiler/Dialect/HAL/Target/LLVM/test/matmul_op.mlir
index 6744a70..1a48518 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/test/matmul_op.mlir
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/test/matmul_op.mlir
@@ -11,7 +11,7 @@
}
}
-// CHECK-LABEL: hal.executable @linked_llvm_ir
+// CHECK-LABEL: hal.executable @matmul_op_linked_llvm_ir
// CHECK-DAG: hal.executable.binary attributes {
// CHECK-SAME: data = dense
// CHECK-SAME: format = 1280071245 : i32} {
diff --git a/iree/hal/dylib/dylib_executable.cc b/iree/hal/dylib/dylib_executable.cc
index f3d79cf..3d3ff01 100644
--- a/iree/hal/dylib/dylib_executable.cc
+++ b/iree/hal/dylib/dylib_executable.cc
@@ -34,9 +34,13 @@
DyLibExecutable::~DyLibExecutable() {
IREE_TRACE_SCOPE0("DyLibExecutable::dtor");
- // TODO(benvanik): move to an atexit handler when tracing is enabled.
- // executable_library_.release();
+#if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION
+ // Leak the library when tracing, since the profiler may still be reading it.
+ // TODO(benvanik): move to an atexit handler instead, verify with ASAN/MSAN
+ executable_library_.release();
+#else
executable_library_.reset();
+#endif // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION
for (const auto& file_path : temp_file_paths_) {
file_io::DeleteFile(file_path).IgnoreError();
}
diff --git a/iree/test/e2e/regression/slice_add.mlir b/iree/test/e2e/regression/slice_add.mlir
new file mode 100644
index 0000000..762d979
--- /dev/null
+++ b/iree/test/e2e/regression/slice_add.mlir
@@ -0,0 +1,15 @@
+// RUN: iree-run-mlir -export-all -iree-hal-target-backends=vmla -function-input="3x4xi32=[[1,2,3,4],[5,6,7,8],[9,10,11,12]]" -function-input="1x2xi32=10" %s | IreeFileCheck %s
+// RUN: [[ $IREE_LLVMJIT_DISABLE == 1 ]] || (iree-run-mlir -export-all -iree-hal-target-backends=llvm-ir -function-input="3x4xi32=[[1,2,3,4],[5,6,7,8],[9,10,11,12]]" -function-input="1x2xi32=10" %s | IreeFileCheck %s)
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir -export-all -iree-hal-target-backends=vulkan-spirv -function-input="3x4xi32=[[1,2,3,4],[5,6,7,8],[9,10,11,12]]" -function-input="1x2xi32=10" %s | IreeFileCheck %s)
+
+// CHECK: EXEC @slice_stride_part
+// CHECK: 1x2xi32=[16 17]
+func @slice_stride_part(%arg0: tensor<3x4xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
+ %1 = "mhlo.slice"(%arg0) {
+ start_indices = dense<[1, 1]> : tensor<2xi64>,
+ limit_indices = dense<[2, 3]> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>
+ } : (tensor<3x4xi32>) -> tensor<1x2xi32>
+ %2 = mhlo.add %1, %arg1 : tensor<1x2xi32>
+ return %2 : tensor<1x2xi32>
+}
diff --git a/iree/vm/stack.h b/iree/vm/stack.h
index 6c50e7f..bf59de6 100644
--- a/iree/vm/stack.h
+++ b/iree/vm/stack.h
@@ -87,7 +87,7 @@
// code), etc.
iree_vm_source_offset_t pc;
- IREE_TRACE(iree_zone_id_t trace_zone)
+ IREE_TRACE(iree_zone_id_t trace_zone;)
} iree_vm_stack_frame_t;
// Returns the implementation-defined frame storage associated with |frame|.
diff --git a/scripts/update_e2e_coverage.py b/scripts/update_e2e_coverage.py
index 325cb28..fed8c12 100755
--- a/scripts/update_e2e_coverage.py
+++ b/scripts/update_e2e_coverage.py
@@ -28,13 +28,12 @@
TENSORFLOW_COVERAGE_DIR = 'tensorflow_coverage'
REFERENCE_BACKEND = 'tf'
-# Assumes that tests are expanded for the tf, iree_vmla, iree_llvmjit and
+# Assumes that tests are expanded for the tf, tflite, iree_vmla, and
# iree_vulkan backends.
BACKENDS_TO_TITLES = collections.OrderedDict([
('tf', 'tensorflow'),
('tflite', 'tflite'),
('iree_vmla', 'vmla'),
- ('iree_llvmjit', 'llvm-ir'),
('iree_vulkan', 'vulkan-spirv'),
])