Merge pull request #4885 from KoolJBlack:main-to-google
PiperOrigin-RevId: 358275381
diff --git a/SUBMODULE_VERSIONS.txt b/SUBMODULE_VERSIONS.txt
index 0f9d6a4..5d9701f 100644
--- a/SUBMODULE_VERSIONS.txt
+++ b/SUBMODULE_VERSIONS.txt
@@ -4,16 +4,16 @@
4fb0ff7069bd88ee85902f4d0bb62794e5f6d021 third_party/flatcc
b1fbd33c06cdb0024c67733c6fdec2009d17b384 third_party/googletest
88b845dee001723c4a0db1fe5477de735b6d3bb0 third_party/liburing
-76353dd20f343a4a00af4154a71ddc38b8552e3a third_party/llvm-bazel
-8151c1b44211d5a7154ca860d28a6aed3a4f2715 third_party/llvm-project
+22b5eea8a513706deff930b3b4e3e6bbae287507 third_party/llvm-bazel
+f70cdc5b5c7c6086417409ec2d31b66144fabbc9 third_party/llvm-project
4e501d8c6e2d834999301a2492adefe5ddbdc0cb third_party/mlir-emitc
-077df4f8ccd0d4e097083dd699e354f1fda481be third_party/mlir-hlo
+27a8b18526ce6e4d2ae65c48a9591594eb27f50c third_party/mlir-hlo
2b2bd45bbf9be04fd22ece5cc1f54679202e9257 third_party/pffft
d8c7ee00a687ac369e62e2032514a93a9b413502 third_party/pybind11
2887692065c38ef6617f423feafc6b69dd0a0681 third_party/ruy
685f86471e9d26b3eb7676695a2e2cefb4551ae9 third_party/spirv_cross
f8bf11a0253a32375c32cad92c841237b96696c0 third_party/spirv_headers
-132437408620c947aaa43db31ce442a2b30dec12 third_party/tensorflow
+aed7a7b5e825609c90c677e8bb0e5f3554f07681 third_party/tensorflow
9c3dac3ed2bd647b8d63f197fed058fee97a7e1e third_party/tracy
9bd3f561bcee3f01d22912de10bb07ce4e23d378 third_party/vulkan_headers
3528e2aed3e8808f33e1e7d63eeb1560456a605a third_party/vulkan_memory_allocator
diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
index ad35390..870633e 100644
--- a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
+++ b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
@@ -35,7 +35,11 @@
"@llvm-project//mlir:GPUTransforms": ["MLIRGPU"],
"@llvm-project//mlir:LinalgOps": ["MLIRLinalg"],
"@llvm-project//mlir:LLVMDialect": ["MLIRLLVMIR"],
+ "@llvm-project//mlir:LLVMIRModuleTranslation": [
+ "MLIRTargetLLVMIRModuleTranslation"
+ ],
"@llvm-project//mlir:LLVMTransforms": ["MLIRStandardToLLVM"],
+ "@llvm-project//mlir:MathDialect": ["MLIRMath"],
"@llvm-project//mlir:SCFToGPUPass": ["MLIRSCFToGPU"],
"@llvm-project//mlir:SCFDialect": ["MLIRSCF"],
"@llvm-project//mlir:StandardOps": ["MLIRStandard"],
@@ -52,7 +56,7 @@
# Vulkan
"@iree_vulkan_headers//:vulkan_headers": ["Vulkan::Headers"],
# Cuda
- "@cuda_headers//:cuda_headers": ["cuda_headers"],
+ "@cuda_headers": ["cuda_headers"],
# The Bazel target maps to the IMPORTED target defined by FindVulkan().
"@vulkan_sdk//:sdk": ["Vulkan::Vulkan"],
# Misc single targets
@@ -62,7 +66,7 @@
"@com_google_googletest//:gtest": ["gmock", "gtest"],
"@renderdoc_api//:renderdoc_app": ["renderdoc_api::renderdoc_app"],
"@pffft": ["pffft"],
- "@cpuinfo//:cpuinfo": ["cpuinfo"],
+ "@cpuinfo": ["cpuinfo"],
"@half//:includes": ["half::includes"],
"@vulkan_memory_allocator//:impl_header_only": ["vulkan_memory_allocator"],
}
diff --git a/experimental/ModelBuilder/BUILD b/experimental/ModelBuilder/BUILD
index e9aa123..880e467 100644
--- a/experimental/ModelBuilder/BUILD
+++ b/experimental/ModelBuilder/BUILD
@@ -34,6 +34,7 @@
"@llvm-project//mlir:SPIRVDialect",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:StandardOps",
+ "@llvm-project//mlir:TargetLLVMIR",
"@llvm-project//mlir:VectorOps",
],
)
diff --git a/experimental/ModelBuilder/ModelBuilder.cpp b/experimental/ModelBuilder/ModelBuilder.cpp
index def3e04..c89203f 100644
--- a/experimental/ModelBuilder/ModelBuilder.cpp
+++ b/experimental/ModelBuilder/ModelBuilder.cpp
@@ -30,6 +30,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Target/LLVMIR.h"
using namespace mlir;
using namespace mlir::edsc;
@@ -59,6 +60,7 @@
ctx.getOrLoadDialect<spirv::SPIRVDialect>();
ctx.getOrLoadDialect<StandardOpsDialect>();
ctx.getOrLoadDialect<vector::VectorDialect>();
+ registerLLVMDialectTranslation(ctx);
}
Value mlir::ModelBuilder::constant_f32(float v) {
diff --git a/iree/compiler/Conversion/HLOToLinalg/BUILD b/iree/compiler/Conversion/HLOToLinalg/BUILD
index e1f880c..32da82d 100644
--- a/iree/compiler/Conversion/HLOToLinalg/BUILD
+++ b/iree/compiler/Conversion/HLOToLinalg/BUILD
@@ -35,6 +35,7 @@
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:LinalgTransforms",
+ "@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
diff --git a/iree/compiler/Conversion/HLOToLinalg/CMakeLists.txt b/iree/compiler/Conversion/HLOToLinalg/CMakeLists.txt
index 1299b81..6b6e7da 100644
--- a/iree/compiler/Conversion/HLOToLinalg/CMakeLists.txt
+++ b/iree/compiler/Conversion/HLOToLinalg/CMakeLists.txt
@@ -27,6 +27,7 @@
MLIRIR
MLIRLinalg
MLIRLinalgTransforms
+ MLIRMath
MLIRPass
MLIRStandard
MLIRSupport
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
index c145e90..e734f26 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
@@ -27,6 +27,7 @@
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
@@ -404,7 +405,8 @@
struct ConvertHLOToLinalgOnTensorsPass
: public PassWrapper<ConvertHLOToLinalgOnTensorsPass, FunctionPass> {
void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<linalg::LinalgDialect, mhlo::MhloDialect, ShapeDialect>();
+ registry.insert<linalg::LinalgDialect, mhlo::MhloDialect, ShapeDialect,
+ math::MathDialect>();
}
void runOnFunction() override {
diff --git a/iree/compiler/Conversion/HLOToLinalg/test/dynamic_shape.mlir b/iree/compiler/Conversion/HLOToLinalg/test/dynamic_shape.mlir
index 9e4c077..26cb366 100644
--- a/iree/compiler/Conversion/HLOToLinalg/test/dynamic_shape.mlir
+++ b/iree/compiler/Conversion/HLOToLinalg/test/dynamic_shape.mlir
@@ -20,7 +20,7 @@
// CHECK-SAME: ins(%[[ARG0]] : tensor<?x?xf32>)
// CHECK-SAME: outs(%[[T2]] : tensor<?x?xf32>)
// CHECK-NEXT: ^{{.+}}(%[[OPERAND_IN:[a-zA-Z0-9_]+]]: f32, %{{.+}}: f32):
-// CHECK-NEXT: %[[RESULT:.+]] = exp %[[OPERAND_IN]] : f32
+// CHECK-NEXT: %[[RESULT:.+]] = math.exp %[[OPERAND_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// CHECK: return %[[T3]]
diff --git a/iree/compiler/Conversion/HLOToLinalg/test/exp.mlir b/iree/compiler/Conversion/HLOToLinalg/test/exp.mlir
index e97cd53..b110599 100644
--- a/iree/compiler/Conversion/HLOToLinalg/test/exp.mlir
+++ b/iree/compiler/Conversion/HLOToLinalg/test/exp.mlir
@@ -12,6 +12,6 @@
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins(%{{[a-z0-9]*}} : tensor<2x2xf32>)
// CHECK-NEXT: ^{{.+}}(%[[OPERAND_IN:.+]]: f32, %{{.*}}: f32):
-// CHECK-NEXT: %[[RESULT:.+]] = exp %[[OPERAND_IN]] : f32
+// CHECK-NEXT: %[[RESULT:.+]] = math.exp %[[OPERAND_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// CHECK-NEXT: } -> tensor<2x2xf32>
diff --git a/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir b/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir
index 3865bad..de16c47 100644
--- a/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir
+++ b/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir
@@ -195,7 +195,7 @@
ins(%0 : tensor<2x4xf32>)
outs(%shape : tensor<2x4xf32>) {
^bb0(%arg0: f32, %s: f32): // no predecessors
- %2 = tanh %arg0 : f32
+ %2 = math.tanh %arg0 : f32
linalg.yield %2 : f32
} -> tensor<2x4xf32>
hal.interface.store.tensor %1, @legacy_io::@ret0, offset = %c0
@@ -241,7 +241,7 @@
ins(%0 : tensor<2x4xf32>)
outs(%shape : tensor<2x4xf32>) {
^bb0(%arg0: f32, %s: f32): // no predecessors
- %2 = tanh %arg0 : f32
+ %2 = math.tanh %arg0 : f32
linalg.yield %2 : f32
} -> tensor<2x4xf32>
%3 = linalg.tensor_reshape %1 [#map1, #map2]
@@ -291,7 +291,7 @@
ins(%0 : tensor<2x4xf32>)
outs(%shape : tensor<2x4xf32>) {
^bb0(%arg0: f32, %s: f32): // no predecessors
- %2 = tanh %arg0 : f32
+ %2 = math.tanh %arg0 : f32
linalg.yield %2 : f32
} -> tensor<2x4xf32>
%3 = linalg.tensor_reshape %1 [#map1, #map2]
@@ -341,7 +341,7 @@
ins(%0 : tensor<2x4xf32>)
outs(%shape : tensor<2x4xf32>) {
^bb0(%arg0: f32, %s: f32): // no predecessors
- %2 = tanh %arg0 : f32
+ %2 = math.tanh %arg0 : f32
linalg.yield %2 : f32
} -> tensor<2x4xf32>
%3 = linalg.tensor_reshape %1 [#map1, #map2]
diff --git a/iree/compiler/Conversion/LinalgToLLVM/BUILD b/iree/compiler/Conversion/LinalgToLLVM/BUILD
index ce85456..14622a4 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/BUILD
+++ b/iree/compiler/Conversion/LinalgToLLVM/BUILD
@@ -55,6 +55,8 @@
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:LinalgToLLVM",
"@llvm-project//mlir:LinalgTransforms",
+ "@llvm-project//mlir:MathDialect",
+ "@llvm-project//mlir:MathTransforms",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps",
diff --git a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
index 92a7890..88af990 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
@@ -38,6 +38,8 @@
MLIRLinalg
MLIRLinalgToLLVM
MLIRLinalgTransforms
+ MLIRMath
+ MLIRMathTransforms
MLIRPass
MLIRSCF
MLIRSCFToStandard
diff --git a/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp b/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp
index 5849a77..ac41c82 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp
@@ -31,6 +31,8 @@
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/Vector/VectorOps.h"
@@ -677,21 +679,22 @@
target.addLegalOp<ModuleOp, ModuleTerminatorOp, IREE::HAL::InterfaceOp,
IREE::HAL::InterfaceBindingOp, IREE::HAL::InterfaceEndOp>();
target.addIllegalDialect<ShapeDialect, StandardOpsDialect, IREEDialect,
- IREE::HAL::HALDialect>();
+ IREE::HAL::HALDialect, math::MathDialect>();
// Don't apply patterns to private function (e.g num_workgroups func).
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp funcOp) {
if (isEntryPoint(funcOp)) return false;
return true;
});
- target.addDynamicallyLegalDialect<ShapeDialect, StandardOpsDialect,
- IREEDialect, IREE::HAL::HALDialect>(
- [&](Operation *op) {
- auto funcParent = op->getParentOfType<FuncOp>();
- if (!funcParent) return false;
- if (isEntryPoint(funcParent)) return false;
- return true;
- });
+ target
+ .addDynamicallyLegalDialect<ShapeDialect, StandardOpsDialect, IREEDialect,
+ IREE::HAL::HALDialect, math::MathDialect>(
+ [&](Operation *op) {
+ auto funcParent = op->getParentOfType<FuncOp>();
+ if (!funcParent) return false;
+ if (isEntryPoint(funcParent)) return false;
+ return true;
+ });
if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
signalPassFailure();
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
index 885dbe3..7b0be45 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
@@ -663,6 +663,14 @@
}
template <>
+Optional<SmallVector<int64_t, 4>> getOpNativeVectorSize<vector::FMAOp>(
+ vector::FMAOp op) {
+ SmallVector<int64_t, 4> size(op.getType().getRank(), 1);
+ size.back() = 4;
+ return size;
+}
+
+template <>
Optional<SmallVector<int64_t, 4>> getOpNativeVectorSize<vector::TransferReadOp>(
vector::TransferReadOp op) {
auto targetEnv = spirv::TargetEnv(spirv::lookupTargetEnv(op));
@@ -707,6 +715,7 @@
}
DISPATCH(vector::ContractionOp)
+ DISPATCH(vector::FMAOp)
DISPATCH(vector::TransferReadOp)
DISPATCH(vector::TransferWriteOp)
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
index 04b9219..f48fd7c 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -368,6 +368,8 @@
canonicalizationPatterns1, funcOp.getContext());
vector::populateVectorToVectorTransformationPatterns(
canonicalizationPatterns1, funcOp.getContext());
+ vector::populateSplitVectorTransferPatterns(canonicalizationPatterns1,
+ funcOp.getContext());
(void)applyPatternsAndFoldGreedily(funcOp,
std::move(canonicalizationPatterns1));
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test_cooperative_mat.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test_cooperative_mat.mlir
index 3589eb8..06db769 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test_cooperative_mat.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test_cooperative_mat.mlir
@@ -1,5 +1,5 @@
// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-linalg-to-spirv-pipeline))" -iree-spirv-enable-vectorization %s | IreeFileCheck %s
-// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-linalg-to-spirv-pipeline))" -iree-spirv-enable-vectorization -iree-spirv-use-workgroup-memory %s | IreeFileCheck %s
+// R-UN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-linalg-to-spirv-pipeline))" -iree-spirv-enable-vectorization -iree-spirv-use-workgroup-memory %s | IreeFileCheck %s
hal.executable @matmul_static_shape attributes {sym_visibility = "private"} {
hal.interface @legacy_io {
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
index 9e0fa44..ed295f7 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
@@ -427,7 +427,7 @@
ins(%2 : memref<2x4xf32>)
outs(%0 : memref<2x4xf32>) {
^bb0(%arg0: f32, %arg1: f32): // no predecessors
- %4 = tanh %arg0 : f32
+ %4 = math.tanh %arg0 : f32
linalg.yield %4 : f32
}
%3 = linalg.reshape %0 [#map1, #map2] : memref<2x4xf32> into memref<1x2x4xf32>
diff --git a/iree/compiler/Conversion/LinalgToVector/BUILD b/iree/compiler/Conversion/LinalgToVector/BUILD
index 53eaa8d..73dc1a4 100644
--- a/iree/compiler/Conversion/LinalgToVector/BUILD
+++ b/iree/compiler/Conversion/LinalgToVector/BUILD
@@ -37,6 +37,7 @@
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:LinalgTransforms",
+ "@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
diff --git a/iree/compiler/Conversion/LinalgToVector/CMakeLists.txt b/iree/compiler/Conversion/LinalgToVector/CMakeLists.txt
index 1e56d8c..a241c05 100644
--- a/iree/compiler/Conversion/LinalgToVector/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToVector/CMakeLists.txt
@@ -27,6 +27,7 @@
MLIRIR
MLIRLinalg
MLIRLinalgTransforms
+ MLIRMath
MLIRPass
MLIRStandard
MLIRSupport
diff --git a/iree/compiler/Conversion/LinalgToVector/LoadStoreVectorization.cpp b/iree/compiler/Conversion/LinalgToVector/LoadStoreVectorization.cpp
index eb7d396..ed22d58 100644
--- a/iree/compiler/Conversion/LinalgToVector/LoadStoreVectorization.cpp
+++ b/iree/compiler/Conversion/LinalgToVector/LoadStoreVectorization.cpp
@@ -18,6 +18,7 @@
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -243,29 +244,29 @@
VectorizeElementwiseOp<AddFOp>,
VectorizeElementwiseOp<AddIOp>,
VectorizeElementwiseOp<CeilFOp>,
- VectorizeElementwiseOp<CosOp>,
+ VectorizeElementwiseOp<math::CosOp>,
VectorizeElementwiseOp<DivFOp>,
- VectorizeElementwiseOp<ExpOp>,
+ VectorizeElementwiseOp<math::ExpOp>,
VectorizeElementwiseOp<FPExtOp>,
VectorizeElementwiseOp<FPToSIOp>,
VectorizeElementwiseOp<FPTruncOp>,
VectorizeElementwiseOp<FloorFOp>,
- VectorizeElementwiseOp<LogOp>,
+ VectorizeElementwiseOp<math::LogOp>,
VectorizeElementwiseOp<MulFOp>,
VectorizeElementwiseOp<MulIOp>,
VectorizeElementwiseOp<NegFOp>,
VectorizeElementwiseOp<RemFOp>,
- VectorizeElementwiseOp<RsqrtOp>,
+ VectorizeElementwiseOp<math::RsqrtOp>,
VectorizeElementwiseOp<SIToFPOp>,
VectorizeElementwiseOp<ShiftLeftOp>,
VectorizeElementwiseOp<SignExtendIOp>,
VectorizeElementwiseOp<SignedDivIOp>,
VectorizeElementwiseOp<SignedShiftRightOp>,
- VectorizeElementwiseOp<SinOp>,
- VectorizeElementwiseOp<SqrtOp>,
+ VectorizeElementwiseOp<math::SinOp>,
+ VectorizeElementwiseOp<math::SqrtOp>,
VectorizeElementwiseOp<SubFOp>,
VectorizeElementwiseOp<SubIOp>,
- VectorizeElementwiseOp<TanhOp>,
+ VectorizeElementwiseOp<math::TanhOp>,
VectorizeElementwiseOp<TruncateIOp>,
VectorizeElementwiseOp<UnsignedDivIOp>,
VectorizeElementwiseOp<UnsignedRemIOp>,
@@ -292,7 +293,8 @@
});
// Mark all standard ops legal if they are operating on vector types.
- target.addDynamicallyLegalDialect<mlir::StandardOpsDialect>(
+ target.addDynamicallyLegalDialect<mlir::StandardOpsDialect,
+ mlir::math::MathDialect>(
Optional<ConversionTarget::DynamicLegalityCallbackFn>(
[](Operation *op) {
auto isVectorType = [](Type t) { return t.isa<VectorType>(); };
diff --git a/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_conv.mlir b/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_conv.mlir
index 88a93e8..2ca5b2f 100644
--- a/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_conv.mlir
+++ b/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_conv.mlir
@@ -32,44 +32,44 @@
// CHECK: %[[INPUT_0:.+]] = vector.transfer_read %[[INPUT]][%c0, %c0, %c0, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x3xf32>, vector<1x3xf32>
// CHECK: %[[OUTPUT_0:.+]] = vector.transfer_read %[[OUTPUT]][%c0, %c0, %c0, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x4xf32>, vector<1x4xf32>
// CHECK: %[[INPUT_0_0:.+]] = vector.extract_strided_slice %[[INPUT_0]] {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
-// CHECK: %[[DOT_0:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_0_0]], %[[FILTER_0]], %[[OUTPUT_0]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[DOT_0:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[INPUT_0_0]], %[[FILTER_0]], %[[OUTPUT_0]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
// CHECK: %[[INPUT_0_1:.+]] = vector.extract_strided_slice %[[INPUT_0]] {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
-// CHECK: %[[DOT_1:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_0_1]], %[[FILTER_1]], %[[DOT_0]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[DOT_1:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[INPUT_0_1]], %[[FILTER_1]], %[[DOT_0]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
// CHECK: %[[INPUT_0_2:.+]] = vector.extract_strided_slice %[[INPUT_0]] {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
-// CHECK: %[[DOT_2:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_0_2]], %[[FILTER_2]], %[[DOT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[DOT_2:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[INPUT_0_2]], %[[FILTER_2]], %[[DOT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
// CHECK: vector.transfer_write %[[DOT_2]], %[[OUTPUT]][%c0, %c0, %c0, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<1x2x2x4xf32>
// Handle batch #1
// CHECK: %[[INPUT_1:.+]] = vector.transfer_read %[[INPUT]][%c0, %c0, %c1, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x3xf32>, vector<1x3xf32>
// CHECK: %[[OUTPUT_1:.+]] = vector.transfer_read %[[OUTPUT]][%c0, %c0, %c1, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x4xf32>, vector<1x4xf32>
// CHECK: %[[INPUT_1_0:.+]] = vector.extract_strided_slice %[[INPUT_1]] {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
-// CHECK: %[[DOT_0:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_1_0]], %[[FILTER_0]], %[[OUTPUT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[DOT_0:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[INPUT_1_0]], %[[FILTER_0]], %[[OUTPUT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
// CHECK: %[[INPUT_1_1:.+]] = vector.extract_strided_slice %[[INPUT_1]] {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
-// CHECK: %[[DOT_1:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_1_1]], %[[FILTER_1]], %[[DOT_0]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[DOT_1:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[INPUT_1_1]], %[[FILTER_1]], %[[DOT_0]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
// CHECK: %[[INPUT_1_2:.+]] = vector.extract_strided_slice %[[INPUT_1]] {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
-// CHECK: %[[DOT_2:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_1_2]], %[[FILTER_2]], %[[DOT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[DOT_2:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[INPUT_1_2]], %[[FILTER_2]], %[[DOT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
// CHECK: vector.transfer_write %[[DOT_2]], %[[OUTPUT]][%c0, %c0, %c1, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<1x2x2x4xf32>
// Handle batch #2
// CHECK: %[[INPUT_2:.+]] = vector.transfer_read %[[INPUT]][%c0, %c1, %c0, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x3xf32>, vector<1x3xf32>
// CHECK: %[[OUTPUT_2:.+]] = vector.transfer_read %[[OUTPUT]][%c0, %c1, %c0, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x4xf32>, vector<1x4xf32>
// CHECK: %[[INPUT_2_0:.+]] = vector.extract_strided_slice %[[INPUT_2]] {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
-// CHECK: %[[DOT_0:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_2_0]], %[[FILTER_0]], %[[OUTPUT_2]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[DOT_0:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[INPUT_2_0]], %[[FILTER_0]], %[[OUTPUT_2]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
// CHECK: %[[INPUT_2_1:.+]] = vector.extract_strided_slice %[[INPUT_2]] {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
-// CHECK: %[[DOT_1:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_2_1]], %[[FILTER_1]], %[[DOT_0]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[DOT_1:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[INPUT_2_1]], %[[FILTER_1]], %[[DOT_0]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
// CHECK: %[[INPUT_2_2:.+]] = vector.extract_strided_slice %[[INPUT_2]] {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
-// CHECK: %[[DOT_2:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_2_2]], %[[FILTER_2]], %[[DOT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[DOT_2:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[INPUT_2_2]], %[[FILTER_2]], %[[DOT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
// CHECK: vector.transfer_write %[[DOT_2]], %[[OUTPUT]][%c0, %c1, %c0, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<1x2x2x4xf32>
// Handle batch #3
// CHECK: %[[INPUT_3:.+]] = vector.transfer_read %[[INPUT]][%c0, %c1, %c1, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x3xf32>, vector<1x3xf32>
// CHECK: %[[OUTPUT_3:.+]] = vector.transfer_read %[[OUTPUT]][%c0, %c1, %c1, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x4xf32>, vector<1x4xf32>
// CHECK: %[[INPUT_3_0:.+]] = vector.extract_strided_slice %[[INPUT_3]] {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
-// CHECK: %[[DOT_0:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_3_0]], %[[FILTER_0]], %[[OUTPUT_3]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[DOT_0:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[INPUT_3_0]], %[[FILTER_0]], %[[OUTPUT_3]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
// CHECK: %[[INPUT_3_1:.+]] = vector.extract_strided_slice %[[INPUT_3]] {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
-// CHECK: %[[DOT_1:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_3_1]], %[[FILTER_1]], %[[DOT_0]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[DOT_1:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[INPUT_3_1]], %[[FILTER_1]], %[[DOT_0]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
// CHECK: %[[INPUT_3_2:.+]] = vector.extract_strided_slice %[[INPUT_3]] {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
-// CHECK: %[[DOT_2:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_3_2]], %[[FILTER_2]], %[[DOT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[DOT_2:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[INPUT_3_2]], %[[FILTER_2]], %[[DOT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
// CHECK: vector.transfer_write %[[DOT_2]], %[[OUTPUT]][%c0, %c1, %c1, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<1x2x2x4xf32>
// -----
diff --git a/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_ops.mlir b/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_ops.mlir
index 93de2fc..2f43446 100644
--- a/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_ops.mlir
+++ b/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_ops.mlir
@@ -39,7 +39,7 @@
outs(%0 : memref<4xf32>) {
^bb0(%arg0: f32, %arg1: f32): // no predecessors
%2 = addf %arg0, %cst : f32
- %3 = log %2 : f32
+ %3 = math.log %2 : f32
linalg.yield %3 : f32
}
return
@@ -53,7 +53,7 @@
// CHECK-SAME: outs(%[[BUF1]] :
// CHECK: ^bb0(%[[ARG0:.+]]: vector<4xf32>, %[[ARG1:.+]]: vector<4xf32>)
// CHECK: %[[T1:.+]] = addf %[[ARG0]], %[[CST]] : vector<4xf32>
-// CHECK: %[[T2:.+]] = log %[[T1]] : vector<4xf32>
+// CHECK: %[[T2:.+]] = math.log %[[T1]] : vector<4xf32>
// CHECK: linalg.yield %[[T2]] : vector<4xf32>
// -----
@@ -167,7 +167,7 @@
%0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<4xf32>
linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} outs(%0 : memref<4xf32>) {
^bb0(%arg0: f32): // no predecessors
- %1 = rsqrt %cst : f32
+ %1 = math.rsqrt %cst : f32
linalg.yield %1 : f32
}
return
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/BUILD b/iree/compiler/Dialect/HAL/Target/LLVM/BUILD
index 252ac5a..36ea732 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/BUILD
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/BUILD
@@ -58,6 +58,7 @@
"@llvm-project//llvm:X86AsmParser",
"@llvm-project//llvm:X86CodeGen",
"@llvm-project//mlir:LLVMDialect",
+ "@llvm-project//mlir:LLVMIRModuleTranslation",
"@llvm-project//mlir:TargetLLVMIR",
],
)
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt
index 9dd8681..78da2d6 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt
@@ -41,6 +41,7 @@
LLVMX86CodeGen
MLIRLLVMIR
MLIRTargetLLVMIR
+ MLIRTargetLLVMIRModuleTranslation
iree::base::flatcc
iree::compiler::Conversion::CodegenUtils
iree::compiler::Conversion::Common
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
index d0dbde6..c0681d9 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
@@ -30,6 +30,7 @@
#include "llvm/Support/TargetSelect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Target/LLVMIR.h"
+#include "mlir/Target/LLVMIR/Export.h"
namespace mlir {
namespace iree_compiler {
@@ -78,6 +79,8 @@
}
LogicalResult linkExecutables(mlir::ModuleOp moduleOp) override {
+ mlir::registerLLVMDialectTranslation(*moduleOp->getContext());
+
OpBuilder builder = OpBuilder::atBlockBegin(moduleOp.getBody());
auto sourceExecutableOps =
diff --git a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/BUILD b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/BUILD
index 61cafd8..58933a4 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/BUILD
+++ b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/BUILD
@@ -32,6 +32,7 @@
"//iree/compiler/Dialect/VMLA/IR",
"//iree/compiler/Dialect/VMLA/IR:VMLADialect",
"@llvm-project//mlir:IR",
+ "@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
diff --git a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/CMakeLists.txt
index 18137b6..8874b50 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/CMakeLists.txt
+++ b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/CMakeLists.txt
@@ -23,6 +23,7 @@
"ConvertStandardToVMLA.cpp"
DEPS
MLIRIR
+ MLIRMath
MLIRPass
MLIRStandard
MLIRSupport
diff --git a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/ConvertStandardToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/ConvertStandardToVMLA.cpp
index a5b62b8..3570a24 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/ConvertStandardToVMLA.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/ConvertStandardToVMLA.cpp
@@ -19,6 +19,7 @@
#include "iree/compiler/Dialect/VMLA/IR/VMLADialect.h"
#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.h"
#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
+#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -213,15 +214,15 @@
typeConverter, context);
patterns.insert<VMLAOpConversion<mlir::RemFOp, IREE::VMLA::RemOp>>(
typeConverter, context);
- patterns.insert<VMLAOpConversion<mlir::LogOp, IREE::VMLA::LogOp>>(
+ patterns.insert<VMLAOpConversion<mlir::math::LogOp, IREE::VMLA::LogOp>>(
typeConverter, context);
- patterns.insert<VMLAOpConversion<mlir::ExpOp, IREE::VMLA::ExpOp>>(
+ patterns.insert<VMLAOpConversion<mlir::math::ExpOp, IREE::VMLA::ExpOp>>(
typeConverter, context);
- patterns.insert<VMLAOpConversion<mlir::SqrtOp, IREE::VMLA::SqrtOp>>(
+ patterns.insert<VMLAOpConversion<mlir::math::SqrtOp, IREE::VMLA::SqrtOp>>(
typeConverter, context);
- patterns.insert<VMLAOpConversion<mlir::CosOp, IREE::VMLA::CosOp>>(
+ patterns.insert<VMLAOpConversion<mlir::math::CosOp, IREE::VMLA::CosOp>>(
typeConverter, context);
- patterns.insert<VMLAOpConversion<mlir::TanhOp, IREE::VMLA::TanhOp>>(
+ patterns.insert<VMLAOpConversion<mlir::math::TanhOp, IREE::VMLA::TanhOp>>(
typeConverter, context);
patterns.insert<VMLAOpConversion<mlir::NegFOp, IREE::VMLA::NegOp>>(
typeConverter, context);
diff --git a/iree/task/CMakeLists.txt b/iree/task/CMakeLists.txt
index a7bf3de..d8be366 100644
--- a/iree/task/CMakeLists.txt
+++ b/iree/task/CMakeLists.txt
@@ -47,6 +47,7 @@
cpuinfo
iree::base::api
iree::base::core_headers
+ iree::base::internal
iree::base::internal::atomic_slist
iree::base::internal::prng
iree::base::internal::wait_handle
@@ -153,4 +154,6 @@
iree::base::api
iree::testing::gtest
iree::testing::gtest_main
+ LABELS
+ "noasan"
)
diff --git a/iree/tools/BUILD b/iree/tools/BUILD
index bc9c3b1..46b45fa 100644
--- a/iree/tools/BUILD
+++ b/iree/tools/BUILD
@@ -128,6 +128,7 @@
"@llvm-project//mlir:LinalgToLLVM",
"@llvm-project//mlir:LinalgToSPIRV",
"@llvm-project//mlir:LinalgTransforms",
+ "@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:SCFToGPUPass",
@@ -264,6 +265,7 @@
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TargetLLVMIR",
],
)
@@ -321,6 +323,7 @@
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFTransforms",
"@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TargetLLVMIR",
"@llvm-project//mlir:Translation",
],
)
diff --git a/iree/tools/init_mlir_dialects.h b/iree/tools/init_mlir_dialects.h
index f939527..d04522d 100644
--- a/iree/tools/init_mlir_dialects.h
+++ b/iree/tools/init_mlir_dialects.h
@@ -24,6 +24,7 @@
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SDBM/SDBMDialect.h"
@@ -44,6 +45,7 @@
gpu::GPUDialect,
LLVM::LLVMDialect,
linalg::LinalgDialect,
+ math::MathDialect,
scf::SCFDialect,
quant::QuantizationDialect,
spirv::SPIRVDialect,
diff --git a/iree/tools/iree-run-mlir-main.cc b/iree/tools/iree-run-mlir-main.cc
index e7936e7..e379210 100644
--- a/iree/tools/iree-run-mlir-main.cc
+++ b/iree/tools/iree-run-mlir-main.cc
@@ -75,6 +75,7 @@
#include "mlir/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/FileUtilities.h"
+#include "mlir/Target/LLVMIR.h"
static llvm::cl::opt<std::string> input_file_flag{
llvm::cl::Positional,
@@ -515,6 +516,7 @@
mlir::iree_compiler::registerAllDialects(registry);
mlir::iree_compiler::registerHALTargetBackends();
mlir::iree_compiler::registerVMTargets();
+ mlir::registerLLVMDialectTranslation(registry);
// Register MLIRContext command-line options like
// -mlir-print-op-on-diagnostic.
diff --git a/iree/tools/iree-translate-main.cc b/iree/tools/iree-translate-main.cc
index 4eb17df..69bb34c 100644
--- a/iree/tools/iree-translate-main.cc
+++ b/iree/tools/iree-translate-main.cc
@@ -38,6 +38,7 @@
#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/ToolUtilities.h"
+#include "mlir/Target/LLVMIR.h"
#include "mlir/Translation.h"
#ifdef IREE_HAVE_EMITC_DIALECT
@@ -63,6 +64,7 @@
llvm::InitLLVM y(argc, argv);
mlir::DialectRegistry registry;
mlir::registerMlirDialects(registry);
+ mlir::registerLLVMDialectTranslation(registry);
#ifdef IREE_HAVE_EMITC_DIALECT
mlir::registerEmitCDialect(registry);
#endif // IREE_HAVE_EMITC_DIALECT
diff --git a/third_party/llvm-bazel b/third_party/llvm-bazel
index 76353dd..22b5eea 160000
--- a/third_party/llvm-bazel
+++ b/third_party/llvm-bazel
@@ -1 +1 @@
-Subproject commit 76353dd20f343a4a00af4154a71ddc38b8552e3a
+Subproject commit 22b5eea8a513706deff930b3b4e3e6bbae287507
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 8151c1b..f70cdc5 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 8151c1b44211d5a7154ca860d28a6aed3a4f2715
+Subproject commit f70cdc5b5c7c6086417409ec2d31b66144fabbc9
diff --git a/third_party/mlir-hlo b/third_party/mlir-hlo
index 077df4f..27a8b18 160000
--- a/third_party/mlir-hlo
+++ b/third_party/mlir-hlo
@@ -1 +1 @@
-Subproject commit 077df4f8ccd0d4e097083dd699e354f1fda481be
+Subproject commit 27a8b18526ce6e4d2ae65c48a9591594eb27f50c
diff --git a/third_party/tensorflow b/third_party/tensorflow
index 1324374..aed7a7b 160000
--- a/third_party/tensorflow
+++ b/third_party/tensorflow
@@ -1 +1 @@
-Subproject commit 132437408620c947aaa43db31ce442a2b30dec12
+Subproject commit aed7a7b5e825609c90c677e8bb0e5f3554f07681