Merge pull request #3497 from ScottTodd:main-to-google
PiperOrigin-RevId: 337550323
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 472cea2..9d5a461 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -46,7 +46,6 @@
option(IREE_BUILD_PYTHON_BINDINGS "Builds the IREE python bindings" OFF)
option(IREE_BUILD_EXPERIMENTAL "Builds experimental projects." OFF)
-#TODO: Make this functional!
set(IREE_HAL_DRIVERS_TO_BUILD "all"
CACHE STRING "Semicolon-separated list of HAL drivers to build, or \"all\".")
set(IREE_TARGET_BACKENDS_TO_BUILD "all"
diff --git a/integrations/tensorflow/compiler/dialect/tf_strings/conversion/test/convert_flow_to_hal.mlir b/integrations/tensorflow/compiler/dialect/tf_strings/conversion/test/convert_flow_to_hal.mlir
index 2a4c4b1..5664219 100644
--- a/integrations/tensorflow/compiler/dialect/tf_strings/conversion/test/convert_flow_to_hal.mlir
+++ b/integrations/tensorflow/compiler/dialect/tf_strings/conversion/test/convert_flow_to_hal.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-tf-opt --iree-convert-flow-to-hal %s --split-input-file | IreeFileCheck %s
+// RUN: iree-tf-opt --iree-convert-to-hal %s --split-input-file | IreeFileCheck %s
// CHECK-LABEL: @i32_to_string
func @i32_to_string(%arg0 : i32) -> !tf_strings.string {
diff --git a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/test/convert_flow_to_hal.mlir b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/test/convert_flow_to_hal.mlir
index aa07172..a1fc6ad 100644
--- a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/test/convert_flow_to_hal.mlir
+++ b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/test/convert_flow_to_hal.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-tf-opt <%s -iree-convert-flow-to-hal -split-input-file | IreeFileCheck %s
+// RUN: iree-tf-opt <%s -iree-convert-to-hal -split-input-file | IreeFileCheck %s
// CHECK-LABEL: func @Reserve(%arg0: !hal.buffer, %arg1: !hal.buffer) -> !tensorlist.list {
func @Reserve(%arg0: tensor<0xi32>, %arg1: tensor<i32>) -> !tf_tensorlist.list{
diff --git a/integrations/tensorflow/e2e/iree_e2e_cartesian_product_test_suite.bzl b/integrations/tensorflow/e2e/iree_e2e_cartesian_product_test_suite.bzl
index e7a4fe5..0af8da8 100644
--- a/integrations/tensorflow/e2e/iree_e2e_cartesian_product_test_suite.bzl
+++ b/integrations/tensorflow/e2e/iree_e2e_cartesian_product_test_suite.bzl
@@ -141,6 +141,12 @@
for failing_configuration in failing_configurations:
failing_configuration = _normalize_dictionary(failing_configuration)
+ for key in failing_configuration:
+ if key not in flags_to_values:
+ fail("Encountered unexpected key \"{}\" ".format(key) +
+ "in a failing configuration. Expected one of " +
+ "{}.".format(list(flags_to_values.keys())))
+
# If a flag isn't specified in the failing configuration, assume it
# is failing for all values of that flag.
for key, values in flags_to_values.items():
diff --git a/integrations/tensorflow/e2e/keras/BUILD b/integrations/tensorflow/e2e/keras/BUILD
index 78098b1..0643127 100644
--- a/integrations/tensorflow/e2e/keras/BUILD
+++ b/integrations/tensorflow/e2e/keras/BUILD
@@ -173,7 +173,8 @@
# All models with runtime shorter than ResNet50.
"MobileNet", # Max: Vulkan 61.0s
"MobileNetV2", # Max: LLVM 96.3s
- "ResNet50", # Max: LLVM 145.6s
+ # TODO(#3489): started having vulkan diffs after #3424; investigate.
+ # "ResNet50", # Max: LLVM 145.6s
"VGG16", # Max: LLVM 89.5s
"VGG19", # Max: LLVM 94.7s
],
@@ -294,14 +295,13 @@
"InceptionResNetV2",
"InceptionV3",
],
- "datasets": ["imagenet"],
- "backends": [
+ "target_backends": [
"iree_vulkan",
],
},
{
# Failing llvm and vulkan:
- "models": [
+ "model": [
"NASNetLarge",
"NASNetMobile",
"ResNet50V2",
diff --git a/iree/base/threading.c b/iree/base/threading.c
index 2ca8cdb..0b3a49f 100644
--- a/iree/base/threading.c
+++ b/iree/base/threading.c
@@ -24,8 +24,8 @@
#include <intrin.h>
#endif // IREE_COMPILER_MSVC
-int iree_strncpy_s(char* restrict dest, size_t destsz, const char* restrict src,
- size_t count) {
+int iree_strncpy_s(char* IREE_RESTRICT dest, size_t destsz,
+ const char* IREE_RESTRICT src, size_t count) {
#if defined(IREE_COMPILER_MSVC) || defined(__STDC_LIB_EXT1__)
return strncpy_s(dest, destsz, src, count);
#else
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index a282eb0..df4579e 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -156,7 +156,10 @@
auto variableOp = dyn_cast_or_null<VariableOp>(
SymbolTable::lookupNearestSymbolFrom(*this, variable()));
if (!variableOp) return {};
- if (variableOp.is_mutable()) {
+ if (variableOp.getAttr("noinline")) {
+ // Inlining of the constant has been disabled.
+ return {};
+ } else if (variableOp.is_mutable()) {
// We can't inline mutable variables as they may be changed at any time.
// There may still be other folders/canonicalizers that can help (such as
// store-forwarding).
diff --git a/iree/compiler/Dialect/Flow/IR/test/variable_folding.mlir b/iree/compiler/Dialect/Flow/IR/test/variable_folding.mlir
index 3343677..09c62b3 100644
--- a/iree/compiler/Dialect/Flow/IR/test/variable_folding.mlir
+++ b/iree/compiler/Dialect/Flow/IR/test/variable_folding.mlir
@@ -32,6 +32,16 @@
// -----
+flow.variable @v_const dense<1.0> : tensor<8xf32> attributes {noinline}
+// CHECK-LABEL: @no_fold_noinline_immutable_const
+func @no_fold_noinline_immutable_const() -> tensor<8xf32> {
+ // CHECK-NEXT: = flow.variable.load @v_const : tensor<8xf32>
+ %0 = flow.variable.load @v_const : tensor<8xf32>
+ return %0 : tensor<8xf32>
+}
+
+// -----
+
flow.variable @v_nop mutable : tensor<4xi32>
// CHECK-LABEL: @nop_load_store
func @nop_load_store() {
diff --git a/iree/compiler/Dialect/Flow/Transforms/BUILD b/iree/compiler/Dialect/Flow/Transforms/BUILD
index e3ce6a7..c21ed77 100644
--- a/iree/compiler/Dialect/Flow/Transforms/BUILD
+++ b/iree/compiler/Dialect/Flow/Transforms/BUILD
@@ -34,6 +34,7 @@
"MaterializeExportedReflection.cpp",
"MergeExportedReflection.cpp",
"OutlineDispatchRegions.cpp",
+ "OutlineLargeConstantsPass.cpp",
"Passes.cpp",
"PrePostPartitioningConversion.cpp",
"RematerializeDispatchConstants.cpp",
diff --git a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index 020450c..d8c8698 100644
--- a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -34,6 +34,7 @@
"MaterializeExportedReflection.cpp"
"MergeExportedReflection.cpp"
"OutlineDispatchRegions.cpp"
+ "OutlineLargeConstantsPass.cpp"
"Passes.cpp"
"PrePostPartitioningConversion.cpp"
"RematerializeDispatchConstants.cpp"
diff --git a/iree/compiler/Dialect/Flow/Transforms/OutlineLargeConstantsPass.cpp b/iree/compiler/Dialect/Flow/Transforms/OutlineLargeConstantsPass.cpp
new file mode 100644
index 0000000..5f2d2cd
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Transforms/OutlineLargeConstantsPass.cpp
@@ -0,0 +1,142 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Flow {
+
+// Minimum unpacked size, in bytes, at which a tensor constant is outlined.
+// NOTE: a rough guess at the most per-dispatch-buffer data we'd want to
+// embed in the command buffer.
+// TODO(benvanik): make a pass option so users can override.
+static constexpr size_t kMinLargeConstantSize = 256;
+
+// Returns true if |constantOp| is large enough to be considered for pooling.
+// Some constants are small enough that inlining them into the ringbuffer is
+// more efficient and requires fewer bindings.
+//
+// Only statically-shaped ranked tensors with integer or float elements are
+// considered: getElementTypeBitWidth() only supports those element types, so
+// tensors of e.g. index or complex elements are never treated as large.
+static bool isConstantLarge(ConstantOp constantOp) {
+  auto tensorType = constantOp.getType().dyn_cast<RankedTensorType>();
+  if (!tensorType || !tensorType.hasStaticShape()) return false;
+  if (!tensorType.getElementType().isIntOrFloat()) return false;
+  int64_t unpackedByteLength =
+      (tensorType.getNumElements() * tensorType.getElementTypeBitWidth()) / 8;
+  return unpackedByteLength >= static_cast<int64_t>(kMinLargeConstantSize);
+}
+
+// Returns a list of all large constants in the module.
+// Only ops directly nested in the blocks of top-level functions are visited;
+// constants inside executables (or nested regions within functions) are left
+// untouched.
+static std::vector<ConstantOp> findLargeConstantsInModule(ModuleOp moduleOp) {
+  std::vector<ConstantOp> largeConstantOps;
+  for (auto funcOp : moduleOp.getOps<FuncOp>()) {
+    for (auto &block : funcOp.getBlocks()) {
+      for (auto constantOp : block.getOps<ConstantOp>()) {
+        if (isConstantLarge(constantOp)) {
+          largeConstantOps.push_back(constantOp);
+        }
+      }
+    }
+  }
+  return largeConstantOps;
+}
+
+// Outlines large tensor constants in top-level functions into private,
+// immutable, `noinline` flow.variables and replaces each constant with a
+// flow.variable.load of its variable.
+class OutlineLargeConstantsPass
+    : public PassWrapper<OutlineLargeConstantsPass, OperationPass<ModuleOp>> {
+ public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<IREE::Flow::FlowDialect>();
+  }
+
+  void runOnOperation() override {
+    auto moduleOp = getOperation();
+
+    // For name uniquing. Generated names are not added to |moduleSymbols|,
+    // but |uniqueId| only increases so they can never collide with each other.
+    SymbolTable moduleSymbols(moduleOp);
+    std::string baseName = "_large_const_";
+    int uniqueId = 0;
+
+    // Create all top-level flow.variables from large constants in the module.
+    OpBuilder moduleBuilder(&moduleOp.getBody()->front());
+    std::vector<std::pair<ConstantOp, IREE::Flow::VariableOp>> replacements;
+    for (auto largeConstantOp : findLargeConstantsInModule(moduleOp)) {
+      std::string name;
+      do {
+        name = baseName + std::to_string(uniqueId++);
+      } while (moduleSymbols.lookup(name) != nullptr);
+      auto variableOp = moduleBuilder.create<IREE::Flow::VariableOp>(
+          largeConstantOp.getLoc(), name, /*isMutable=*/false,
+          largeConstantOp.getType(), largeConstantOp.getValue());
+      SymbolTable::setSymbolVisibility(variableOp,
+                                       SymbolTable::Visibility::Private);
+      replacements.emplace_back(largeConstantOp, variableOp);
+
+      // Prevent the variable from being re-inlined if the canonicalizer runs.
+      // By the time we've outlined things here we are sure we want them
+      // outlined even if the user runs an arbitrary number of passes between
+      // now and when we may use that information (HAL constant pooling, etc).
+      variableOp.setAttr("noinline", moduleBuilder.getUnitAttr());
+    }
+
+    // Replace all of the constants with loads of the new variables. A single
+    // builder is reused; only its insertion point changes per constant.
+    OpBuilder builder(moduleOp.getContext());
+    for (auto &replacement : replacements) {
+      ConstantOp constantOp = replacement.first;
+      IREE::Flow::VariableOp variableOp = replacement.second;
+      builder.setInsertionPoint(constantOp);
+      auto loadOp = builder.create<IREE::Flow::VariableLoadOp>(
+          constantOp.getLoc(), constantOp.getType(), variableOp.getName());
+      constantOp.getResult().replaceAllUsesWith(loadOp);
+      constantOp.erase();
+    }
+  }
+};
+
+std::unique_ptr<OperationPass<ModuleOp>> createOutlineLargeConstantsPass() {
+  return std::make_unique<OutlineLargeConstantsPass>();  // NOLINT
+}
+
+static PassRegistration<OutlineLargeConstantsPass> pass(
+    "iree-flow-outline-large-constants",
+    "Outlines large tensor constants into flow.variables at the module level.");
+
+}  // namespace Flow
+}  // namespace IREE
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index 636ddd4..c2f821f 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -196,6 +196,10 @@
passManager.addNestedPass<FuncOp>(createCanonicalizerPass());
passManager.addNestedPass<FuncOp>(createCSEPass());
+ // Prior to leaving the pipeline we need to clean things up for following
+ // layers. Outlined constants are tagged noinline so later folding keeps them.
+ passManager.addPass(createOutlineLargeConstantsPass());
+
// Symbol DCE any remaining variables/functions that are now no longer
// required.
passManager.addPass(createSymbolDCEPass());
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.h b/iree/compiler/Dialect/Flow/Transforms/Passes.h
index 0ce0551..0e432cb 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.h
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -123,6 +123,9 @@
// TODO(benvanik): pass to dedupe similar executables (by making dynamically
// shaped, adjusting types, etc).
+// Outlines large tensor constants into flow.variables at the module level.
+std::unique_ptr<OperationPass<ModuleOp>> createOutlineLargeConstantsPass();
+
//===----------------------------------------------------------------------===//
// Stream Formation and Folding
//===----------------------------------------------------------------------===//
@@ -158,6 +161,7 @@
createFoldCompatibleDispatchRegionsPass();
createRematerializeDispatchConstantsPass();
createOutlineDispatchRegionsPass();
+ createOutlineLargeConstantsPass();
createFormStreamsPass();
createHoistUnstreamableOpsPass();
}
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/outline_large_constants.mlir b/iree/compiler/Dialect/Flow/Transforms/test/outline_large_constants.mlir
new file mode 100644
index 0000000..e08392f
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Transforms/test/outline_large_constants.mlir
@@ -0,0 +1,10 @@
+// RUN: iree-opt -split-input-file -iree-flow-outline-large-constants %s | IreeFileCheck %s
+
+// CHECK: flow.variable @[[LARGE_VARIABLE:.+]] dense<1.200000e+00> : tensor<512x128xf32>
+func @fn1() -> (tensor<2xf32>, tensor<512x128xf32>) {
+ // CHECK-DAG: %[[SMALL_VALUE:.+]] = constant dense<{{.+}}> : tensor<2xf32>
+ %cst_0 = constant dense<[0.0287729427, 0.0297581609]> : tensor<2xf32>
+ // CHECK-DAG: %[[LARGE_VALUE:.+]] = flow.variable.load @[[LARGE_VARIABLE]] : tensor<512x128xf32>
+ %cst_1 = constant dense<1.2> : tensor<512x128xf32>
+ return %cst_0, %cst_1 : tensor<2xf32>, tensor<512x128xf32>
+}
diff --git a/iree/compiler/Dialect/HAL/Conversion/BUILD b/iree/compiler/Dialect/HAL/Conversion/BUILD
index 260c837..f435b85 100644
--- a/iree/compiler/Dialect/HAL/Conversion/BUILD
+++ b/iree/compiler/Dialect/HAL/Conversion/BUILD
@@ -31,7 +31,6 @@
],
deps = [
"//iree/compiler/Dialect/HAL/IR",
- "//iree/compiler/Dialect/HAL/IR:HALDialect",
"//iree/compiler/Dialect/HAL/Utils",
"//iree/compiler/Dialect/IREE/IR",
"//iree/compiler/Dialect/Shape/IR",
@@ -47,7 +46,6 @@
"Passes.h",
],
deps = [
- "//iree/compiler/Dialect/HAL/Conversion/FlowToHAL",
"//iree/compiler/Dialect/HAL/Conversion/HALToVM",
"@llvm-project//mlir:Pass",
],
diff --git a/iree/compiler/Dialect/HAL/Conversion/CMakeLists.txt b/iree/compiler/Dialect/HAL/Conversion/CMakeLists.txt
index 2858ada..ffcb261 100644
--- a/iree/compiler/Dialect/HAL/Conversion/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Conversion/CMakeLists.txt
@@ -29,7 +29,6 @@
MLIRStandard
MLIRTransforms
iree::compiler::Dialect::HAL::IR
- iree::compiler::Dialect::HAL::IR::HALDialect
iree::compiler::Dialect::HAL::Utils
iree::compiler::Dialect::IREE::IR
iree::compiler::Dialect::Shape::IR
@@ -43,7 +42,6 @@
"Passes.h"
DEPS
MLIRPass
- iree::compiler::Dialect::HAL::Conversion::FlowToHAL
iree::compiler::Dialect::HAL::Conversion::HALToVM
PUBLIC
)
diff --git a/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.cpp b/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.cpp
index 2cd0cf3..30eb695 100644
--- a/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.cpp
@@ -34,11 +34,15 @@
// illegal types.
markUnknownOpDynamicallyLegal();
- // The HAL dialect expects both standard ops and the HAL ops (in case some
- // conversion has already happened).
- addLegalDialect<StandardOpsDialect>();
- addLegalOp<ModuleOp, ModuleTerminatorOp>();
- addLegalDialect<IREE::HAL::HALDialect>();
+ // The HAL dialect allows hal ops as input as we may be running on partially
+ // processed files or may have already lowered some constructs (like constant
+ // pools).
+ addLegalDialect("hal");
+
+ // We don't care about the contents of a HAL executable: it may have any kind
+ // of dialect and type usage.
+ addLegalOp<IREE::HAL::ExecutableOp>();
+ markOpRecursivelyLegal<IREE::HAL::ExecutableOp>();
// There are a variety of patterns which convert std.dim and std.rank ops
// to corresponding HAL ops. All should be eliminated.
@@ -49,18 +53,6 @@
addDynamicallyLegalOp<Shape::TieShapeOp>([&](Shape::TieShapeOp op) {
return typeConverter.isLegal(op.result().getType());
});
-
- // We don't care about the contents of a HAL executable: it may have any kind
- // of dialect and type usage.
- addLegalOp<IREE::HAL::ExecutableOp>();
- markOpRecursivelyLegal<IREE::HAL::ExecutableOp>();
-
- addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
- return typeConverter.isSignatureLegal(op.getType()) &&
- typeConverter.isLegal(&op.getBody());
- });
- addDynamicallyLegalOp<ConstantOp>(
- [&](ConstantOp op) { return typeConverter.isLegal(op.getType()); });
}
bool HALConversionTarget::isDynamicallyLegal(Operation *op) const {
diff --git a/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h b/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h
index b210d10..2140d42 100644
--- a/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h
+++ b/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h
@@ -15,7 +15,6 @@
#ifndef IREE_COMPILER_DIALECT_HAL_CONVERSION_CONVERSIONTARGET_H_
#define IREE_COMPILER_DIALECT_HAL_CONVERSION_CONVERSIONTARGET_H_
-#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/MLIRContext.h"
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/BUILD b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/BUILD
index f0bc225..c6da78b 100644
--- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/BUILD
+++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/BUILD
@@ -24,7 +24,6 @@
"ConvertFlowToHAL.cpp",
"ConvertShapeQueryOps.cpp",
"ConvertStreamOps.cpp",
- "ConvertStructuralOps.cpp",
"ConvertTensorOps.cpp",
"ConvertVariableOps.cpp",
],
@@ -38,8 +37,6 @@
"//iree/compiler/Dialect/HAL/IR:HALDialect",
"//iree/compiler/Dialect/HAL/Target",
"//iree/compiler/Dialect/HAL/Utils",
- "//iree/compiler/Dialect/IREE/Conversion:ConvertToHAL",
- "//iree/compiler/Dialect/IREE/Conversion:PreserveCompilerHints",
"//iree/compiler/Dialect/IREE/IR",
"//iree/compiler/Dialect/Shape/IR",
"@llvm-project//llvm:Support",
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/CMakeLists.txt b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/CMakeLists.txt
index f58d6df..4845e08 100644
--- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/CMakeLists.txt
@@ -23,7 +23,6 @@
"ConvertFlowToHAL.cpp"
"ConvertShapeQueryOps.cpp"
"ConvertStreamOps.cpp"
- "ConvertStructuralOps.cpp"
"ConvertTensorOps.cpp"
"ConvertVariableOps.cpp"
DEPS
@@ -38,8 +37,6 @@
iree::compiler::Dialect::HAL::IR::HALDialect
iree::compiler::Dialect::HAL::Target
iree::compiler::Dialect::HAL::Utils
- iree::compiler::Dialect::IREE::Conversion::ConvertToHAL
- iree::compiler::Dialect::IREE::Conversion::PreserveCompilerHints
iree::compiler::Dialect::IREE::IR
iree::compiler::Dialect::Shape::IR
PUBLIC
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertFlowToHAL.cpp b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertFlowToHAL.cpp
index c91168f..cec7ace 100644
--- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertFlowToHAL.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertFlowToHAL.cpp
@@ -15,25 +15,7 @@
#include "iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertFlowToHAL.h"
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
-#include "iree/compiler/Dialect/HAL/Conversion/ConversionDialectInterface.h"
-#include "iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h"
-#include "iree/compiler/Dialect/HAL/Conversion/TypeConverter.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
-#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
-#include "iree/compiler/Dialect/IREE/Conversion/ConvertToHAL.h"
-#include "iree/compiler/Dialect/IREE/Conversion/PreserveCompilerHints.h"
-#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
-#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Function.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/Module.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
@@ -44,11 +26,6 @@
OwningRewritePatternList &patterns,
TypeConverter &converter);
-// Populates only the structural (module/function/etc) conversion patterns.
-void populateFlowStructuralToHALPatterns(MLIRContext *context,
- OwningRewritePatternList &patterns,
- TypeConverter &converter);
-
// Populates only the flow.tensor.* conversion patterns.
void populateFlowTensorToHALPatterns(MLIRContext *context,
OwningRewritePatternList &patterns,
@@ -64,69 +41,21 @@
OwningRewritePatternList &patterns,
TypeConverter &converter);
-namespace {
-
-// A pass converting the IREE flow dialect into the IREE HAL dialect.
-class ConvertFlowToHALPass
- : public PassWrapper<ConvertFlowToHALPass, OperationPass<ModuleOp>> {
- public:
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<IREE::HAL::HALDialect>();
- }
-
- void runOnOperation() override {
- auto *context = &getContext();
-
- SmallVector<const HALConversionDialectInterface *, 4> conversionInterfaces;
- // Gather all interfaces from registered dialects.
- // These will perform the tensor->buffer mapping for their ops.
- for (auto *dialect : context->getLoadedDialects()) {
- if (auto *conversionInterface =
- dialect
- ->getRegisteredInterface<HALConversionDialectInterface>()) {
- conversionInterfaces.emplace_back(conversionInterface);
- }
- }
- HALTypeConverter typeConverter(conversionInterfaces);
- HALConversionTarget target(context, typeConverter);
- target.addIllegalDialect<IREE::Flow::FlowDialect>();
-
- OwningRewritePatternList patterns;
- populateFlowStreamToHALPatterns(context, patterns, typeConverter);
- populateFlowStructuralToHALPatterns(context, patterns, typeConverter);
- populateFlowTensorToHALPatterns(context, patterns, typeConverter);
- populateFlowVariableToHALPatterns(context, patterns, typeConverter);
- populateHalBufferViewShapePatterns(context, patterns, typeConverter);
- populateIREEToHALPatterns(context, patterns);
- setupIREEToHALLegality(context, target);
- populatePreserveCompilerHintsPatterns(context, patterns);
- setupCompilerHintsLegality(context, target, typeConverter);
-
- // Gather all HAL dialect conversion patterns from custom dialects.
- // These will perform the tensor->buffer mapping for their ops.
- for (auto *conversionInterface : conversionInterfaces) {
- conversionInterface->setupConversionTarget(target, patterns,
- typeConverter);
- }
-
- // NOTE: we allow ops that we don't know about to allow custom dialects
- // that don't need anything HAL-specific to pass through. This is handled by
- // the fallback type legality support of the
- if (failed(applyPartialConversion(getOperation(), target, patterns))) {
- return signalPassFailure();
- }
- }
-};
-
-} // namespace
-
-std::unique_ptr<OperationPass<ModuleOp>> createConvertFlowToHALPass() {
- return std::make_unique<ConvertFlowToHALPass>(); // NOLINT
+// Marks the whole flow dialect illegal so that every flow op must be
+// converted away during Flow->HAL lowering. |context| and |typeConverter| are
+// currently unused.
+void setupFlowToHALLegality(MLIRContext *context,
+ ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter) {
+ conversionTarget.addIllegalDialect<IREE::Flow::FlowDialect>();
}
-static PassRegistration<ConvertFlowToHALPass> pass(
- "iree-convert-flow-to-hal",
- "Convert input flow ops to the IREE HAL dialect");
+// Populates conversion patterns for Flow->HAL by aggregating the flow stream,
+// tensor, and variable op patterns plus populateHalBufferViewShapePatterns.
+void populateFlowToHALPatterns(MLIRContext *context,
+ OwningRewritePatternList &patterns,
+ TypeConverter &typeConverter) {
+ populateFlowStreamToHALPatterns(context, patterns, typeConverter);
+ populateFlowTensorToHALPatterns(context, patterns, typeConverter);
+ populateFlowVariableToHALPatterns(context, patterns, typeConverter);
+ populateHalBufferViewShapePatterns(context, patterns, typeConverter);
+}
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertFlowToHAL.h b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertFlowToHAL.h
index e74941d..1be74d2 100644
--- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertFlowToHAL.h
+++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertFlowToHAL.h
@@ -21,8 +21,16 @@
namespace mlir {
namespace iree_compiler {
-// Converts flow streams to command buffer recording ops.
-std::unique_ptr<OperationPass<ModuleOp>> createConvertFlowToHALPass();
+// Adds op legality rules to |conversionTarget| to ensure all incoming flow ops
+// are removed during Flow->HAL lowering.
+void setupFlowToHALLegality(MLIRContext *context,
+ ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter);
+
+// Populates conversion patterns for Flow->HAL.
+void populateFlowToHALPatterns(MLIRContext *context,
+ OwningRewritePatternList &patterns,
+ TypeConverter &typeConverter);
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStructuralOps.cpp b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStructuralOps.cpp
deleted file mode 100644
index 6e560af..0000000
--- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStructuralOps.cpp
+++ /dev/null
@@ -1,87 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
-#include "iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertFlowToHAL.h"
-#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
-#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
-#include "llvm/ADT/DenseMap.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Function.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/Module.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace {
-
-class FuncOpSignatureConversion : public OpConversionPattern<mlir::FuncOp> {
- public:
- FuncOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
- : OpConversionPattern(ctx), converter(converter) {}
-
- LogicalResult matchAndRewrite(
- mlir::FuncOp funcOp, llvm::ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- // Convert the input signature types.
- // TODO(benvanik): dynamic shapes by passing in tensor dynamic dims.
- auto originalType = funcOp.getType();
- TypeConverter::SignatureConversion newSignature(
- originalType.getNumInputs());
- for (auto argType : llvm::enumerate(originalType.getInputs())) {
- if (failed(converter.convertSignatureArg(argType.index(), argType.value(),
- newSignature))) {
- return failure();
- }
- }
- SmallVector<Type, 4> newResultTypes;
- if (failed(converter.convertTypes(originalType.getResults(),
- newResultTypes))) {
- return failure();
- }
-
- // Replace function.
- auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
- newFuncOp.getBlocks().clear();
- rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
- newFuncOp.end());
- newFuncOp.setType(rewriter.getFunctionType(newSignature.getConvertedTypes(),
- newResultTypes));
- if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), converter,
- &newSignature)))
- return failure();
-
- rewriter.eraseOp(funcOp);
- return success();
- }
-
- private:
- TypeConverter &converter;
-};
-
-} // namespace
-
-void populateFlowStructuralToHALPatterns(MLIRContext *context,
- OwningRewritePatternList &patterns,
- TypeConverter &converter) {
- patterns.insert<FuncOpSignatureConversion>(context, converter);
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertVariableOps.cpp b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertVariableOps.cpp
index 40053dd..195b7a0 100644
--- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertVariableOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertVariableOps.cpp
@@ -70,6 +70,7 @@
// an appropriate HAL Buffer-based initializer.
if (auto initialValueElements =
variableOp.initial_valueAttr().dyn_cast_or_null<ElementsAttr>()) {
+ rewriter.setInsertionPointAfter(variableOp);
auto initializerFunc = createInitializerFromImmediate(
variableOp, initialValueElements, rewriter);
initializer = initializerFunc.getName();
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/constant_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/constant_ops.mlir
new file mode 100644
index 0000000..fa4fe64
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/constant_ops.mlir
@@ -0,0 +1,65 @@
+// RUN: iree-opt -split-input-file -iree-convert-to-hal -verify-diagnostics %s | IreeFileCheck %s
+
+// CHECK-LABEL: hal.variable @var_i32 mutable : !hal.buffer
+flow.variable @var_i32 mutable : tensor<i32>
+func @fn() {
+ // CHECK: %[[V:.+]] = hal.variable.load @var_i32 : !hal.buffer
+ %0 = flow.variable.load @var_i32 : tensor<i32>
+ // CHECK-NEXT: hal.variable.store %[[V]], @var_i32 : !hal.buffer
+ flow.variable.store %0, @var_i32 : tensor<i32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: hal.variable @var_i1 mutable : !hal.buffer
+flow.variable @var_i1 mutable : tensor<i1>
+func @fn() {
+ // CHECK: %[[V:.+]] = hal.variable.load @var_i1 : !hal.buffer
+ %0 = flow.variable.load @var_i1 : tensor<i1>
+ // CHECK-NEXT: hal.variable.store %[[V]], @var_i1 : !hal.buffer
+ flow.variable.store %0, @var_i1 : tensor<i1>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: hal.variable @var_indirect mutable : !hal.buffer
+flow.variable @var_indirect mutable : tensor<i32>
+func @fn() {
+ // CHECK: %[[ADDR:.+]] = hal.variable.address @var_indirect
+ %0 = flow.variable.address @var_indirect : !iree.ptr<tensor<i32>>
+ // CHECK-NEXT: %[[VALUE:.+]] = hal.variable.load.indirect %[[ADDR]]
+ %1 = flow.variable.load.indirect %0 : !iree.ptr<tensor<i32>> -> tensor<i32>
+ // CHECK-NEXT: hal.variable.store.indirect %[[VALUE]], %[[ADDR]]
+ flow.variable.store.indirect %1, %0 : tensor<i32> -> !iree.ptr<tensor<i32>>
+ return
+}
+
+// -----
+
+// Checks that an initializer function is generated, used and operates on
+// a hal.buffer (versus tensor).
+// CHECK: hal.variable @var_with_tensor_default
+// CHECK-SAME: init(@__var_with_tensor_default_initializer)
+// CHECK-SAME: : !hal.buffer
+// CHECK-LABEL: func @__var_with_tensor_default_initializer() -> !hal.buffer
+flow.variable @var_with_tensor_default mutable dense<0.000000e+00> : tensor<f32>
+func @fn() {
+ %0 = flow.variable.load @var_with_tensor_default : tensor<f32>
+ flow.variable.store %0, @var_with_tensor_default : tensor<f32>
+ return
+}
+
+// -----
+
+// TODO(b/145839814): It should not be possible to produce a name collision
+// expected-error @+3 {{redefinition of symbol named '__var_with_initializer_initializer'}}
+// expected-note @+1 {{see existing symbol definition here}}
+func @__var_with_initializer_initializer() -> ()
+flow.variable @var_with_initializer mutable dense<0.000000e+00> : tensor<f32>
+func @fn() {
+ %0 = flow.variable.load @var_with_initializer : tensor<f32>
+ flow.variable.store %0, @var_with_initializer : tensor<f32>
+ return
+}
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir
index 154c8e9..72922dc 100644
--- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir
+++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -iree-convert-flow-to-hal -canonicalize %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -iree-convert-to-hal -canonicalize %s | IreeFileCheck %s
hal.executable @ex0 {
hal.interface @interface {
@@ -26,7 +26,7 @@
// CHECK-NEXT: hal.command_buffer.begin %[[CMD]]
%0 = flow.ex.stream.fragment(%arg1 = %cst : index, %arg2 = %arg0 : tensor<128xf32>) -> tensor<128xf32> {
// CHECK-DAG: %[[EXE_LAYOUT:.+]] = hal.executable_layout.lookup
- // CHECK: hal.command_buffer.push_descriptor_set %[[CMD]], %[[EXE_LAYOUT]], set=0, bindings=[0 = (%arg0, %c0, %sz_3), 1 = (%buffer_1, %c0, %sz_4)]
+ // CHECK: hal.command_buffer.push_descriptor_set %[[CMD]], %[[EXE_LAYOUT]], set=0, bindings=[0 = (%arg0, %c0, %c512), 1 = (%[[TMP_BUF]], %c0, %c512)]
// CHECK: hal.command_buffer.dispatch.symbol {{.+}}, @ex0::@vmla::@entry0, workgroup_xyz
// CHECK: hal.command_buffer.execution_barrier
%1 = flow.dispatch @ex0::@entry0[%arg1 : index](%arg2) : (tensor<128xf32>) -> tensor<128xf32>
@@ -47,7 +47,6 @@
// CHECK-LABEL: @tensorUpdate
// CHECK-SAME: (%[[UBUF:.+]]:{{.+}}, %[[TBUF:.+]]:{{.+}})
func @tensorUpdate(%arg0 : tensor<1x1x10xf32>, %arg1 : tensor<5x1x10xf32>) -> tensor<5x1x10xf32> {
- // CHECK: %[[C0:.+]] = constant 0
%c4 = constant 4 : index
%c1 = constant 1 : index
// CHECK: %[[RET_BUF:.+]] = hal.allocator.allocate
@@ -56,11 +55,9 @@
%0 = flow.ex.stream.fragment(%arg2 = %arg0 : tensor<1x1x10xf32>, %arg3 = %arg1 : tensor<5x1x10xf32>, %arg4 = %c4 : index, %arg5 = %c1 : index) -> tensor<5x1x10xf32> {
// TODO(laurenzo): Update these checks to be more precise. The regexes can
// match too much, masking issues.
- // CHECK: %[[UOFF:.+]], %[[ULEN:.+]] = hal.allocator.compute_range %{{.+}}
- // CHECK: %[[TLEN:.+]] = hal.allocator.compute_size %{{.+}}
- // CHECK-NEXT: hal.command_buffer.copy_buffer %[[CMD]], %[[TBUF]], %[[C0]], %[[RET_BUF]], %[[C0]], %[[TLEN]]
+ // CHECK-NEXT: hal.command_buffer.copy_buffer %[[CMD]], %[[TBUF]], %c0, %[[RET_BUF]], %c0, %c200
// CHECK: hal.command_buffer.execution_barrier
- // CHECK-NEXT: hal.command_buffer.copy_buffer %[[CMD]], %[[UBUF]], %[[C0]], %[[RET_BUF]], %[[UOFF]], %[[ULEN]]
+ // CHECK-NEXT: hal.command_buffer.copy_buffer %[[CMD]], %[[UBUF]], %c0, %[[RET_BUF]], %c204, %c40
%1 = flow.tensor.update %arg2, %arg3[%arg4, %arg5, %arg5] : tensor<1x1x10xf32> -> tensor<5x1x10xf32>
flow.return %1 : tensor<5x1x10xf32>
}
@@ -89,14 +86,13 @@
// CHECK-LABEL: func @dispatchWithShapeTies
// CHECK-SAME: (%[[T:.+]]:{{.+}}, %[[BS:.+]]:{{.+}})
func @dispatchWithShapeTies(%arg0: tensor<?x128xf32>, %bs : index) -> tensor<?x128xf32> {
- // CHECK: %[[C128:.+]] = constant 128
%cst = constant 128 : index
// Verify that size computation derives from the passed dynamic index.
- // CHECK: hal.allocator.compute_size %allocator, shape = [%[[BS]], %[[C128]]], element_type = 50331680
+ // CHECK-DAG: %[[BS4:.+]] = muli %[[BS]], %c4 : index
+ // CHECK-DAG: = muli %[[BS4]], %c128 : index
// Verify that an i32 is pushed.
// CHECK: %[[CAST_BS:.+]] = index_cast %[[BS]] : index to i32
// CHECK: hal.command_buffer.push_constants %[[UNUSED0:.+]], %[[UNUSED1:.+]], offset = 0, values = [%[[CAST_BS]]] : i32
- // CHECK: %[[ALLOCATOR0:.+]] = hal.buffer.allocator %[[T]] : !hal.allocator
// Note that multiple dispatches in the stream verifies that transient
// allocation is covering all ops.
%0 = flow.ex.stream.fragment(%arg1 = %cst : index, %arg2 = %arg0 : tensor<?x128xf32>, %arg3 = %bs : index) -> tensor<?x128xf32> {
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/tensor_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/tensor_ops.mlir
index 89bd590..4723331 100644
--- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/tensor_ops.mlir
+++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/tensor_ops.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -iree-convert-flow-to-hal %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -iree-convert-to-hal %s | IreeFileCheck %s
// CHECK-LABEL: @constantTensor
func @constantTensor() {
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/variable_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/variable_ops.mlir
index 2c46543..284199c 100644
--- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/variable_ops.mlir
+++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/variable_ops.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -iree-convert-flow-to-hal -verify-diagnostics %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -iree-convert-to-hal -verify-diagnostics %s | IreeFileCheck %s
// CHECK-LABEL: hal.variable @var_i32 mutable : !hal.buffer
flow.variable @var_i32 mutable : tensor<i32>
@@ -39,10 +39,10 @@
// -----
// Checks that an initializer function is generated, used and operates on
// a hal.buffer (versus tensor).
-// CHECK-LABEL: func @__var_with_tensor_initializer_initializer() -> !hal.buffer
// CHECK: hal.variable @var_with_tensor_initializer
// CHECK-SAME: init(@__var_with_tensor_initializer_initializer)
// CHECK-SAME: : !hal.buffer
+// CHECK-LABEL: func @__var_with_tensor_initializer_initializer() -> !hal.buffer
flow.variable @var_with_tensor_initializer mutable dense<0.000000e+00> : tensor<f32>
func @fn() {
%0 = flow.variable.load @var_with_tensor_initializer : tensor<f32>
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToHAL/BUILD b/iree/compiler/Dialect/HAL/Conversion/HALToHAL/BUILD
new file mode 100644
index 0000000..853e495
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToHAL/BUILD
@@ -0,0 +1,41 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "HALToHAL",
+ srcs = [
+ "ConvertConstantOps.cpp",
+ "ConvertHALToHAL.cpp",
+ ],
+ hdrs = [
+ "ConvertHALToHAL.h",
+ ],
+ deps = [
+ "//iree/compiler/Dialect/HAL/Conversion",
+ "//iree/compiler/Dialect/HAL/IR",
+ "//iree/compiler/Dialect/HAL/Utils",
+ "//iree/compiler/Dialect/IREE/IR",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:StandardOps",
+ "@llvm-project//mlir:Transforms",
+ ],
+)
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToHAL/CMakeLists.txt b/iree/compiler/Dialect/HAL/Conversion/HALToHAL/CMakeLists.txt
new file mode 100644
index 0000000..10816c8
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToHAL/CMakeLists.txt
@@ -0,0 +1,36 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ HALToHAL
+ HDRS
+ "ConvertHALToHAL.h"
+ SRCS
+ "ConvertConstantOps.cpp"
+ "ConvertHALToHAL.cpp"
+ DEPS
+ LLVMSupport
+ MLIRIR
+ MLIRPass
+ MLIRStandard
+ MLIRTransforms
+ iree::compiler::Dialect::HAL::Conversion
+ iree::compiler::Dialect::HAL::IR
+ iree::compiler::Dialect::HAL::Utils
+ iree::compiler::Dialect::IREE::IR
+ PUBLIC
+)
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToHAL/ConvertConstantOps.cpp b/iree/compiler/Dialect/HAL/Conversion/HALToHAL/ConvertConstantOps.cpp
new file mode 100644
index 0000000..68516a5
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToHAL/ConvertConstantOps.cpp
@@ -0,0 +1,66 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
+#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace {
+
+// Expands the hal.constant.subspan pseudo op into concrete ops:
+//   %buffer = hal.variable.load @<leaf of runtime_buffer> : !hal.buffer
+//   %offset = constant <runtime_range offset>
+//   %length = constant <runtime_range length>
+//   %subspan = hal.buffer.subspan %buffer, %offset, %length : !hal.buffer
+// Everything is derived from the op's attributes; |operands| is not read.
+class ConstantSubspanConversion
+    : public OpConversionPattern<IREE::HAL::ConstantSubspanOp> {
+ public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      IREE::HAL::ConstantSubspanOp op, llvm::ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    // Both the loaded variable and the produced subspan are !hal.buffer.
+    auto bufferType = IREE::HAL::BufferType::get(rewriter.getContext());
+
+    auto sourceBuffer = rewriter.createOrFold<IREE::HAL::VariableLoadOp>(
+        loc, bufferType, op.runtime_buffer().getLeafReference());
+    auto rangeOffset = rewriter.createOrFold<mlir::ConstantOp>(
+        loc, op.runtime_range().offsetAttr());
+    auto rangeLength = rewriter.createOrFold<mlir::ConstantOp>(
+        loc, op.runtime_range().lengthAttr());
+
+    rewriter.replaceOpWithNewOp<IREE::HAL::BufferSubspanOp>(
+        op, bufferType, sourceBuffer, rangeOffset, rangeLength);
+    return success();
+  }
+};
+
+}  // namespace
+
+// Adds the hal.constant.* -> HAL lowering patterns to |patterns|.
+void populateHALConstantToHALPatterns(MLIRContext *context,
+                                      OwningRewritePatternList &patterns,
+                                      TypeConverter &typeConverter) {
+  patterns.insert<ConstantSubspanConversion>(typeConverter, context);
+}
+
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToHAL/ConvertHALToHAL.cpp b/iree/compiler/Dialect/HAL/Conversion/HALToHAL/ConvertHALToHAL.cpp
new file mode 100644
index 0000000..17ec9dd
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToHAL/ConvertHALToHAL.cpp
@@ -0,0 +1,47 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/HAL/Conversion/HALToHAL/ConvertHALToHAL.h"
+
+#include "iree/compiler/Dialect/HAL/Conversion/TypeConverter.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Populates only the hal.constant.* conversion patterns.
+// Defined in ConvertConstantOps.cpp of this same library.
+void populateHALConstantToHALPatterns(MLIRContext *context,
+                                      OwningRewritePatternList &patterns,
+                                      TypeConverter &typeConverter);
+
+// Marks the HAL pseudo ops handled by the HAL->HAL patterns (currently only
+// hal.constant.subspan) as illegal so the conversion must lower them.
+// |context| and |typeConverter| are currently unused.
+void setupHALToHALLegality(MLIRContext *context,
+                           ConversionTarget &conversionTarget,
+                           TypeConverter &typeConverter) {
+  conversionTarget.addIllegalOp<IREE::HAL::ConstantSubspanOp>();
+}
+
+// Adds every HAL->HAL lowering pattern to |patterns|; today that is just the
+// hal.constant.* set.
+void populateHALToHALPatterns(MLIRContext *context,
+                              OwningRewritePatternList &patterns,
+                              TypeConverter &typeConverter) {
+  populateHALConstantToHALPatterns(context, patterns, typeConverter);
+}
+
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToHAL/ConvertHALToHAL.h b/iree/compiler/Dialect/HAL/Conversion/HALToHAL/ConvertHALToHAL.h
new file mode 100644
index 0000000..9b3470d
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToHAL/ConvertHALToHAL.h
@@ -0,0 +1,39 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_DIALECT_HAL_CONVERSION_HALTOHAL_CONVERTHALTOHAL_H_
+#define IREE_COMPILER_DIALECT_HAL_CONVERSION_HALTOHAL_CONVERTHALTOHAL_H_
+
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Adds op legality rules to |conversionTarget| to ensure all incoming HAL
+// pseudo ops are removed during HAL->HAL lowering.
+void setupHALToHALLegality(MLIRContext *context,
+ ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter);
+
+// Populates conversion patterns for HAL->HAL (pseudo ops, etc).
+void populateHALToHALPatterns(MLIRContext *context,
+ OwningRewritePatternList &patterns,
+ TypeConverter &typeConverter);
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_DIALECT_HAL_CONVERSION_HALTOHAL_CONVERTHALTOHAL_H_
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD b/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD
new file mode 100644
index 0000000..b780b11
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD
@@ -0,0 +1,30 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//iree:lit_test.bzl", "iree_lit_test_suite")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_lit_test_suite(
+ name = "lit",
+ srcs = glob(["*.mlir"]),
+ data = [
+ "//iree/tools:IreeFileCheck",
+ "//iree/tools:iree-opt",
+ ],
+)
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/CMakeLists.txt b/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/CMakeLists.txt
new file mode 100644
index 0000000..fcc538b
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/CMakeLists.txt
@@ -0,0 +1,26 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir)
+iree_lit_test_suite(
+ NAME
+ lit
+ SRCS
+ "${_GLOB_X_MLIR}"
+ DATA
+ iree::tools::IreeFileCheck
+ iree::tools::iree-opt
+)
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/constant_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/constant_ops.mlir
new file mode 100644
index 0000000..1e43084
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/constant_ops.mlir
@@ -0,0 +1,12 @@
+// RUN: iree-opt -split-input-file -iree-convert-to-hal -verify-diagnostics %s | IreeFileCheck %s
+
+// CHECK-LABEL: func @constant_subspan
+func @constant_subspan() {
+ // CHECK-DAG: [[BUFFER:%.+]] = hal.variable.load @pool_buffer : !hal.buffer
+ // CHECK-DAG: [[OFFSET:%.+]] = constant 123 : index
+ // CHECK-DAG: [[LENGTH:%.+]] = constant 16 : index
+ // CHECK-NEXT: = hal.buffer.subspan [[BUFFER]], [[OFFSET]], [[LENGTH]] : !hal.buffer
+ %cst0 = hal.constant.subspan @pool_buffer[#hal.byte_range<123, 16>] : tensor<4xf32>
+ return
+}
+hal.variable @pool_buffer : !hal.buffer
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/BUILD b/iree/compiler/Dialect/HAL/Conversion/HALToVM/BUILD
index 72c80e7..87259ce 100644
--- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/BUILD
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/BUILD
@@ -25,6 +25,7 @@
"ConvertBufferOps.cpp",
"ConvertBufferViewOps.cpp",
"ConvertCommandBufferOps.cpp",
+ "ConvertConstantOps.cpp",
"ConvertControlFlowOps.cpp",
"ConvertDeviceOps.cpp",
"ConvertExecutableOps.cpp",
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/CMakeLists.txt b/iree/compiler/Dialect/HAL/Conversion/HALToVM/CMakeLists.txt
index f12e7c0..9b6143a 100644
--- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/CMakeLists.txt
@@ -24,6 +24,7 @@
"ConvertBufferOps.cpp"
"ConvertBufferViewOps.cpp"
"ConvertCommandBufferOps.cpp"
+ "ConvertConstantOps.cpp"
"ConvertControlFlowOps.cpp"
"ConvertDeviceOps.cpp"
"ConvertExecutableOps.cpp"
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertAllocatorOps.cpp b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertAllocatorOps.cpp
index 105d845..1363dc6 100644
--- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertAllocatorOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertAllocatorOps.cpp
@@ -22,82 +22,36 @@
namespace iree_compiler {
namespace {
-class AllocatorAllocateConstOpConversion
- : public OpConversionPattern<IREE::HAL::AllocatorAllocateConstOp> {
+class AllocatorMapOpConversion
+ : public OpConversionPattern<IREE::HAL::AllocatorMapOp> {
public:
- AllocatorAllocateConstOpConversion(MLIRContext *context,
- SymbolTable &importSymbols,
- TypeConverter &typeConverter,
- StringRef importName)
- : OpConversionPattern(context) {
- importOp = importSymbols.lookup<IREE::VM::ImportOp>(importName);
- assert(importOp);
+ AllocatorMapOpConversion(TypeConverter &typeConverter, MLIRContext *context,
+ SymbolTable &importSymbols)
+ : OpConversionPattern(typeConverter, context) {
+ wrapByteBufferImportOp = importSymbols.lookup<IREE::VM::ImportOp>(
+ "hal.allocator.wrap.byte_buffer");
+ assert(wrapByteBufferImportOp);
}
LogicalResult matchAndRewrite(
- IREE::HAL::AllocatorAllocateConstOp op, llvm::ArrayRef<Value> operands,
+ IREE::HAL::AllocatorMapOp op, llvm::ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- // Encode constant data into a rodata segment. These will eventually get
- // deduped and combined.
- auto ip = rewriter.saveInsertionPoint();
- auto parentFuncOp = op.getParentOfType<IREE::VM::FuncOp>();
- rewriter.setInsertionPoint(parentFuncOp);
- auto constName = (parentFuncOp.getName() + "_const_" +
- std::to_string(allocateUniqueId(parentFuncOp)))
- .str();
- auto rodataOp =
- rewriter.create<IREE::VM::RodataOp>(op.getLoc(), constName, op.value());
- rewriter.restoreInsertionPoint(ip);
- auto loadRodataOp =
- rewriter.create<IREE::VM::ConstRefRodataOp>(op.getLoc(), rodataOp);
-
- IREE::HAL::AllocatorAllocateConstOp::Adaptor opAdaptor(operands);
- auto shape = IREE::HAL::getStaticShapeDims(op.getLoc(),
- op.value().getType(), rewriter);
- SmallVector<Value, 8> callOperands = {
- opAdaptor.allocator(),
- rewriter.create<mlir::ConstantOp>(
- op.getLoc(), rewriter.getI32IntegerAttr(
- static_cast<int32_t>(op.memory_types()))),
- rewriter.create<mlir::ConstantOp>(
- op.getLoc(), rewriter.getI32IntegerAttr(
- static_cast<int32_t>(op.buffer_usage()))),
- };
- callOperands.append(shape.begin(), shape.end());
- callOperands.push_back(rewriter.create<mlir::ConstantOp>(
- op.getLoc(),
- IREE::HAL::getElementTypeAttr(op.value().getType().getElementType())));
- callOperands.push_back(loadRodataOp.getResult());
- SmallVector<int16_t, 6> segmentSizes = {
- /*allocator=*/-1,
- /*memory_types=*/-1,
- /*buffer_usage=*/-1,
- /*shape=*/static_cast<int16_t>(shape.size()),
- /*element_type=*/-1,
- /*value=*/-1,
- };
-
- auto importType = importOp.getType();
- rewriter.replaceOpWithNewOp<IREE::VM::CallVariadicOp>(
- op, rewriter.getSymbolRefAttr(importOp), importType.getResults(),
- segmentSizes, importType.getInputs(), callOperands);
+ IREE::HAL::AllocatorMapOp::Adaptor opAdaptor(operands);
+ rewriter.replaceOpWithNewOp<IREE::VM::CallOp>(
+ op, wrapByteBufferImportOp.getName(),
+ ArrayRef<Type>{getTypeConverter()->convertType(op.getType())},
+ ArrayRef<Value>{opAdaptor.allocator(),
+ rewriter.createOrFold<IREE::VM::ConstI32Op>(
+ op.getLoc(), op.memory_typesAttr()),
+ rewriter.createOrFold<IREE::VM::ConstI32Op>(
+ op.getLoc(), op.buffer_usageAttr()),
+ opAdaptor.source(), opAdaptor.offset(),
+ opAdaptor.length()});
return success();
}
private:
- // TODO(b/145839814): find a name that's unique or make the rewriter support
- // assigning unique names.
- int allocateUniqueId(Operation *context) const {
- if (uniqueContext != context) {
- uniqueContext = context;
- uniqueCounter = 0;
- }
- return uniqueCounter++;
- }
- mutable Operation *uniqueContext = nullptr;
- mutable int uniqueCounter = 0;
-
- mutable IREE::VM::ImportOp importOp;
+ mutable IREE::VM::ImportOp wrapByteBufferImportOp;
};
} // namespace
@@ -106,16 +60,10 @@
SymbolTable &importSymbols,
TypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
- patterns.insert<VMImportOpConversion<IREE::HAL::AllocatorComputeSizeOp>>(
- context, importSymbols, typeConverter, "hal.allocator.compute_size");
- patterns.insert<VMImportOpConversion<IREE::HAL::AllocatorComputeOffsetOp>>(
- context, importSymbols, typeConverter, "hal.allocator.compute_offset");
- patterns.insert<VMImportOpConversion<IREE::HAL::AllocatorComputeRangeOp>>(
- context, importSymbols, typeConverter, "hal.allocator.compute_range");
patterns.insert<VMImportOpConversion<IREE::HAL::AllocatorAllocateOp>>(
context, importSymbols, typeConverter, "hal.allocator.allocate");
- patterns.insert<AllocatorAllocateConstOpConversion>(
- context, importSymbols, typeConverter, "hal.allocator.allocate.const");
+ patterns.insert<AllocatorMapOpConversion>(typeConverter, context,
+ importSymbols);
}
} // namespace iree_compiler
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertConstantOps.cpp b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertConstantOps.cpp
new file mode 100644
index 0000000..e1e4460
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertConstantOps.cpp
@@ -0,0 +1,92 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
+#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
+#include "iree/compiler/Dialect/VM/IR/VMOps.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace {
+
+// Returns the vm.rodata symbol name used for the hal.constant_storage op
+// |storageName| nested in the hal.constant_pool op |poolName|.
+//
+// Both the pool lowering (which defines the rodata) and the lookup lowering
+// (which references it) call this, so the two can never disagree.
+// NOTE: the name is a plain concatenation with no separator
+// (@pool + @_storage0 -> @pool_storage0), so different pool/storage pairs
+// could in principle map to the same name.
+std::string makeRodataName(StringRef poolName, StringRef storageName) {
+  return (poolName + storageName).str();
+}
+
+// Converts a hal.constant_pool into one private vm.rodata per nested
+// hal.constant_storage op, then erases the pool (and with it any other
+// nested ops such as spans and splats).
+class ConstantPoolOpConversion
+    : public OpConversionPattern<IREE::HAL::ConstantPoolOp> {
+ public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      IREE::HAL::ConstantPoolOp op, llvm::ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    for (auto storageOp : op.getOps<IREE::HAL::ConstantStorageOp>()) {
+      auto rodataName = makeRodataName(op.sym_name(), storageOp.sym_name());
+      auto rodataOp = rewriter.create<IREE::VM::RodataOp>(
+          storageOp.getLoc(), rodataName, storageOp.value());
+      SymbolTable::setSymbolVisibility(rodataOp,
+                                       SymbolTable::Visibility::Private);
+    }
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+// Replaces hal.constant_storage.lookup @pool::@storage with a
+// vm.const.ref.rodata of the rodata that ConstantPoolOpConversion emits for
+// that storage.
+class ConstantStorageLookupOpConversion
+    : public OpConversionPattern<IREE::HAL::ConstantStorageLookupOp> {
+ public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      IREE::HAL::ConstantStorageLookupOp op, llvm::ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    // The root reference names the pool and the leaf names the storage op.
+    auto rodataName = makeRodataName(op.constant().getRootReference(),
+                                     op.constant().getLeafReference());
+    rewriter.replaceOpWithNewOp<IREE::VM::ConstRefRodataOp>(op, rodataName);
+    return success();
+  }
+};
+
+}  // namespace
+
+// Adds the hal.constant_pool and hal.constant_storage.lookup -> VM patterns.
+// |importSymbols| is unused: neither lowering calls a HAL import.
+void populateHALConstantToVMPatterns(MLIRContext *context,
+                                     SymbolTable &importSymbols,
+                                     TypeConverter &typeConverter,
+                                     OwningRewritePatternList &patterns) {
+  patterns.insert<ConstantPoolOpConversion, ConstantStorageLookupOpConversion>(
+      typeConverter, context);
+}
+
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.cpp b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.cpp
index e2a3677..759e7fc 100644
--- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.cpp
@@ -49,6 +49,10 @@
extern void populateHALCommandBufferToVMPatterns(
MLIRContext *context, SymbolTable &importSymbols,
TypeConverter &typeConverter, OwningRewritePatternList &patterns);
+extern void populateHALConstantToVMPatterns(MLIRContext *context,
+ SymbolTable &importSymbols,
+ TypeConverter &typeConverter,
+ OwningRewritePatternList &patterns);
extern void populateHALControlFlowToVMPatterns(
MLIRContext *context, SymbolTable &importSymbols,
TypeConverter &typeConverter, OwningRewritePatternList &patterns);
@@ -81,6 +85,8 @@
patterns);
populateHALCommandBufferToVMPatterns(context, importSymbols, typeConverter,
patterns);
+ populateHALConstantToVMPatterns(context, importSymbols, typeConverter,
+ patterns);
populateHALControlFlowToVMPatterns(context, importSymbols, typeConverter,
patterns);
populateHALDeviceToVMPatterns(context, importSymbols, typeConverter,
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/allocator_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/allocator_ops.mlir
index d00d4f0..2541eb8 100644
--- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/allocator_ops.mlir
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/allocator_ops.mlir
@@ -1,9 +1,10 @@
-// RUN: iree-opt -split-input-file -iree-convert-hal-to-vm %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -canonicalize -iree-convert-hal-to-vm %s | IreeFileCheck %s
-// CHECK-LABEL: @allocatorComputeSize
-func @allocatorComputeSize(%arg0 : !hal.allocator) -> index {
+// CHECK-LABEL: @allocatorComputeSizeFoldsAway
+func @allocatorComputeSizeFoldsAway(%arg0 : !hal.allocator) -> index {
+ // CHECK: %c4194304 = vm.const.i32 4194304 : i32
+ // CHECK-NOT: hal.allocator.compute_size
%c1024 = constant 1024 : index
- // CHECK: %0 = vm.call.variadic @hal.allocator.compute_size(%arg0, [%c1024, %c1024], %c32) : (!vm.ref<!hal.allocator>, i32 ..., i32) -> i32
%0 = hal.allocator.compute_size %arg0, shape=[%c1024, %c1024], element_type=32
return %0 : index
}
@@ -20,11 +21,11 @@
// -----
-// CHECK: vm.rodata @allocatorAllocateConst_const_0 dense<123> : tensor<4x4xi32>
-// CHECK-LABEL: func @allocatorAllocateConst
-func @allocatorAllocateConst(%arg0 : !hal.allocator) -> !hal.buffer {
- // CHECK: %allocatorAllocateConst_const_0 = vm.const.ref.rodata @allocatorAllocateConst_const_0 : !vm.ref<!iree.byte_buffer>
- // CHECK: %ref = vm.call.variadic @hal.allocator.allocate.const(%arg0, %c6, %c2, [%c4, %c4_0], %c16777248, %allocatorAllocateConst_const_0) : (!vm.ref<!hal.allocator>, i32, i32, i32 ..., i32, !vm.ref<!iree.byte_buffer>) -> !vm.ref<!hal.buffer>
- %buffer = hal.allocator.allocate.const %arg0, "HostVisible|HostCoherent", "Transfer" : !hal.buffer = dense<123> : tensor<4x4xi32>
+// CHECK-LABEL: func @allocatorMapByteBuffer
+func @allocatorMapByteBuffer(%arg0 : !hal.allocator, %arg1 : !iree.byte_buffer) -> !hal.buffer {
+ %offset = constant 128 : index
+ %length = constant 256 : index
+ // CHECK: = vm.call @hal.allocator.wrap.byte_buffer(%arg0, %c6, %c2, %arg1, %c128, %c256) : (!vm.ref<!hal.allocator>, i32, i32, !vm.ref<!iree.byte_buffer>, i32, i32) -> !vm.ref<!hal.buffer>
+ %buffer = hal.allocator.map %arg0, "HostVisible|HostCoherent", "Transfer", %arg1[%offset, %length] : !iree.byte_buffer -> !hal.buffer
return %buffer : !hal.buffer
}
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/constant_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/constant_ops.mlir
new file mode 100644
index 0000000..17dae9c
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/constant_ops.mlir
@@ -0,0 +1,52 @@
+// RUN: iree-opt -split-input-file -iree-convert-hal-to-vm %s | IreeFileCheck %s
+
+// CHECK: vm.rodata @pool_storage0 dense<[102, 102, 6, 64, -51, -52, 76, 64, -102, -103, -119, 64, -51, -52, -84, 64]> : vector<16xi8>
+// CHECK: vm.rodata @pool_storage1 dense<[6, 7, 8, 0]> : vector<4xi8>
+hal.constant_pool @pool attributes {buffer_constraints = #hal.buffer_constraints<max_allocation_size = 1073741824, min_buffer_offset_alignment = 32, max_buffer_range = 134217728, min_buffer_range_alignment = 4>} {
+ hal.constant_pool.span @cst0 : tensor<4xf32> = @_storage0[#hal.byte_range<0, 16>] -> @pool_storage0_buffer[#hal.byte_range<0, 16>]
+ hal.constant_pool.span @cst1 : tensor<3xi8> = @_storage1[#hal.byte_range<0, 3>] -> @pool_storage1_buffer[#hal.byte_range<0, 3>]
+ hal.constant_pool.splat @cst2 = dense<1.000000e+00> : tensor<1xf32> -> @pool_splats[#hal.byte_range<0, 4>]
+ hal.constant_pool.splat @cst3 = dense<1234567890> : tensor<8xi32> -> @pool_splats[#hal.byte_range<32, 32>]
+ hal.constant_storage @_storage0 = dense<[102, 102, 6, 64, -51, -52, 76, 64, -102, -103, -119, 64, -51, -52, -84, 64]> : vector<16xi8>
+ hal.constant_storage @_storage1 = dense<[6, 7, 8, 0]> : vector<4xi8>
+}
+
+// CHECK: vm.global.ref @pool_storage0_buffer init(@pool_storage0_buffer_initializer) : !vm.ref<!hal.buffer>
+hal.variable @pool_storage0_buffer init(@pool_storage0_buffer_initializer) : !hal.buffer attributes {sym_visibility = "private"}
+// CHECK: vm.func @pool_storage0_buffer_initializer() -> !vm.ref<!hal.buffer>
+func @pool_storage0_buffer_initializer() -> !hal.buffer attributes {sym_visibility = "private"} {
+ %c0 = constant 0 : index
+ %c16 = constant 16 : index
+ %dev = hal.ex.shared_device : !hal.device
+ %allocator = hal.device.allocator %dev : !hal.allocator
+ // CHECK: [[STORAGE_REF:%.+]] = vm.const.ref.rodata @pool_storage0 : !vm.ref<!iree.byte_buffer>
+ %storage = hal.constant_storage.lookup @pool::@_storage0 : !iree.byte_buffer
+ // CHECK: = vm.call @hal.allocator.wrap.byte_buffer({{.+}}, %c22, %c15, [[STORAGE_REF]], %zero, %c16)
+ %mapped = hal.allocator.map %allocator, "HostVisible|HostCoherent|DeviceVisible", "Constant|Transfer|Mapping|Dispatch", %storage[%c0, %c16] : !iree.byte_buffer -> !hal.buffer
+ return %mapped : !hal.buffer
+}
+
+// CHECK: vm.global.ref @pool_storage1_buffer init(@pool_storage1_buffer_initializer) : !vm.ref<!hal.buffer>
+hal.variable @pool_storage1_buffer init(@pool_storage1_buffer_initializer) : !hal.buffer attributes {sym_visibility = "private"}
+func @pool_storage1_buffer_initializer() -> !hal.buffer attributes {sym_visibility = "private"}
+
+// CHECK: vm.global.ref @pool_splats init(@pool_splats_initializer) : !vm.ref<!hal.buffer>
+hal.variable @pool_splats init(@pool_splats_initializer) : !hal.buffer attributes {sym_visibility = "private"}
+// CHECK: vm.func @pool_splats_initializer() -> !vm.ref<!hal.buffer>
+func @pool_splats_initializer() -> !hal.buffer attributes {sym_visibility = "private"} {
+ %c64 = constant 64 : index
+ %c0 = constant 0 : index
+ %c4 = constant 4 : index
+ %c1065353216_i32 = constant 1065353216 : i32
+ %c32 = constant 32 : index
+ %c1234567890_i32 = constant 1234567890 : i32
+ %dev = hal.ex.shared_device : !hal.device
+ %allocator = hal.device.allocator %dev : !hal.allocator
+ // CHECK: [[BUFFER:%.+]] = vm.call @hal.allocator.allocate({{.+}}, %c50, %c15, %c64)
+ %buffer = hal.allocator.allocate %allocator, "HostVisible|DeviceVisible|DeviceLocal", "Constant|Transfer|Mapping|Dispatch", %c64 : !hal.buffer
+ // CHECK: vm.call @hal.buffer.fill([[BUFFER]], %zero, %c4, %c1065353216)
+ hal.buffer.fill %buffer, %c0, %c4, %c1065353216_i32
+ // CHECK: vm.call @hal.buffer.fill([[BUFFER]], %c32, %c32, %c1234567890)
+ hal.buffer.fill %buffer, %c32, %c32, %c1234567890_i32
+ return %buffer : !hal.buffer
+}
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/variable_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/variable_ops.mlir
index 5ff3018..acaf1c1 100644
--- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/variable_ops.mlir
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/variable_ops.mlir
@@ -1,7 +1,7 @@
// RUN: iree-opt -split-input-file -iree-convert-hal-to-vm %s | IreeFileCheck %s
// CHECK: vm.global.i32 @v_initialized_const 4 : i32
-hal.variable @v_initialized_const 4 : i32
+hal.variable @v_initialized_const = 4 : i32
// -----
diff --git a/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/BUILD b/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/BUILD
new file mode 100644
index 0000000..9cbaaa5
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/BUILD
@@ -0,0 +1,36 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "IREEToHAL",
+ srcs = [
+ "ConvertIREEToHAL.cpp",
+ ],
+ hdrs = [
+ "ConvertIREEToHAL.h",
+ ],
+ deps = [
+ "//iree/compiler/Dialect/HAL/IR",
+ "//iree/compiler/Dialect/IREE/IR",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:StandardOps",
+ "@llvm-project//mlir:Transforms",
+ ],
+)
diff --git a/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/CMakeLists.txt b/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/CMakeLists.txt
new file mode 100644
index 0000000..21f3b76
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/CMakeLists.txt
@@ -0,0 +1,31 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ IREEToHAL
+ HDRS
+ "ConvertIREEToHAL.h"
+ SRCS
+ "ConvertIREEToHAL.cpp"
+ DEPS
+ MLIRIR
+ MLIRStandard
+ MLIRTransforms
+ iree::compiler::Dialect::HAL::IR
+ iree::compiler::Dialect::IREE::IR
+ PUBLIC
+)
diff --git a/iree/compiler/Dialect/IREE/Conversion/ConvertToHAL.cpp b/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/ConvertIREEToHAL.cpp
similarity index 96%
rename from iree/compiler/Dialect/IREE/Conversion/ConvertToHAL.cpp
rename to iree/compiler/Dialect/HAL/Conversion/IREEToHAL/ConvertIREEToHAL.cpp
index 6696ca0..dfa3657 100644
--- a/iree/compiler/Dialect/IREE/Conversion/ConvertToHAL.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/ConvertIREEToHAL.cpp
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "iree/compiler/Dialect/IREE/Conversion/ConvertToHAL.h"
+#include "iree/compiler/Dialect/HAL/Conversion/IREEToHAL/ConvertIREEToHAL.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
diff --git a/iree/compiler/Dialect/IREE/Conversion/ConvertToHAL.h b/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/ConvertIREEToHAL.h
similarity index 86%
rename from iree/compiler/Dialect/IREE/Conversion/ConvertToHAL.h
rename to iree/compiler/Dialect/HAL/Conversion/IREEToHAL/ConvertIREEToHAL.h
index 7f57558..d14eba6 100644
--- a/iree/compiler/Dialect/IREE/Conversion/ConvertToHAL.h
+++ b/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/ConvertIREEToHAL.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifndef IREE_COMPILER_DIALECT_IREE_CONVERSION_CONVERTTOHAL_H_
-#define IREE_COMPILER_DIALECT_IREE_CONVERSION_CONVERTTOHAL_H_
+#ifndef IREE_COMPILER_DIALECT_HAL_CONVERSION_CONVERTIREETOHAL_H_
+#define IREE_COMPILER_DIALECT_HAL_CONVERSION_CONVERTIREETOHAL_H_
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
@@ -36,4 +36,4 @@
} // namespace iree_compiler
} // namespace mlir
-#endif // IREE_COMPILER_DIALECT_IREE_CONVERSION_CONVERTTOHAL_H_
+#endif // IREE_COMPILER_DIALECT_HAL_CONVERSION_CONVERTIREETOHAL_H_
diff --git a/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/test/BUILD b/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/test/BUILD
new file mode 100644
index 0000000..b780b11
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/test/BUILD
@@ -0,0 +1,30 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//iree:lit_test.bzl", "iree_lit_test_suite")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_lit_test_suite(
+ name = "lit",
+ srcs = glob(["*.mlir"]),
+ data = [
+ "//iree/tools:IreeFileCheck",
+ "//iree/tools:iree-opt",
+ ],
+)
diff --git a/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/test/CMakeLists.txt b/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/test/CMakeLists.txt
new file mode 100644
index 0000000..fcc538b
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/test/CMakeLists.txt
@@ -0,0 +1,26 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir)
+iree_lit_test_suite(
+ NAME
+ lit
+ SRCS
+ "${_GLOB_X_MLIR}"
+ DATA
+ iree::tools::IreeFileCheck
+ iree::tools::iree-opt
+)
diff --git a/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/test/shape_constants.mlir b/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/test/shape_constants.mlir
new file mode 100644
index 0000000..a6af406
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/test/shape_constants.mlir
@@ -0,0 +1,11 @@
+// RUN: iree-opt -iree-convert-to-hal %s --split-input-file | IreeFileCheck %s
+
+// CHECK-LABEL: @dynamic_shape_constant
+func @dynamic_shape_constant() {
+ // CHECK: %dev = hal.ex.shared_device
+ // CHECK: %allocator = hal.device.allocator %dev
+ // CHECK: %view = hal.buffer_view.const %allocator, "HostVisible|DeviceVisible|DeviceLocal", "Constant|Transfer|Mapping|Dispatch" : !hal.buffer_view = dense<2> : tensor<2xi32>
+ // CHECK: %[[RES:.+]] = iree.do_not_optimize(%view) : !hal.buffer_view
+ %c = iree.dynamic_shape_constant dense<2> : tensor<2xi32> -> tensor<?xi32>
+ return
+}
diff --git a/iree/compiler/Dialect/HAL/Conversion/Passes.h b/iree/compiler/Dialect/HAL/Conversion/Passes.h
index a78c312..6dacbd3 100644
--- a/iree/compiler/Dialect/HAL/Conversion/Passes.h
+++ b/iree/compiler/Dialect/HAL/Conversion/Passes.h
@@ -23,11 +23,9 @@
std::unique_ptr<OperationPass<ModuleOp>> createConvertHALToVMPass(
IREE::VM::TargetOptions targetOptions);
-std::unique_ptr<OperationPass<ModuleOp>> createConvertFlowToHALPass();
inline void registerHALConversionPasses() {
createConvertHALToVMPass(IREE::VM::getTargetOptionsFromFlags());
- createConvertFlowToHALPass();
}
} // namespace iree_compiler
diff --git a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/BUILD b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/BUILD
new file mode 100644
index 0000000..89b7980
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/BUILD
@@ -0,0 +1,42 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "StandardToHAL",
+ srcs = [
+ "ConvertStandardToHAL.cpp",
+ "ConvertStructuralOps.cpp",
+ ],
+ hdrs = [
+ "ConvertStandardToHAL.h",
+ ],
+ deps = [
+ "//iree/compiler/Dialect/HAL/Conversion",
+ "//iree/compiler/Dialect/HAL/IR",
+ "//iree/compiler/Dialect/HAL/IR:HALDialect",
+ "//iree/compiler/Dialect/HAL/Target",
+ "//iree/compiler/Dialect/HAL/Utils",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:StandardOps",
+ "@llvm-project//mlir:Transforms",
+ ],
+)
diff --git a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/CMakeLists.txt b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/CMakeLists.txt
new file mode 100644
index 0000000..b4c195c
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/CMakeLists.txt
@@ -0,0 +1,37 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ StandardToHAL
+ HDRS
+ "ConvertStandardToHAL.h"
+ SRCS
+ "ConvertStandardToHAL.cpp"
+ "ConvertStructuralOps.cpp"
+ DEPS
+ LLVMSupport
+ MLIRIR
+ MLIRPass
+ MLIRStandard
+ MLIRTransforms
+ iree::compiler::Dialect::HAL::Conversion
+ iree::compiler::Dialect::HAL::IR
+ iree::compiler::Dialect::HAL::IR::HALDialect
+ iree::compiler::Dialect::HAL::Target
+ iree::compiler::Dialect::HAL::Utils
+ PUBLIC
+)
diff --git a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStandardToHAL.cpp b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStandardToHAL.cpp
new file mode 100644
index 0000000..3785ddb
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStandardToHAL.cpp
@@ -0,0 +1,50 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStandardToHAL.h"
+
+#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+void populateStandardStructuralToHALPatterns(MLIRContext *context,
+ OwningRewritePatternList &patterns,
+ TypeConverter &converter);
+
+void setupStandardToHALLegality(MLIRContext *context,
+ ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter) {
+ conversionTarget.addLegalOp<mlir::ModuleOp, mlir::ModuleTerminatorOp>();
+ // NOTE: the FuncOp legality callback below holds |typeConverter| by ref.
+ // We need to rewrite certain types on operands/results so use the default
+ // dynamic legality checker to force any ops using such types to run through
+ // our patterns.
+ conversionTarget.addDynamicallyLegalDialect<mlir::StandardOpsDialect>();
+ conversionTarget.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp op) {
+ return typeConverter.isSignatureLegal(op.getType()) &&
+ typeConverter.isLegal(&op.getBody());
+ });
+}
+// std->HAL currently only delegates to the structural (func/branch) patterns.
+void populateStandardToHALPatterns(MLIRContext *context,
+ OwningRewritePatternList &patterns,
+ TypeConverter &typeConverter) {
+ populateStandardStructuralToHALPatterns(context, patterns, typeConverter);
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStandardToHAL.h b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStandardToHAL.h
new file mode 100644
index 0000000..dd2e555
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStandardToHAL.h
@@ -0,0 +1,38 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_DIALECT_HAL_CONVERSION_STANDARDTOHAL_CONVERTSTANDARDTOHAL_H_
+#define IREE_COMPILER_DIALECT_HAL_CONVERSION_STANDARDTOHAL_CONVERTSTANDARDTOHAL_H_
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Adds op legality rules to |conversionTarget| to ensure all incoming std ops
+// are removed during ->HAL lowering.
+void setupStandardToHALLegality(MLIRContext *context,
+ ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter);
+
+// Populates conversion patterns for std->HAL.
+void populateStandardToHALPatterns(MLIRContext *context,
+ OwningRewritePatternList &patterns,
+ TypeConverter &typeConverter);
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_DIALECT_HAL_CONVERSION_STANDARDTOHAL_CONVERTSTANDARDTOHAL_H_
diff --git a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStructuralOps.cpp b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStructuralOps.cpp
new file mode 100644
index 0000000..565805f
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStructuralOps.cpp
@@ -0,0 +1,128 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStandardToHAL.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
+#include "llvm/ADT/DenseMap.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace {
+
+class FuncOpSignatureConversion : public OpConversionPattern<mlir::FuncOp> {
+ public:
+ using OpConversionPattern::OpConversionPattern;
+ // Clones |funcOp| with converted types and moves its body into the clone.
+ LogicalResult matchAndRewrite(
+ mlir::FuncOp funcOp, llvm::ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto &typeConverter = *getTypeConverter();
+
+ // Convert the input signature types.
+ // TODO(benvanik): dynamic shapes by passing in tensor dynamic dims.
+ auto originalType = funcOp.getType();
+ TypeConverter::SignatureConversion newSignature(
+ originalType.getNumInputs());
+ for (auto argType : llvm::enumerate(originalType.getInputs())) {
+ if (failed(typeConverter.convertSignatureArg(
+ argType.index(), argType.value(), newSignature))) {
+ return failure();
+ }
+ }
+ SmallVector<Type, 4> newResultTypes;
+ if (failed(typeConverter.convertTypes(originalType.getResults(),
+ newResultTypes))) {
+ return failure();
+ }
+
+ // Replace function.
+ auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
+ newFuncOp.getBlocks().clear();
+ rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+ newFuncOp.end());
+ newFuncOp.setType(rewriter.getFunctionType(newSignature.getConvertedTypes(),
+ newResultTypes));
+ if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
+ &newSignature))) {
+ return failure();
+ }
+ // The body now lives in |newFuncOp|; the emptied original can be erased.
+ rewriter.eraseOp(funcOp);
+ return success();
+ }
+};
+
+class BranchOpConversion : public OpConversionPattern<mlir::BranchOp> {
+ public:
+ using OpConversionPattern::OpConversionPattern;
+ // Rebuilds the branch to the same successor using the pattern's |operands|.
+ LogicalResult matchAndRewrite(
+ mlir::BranchOp op, llvm::ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ mlir::BranchOpAdaptor adaptor(operands);
+ rewriter.replaceOpWithNewOp<mlir::BranchOp>(op, op.dest(),
+ adaptor.destOperands());
+ return success();
+ }
+};
+
+class CondBranchOpConversion : public OpConversionPattern<mlir::CondBranchOp> {
+ public:
+ using OpConversionPattern::OpConversionPattern;
+ // Condition and BOTH successors' operands come from the adaptor (|operands|).
+ LogicalResult matchAndRewrite(
+ mlir::CondBranchOp op, llvm::ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ mlir::CondBranchOpAdaptor adaptor(operands,
+ op.getOperation()->getAttrDictionary());
+ rewriter.replaceOpWithNewOp<mlir::CondBranchOp>(
+ op, adaptor.condition(), op.trueDest(), adaptor.trueDestOperands(),
+ op.falseDest(), adaptor.falseDestOperands());
+ return success();
+ }
+};
+
+class ReturnOpConversion : public OpConversionPattern<mlir::ReturnOp> {
+ public:
+ using OpConversionPattern::OpConversionPattern;
+ // Rebuilds the return so it yields the |operands| handed to this pattern.
+ LogicalResult matchAndRewrite(
+ mlir::ReturnOp returnOp, llvm::ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<mlir::ReturnOp>(returnOp, operands);
+ return success();
+ }
+};
+
+} // namespace
+// Registers the func/br/cond_br/return conversions, all using |converter|.
+void populateStandardStructuralToHALPatterns(MLIRContext *context,
+ OwningRewritePatternList &patterns,
+ TypeConverter &converter) {
+ patterns.insert<FuncOpSignatureConversion, BranchOpConversion,
+ CondBranchOpConversion, ReturnOpConversion>(converter,
+ context);
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/BUILD b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/BUILD
new file mode 100644
index 0000000..b780b11
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/BUILD
@@ -0,0 +1,30 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//iree:lit_test.bzl", "iree_lit_test_suite")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_lit_test_suite(
+ name = "lit",
+ srcs = glob(["*.mlir"]),
+ data = [
+ "//iree/tools:IreeFileCheck",
+ "//iree/tools:iree-opt",
+ ],
+)
diff --git a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/CMakeLists.txt b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/CMakeLists.txt
new file mode 100644
index 0000000..fcc538b
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/CMakeLists.txt
@@ -0,0 +1,26 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir)
+iree_lit_test_suite(
+ NAME
+ lit
+ SRCS
+ "${_GLOB_X_MLIR}"
+ DATA
+ iree::tools::IreeFileCheck
+ iree::tools::iree-opt
+)
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/structural_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/structural_ops.mlir
similarity index 81%
rename from iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/structural_ops.mlir
rename to iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/structural_ops.mlir
index b5ead37..da580f7 100644
--- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/structural_ops.mlir
+++ b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/structural_ops.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -iree-convert-flow-to-hal %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -iree-convert-to-hal %s | IreeFileCheck %s
// CHECK-LABEL: func @tensorIO(%arg0: !hal.buffer) -> !hal.buffer
func @tensorIO(%arg0 : tensor<1x1xi32>) -> tensor<1x1xi32> {
diff --git a/iree/compiler/Dialect/HAL/IR/BUILD b/iree/compiler/Dialect/HAL/IR/BUILD
index 0ecb83b..09ca578 100644
--- a/iree/compiler/Dialect/HAL/IR/BUILD
+++ b/iree/compiler/Dialect/HAL/IR/BUILD
@@ -71,7 +71,9 @@
deps = [
":IR",
"//iree/compiler/Dialect/HAL:hal_imports",
+ "//iree/compiler/Dialect/HAL/Conversion/HALToHAL",
"//iree/compiler/Dialect/HAL/Conversion/HALToVM",
+ "//iree/compiler/Dialect/IREE/IR",
"//iree/compiler/Dialect/VM/Conversion",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
diff --git a/iree/compiler/Dialect/HAL/IR/CMakeLists.txt b/iree/compiler/Dialect/HAL/IR/CMakeLists.txt
index 15dd12b..2202d6c 100644
--- a/iree/compiler/Dialect/HAL/IR/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/IR/CMakeLists.txt
@@ -60,8 +60,10 @@
MLIRParser
MLIRStandard
MLIRTransformUtils
+ iree::compiler::Dialect::HAL::Conversion::HALToHAL
iree::compiler::Dialect::HAL::Conversion::HALToVM
iree::compiler::Dialect::HAL::hal_imports
+ iree::compiler::Dialect::IREE::IR
iree::compiler::Dialect::VM::Conversion
PUBLIC
)
diff --git a/iree/compiler/Dialect/HAL/IR/HALBase.td b/iree/compiler/Dialect/HAL/IR/HALBase.td
index bfa4698..ecbf55f 100644
--- a/iree/compiler/Dialect/HAL/IR/HALBase.td
+++ b/iree/compiler/Dialect/HAL/IR/HALBase.td
@@ -50,6 +50,16 @@
// HAL enums
//===----------------------------------------------------------------------===//
+def HAL_MemoryModel_Unified : I32EnumAttrCase<"Unified", 0>;
+def HAL_MemoryModel_Discrete : I32EnumAttrCase<"Discrete", 1>;
+def HAL_MemoryModelAttr :
+ I32EnumAttr<"MemoryModel", "IREE HAL MemoryModel", [
+ HAL_MemoryModel_Unified,
+ HAL_MemoryModel_Discrete,
+ ]> {
+ let cppNamespace = "::mlir::iree_compiler::IREE::HAL";
+}
+
def HAL_MemoryType_None : BitEnumAttrCase<"None", 0x0000>;
def HAL_MemoryType_Transient : BitEnumAttrCase<"Transient", 0x0001>;
def HAL_MemoryType_HostVisible : BitEnumAttrCase<"HostVisible", 0x0002>;
@@ -489,6 +499,34 @@
// HAL structs
//===----------------------------------------------------------------------===//
+def HAL_BufferConstraintsAttr :
+ IREE_StructAttr<"buffer_constraints", "BufferConstraintsAttr", HAL_Dialect, [
+ // The maximum size of a memory allocation that can be created, even if
+ // there is more space available in the heap.
+ IREE_StructFieldAttr<"max_allocation_size", HAL_DeviceSizeAttr>,
+ // The minimum required alignment, in bytes, for offsets used in runtime
+ // buffer bindings for target backends. Offset values (both dynamic and
+ // static) must be an integer multiple of this limit.
+ IREE_StructFieldAttr<"min_buffer_offset_alignment", HAL_DeviceSizeAttr>,
+ // The maximum value that can be specified for size ranges of buffer
+ // bindings. The underlying allocation may be larger than this but only
+ // up to this amount will be visible to kernels.
+ IREE_StructFieldAttr<"max_buffer_range", HAL_DeviceSizeAttr>,
+ // The minimum required alignment, in bytes, for size ranges of buffer
+ // bindings.
+ IREE_StructFieldAttr<"min_buffer_range_alignment", HAL_DeviceSizeAttr>,
+ ]> {
+ let cppNamespace = "mlir::iree_compiler::IREE::HAL";
+}
+
+def HAL_ByteRangeAttr :
+ IREE_StructAttr<"byte_range", "ByteRangeAttr", HAL_Dialect, [
+ IREE_StructFieldAttr<"offset", HAL_DeviceSizeAttr>,
+ IREE_StructFieldAttr<"length", HAL_DeviceSizeAttr>,
+ ]> {
+ let cppNamespace = "mlir::iree_compiler::IREE::HAL";
+}
+
def HAL_MemoryBarrier : NamedTupleOf<[
NamedTupleElement<0, "source_scope", I32>,
NamedTupleElement<1, "target_scope", I32>
@@ -568,6 +606,13 @@
let cppNamespace = "mlir::iree_compiler::IREE::HAL";
}
+def HAL_DeviceMatchMemoryModelAttr : IREE_StructAttr<
+ "device.match.memory_model", "DeviceMatchMemoryModelAttr", HAL_Dialect, [
+ IREE_StructFieldAttr<"memory_model", HAL_MemoryModelAttr>,
+ ]> {
+ let cppNamespace = "mlir::iree_compiler::IREE::HAL";
+}
+
//===----------------------------------------------------------------------===//
// Base HAL op classes
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/HAL/IR/HALDialect.cpp b/iree/compiler/Dialect/HAL/IR/HALDialect.cpp
index a29d4f1..e2863d8 100644
--- a/iree/compiler/Dialect/HAL/IR/HALDialect.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALDialect.cpp
@@ -14,13 +14,16 @@
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
+#include "iree/compiler/Dialect/HAL/Conversion/HALToHAL/ConvertHALToHAL.h"
#include "iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "iree/compiler/Dialect/HAL/hal.imports.h"
+#include "iree/compiler/Dialect/IREE/IR/IREEDialect.h"
#include "iree/compiler/Dialect/VM/Conversion/ConversionDialectInterface.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/SourceMgr.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Parser.h"
@@ -63,6 +66,8 @@
void populateVMConversionPatterns(
SymbolTable &importSymbols, OwningRewritePatternList &patterns,
TypeConverter &typeConverter) const override {
+ populateHALToHALPatterns(getDialect()->getContext(), patterns,
+ typeConverter);
populateHALToVMPatterns(getDialect()->getContext(), importSymbols, patterns,
typeConverter);
}
@@ -80,10 +85,13 @@
HALDialect::HALDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context, TypeID::get<HALDialect>()) {
+ context->loadDialect<IREEDialect>();
+
addInterfaces<HALInlinerInterface, HALToVMConversionInterface>();
- addAttributes<DescriptorSetLayoutBindingAttr, MatchAlwaysAttr, MatchAnyAttr,
- MatchAllAttr, DeviceMatchIDAttr>();
+ addAttributes<BufferConstraintsAttr, ByteRangeAttr,
+ DescriptorSetLayoutBindingAttr, MatchAlwaysAttr, MatchAnyAttr,
+ MatchAllAttr, DeviceMatchIDAttr, DeviceMatchMemoryModelAttr>();
addTypes<AllocatorType, BufferType, BufferViewType, CommandBufferType,
DescriptorSetType, DescriptorSetLayoutType, DeviceType, EventType,
@@ -104,7 +112,11 @@
Type type) const {
StringRef attrKind;
if (failed(parser.parseKeyword(&attrKind))) return {};
- if (attrKind == DescriptorSetLayoutBindingAttr::getKindName()) {
+ if (attrKind == BufferConstraintsAttr::getKindName()) {
+ return BufferConstraintsAttr::parse(parser);
+ } else if (attrKind == ByteRangeAttr::getKindName()) {
+ return ByteRangeAttr::parse(parser);
+ } else if (attrKind == DescriptorSetLayoutBindingAttr::getKindName()) {
return DescriptorSetLayoutBindingAttr::parse(parser);
} else if (attrKind == MatchAlwaysAttr::getKindName()) {
return MatchAlwaysAttr::parse(parser);
@@ -114,6 +126,8 @@
return MatchAllAttr::parse(parser);
} else if (attrKind == DeviceMatchIDAttr::getKindName()) {
return DeviceMatchIDAttr::parse(parser);
+ } else if (attrKind == DeviceMatchMemoryModelAttr::getKindName()) {
+ return DeviceMatchMemoryModelAttr::parse(parser);
}
parser.emitError(parser.getNameLoc())
<< "unknown HAL attribute: " << attrKind;
@@ -122,8 +136,9 @@
void HALDialect::printAttribute(Attribute attr, DialectAsmPrinter &p) const {
TypeSwitch<Attribute>(attr)
- .Case<DescriptorSetLayoutBindingAttr, MatchAlwaysAttr, MatchAnyAttr,
- MatchAllAttr, DeviceMatchIDAttr>(
+ .Case<BufferConstraintsAttr, ByteRangeAttr,
+ DescriptorSetLayoutBindingAttr, MatchAlwaysAttr, MatchAnyAttr,
+ MatchAllAttr, DeviceMatchIDAttr, DeviceMatchMemoryModelAttr>(
[&](auto typedAttr) { typedAttr.print(p); })
.Default(
[](Attribute) { llvm_unreachable("unhandled HAL attribute kind"); });
@@ -192,6 +207,21 @@
}
}
+//===----------------------------------------------------------------------===//
+// Dialect hooks
+//===----------------------------------------------------------------------===//
+
+/// Materializes a constant produced by a HAL op folder.
+///
+/// Only index-typed integer constants are supported; they become std
+/// constant_index ops. For anything else (including a non-integer attribute
+/// paired with an index type) this returns nullptr, which is the documented
+/// way for a dialect to decline materialization, instead of asserting.
+Operation *HALDialect::materializeConstant(OpBuilder &builder, Attribute value,
+                                           Type type, Location loc) {
+  if (type.isa<IndexType>()) {
+    // Some folders materialize raw index types, which just become std
+    // constants.
+    if (auto intAttr = value.dyn_cast<IntegerAttr>()) {
+      return builder.create<mlir::ConstantIndexOp>(
+          loc, intAttr.getValue().getSExtValue());
+    }
+  }
+  return nullptr;
+}
+
} // namespace HAL
} // namespace IREE
} // namespace iree_compiler
diff --git a/iree/compiler/Dialect/HAL/IR/HALDialect.h b/iree/compiler/Dialect/HAL/IR/HALDialect.h
index b6e086b..a99d8bd 100644
--- a/iree/compiler/Dialect/HAL/IR/HALDialect.h
+++ b/iree/compiler/Dialect/HAL/IR/HALDialect.h
@@ -33,6 +33,9 @@
Type parseType(DialectAsmParser &parser) const override;
void printType(Type type, DialectAsmPrinter &p) const override;
+
+ Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
+ Location loc) override;
};
} // namespace HAL
diff --git a/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp b/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
index 626fc01..66749ce 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
@@ -13,6 +13,7 @@
// limitations under the License.
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
#include "llvm/ADT/StringExtras.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
@@ -30,7 +31,7 @@
namespace HAL {
//===----------------------------------------------------------------------===//
-// Variables
+// hal.variable.*
//===----------------------------------------------------------------------===//
namespace {
@@ -155,12 +156,178 @@
}
//===----------------------------------------------------------------------===//
-// iree::hal::Buffer
+// hal.allocator.*
+//===----------------------------------------------------------------------===//
+
+// Multiplies |baseValue| by every dimension in |shapeDims| (folding where
+// possible) and returns the product.
+//
+// Despite the name, the result is only an element count when |baseValue| is 1.
+// The caller visible here (ExpandAllocatorComputeSizeOp) passes the element
+// byte size, so the product is the total byte size of a possibly-dynamic
+// shaped tensor.
+static Value getElementCount(Location loc, Value baseValue,
+                             ValueRange shapeDims, OpBuilder &builder) {
+  Value value = baseValue;
+  for (Value dim : shapeDims) {
+    value = builder.createOrFold<mlir::MulIOp>(loc, value, dim);
+  }
+  return value;
+}
+
+namespace {
+
+/// Rewrites hal.allocator.compute_size into std arithmetic: the element byte
+/// size multiplied by each (possibly dynamic) dimension of the shape.
+struct ExpandAllocatorComputeSizeOp
+    : public OpRewritePattern<AllocatorComputeSizeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AllocatorComputeSizeOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    // TODO(benvanik): use buffer constraints for alignment.
+    BufferConstraintsAdaptor bufferConstraints(loc, op.allocator());
+
+    // Start from the size of a single element and scale by every dimension.
+    Value totalBytes = rewriter.createOrFold<mlir::ConstantIndexOp>(
+        loc, getElementByteCount(op.element_typeAttr()));
+    totalBytes = getElementCount(loc, totalBytes, op.shape(), rewriter);
+
+    rewriter.replaceOp(op, {totalBytes});
+    return success();
+  }
+};
+
+}  // namespace
+
+void AllocatorComputeSizeOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<ExpandAllocatorComputeSizeOp>(context);
+}
+
+namespace {
+
+/// Rewrites hal.allocator.compute_offset into std arithmetic computing the
+/// byte offset of |indices| within a tensor of |shape|: each index is scaled
+/// by the product of all trailing dimensions, the results are summed, and the
+/// sum is scaled by the element byte size.
+struct ExpandAllocatorComputeOffsetOp
+    : public OpRewritePattern<AllocatorComputeOffsetOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AllocatorComputeOffsetOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto indices = op.indices();
+    auto shape = op.shape();
+
+    // TODO(benvanik): use buffer constraints.
+    BufferConstraintsAdaptor bufferConstraints(loc, op.allocator());
+
+    Value elementOffset = rewriter.createOrFold<mlir::ConstantIndexOp>(loc, 0);
+    for (size_t axis = 0, rank = indices.size(); axis < rank; ++axis) {
+      // TODO(benvanik): check error case in debug builds.
+      // if (indices[axis] >= shape[axis]) {
+      //   return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+      //                           "index[%zu] out of bounds: %d >= %d", axis,
+      //                           indices[axis], shape[axis]);
+      // }
+      // index[axis] * shape[axis + 1] * ... * shape[N - 1]
+      Value axisOffset = indices[axis];
+      for (size_t dim = axis + 1; dim < shape.size(); ++dim) {
+        axisOffset =
+            rewriter.createOrFold<mlir::MulIOp>(loc, axisOffset, shape[dim]);
+      }
+      elementOffset =
+          rewriter.createOrFold<mlir::AddIOp>(loc, elementOffset, axisOffset);
+    }
+
+    // Convert the element offset to a byte offset.
+    Value elementSize = rewriter.createOrFold<mlir::ConstantIndexOp>(
+        loc, getElementByteCount(op.element_typeAttr()));
+    Value byteOffset =
+        rewriter.createOrFold<mlir::MulIOp>(loc, elementOffset, elementSize);
+
+    rewriter.replaceOp(op, {byteOffset});
+    return success();
+  }
+};
+
+}  // namespace
+
+void AllocatorComputeOffsetOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<ExpandAllocatorComputeOffsetOp>(context);
+}
+
+namespace {
+
+/// Rewrites hal.allocator.compute_range into std arithmetic.
+///
+/// The range starts at the byte offset of |indices| and extends through the
+/// last byte of the element at |indices| + |lengths| - 1. Both offsets are
+/// computed with hal.allocator.compute_offset ops, which are expanded by
+/// their own canonicalization pattern.
+struct ExpandAllocatorComputeRangeOp
+    : public OpRewritePattern<AllocatorComputeRangeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AllocatorComputeRangeOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    // TODO(benvanik): use buffer constraints.
+    BufferConstraintsAdaptor bufferConstraints(loc, op.allocator());
+
+    // Inclusive last coordinate along each axis: index + length - 1.
+    size_t rank = op.shape().size();
+    Value one = rewriter.createOrFold<mlir::ConstantIndexOp>(loc, 1);
+    SmallVector<Value, 6> lastIndices;
+    lastIndices.reserve(rank);
+    for (size_t axis = 0; axis < rank; ++axis) {
+      Value pastEnd = rewriter.createOrFold<mlir::AddIOp>(
+          loc, op.indices()[axis], op.lengths()[axis]);
+      lastIndices.push_back(
+          rewriter.createOrFold<mlir::SubIOp>(loc, pastEnd, one));
+    }
+
+    auto indexType = rewriter.getIndexType();
+    Value firstByte = rewriter.createOrFold<AllocatorComputeOffsetOp>(
+        loc, indexType, op.allocator(), op.shape(), op.element_typeAttr(),
+        op.indices());
+    Value lastElementByte = rewriter.createOrFold<AllocatorComputeOffsetOp>(
+        loc, indexType, op.allocator(), op.shape(), op.element_typeAttr(),
+        lastIndices);
+
+    // Length covers everything from the first byte through the end of the
+    // last element: (lastElementByte - firstByte) + elementSize.
+    Value elementSize = rewriter.createOrFold<mlir::ConstantIndexOp>(
+        loc, getElementByteCount(op.element_typeAttr()));
+    Value span =
+        rewriter.createOrFold<mlir::SubIOp>(loc, lastElementByte, firstByte);
+    Value length = rewriter.createOrFold<mlir::AddIOp>(loc, span, elementSize);
+
+    rewriter.replaceOp(op, {firstByte, length});
+    return success();
+  }
+};
+
+}  // namespace
+
+void AllocatorComputeRangeOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<ExpandAllocatorComputeRangeOp>(context);
+}
+
+namespace {
+
+/// Rewrites hal.allocator.allocate.const into an IREE byte buffer constant
+/// holding the op's value, wrapped as a !hal.buffer via hal.allocator.map.
+struct ExpandAllocatorAllocateConstOp
+    : public OpRewritePattern<AllocatorAllocateConstOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AllocatorAllocateConstOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    // Host-side copy of the constant data.
+    auto byteBufferType = IREE::ByteBufferType::get(rewriter.getContext());
+    Value hostBuffer = rewriter.createOrFold<IREE::ByteBufferConstantOp>(
+        loc, byteBufferType, op.value());
+
+    // Map with offset 0 and length -1 (presumably "to the end of the source";
+    // NOTE(review): confirm against hal.allocator.map semantics).
+    Value mapOffset = rewriter.createOrFold<mlir::ConstantIndexOp>(loc, 0);
+    Value mapLength = rewriter.createOrFold<mlir::ConstantIndexOp>(loc, -1);
+    Value deviceBuffer = rewriter.createOrFold<AllocatorMapOp>(
+        loc, op.allocator(), op.memory_types(), op.buffer_usage(), hostBuffer,
+        mapOffset, mapLength);
+
+    rewriter.replaceOp(op, {deviceBuffer});
+    return success();
+  }
+};
+
+}  // namespace
+
+void AllocatorAllocateConstOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<ExpandAllocatorAllocateConstOp>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// hal.buffer.*
//===----------------------------------------------------------------------===//
namespace {
-/// Skips a hal.buffer_view.buffer accessor when the buffer view was created in
+/// Skips a hal.buffer.allocator accessor when the buffer was created in
/// the same scope and we know the origin buffer.
struct SkipBufferAllocatorOp : public OpRewritePattern<BufferAllocatorOp> {
using OpRewritePattern<BufferAllocatorOp>::OpRewritePattern;
@@ -175,6 +342,11 @@
op.buffer().getDefiningOp())) {
rewriter.replaceOp(op, allocateOp.allocator());
return success();
+ } else if (auto subspanOp = dyn_cast_or_null<BufferSubspanOp>(
+ op.buffer().getDefiningOp())) {
+ rewriter.replaceOpWithNewOp<BufferAllocatorOp>(op,
+ subspanOp.source_buffer());
+ return success();
}
return failure();
}
@@ -188,7 +360,7 @@
}
//===----------------------------------------------------------------------===//
-// iree::hal::BufferView
+// hal.buffer_view.*
//===----------------------------------------------------------------------===//
namespace {
@@ -256,7 +428,7 @@
}
//===----------------------------------------------------------------------===//
-// iree::hal::CommandBuffer
+// hal.command_buffer.*
//===----------------------------------------------------------------------===//
namespace {
@@ -285,6 +457,88 @@
results.insert<SkipCommandBufferDeviceOp>(context);
}
+namespace {
+
+/// Folds hal.buffer.subspans into push descriptor bindings.
+/// The binding range is always equal to or a subset of the subspan.
+///
+/// Each binding whose buffer comes from a hal.buffer.subspan is rebased onto
+/// the subspan's source buffer. The subspan's source offset is added to the
+/// binding offset; binding lengths are left unchanged.
+struct FoldCommandBufferPushDescriptorSetBufferSubspan
+    : public OpRewritePattern<CommandBufferPushDescriptorSetOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CommandBufferPushDescriptorSetOp op,
+                                PatternRewriter &rewriter) const override {
+    auto bindingBuffers = llvm::to_vector<4>(op.binding_buffers());
+    auto bindingOffsets = llvm::to_vector<4>(op.binding_offsets());
+    bool needsUpdate = false;
+    {
+      // Offset math must dominate |op|. The guard restores the rewriter's
+      // original insertion point when this scope exits.
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPoint(op);
+      for (size_t i = 0; i < bindingBuffers.size(); ++i) {
+        auto subspanOp = dyn_cast_or_null<BufferSubspanOp>(
+            bindingBuffers[i].getDefiningOp());
+        if (!subspanOp) continue;
+        needsUpdate = true;
+        bindingBuffers[i] = subspanOp.source_buffer();
+        bindingOffsets[i] = rewriter.createOrFold<mlir::AddIOp>(
+            subspanOp.getLoc(), subspanOp.source_offset(), bindingOffsets[i]);
+      }
+    }
+    if (!needsUpdate) return failure();
+    rewriter.updateRootInPlace(op, [&]() {
+      auto mutableBindingBuffers = op.binding_buffersMutable();
+      mutableBindingBuffers.clear();
+      mutableBindingBuffers.append(bindingBuffers);
+      auto mutableBindingOffsets = op.binding_offsetsMutable();
+      mutableBindingOffsets.clear();
+      mutableBindingOffsets.append(bindingOffsets);
+    });
+    return success();
+  }
+};
+
+}  // namespace
+
+void CommandBufferPushDescriptorSetOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<FoldCommandBufferPushDescriptorSetBufferSubspan>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// hal.constant_pool.*
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Resolves hal.constant_pool.load ops to a hal.constant.subspan of the
+/// runtime buffer assigned to the referenced constant.
+///
+/// This only fires once the referenced hal.constant_pool.span or
+/// hal.constant_pool.splat has both runtime_buffer and runtime_range set.
+struct ResolveConstantPoolLoadToRuntimeBuffer
+    : public OpRewritePattern<ConstantPoolLoadOp> {
+  using OpRewritePattern<ConstantPoolLoadOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(ConstantPoolLoadOp op,
+                                PatternRewriter &rewriter) const override {
+    auto *constOp = SymbolTable::lookupNearestSymbolFrom(op, op.constant());
+    // The symbol may not resolve (nothing verifies it). dyn_cast asserts on
+    // null, so bail out before inspecting the op.
+    if (!constOp) return failure();
+    SymbolRefAttr runtimeBufferSymRef;
+    ByteRangeAttr runtimeBufferRange;
+    if (auto spanOp = dyn_cast<ConstantPoolSpanOp>(constOp)) {
+      runtimeBufferSymRef = spanOp.runtime_bufferAttr();
+      runtimeBufferRange = spanOp.runtime_rangeAttr();
+    } else if (auto splatOp = dyn_cast<ConstantPoolSplatOp>(constOp)) {
+      runtimeBufferSymRef = splatOp.runtime_bufferAttr();
+      runtimeBufferRange = splatOp.runtime_rangeAttr();
+    }
+    // Loads whose constant has no runtime placement yet (or that reference a
+    // different kind of symbol) are left untouched.
+    if (!runtimeBufferSymRef || !runtimeBufferRange) return failure();
+    rewriter.replaceOpWithNewOp<IREE::HAL::ConstantSubspanOp>(
+        op, op.getType(), runtimeBufferSymRef, runtimeBufferRange);
+    return success();
+  }
+};
+
+}  // namespace
+
+void ConstantPoolLoadOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<ResolveConstantPoolLoadToRuntimeBuffer>(context);
+}
+
//===----------------------------------------------------------------------===//
// hal.device.switch
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index bc5e21f..fc90e0a 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -138,24 +138,34 @@
}
}
- if (failed(parser.parseOptionalColon())) {
+ if (failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) {
+ return failure();
+ }
+
+ Type type;
+ if (succeeded(parser.parseOptionalEqual())) {
+ // @foo = 4 : i32
Attribute initialValueAttr;
if (failed(parser.parseAttribute(initialValueAttr, "initial_value",
result->attributes))) {
return failure();
}
- result->addAttribute("type", TypeAttr::get(initialValueAttr.getType()));
+ type = initialValueAttr.getType();
} else {
- Type type;
- if (failed(parser.parseType(type))) {
+ // @foo : index = 4 : i32
+ if (failed(parser.parseColonType(type)) ||
+ failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) {
return failure();
}
- result->addAttribute("type", TypeAttr::get(type));
+ if (succeeded(parser.parseOptionalEqual())) {
+ Attribute initialValueAttr;
+ if (failed(parser.parseAttribute(initialValueAttr, "initial_value",
+ result->attributes))) {
+ return failure();
+ }
+ }
}
-
- if (failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) {
- return failure();
- }
+ result->addAttribute("type", TypeAttr::get(type));
return success();
}
@@ -171,10 +181,11 @@
p.printSymbolName(op.initializer().getValue());
p << ')';
}
- if (op.initial_value().hasValue()) {
- p << ' ';
- p.printAttribute(op.initial_value().getValue());
+ if (op.initial_value().hasValue() &&
+ op.type() == op.initial_value().getValue().getType()) {
+ // @foo = 4 : i32
} else {
+ // @foo : index = 4 : i32
p << " : ";
p.printType(op.type());
}
@@ -185,6 +196,10 @@
"initializer",
"initial_value",
});
+ if (op.initial_value().hasValue()) {
+ p << " = ";
+ p.printAttribute(op.initial_value().getValue());
+ }
}
static LogicalResult verifyVariableOp(VariableOp op) {
@@ -208,14 +223,6 @@
<< " is " << op.type() << " but initializer function "
<< initializerOp.getName() << " is " << initializerOp.getType();
}
- } else if (op.initial_value().hasValue()) {
- // Ensure the value is something we can store in the variable
- if (!isVariableTypeCompatible(op.type(), op.initial_value()->getType())) {
- return op.emitOpError()
- << "initial value type mismatch; variable " << op.sym_name()
- << " is " << op.type() << " but initial value provided is "
- << op.initial_value()->getType();
- }
}
return success();
}
@@ -438,6 +445,28 @@
}
//===----------------------------------------------------------------------===//
+// hal.allocator.map
+//===----------------------------------------------------------------------===//
+
+/// Builds a hal.allocator.map op wrapping the |offset|/|length| range of the
+/// host byte buffer |source| as a !hal.buffer.
+void AllocatorMapOp::build(OpBuilder &builder, OperationState &state,
+                           Value allocator,
+                           IREE::HAL::MemoryTypeBitfield memoryTypes,
+                           IREE::HAL::BufferUsageBitfield bufferUsage,
+                           Value source, Value offset, Value length) {
+  // Operands follow the ODS argument order with attributes removed:
+  // allocator, source, offset, length.
+  state.addOperands({allocator, source, offset, length});
+  // The bitfield enums are stored as i32 integer attributes.
+  state.addAttribute("memory_types", builder.getI32IntegerAttr(
+                                         static_cast<int32_t>(memoryTypes)));
+  state.addAttribute("buffer_usage", builder.getI32IntegerAttr(
+                                         static_cast<int32_t>(bufferUsage)));
+  state.addTypes({BufferType::get(builder.getContext())});
+}
+
+/// Suggests "mapped" as the SSA name of the result in printed IR.
+void AllocatorMapOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
+  setNameFn(result(), "mapped");
+}
+
+//===----------------------------------------------------------------------===//
// hal.buffer.allocator
//===----------------------------------------------------------------------===//
@@ -735,6 +764,9 @@
auto i32Type = parser.getBuilder().getIntegerType(32);
auto indexType = parser.getBuilder().getIndexType();
SmallVector<Attribute, 4> bindingAttrs;
+ SmallVector<Value, 4> bindingBuffers;
+ SmallVector<Value, 4> bindingOffsets;
+ SmallVector<Value, 4> bindingLengths;
do {
IntegerAttr bindingAttr;
NamedAttrList attrList;
@@ -746,15 +778,15 @@
failed(parser.parseEqual()) || failed(parser.parseLParen()) ||
failed(parser.parseOperand(buffer)) ||
failed(parser.resolveOperand(
- buffer, BufferType::get(result->getContext()), result->operands)) ||
+ buffer, BufferType::get(result->getContext()), bindingBuffers)) ||
failed(parser.parseComma()) ||
failed(parser.parseOperand(bufferOffset)) ||
failed(
- parser.resolveOperand(bufferOffset, indexType, result->operands)) ||
+ parser.resolveOperand(bufferOffset, indexType, bindingOffsets)) ||
failed(parser.parseComma()) ||
failed(parser.parseOperand(bufferLength)) ||
failed(
- parser.resolveOperand(bufferLength, indexType, result->operands)) ||
+ parser.resolveOperand(bufferLength, indexType, bindingLengths)) ||
failed(parser.parseRParen())) {
return failure();
}
@@ -762,6 +794,9 @@
} while (succeeded(parser.parseOptionalComma()));
result->addAttribute("bindings",
parser.getBuilder().getArrayAttr(bindingAttrs));
+ result->addOperands(bindingBuffers);
+ result->addOperands(bindingOffsets);
+ result->addOperands(bindingLengths);
return success();
}
@@ -870,6 +905,75 @@
}
//===----------------------------------------------------------------------===//
+// hal.constant_pool
+//===----------------------------------------------------------------------===//
+
+/// Builds a hal.constant_pool named |name| with the given buffer constraints.
+/// The single-block body is created with its implicit
+/// hal.constant_pool_end terminator already in place.
+void ConstantPoolOp::build(OpBuilder &builder, OperationState &state,
+                           StringRef name,
+                           BufferConstraintsAttr bufferConstraints) {
+  ensureTerminator(*state.addRegion(), builder, state.location);
+  state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
+                     builder.getStringAttr(name));
+  state.addAttribute("buffer_constraints", bufferConstraints);
+}
+
+/// Parses `hal.constant_pool @name attributes {...} { ... }`. The terminator
+/// may be elided in the textual form and is re-added after parsing.
+static ParseResult parseConstantPoolOp(OpAsmParser &parser,
+                                       OperationState *result) {
+  StringAttr nameAttr;
+  if (failed(parser.parseSymbolName(nameAttr,
+                                    mlir::SymbolTable::getSymbolAttrName(),
+                                    result->attributes)) ||
+      failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) {
+    return failure();
+  }
+
+  // Parse the pool body.
+  auto *body = result->addRegion();
+  if (failed(parser.parseRegion(*body, llvm::None, llvm::None))) {
+    return failure();
+  }
+
+  // Ensure that the pool body has a valid terminator.
+  ConstantPoolOp::ensureTerminator(*body, parser.getBuilder(),
+                                   result->location);
+  return success();
+}
+
+/// Prints the pool as `hal.constant_pool @name attributes {...} { ... }`.
+/// The symbol name attribute is printed inline rather than in the dictionary,
+/// and the implicit terminator is omitted.
+static void printConstantPoolOp(OpAsmPrinter &p, ConstantPoolOp op) {
+  p << op.getOperationName() << ' ';
+  p.printSymbolName(op.sym_name());
+  p.printOptionalAttrDictWithKeyword(
+      op.getAttrs(),
+      /*elidedAttrs=*/{mlir::SymbolTable::getSymbolAttrName()});
+  p.printRegion(op.body(), /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/false);
+}
+
+//===----------------------------------------------------------------------===//
+// hal.constant_pool.load
+//===----------------------------------------------------------------------===//
+
+/// Suggests "const" as the SSA name of the loaded tensor in printed IR.
+void ConstantPoolLoadOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
+ setNameFn(result(), "const");
+}
+
+//===----------------------------------------------------------------------===//
+// hal.constant_storage.lookup
+//===----------------------------------------------------------------------===//
+
+/// Suggests "storage" as the SSA name of the returned byte buffer.
+void ConstantStorageLookupOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
+ setNameFn(result(), "storage");
+}
+
+//===----------------------------------------------------------------------===//
+// hal.constant.subspan
+//===----------------------------------------------------------------------===//
+
+/// Suggests "const_span" as the SSA name of the buffer-backed tensor.
+void ConstantSubspanOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
+ setNameFn(result(), "const_span");
+}
+
+//===----------------------------------------------------------------------===//
// hal.descriptor_set.create
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.td b/iree/compiler/Dialect/HAL/IR/HALOps.td
index 1541a46..e95d5be 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -341,13 +341,14 @@
$element_type attr-dict
}];
- let skipDefaultBuilders = 1;
let builders = [
OpBuilder<[{
OpBuilder &builder, OperationState &state, Value allocator,
ValueRange shape, int32_t elementSize
}]>,
];
+
+ let hasCanonicalizer = 1;
}
def HAL_AllocatorComputeOffsetOp : HAL_PureOp<"allocator.compute_offset", [
@@ -376,13 +377,14 @@
$element_type `,` `indices` `=` `[` $indices `]` attr-dict
}];
- let skipDefaultBuilders = 1;
let builders = [
OpBuilder<[{
OpBuilder &builder, OperationState &state, Value allocator,
ValueRange shape, int32_t elementType, ValueRange indices
}]>,
];
+
+ let hasCanonicalizer = 1;
}
def HAL_AllocatorComputeRangeOp : HAL_PureOp<"allocator.compute_range", [
@@ -414,7 +416,6 @@
$lengths `]` attr-dict
}];
- let skipDefaultBuilders = 1;
let builders = [
OpBuilder<[{
OpBuilder &builder, OperationState &state, Value allocator,
@@ -422,6 +423,8 @@
ValueRange lengths
}]>,
];
+
+ let hasCanonicalizer = 1;
}
def HAL_AllocatorAllocateOp : HAL_Op<"allocator.allocate", [
@@ -495,6 +498,51 @@
ElementsAttr value
}]>,
];
+
+ let hasCanonicalizer = 1;
+}
+
+// hal.allocator.map: wraps a range of a host byte buffer as a !hal.buffer.
+// Created by the hal.allocator.allocate.const canonicalization.
+def HAL_AllocatorMapOp : HAL_Op<"allocator.map", [
+ // getAsmResultNames names the result %mapped (see HALOps.cpp).
+ DeclareOpInterfaceMethods<OpAsmOpInterface>,
+ ]> {
+ let summary = [{allocator-supported host buffer wrapping operation}];
+ let description = [{
+ Wraps a !hal.buffer around host read-only memory backed by the given byte
+ buffer. The returned buffer may be host-only and not directly usable on
+ devices.
+ }];
+
+ let arguments = (ins
+ HAL_Allocator:$allocator,
+ HAL_MemoryTypeBitfieldAttr:$memory_types,
+ HAL_BufferUsageBitfieldAttr:$buffer_usage,
+ // TODO(benvanik): support other types (and mutable buffers).
+ ByteBufferType:$source,
+ // Range of |source| to wrap. The allocate.const expansion passes 0 and -1.
+ HAL_DeviceSize:$offset,
+ HAL_DeviceSize:$length
+ );
+ let results = (outs
+ HAL_Buffer:$result
+ );
+
+ let assemblyFormat = [{
+ $allocator `,` $memory_types `,` $buffer_usage `,`
+ $source `[` $offset `,` $length `]` attr-dict-with-keyword
+ `:` type($source) `->` type($result)
+ }];
+
+ // Only the custom builder below exists. It is implemented in HALOps.cpp and
+ // takes the bitfield enums directly, storing them as i32 attributes.
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<[{
+ OpBuilder &builder, OperationState &state,
+ Value allocator,
+ IREE::HAL::MemoryTypeBitfield memoryTypes,
+ IREE::HAL::BufferUsageBitfield bufferUsage,
+ Value source,
+ Value offset,
+ Value length
+ }]>,
+ ];
}
//===----------------------------------------------------------------------===//
@@ -1148,6 +1196,8 @@
ArrayRef<DescriptorSetBindingValue> bindings
}]>,
];
+
+ let hasCanonicalizer = 1;
}
def HAL_CommandBufferBindDescriptorSetOp :
@@ -1304,6 +1354,231 @@
}
//===----------------------------------------------------------------------===//
+// Constant pooling
+//===----------------------------------------------------------------------===//
+
+def HAL_ConstantPoolOp : HAL_Op<"constant_pool", [
+    IsolatedFromAbove,
+    SingleBlockImplicitTerminator<"IREE::HAL::ConstantPoolEndOp">,
+    Symbol,
+    SymbolTable,
+  ]> {
+  let summary = [{pool of constants with similar lifetimes}];
+  let description = [{
+    A pool of constants that share a similar lifetime and that should be stored
+    together both in the source files and at runtime. By logically grouping
+    constants by their frequency and locality of access we can reduce the number
+    of bindings required on hal.interface by sourcing constants from the same
+    buffer. We can also optimize module loading by mapping or DMA transferring
+    constant data (based on device).
+  }];
+
+  let arguments = (ins
+    StrAttr:$sym_name,
+    HAL_BufferConstraintsAttr:$buffer_constraints
+  );
+
+  let regions = (region SizedRegion<1>:$body);
+
+  // Only the custom builder is exposed. It creates the body region together
+  // with its implicit hal.constant_pool_end terminator (ConstantPoolOp::build).
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<[{
+      OpBuilder &builder, OperationState &state, StringRef name,
+      BufferConstraintsAttr bufferConstraints
+    }]>,
+  ];
+
+  let extraClassDeclaration = [{
+    Block& getBlock() { return body().front(); }
+  }];
+}
+
+// Implicit terminator of the single hal.constant_pool body block.
+def HAL_ConstantPoolEndOp : HAL_Op<"constant_pool_end", [
+    HasParent<"IREE::HAL::ConstantPoolOp">,
+    Terminator,
+  ]> {
+  let summary = [{terminator pseudo-op for the constant pool op}];
+  let assemblyFormat = "attr-dict-with-keyword";
+}
+
+// A named constant value (an ElementsAttr) held directly inside a
+// hal.constant_pool.
+def HAL_ConstantPoolValueOp : HAL_Op<"constant_pool.value", [
+ Symbol,
+ HasParent<"IREE::HAL::ConstantPoolOp">,
+ ]> {
+ let summary = [{constant value within a parent constant pool}];
+ let description = [{
+ Represents a constant value as part of a constant pool containing constants
+ with a similar lifetime.
+ }];
+
+ let arguments = (ins
+ SymbolNameAttr:$sym_name,
+ ElementsAttr:$value
+ );
+
+ let assemblyFormat = [{
+ $sym_name attr-dict `=` $value
+ }];
+}
+
+// A named constant whose bytes live at |storage_range| within the
+// |storage_buffer| symbol, plus an optional runtime placement.
+def HAL_ConstantPoolSpanOp : HAL_Op<"constant_pool.span", [
+ Symbol,
+ HasParent<"IREE::HAL::ConstantPoolOp">,
+ ]> {
+ let summary = [{constant span within a parent storage block}];
+ let description = [{
+ Represents a constant stored within a hal.constant_pool. Provides a
+ symbol that can be used to reference the constant data as a stored range
+ within the module file.
+ }];
+
+ let arguments = (ins
+ SymbolNameAttr:$sym_name,
+ TypeAttr:$tensor_type,
+ SymbolRefAttr:$storage_buffer,
+ HAL_ByteRangeAttr:$storage_range,
+ // Optional runtime placement. When both are set, hal.constant_pool.load ops
+ // referencing this symbol canonicalize into hal.constant.subspan.
+ OptionalAttr<SymbolRefAttr>:$runtime_buffer,
+ OptionalAttr<HAL_ByteRangeAttr>:$runtime_range
+ );
+
+ // The `-> @buffer[range]` suffix is printed only when runtime_buffer is set.
+ let assemblyFormat = [{
+ $sym_name `:` $tensor_type attr-dict
+ `=` $storage_buffer `[` $storage_range `]`
+ (`->` $runtime_buffer^ `[` $runtime_range `]`)?
+ }];
+}
+
+// A named splat constant with no packed storage. Like the span op, it may
+// carry an optional runtime placement.
+def HAL_ConstantPoolSplatOp : HAL_Op<"constant_pool.splat", [
+ Symbol,
+ HasParent<"IREE::HAL::ConstantPoolOp">,
+ ]> {
+ let summary = [{constant splat within a parent storage block}];
+ let description = [{
+ Represents a splatted constant that has no representation in the storage
+ but must be represented at runtime as splatted 4-byte value.
+ }];
+
+ let arguments = (ins
+ SymbolNameAttr:$sym_name,
+ ElementsAttr:$value,
+ // Optional runtime placement. When both are set, hal.constant_pool.load ops
+ // referencing this symbol canonicalize into hal.constant.subspan.
+ OptionalAttr<SymbolRefAttr>:$runtime_buffer,
+ OptionalAttr<HAL_ByteRangeAttr>:$runtime_range
+ );
+
+ // The `-> @buffer[range]` suffix is printed only when runtime_buffer is set.
+ let assemblyFormat = [{
+ $sym_name attr-dict `=` $value
+ (`->` $runtime_buffer^ `[` $runtime_range `]`)?
+ }];
+}
+
+def HAL_ConstantPoolLoadOp : HAL_PureOp<"constant_pool.load", [
+    DeclareOpInterfaceMethods<OpAsmOpInterface>,
+  ]> {
+  let summary = [{constant pool tensor load pseudo-op}];
+  let description = [{
+    Used during conversion to provide a placeholder for a globally cached and
+    possibly lazily-initialized compile-time constant. Will be replaced with a
+    direct variable access during transformation.
+  }];
+
+  let arguments = (ins
+    SymbolRefAttr:$constant
+  );
+  let results = (outs
+    TypeAlias<AnyRankedTensor>:$result
+  );
+
+  let assemblyFormat = "$constant attr-dict `:` type($result)";
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<[{
+      OpBuilder &builder, OperationState &state,
+      Type resultType, SymbolRefAttr constant
+    }], [{
+      state.addTypes({resultType});
+      state.addAttribute("constant", constant);
+    }]>,
+  ];
+
+  // Canonicalizes into hal.constant.subspan once the referenced
+  // hal.constant_pool.span/splat has a runtime buffer and range assigned.
+  let hasCanonicalizer = 1;
+}
+
+// A named packed constant storage blob. Unlike the constant_pool.* ops it has
+// no HasParent constraint. Presumably referenced through
+// hal.constant_pool.span's storage_buffer symbol (TODO confirm).
+def HAL_ConstantStorageOp : HAL_Op<"constant_storage", [
+ Symbol,
+ ]> {
+ let summary = [{constant data storage block}];
+ let description = [{
+ Represents a packed constant storage buffer meeting the buffer constraints
+ placed on the parent pool. Referenced by other constant pool ops.
+ }];
+
+ let arguments = (ins
+ SymbolNameAttr:$sym_name,
+ ElementsAttr:$value
+ );
+
+ let assemblyFormat = [{
+ $sym_name attr-dict-with-keyword `=` $value
+ }];
+}
+
+// Returns the host byte buffer for a constant storage symbol. The result is
+// named %storage in printed IR (getAsmResultNames in HALOps.cpp).
+def HAL_ConstantStorageLookupOp :
+ HAL_PureOp<"constant_storage.lookup", [
+ DeclareOpInterfaceMethods<OpAsmOpInterface>,
+ ]> {
+ let summary = [{constant storage byte buffer accessor}];
+ let description = [{
+ Returns the read-only host byte buffer storing the constant data.
+ }];
+
+ let arguments = (ins
+ SymbolRefAttr:$constant
+ );
+ let results = (outs
+ ByteBufferType:$result
+ );
+
+ let assemblyFormat = [{
+ $constant `:` type($result) attr-dict
+ }];
+}
+
+// A tensor backed by |runtime_range| of the |runtime_buffer| symbol. It is
+// produced by canonicalizing hal.constant_pool.load once the referenced
+// constant has a runtime placement.
+def HAL_ConstantSubspanOp : HAL_PureOp<"constant.subspan", [
+ DeclareOpInterfaceMethods<OpAsmOpInterface>,
+ ]> {
+ let summary = [{runtime constant buffer lookup pseudo-op}];
+ let description = [{
+ Used during conversion to resolve a runtime representation of a constant as
+ a tensor backed by a buffer range.
+ }];
+
+ let arguments = (ins
+ SymbolRefAttr:$runtime_buffer,
+ HAL_ByteRangeAttr:$runtime_range
+ );
+ let results = (outs
+ AnyRankedTensor:$result
+ );
+
+ let assemblyFormat = [{
+ $runtime_buffer `[` $runtime_range `]` `:` type($result) attr-dict
+ }];
+
+ // Single custom builder (used by the load canonicalization) that stores the
+ // buffer symbol and range attributes as given.
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<[{
+ OpBuilder &builder, OperationState &state,
+ Type resultType, SymbolRefAttr runtimeBuffer, ByteRangeAttr runtimeRange
+ }], [{
+ state.addTypes({resultType});
+ state.addAttribute("runtime_buffer", runtimeBuffer);
+ state.addAttribute("runtime_range", runtimeRange);
+ }]>,
+ ];
+}
+
+//===----------------------------------------------------------------------===//
// iree::hal::DescriptorSet
//===----------------------------------------------------------------------===//
@@ -1569,6 +1844,32 @@
}];
}
+def HAL_DeviceMatchMemoryModelOp : HAL_PureOp<"device.match.memory_model"> {
+ let summary = [{returns true if the device memory model matches the value}];
+ let description = [{
+ Compares the device's memory model against the specified model.
+ This can be used to conditionally evaluate device-specific code when the
+ device is not known at compile-time. The model is written in square
+ brackets after `model =`.
+
+ ```mlir
+ %is_match = hal.device.match.memory_model %device, model = ["Unified"] : (!hal.device) -> i1
+ ```
+ }];
+
+ let arguments = (ins
+ HAL_Device:$device,
+ HAL_MemoryModelAttr:$model
+ );
+ let results = (outs
+ I1:$result
+ );
+
+ let assemblyFormat = [{
+ $device `,` `model` `=` `[` $model `]` attr-dict
+ `:` `(` type($device) `)` `->` type($result)
+ }];
+}
+
//===----------------------------------------------------------------------===//
// iree::hal::Executable
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/HAL/IR/HALTypes.cpp b/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
index 9d10e92..8a16770 100644
--- a/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
@@ -15,6 +15,7 @@
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "llvm/ADT/StringExtras.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
@@ -98,11 +99,190 @@
elementType.getValue());
}
+// Returns the element bit width, taken from the low 8 bits of the encoded
+// |elementType| integer (e.g. 16777248 = 0x01000020 yields 32).
+size_t getElementBitCount(IntegerAttr elementType) {
+ return static_cast<size_t>((elementType.getValue().getZExtValue()) & 0xFF);
+}
+
+// Returns the element size in whole bytes, rounding any partial byte up
+// (e.g. a 1-bit element occupies 1 byte).
+size_t getElementByteCount(IntegerAttr elementType) {
+ return (getElementBitCount(elementType) + 8 - 1) / 8;
+}
+
+//===----------------------------------------------------------------------===//
+// Struct types
+//===----------------------------------------------------------------------===//
+
+// Combines two constraint sets into the most conservative one satisfying both:
+// the smaller of each maximum (allocation size, buffer range) and the larger
+// of each minimum alignment (offset alignment, range alignment).
+// NOTE(review): values are read via getSExtValue(), which assumes every
+// constraint fits in a signed 64-bit integer.
+BufferConstraintsAttr intersectBufferConstraints(BufferConstraintsAttr lhs,
+ BufferConstraintsAttr rhs) {
+ Builder b(lhs.getContext());
+ return BufferConstraintsAttr::get(
+ b.getIndexAttr(std::min(lhs.max_allocation_size().getSExtValue(),
+ rhs.max_allocation_size().getSExtValue())),
+ b.getIndexAttr(
+ std::max(lhs.min_buffer_offset_alignment().getSExtValue(),
+ rhs.min_buffer_offset_alignment().getSExtValue())),
+ b.getIndexAttr(std::min(lhs.max_buffer_range().getSExtValue(),
+ rhs.max_buffer_range().getSExtValue())),
+ b.getIndexAttr(
+ std::max(lhs.min_buffer_range_alignment().getSExtValue(),
+ rhs.min_buffer_range_alignment().getSExtValue())));
+}
+
+// TODO(benvanik): runtime buffer constraint queries from the allocator.
+// We can add folders for those when the allocator is strongly-typed with
+// #hal.buffer_constraints and otherwise leave them for runtime queries.
+// For now the constraints are fixed defaults; |allocator| is stored but none
+// of the accessors below consult it yet.
+BufferConstraintsAdaptor::BufferConstraintsAdaptor(Location loc,
+ Value allocator)
+ : loc_(loc), allocator_(allocator) {
+ // Picked to represent what we kind of want on CPU today.
+ // NOTE(review): these duplicate TargetBackend::makeDefaultBufferConstraints
+ // (1 GiB limits, 16-byte alignments); keep the two in sync.
+ uint64_t maxAllocationSize = 1 * 1024 * 1024 * 1024ull;
+ uint64_t minBufferOffsetAlignment = 16ull;
+ uint64_t maxBufferRange = 1 * 1024 * 1024 * 1024ull;
+ uint64_t minBufferRangeAlignment = 16ull;
+ Builder b(loc.getContext());
+ bufferConstraints_ = BufferConstraintsAttr::get(
+ b.getIndexAttr(maxAllocationSize),
+ b.getIndexAttr(minBufferOffsetAlignment), b.getIndexAttr(maxBufferRange),
+ b.getIndexAttr(minBufferRangeAlignment));
+}
+
+// Each accessor materializes one constraint as an index constant at the
+// adaptor's location, using createOrFold so the builder may fold it.
+Value BufferConstraintsAdaptor::getMaxAllocationSize(OpBuilder &builder) {
+ return builder.createOrFold<mlir::ConstantOp>(
+ loc_, bufferConstraints_.max_allocation_sizeAttr());
+}
+
+Value BufferConstraintsAdaptor::getMinBufferOffsetAlignment(
+ OpBuilder &builder) {
+ return builder.createOrFold<mlir::ConstantOp>(
+ loc_, bufferConstraints_.min_buffer_offset_alignmentAttr());
+}
+
+Value BufferConstraintsAdaptor::getMaxBufferRange(OpBuilder &builder) {
+ return builder.createOrFold<mlir::ConstantOp>(
+ loc_, bufferConstraints_.max_buffer_rangeAttr());
+}
+
+Value BufferConstraintsAdaptor::getMinBufferRangeAlignment(OpBuilder &builder) {
+ return builder.createOrFold<mlir::ConstantOp>(
+ loc_, bufferConstraints_.min_buffer_range_alignmentAttr());
+}
+
//===----------------------------------------------------------------------===//
// Attribute printing and parsing
//===----------------------------------------------------------------------===//
// static
+// Parses:
+//   <max_allocation_size = N, min_buffer_offset_alignment = N,
+//    max_buffer_range = N, min_buffer_range_alignment = N>
+// All four keys are required, in exactly this order, and each value is parsed
+// as an index-typed integer. Returns a null Attribute on any failure.
+Attribute BufferConstraintsAttr::parse(DialectAsmParser &p) {
+ auto b = p.getBuilder();
+ if (failed(p.parseLess())) return {};
+
+ IntegerAttr maxAllocationSizeAttr;
+ IntegerAttr minBufferOffsetAlignmentAttr;
+ IntegerAttr maxBufferRangeAttr;
+ IntegerAttr minBufferRangeAlignmentAttr;
+ if (failed(p.parseKeyword("max_allocation_size")) || failed(p.parseEqual()) ||
+ failed(p.parseAttribute(maxAllocationSizeAttr, b.getIndexType())) ||
+ failed(p.parseComma()) ||
+ failed(p.parseKeyword("min_buffer_offset_alignment")) ||
+ failed(p.parseEqual()) ||
+ failed(
+ p.parseAttribute(minBufferOffsetAlignmentAttr, b.getIndexType())) ||
+ failed(p.parseComma()) || failed(p.parseKeyword("max_buffer_range")) ||
+ failed(p.parseEqual()) ||
+ failed(p.parseAttribute(maxBufferRangeAttr, b.getIndexType())) ||
+ failed(p.parseComma()) ||
+ failed(p.parseKeyword("min_buffer_range_alignment")) ||
+ failed(p.parseEqual()) ||
+ failed(p.parseAttribute(minBufferRangeAlignmentAttr, b.getIndexType()))) {
+ return {};
+ }
+
+ if (failed(p.parseGreater())) return {};
+ return BufferConstraintsAttr::get(
+ maxAllocationSizeAttr, minBufferOffsetAlignmentAttr, maxBufferRangeAttr,
+ minBufferRangeAlignmentAttr);
+}
+
+// Prints the keyed form accepted by BufferConstraintsAttr::parse, using the
+// same key order so the output round-trips.
+void BufferConstraintsAttr::print(DialectAsmPrinter &p) const {
+ auto &os = p.getStream();
+ os << getKindName() << "<";
+ os << "max_allocation_size = " << max_allocation_size() << ", ";
+ os << "min_buffer_offset_alignment = " << min_buffer_offset_alignment()
+ << ", ";
+ os << "max_buffer_range = " << max_buffer_range() << ", ";
+ os << "min_buffer_range_alignment = " << min_buffer_range_alignment();
+ os << ">";
+}
+
+// static
+// Parses a #hal.byte_range attribute into an (offset, length) pair.
+// Accepted forms:
+//   byte_range<offset, length>
+//   byte_range<[start to end]>  (both bounds inclusive)
+//   byte_range<(start to end)>  (both bounds exclusive)
+// where `[`/`]` mark an inclusive bound and `(`/`)` an exclusive one. Range
+// forms are normalized to a half-open [start, end) interval so that
+// length = end - start. Returns a null Attribute on failure.
+Attribute ByteRangeAttr::parse(DialectAsmParser &p) {
+  auto b = p.getBuilder();
+  if (failed(p.parseLess())) return {};
+
+  // TODO(benvanik): support the range syntax; the dialect asm parser fights
+  // with it though by checking for proper []/() nesting.
+
+  // Try first the range style: byte_range<[start to end)>
+  bool startInclusive;
+  if (succeeded(p.parseOptionalLSquare())) {  // [...
+    startInclusive = true;
+  } else if (succeeded(p.parseOptionalLParen())) {  // (...
+    startInclusive = false;
+  } else {
+    // byte_range<offset, length>
+    IntegerAttr offsetAttr;
+    IntegerAttr lengthAttr;
+    if (failed(p.parseAttribute(offsetAttr, b.getIndexType())) ||
+        failed(p.parseComma()) ||
+        failed(p.parseAttribute(lengthAttr, b.getIndexType())) ||
+        failed(p.parseGreater())) {
+      return {};
+    }
+    return get(offsetAttr, lengthAttr);
+  }
+
+  IntegerAttr startAttr;
+  IntegerAttr endAttr;
+  if (failed(p.parseAttribute(startAttr, b.getIndexType())) ||
+      failed(p.parseKeyword("to")) ||
+      failed(p.parseAttribute(endAttr, b.getIndexType()))) {
+    return {};
+  }
+
+  bool endInclusive;
+  if (succeeded(p.parseOptionalRSquare())) {  // ...]
+    endInclusive = true;
+  } else if (succeeded(p.parseOptionalRParen())) {  // ...)
+    endInclusive = false;
+  } else {
+    p.emitError(p.getCurrentLocation()) << "expected ] or ) to end range";
+    return {};
+  }
+
+  if (failed(p.parseGreater())) return {};
+
+  // Normalize to a half-open [start, end) interval. An exclusive start moves
+  // forward by one. An inclusive end also moves forward by one, so that
+  // [0 to 3] and [0 to 4) both describe the same 4 bytes.
+  int64_t start = startAttr.getValue().getSExtValue();
+  int64_t end = endAttr.getValue().getSExtValue();
+  if (!startInclusive) ++start;
+  if (endInclusive) ++end;
+
+  IntegerAttr offsetAttr = b.getIndexAttr(start);
+  IntegerAttr lengthAttr = b.getIndexAttr(end - start);
+  return get(offsetAttr, lengthAttr);
+}
+
+// Always prints the canonical `<offset, length>` form, regardless of which
+// syntax the attribute was originally parsed from.
+void ByteRangeAttr::print(DialectAsmPrinter &p) const {
+ auto &os = p.getStream();
+ os << getKindName() << "<";
+ os << offset();
+ os << ", ";
+ os << length();
+ os << ">";
+}
+
+// static
Attribute DescriptorSetLayoutBindingAttr::parse(DialectAsmParser &p) {
auto b = p.getBuilder();
IntegerAttr bindingAttr;
@@ -204,6 +384,24 @@
os << "\">";
}
+// static
+// Parses `<...>` wrapping a single MemoryModel enum value. Parsing of the enum
+// itself is delegated to parseEnumAttr with the "memory_model" field name.
+Attribute DeviceMatchMemoryModelAttr::parse(DialectAsmParser &p) {
+ IntegerAttr memoryModelAttr;
+ if (failed(p.parseLess()) ||
+ failed(parseEnumAttr<MemoryModel>(p, "memory_model", memoryModelAttr)) ||
+ failed(p.parseGreater())) {
+ return {};
+ }
+ return get(memoryModelAttr);
+}
+
+// Prints the enum case name as a quoted string inside `<...>`. This is
+// intended to mirror what parse() accepts.
+// NOTE(review): confirm that parseEnumAttr accepts the quoted form.
+void DeviceMatchMemoryModelAttr::print(DialectAsmPrinter &p) const {
+ auto &os = p.getStream();
+ os << getKindName() << "<\"";
+ os << stringifyMemoryModel(memory_model());
+ os << "\">";
+}
+
#include "iree/compiler/Dialect/HAL/IR/HALOpInterface.cpp.inc"
} // namespace HAL
diff --git a/iree/compiler/Dialect/HAL/IR/HALTypes.h b/iree/compiler/Dialect/HAL/IR/HALTypes.h
index b1f8cee..6a1b1bd 100644
--- a/iree/compiler/Dialect/HAL/IR/HALTypes.h
+++ b/iree/compiler/Dialect/HAL/IR/HALTypes.h
@@ -52,8 +52,14 @@
// Returns an attribute with the MLIR element type or {}.
IntegerAttr getElementTypeAttr(Type type);
+// Returns the total bit count of elements of the given type.
+// The bit width is decoded from the low 8 bits of |elementType|.
+size_t getElementBitCount(IntegerAttr elementType);
+
+// Returns the rounded-up byte count of elements of the given type.
+// Partial bytes count as a whole byte (e.g. 1 bit -> 1 byte).
+size_t getElementByteCount(IntegerAttr elementType);
+
//===----------------------------------------------------------------------===//
-// RefObject types
+// Object types
//===----------------------------------------------------------------------===//
class AllocatorType : public Type::TypeBase<AllocatorType, Type, TypeStorage> {
@@ -133,6 +139,28 @@
// Struct types
//===----------------------------------------------------------------------===//
+// Returns the intersection (most conservative) constraints |lhs| ∩ |rhs|:
+// the smaller of each maximum and the larger of each minimum alignment.
+BufferConstraintsAttr intersectBufferConstraints(BufferConstraintsAttr lhs,
+ BufferConstraintsAttr rhs);
+
+// TODO(benvanik): runtime buffer constraint queries from the allocator.
+// We can add folders for those when the allocator is strongly-typed with
+// #hal.buffer_constraints and otherwise leave them for runtime queries.
+// Exposes buffer constraints as SSA index constants for use while building IR.
+// The values are currently fixed defaults; the allocator value is retained for
+// the future runtime queries described above.
+class BufferConstraintsAdaptor {
+ public:
+ // Constants are created at |loc|; |allocator| is not consulted yet.
+ BufferConstraintsAdaptor(Location loc, Value allocator);
+
+ // Each accessor returns an index constant for the corresponding constraint,
+ // created (or folded) with |builder|.
+ Value getMaxAllocationSize(OpBuilder &builder);
+ Value getMinBufferOffsetAlignment(OpBuilder &builder);
+ Value getMaxBufferRange(OpBuilder &builder);
+ Value getMinBufferRangeAlignment(OpBuilder &builder);
+
+ private:
+ Location loc_;
+ Value allocator_;
+ BufferConstraintsAttr bufferConstraints_;
+};
+
class BufferBarrierType {
public:
static TupleType get(MLIRContext *context) {
diff --git a/iree/compiler/Dialect/HAL/IR/test/allocator_ops.mlir b/iree/compiler/Dialect/HAL/IR/test/allocator_ops.mlir
index 85a4e07..d6c9885 100644
--- a/iree/compiler/Dialect/HAL/IR/test/allocator_ops.mlir
+++ b/iree/compiler/Dialect/HAL/IR/test/allocator_ops.mlir
@@ -62,3 +62,22 @@
// CHECK-NEXT: return %[[CB]]
return %buffer : !hal.buffer
}
+
+// -----
+
+// CHECK-LABEL: @allocator_map_byte_buffer
+func @allocator_map_byte_buffer() -> !hal.buffer {
+ // CHECK-DAG: [[SOURCE:%.+]] = "test_hal.immutable_data"
+ %source = "test_hal.immutable_data"() : () -> !iree.byte_buffer
+ // CHECK-DAG: [[OFFSET:%.+]] = "test_hal.offset"
+ %offset = "test_hal.offset"() : () -> index
+ // CHECK-DAG: [[LENGTH:%.+]] = "test_hal.length"
+ %length = "test_hal.length"() : () -> index
+ // CHECK-DAG: [[AL:%.+]] = "test_hal.allocator"
+ %allocator = "test_hal.allocator"() : () -> !hal.allocator
+ // CHECK: = hal.allocator.map [[AL]], "HostVisible|HostCoherent", "Transfer", [[SOURCE]][
+ // CHECK-SAME: [[OFFSET]], [[LENGTH]]
+ // CHECK-SAME: ] : !iree.byte_buffer -> !hal.buffer
+ %buffer = hal.allocator.map %allocator, "HostVisible|HostCoherent", "Transfer", %source[%offset, %length] : !iree.byte_buffer -> !hal.buffer
+ return %buffer : !hal.buffer
+}
diff --git a/iree/compiler/Dialect/HAL/IR/test/attributes.mlir b/iree/compiler/Dialect/HAL/IR/test/attributes.mlir
index e74781b..155c222 100644
--- a/iree/compiler/Dialect/HAL/IR/test/attributes.mlir
+++ b/iree/compiler/Dialect/HAL/IR/test/attributes.mlir
@@ -1,6 +1,15 @@
// RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s
-"some.foo"() {
+// CHECK-LABEL: byte_range.offset_length
+"byte_range.offset_length"() {
+ // CHECK: br = #hal.byte_range<123, 456>
+ br = #hal.byte_range<123, 456>
+} : () -> ()
+
+// -----
+
+// CHECK-LABEL: descriptor_set_layout_binding.basic
+"descriptor_set_layout_binding.basic"() {
// CHECK: dslb = #hal.descriptor_set_layout_binding<0, "StorageBuffer", "Read|MayAlias">
dslb = #hal.descriptor_set_layout_binding<0, "StorageBuffer", "Read|MayAlias">
} : () -> ()
diff --git a/iree/compiler/Dialect/HAL/IR/test/buffer_folding.mlir b/iree/compiler/Dialect/HAL/IR/test/buffer_folding.mlir
index 656ad7a..257df6d 100644
--- a/iree/compiler/Dialect/HAL/IR/test/buffer_folding.mlir
+++ b/iree/compiler/Dialect/HAL/IR/test/buffer_folding.mlir
@@ -12,3 +12,19 @@
// CHECK: return %[[AL]]
return %1 : !hal.allocator
}
+
+// -----
+
+// CHECK-LABEL: @skip_subspan_buffer_allocator
+func @skip_subspan_buffer_allocator() -> !hal.allocator {
+ %c0 = constant 0 : index
+ %c184 = constant 184 : index
+ %c384 = constant 384 : index
+ // CHECK-DAG: %[[AL:.+]] = "test_hal.allocator"
+ %allocator = "test_hal.allocator"() : () -> !hal.allocator
+ %source_buffer = hal.allocator.allocate %allocator, "HostVisible|HostCoherent", "Transfer", %c384 : !hal.buffer
+ %span_buffer = hal.buffer.subspan %source_buffer, %c0, %c184 : !hal.buffer
+ %1 = hal.buffer.allocator %span_buffer : !hal.allocator
+ // CHECK: return %[[AL]]
+ return %1 : !hal.allocator
+}
diff --git a/iree/compiler/Dialect/HAL/IR/test/buffer_view_folding.mlir b/iree/compiler/Dialect/HAL/IR/test/buffer_view_folding.mlir
index 54d51dd..2d1d209 100644
--- a/iree/compiler/Dialect/HAL/IR/test/buffer_view_folding.mlir
+++ b/iree/compiler/Dialect/HAL/IR/test/buffer_view_folding.mlir
@@ -1,14 +1,13 @@
-// Tests folding and canonicalization of HAL buffer view ops.
-
// RUN: iree-opt -split-input-file -canonicalize %s | iree-opt -split-input-file | IreeFileCheck %s
// CHECK-LABEL: @expand_buffer_view_const
func @expand_buffer_view_const() -> !hal.buffer_view {
%0 = "test_hal.allocator"() : () -> !hal.allocator
- // CHECK: %[[BUFFER:.+]] = hal.allocator.allocate.const %0, "HostVisible|HostCoherent", "Transfer" : !hal.buffer = dense<[4, 1, 2]> : tensor<3xi32>
- // CHECK-NEXT: %[[VIEW:.+]] = hal.buffer_view.create %[[BUFFER]], shape = [%c3], element_type = 16777248 : !hal.buffer_view
+ // CHECK: [[CONST:%.+]] = iree.byte_buffer.constant : !iree.byte_buffer = dense<[4, 1, 2]> : tensor<3xi32>
+ // CHECK-NEXT: [[BUFFER:%.+]] = hal.allocator.map {{.+}}, "HostVisible|HostCoherent", "Transfer", [[CONST]][%c0, %c-1] : !iree.byte_buffer -> !hal.buffer
+ // CHECK-NEXT: [[VIEW:%.+]] = hal.buffer_view.create [[BUFFER]], shape = [%c3], element_type = 16777248 : !hal.buffer_view
%view = hal.buffer_view.const %0, "HostVisible|HostCoherent", "Transfer" : !hal.buffer_view = dense<[4, 1, 2]> : tensor<3xi32>
- // CHECK-NEXT: return %[[VIEW]]
+ // CHECK-NEXT: return [[VIEW]]
return %view : !hal.buffer_view
}
diff --git a/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir b/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir
index a004780..6ede3e1 100644
--- a/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir
+++ b/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir
@@ -1,5 +1,3 @@
-// Tests folding and canonicalization of HAL command buffer ops.
-
// RUN: iree-opt -split-input-file -canonicalize %s | iree-opt -split-input-file | IreeFileCheck %s
// CHECK-LABEL: @skip_command_buffer_device
@@ -14,3 +12,34 @@
return %exe : !hal.executable
}
+
+// -----
+
+// CHECK-LABEL: @fold_buffer_subspan_into_push_descriptor_set
+// CHECK-SAME: [[BASE_BUFFER:%[a-z0-9]+]]: !hal.buffer
+func @fold_buffer_subspan_into_push_descriptor_set(
+ %cmd : !hal.command_buffer,
+ %layout : !hal.executable_layout,
+ %buffer : !hal.buffer
+ ) {
+ %c0 = constant 0 : index
+ %c4 = constant 4 : index
+ %c4096 = constant 4096 : index
+ %c8000 = constant 8000 : index
+ %c262140 = constant 262140 : index
+ %c262144 = constant 262144 : index
+ %subspan = hal.buffer.subspan %buffer, %c4096, %c262144 : !hal.buffer
+ // CHECK: hal.command_buffer.push_descriptor_set {{.+}}, bindings=[
+ hal.command_buffer.push_descriptor_set %cmd, %layout, set=0, bindings=[
+ // 0 + 4096:
+ // CHECK-SAME: 0 = ([[BASE_BUFFER]], %c4096, %c8000)
+ 0 = (%subspan, %c0, %c8000),
+ // 4096 + 4:
+ // CHECK-SAME: 1 = ([[BASE_BUFFER]], %c4100, %c262140)
+ 1 = (%subspan, %c4, %c262140),
+ // No change:
+ // CHECK-SAME: 2 = ([[BASE_BUFFER]], %c4096, %c262144)
+ 2 = (%buffer, %c4096, %c262144)
+ ]
+ return
+}
diff --git a/iree/compiler/Dialect/HAL/IR/test/constant_folding.mlir b/iree/compiler/Dialect/HAL/IR/test/constant_folding.mlir
new file mode 100644
index 0000000..5d73945
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/IR/test/constant_folding.mlir
@@ -0,0 +1,20 @@
+// RUN: iree-opt -split-input-file -canonicalize %s | iree-opt -split-input-file | IreeFileCheck %s
+
+hal.constant_pool @pool attributes {
+ buffer_constraints = #hal.buffer_constraints<max_allocation_size = 1073741824,
+ min_buffer_offset_alignment = 32,
+ max_buffer_range = 134217728,
+ min_buffer_range_alignment = 4>
+ } {
+ hal.constant_pool.span @cst_span : tensor<3xi8> = @_storage1[#hal.byte_range<0, 3>] -> @pool_storage1_buffer[#hal.byte_range<0, 3>]
+ hal.constant_pool.splat @cst_splat = dense<1.000000e+00> : tensor<1xf32> -> @pool_splats[#hal.byte_range<0, 4>]
+}
+
+// CHECK-LABEL: func @pools_identified
+func @pools_identified() -> (tensor<2x3xf32>, tensor<3x2xf32>) {
+ // CHECK-NEXT: = hal.constant.subspan @pool_storage1_buffer[#hal.byte_range<0, 3>] : tensor<2x3xf32>
+ %cst0 = hal.constant_pool.load @pool::@cst_span : tensor<2x3xf32>
+ // CHECK-NEXT: = hal.constant.subspan @pool_splats[#hal.byte_range<0, 4>] : tensor<3x2xf32>
+ %cst1 = hal.constant_pool.load @pool::@cst_splat : tensor<3x2xf32>
+ return %cst0, %cst1 : tensor<2x3xf32>, tensor<3x2xf32>
+}
diff --git a/iree/compiler/Dialect/HAL/IR/test/constant_ops.mlir b/iree/compiler/Dialect/HAL/IR/test/constant_ops.mlir
new file mode 100644
index 0000000..6c26f70
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/IR/test/constant_ops.mlir
@@ -0,0 +1,86 @@
+// RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s
+
+// CHECK-LABEL: hal.constant_pool @pool0
+hal.constant_pool @pool0 attributes {
+ buffer_constraints = #hal.buffer_constraints<max_allocation_size = 1073741824,
+ min_buffer_offset_alignment = 32,
+ max_buffer_range = 134217728,
+ min_buffer_range_alignment = 4>
+ } {
+ // CHECK-NEXT: hal.constant_pool.value @cst0 = dense<0.{{.+}}> : tensor<2x3xf32>
+ hal.constant_pool.value @cst0 = dense<0.0> : tensor<2x3xf32>
+ // CHECK-NEXT: hal.constant_pool.value @cst1 = dense<1.{{.+}}> : tensor<3x2xf32>
+ hal.constant_pool.value @cst1 = dense<1.0> : tensor<3x2xf32>
+}
+
+// CHECK-LABEL: func @pools_identified()
+func @pools_identified() -> (tensor<2x3xf32>, tensor<3x2xf32>) {
+ // CHECK-NEXT: = hal.constant_pool.load @pool0::@cst0 : tensor<2x3xf32>
+ %cst0 = hal.constant_pool.load @pool0::@cst0 : tensor<2x3xf32>
+ // CHECK-NEXT: = hal.constant_pool.load @pool0::@cst1 : tensor<3x2xf32>
+ %cst1 = hal.constant_pool.load @pool0::@cst1 : tensor<3x2xf32>
+ return %cst0, %cst1 : tensor<2x3xf32>, tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: hal.constant_pool @storage_allocated
+hal.constant_pool @storage_allocated attributes {
+ buffer_constraints = #hal.buffer_constraints<max_allocation_size = 1073741824,
+ min_buffer_offset_alignment = 32,
+ max_buffer_range = 134217728,
+ min_buffer_range_alignment = 4>
+ } {
+ // CHECK-NEXT: hal.constant_pool.span @cst0 : tensor<2x3xf32> = @_storage[#hal.byte_range<0, 1024>]
+ hal.constant_pool.span @cst0 : tensor<2x3xf32> = @_storage[#hal.byte_range<0, 1024>]
+ // CHECK-NEXT: hal.constant_pool.span @cst1 : tensor<3x2xf32> = @_storage[#hal.byte_range<1024, 1024>]
+ hal.constant_pool.span @cst1 : tensor<3x2xf32> = @_storage[#hal.byte_range<1024, 1024>]
+ // CHECK-NEXT: hal.constant_pool.splat @cst2 = dense<1.000000e+00> : tensor<1xf32>
+ hal.constant_pool.splat @cst2 = dense<1.000000e+00> : tensor<1xf32>
+ // CHECK-NEXT: hal.constant_storage @_storage = dense<1> : vector<768xi8>
+ hal.constant_storage @_storage = dense<1> : vector<768xi8>
+}
+
+// -----
+
+// CHECK-LABEL: hal.constant_pool @pool
+// CHECK-SAME: buffer_constraints = #hal.buffer_constraints
+hal.constant_pool @pool attributes {
+ buffer_constraints = #hal.buffer_constraints<max_allocation_size = 1073741824,
+ min_buffer_offset_alignment = 32,
+ max_buffer_range = 134217728,
+ min_buffer_range_alignment = 4>
+ } {
+ // CHECK-NEXT: hal.constant_pool.span @cst0 : tensor<4xf32> = @_storage0[#hal.byte_range<0, 16>] -> @pool_storage0_buffer[#hal.byte_range<0, 16>]
+ hal.constant_pool.span @cst0 : tensor<4xf32> = @_storage0[#hal.byte_range<0, 16>] -> @pool_storage0_buffer[#hal.byte_range<0, 16>]
+ // CHECK-NEXT: hal.constant_pool.span @cst1 : tensor<3xi8> = @_storage1[#hal.byte_range<0, 3>] -> @pool_storage1_buffer[#hal.byte_range<0, 3>]
+ hal.constant_pool.span @cst1 : tensor<3xi8> = @_storage1[#hal.byte_range<0, 3>] -> @pool_storage1_buffer[#hal.byte_range<0, 3>]
+ // CHECK-NEXT: hal.constant_pool.splat @cst2 = dense<1.000000e+00> : tensor<1xf32> -> @pool_splats[#hal.byte_range<0, 4>]
+ hal.constant_pool.splat @cst2 = dense<1.000000e+00> : tensor<1xf32> -> @pool_splats[#hal.byte_range<0, 4>]
+ // CHECK-NEXT: hal.constant_pool.splat @cst3 = dense<1234567890> : tensor<8xi32> -> @pool_splats[#hal.byte_range<32, 32>]
+ hal.constant_pool.splat @cst3 = dense<1234567890> : tensor<8xi32> -> @pool_splats[#hal.byte_range<32, 32>]
+ // CHECK-NEXT: hal.constant_storage @_storage0 = dense<[102, 102, 6, 64, -51, -52, 76, 64, -102, -103, -119, 64, -51, -52, -84, 64]> : vector<16xi8>
+ hal.constant_storage @_storage0 = dense<[102, 102, 6, 64, -51, -52, 76, 64, -102, -103, -119, 64, -51, -52, -84, 64]> : vector<16xi8>
+ // CHECK-NEXT: hal.constant_storage @_storage1 = dense<[6, 7, 8, 0]> : vector<4xi8>
+ hal.constant_storage @_storage1 = dense<[6, 7, 8, 0]> : vector<4xi8>
+}
+
+// CHECK: func @storage_lookup
+func @storage_lookup() {
+ // CHECK-NEXT: = hal.constant_storage.lookup @pool::@_storage1 : !iree.byte_buffer
+ %storage = hal.constant_storage.lookup @pool::@_storage1 : !iree.byte_buffer
+ return
+}
+
+// -----
+
+// CHECK: hal.variable @storage0_buffer0 : !hal.buffer
+hal.variable @storage0_buffer0 : !hal.buffer
+// CHECK-LABEL: func @runtime_buffer_subspan()
+func @runtime_buffer_subspan() {
+ // CHECK-NEXT: = hal.constant.subspan @storage0_buffer0[#hal.byte_range<0, 1024>] : tensor<4xf32>
+ %cst0 = hal.constant.subspan @storage0_buffer0[#hal.byte_range<0, 1024>] : tensor<4xf32>
+ // CHECK-NEXT: = hal.constant.subspan @storage0_buffer0[#hal.byte_range<1024, 2048>] : tensor<4xf32>
+ %cst1 = hal.constant.subspan @storage0_buffer0[#hal.byte_range<1024, 2048>] : tensor<4xf32>
+ return
+}
diff --git a/iree/compiler/Dialect/HAL/IR/test/variable_folding.mlir b/iree/compiler/Dialect/HAL/IR/test/variable_folding.mlir
index 487a7da..887006f 100644
--- a/iree/compiler/Dialect/HAL/IR/test/variable_folding.mlir
+++ b/iree/compiler/Dialect/HAL/IR/test/variable_folding.mlir
@@ -2,7 +2,7 @@
// RUN: iree-opt -split-input-file -canonicalize %s | iree-opt -split-input-file | IreeFileCheck %s
-// CHECK: hal.variable @v_initialized 4 : i32
+// CHECK: hal.variable @v_initialized = 4 : i32
hal.variable @v_initialized init(@initializer) : i32
func @initializer() -> i32 {
%0 = constant 4 : i32
diff --git a/iree/compiler/Dialect/HAL/IR/test/variable_ops.mlir b/iree/compiler/Dialect/HAL/IR/test/variable_ops.mlir
index 46a25e0..2eba9a5 100644
--- a/iree/compiler/Dialect/HAL/IR/test/variable_ops.mlir
+++ b/iree/compiler/Dialect/HAL/IR/test/variable_ops.mlir
@@ -9,8 +9,14 @@
// -----
-// CHECK: hal.variable @v_initialized_const 4 : i32
-hal.variable @v_initialized_const 4 : i32
+// CHECK: hal.variable @v_initialized_const0 = 4 : i32
+hal.variable @v_initialized_const0 = 4 : i32
+
+// CHECK: hal.variable @v_initialized_const1 = 40 : i32
+hal.variable @v_initialized_const1 : i32 = 40 : i32
+
+// CHECK: hal.variable @v_initialized_const2 : i32 = 40 : i64
+hal.variable @v_initialized_const2 : i32 = 40 : i64
// -----
diff --git a/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp b/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp
index 882c7aa..b2037e6 100644
--- a/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp
+++ b/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp
@@ -71,6 +71,26 @@
return false;
}
+// static
+// Returns host-like defaults: 1 GiB maximum allocation size and buffer range,
+// and 16-byte offset and range alignment.
+// NOTE(review): these duplicate the values hard-coded in the
+// BufferConstraintsAdaptor constructor (HALTypes.cpp); keep the two in sync.
+BufferConstraintsAttr TargetBackend::makeDefaultBufferConstraints(
+ MLIRContext *context) {
+ // Picked to represent what we kind of want on CPU today.
+ uint64_t maxAllocationSize = 1 * 1024 * 1024 * 1024ull;
+ uint64_t minBufferOffsetAlignment = 16ull;
+ uint64_t maxBufferRange = 1 * 1024 * 1024 * 1024ull;
+ uint64_t minBufferRangeAlignment = 16ull;
+ Builder b(context);
+ return BufferConstraintsAttr::get(b.getIndexAttr(maxAllocationSize),
+ b.getIndexAttr(minBufferOffsetAlignment),
+ b.getIndexAttr(maxBufferRange),
+ b.getIndexAttr(minBufferRangeAlignment));
+}
+
+// Base implementation: backends that do not override this report the generic
+// defaults from makeDefaultBufferConstraints.
+BufferConstraintsAttr TargetBackend::queryBufferConstraints(
+ MLIRContext *context) {
+ return makeDefaultBufferConstraints(context);
+}
+
void TargetBackend::declareTargetOps(IREE::Flow::ExecutableOp sourceOp,
IREE::HAL::ExecutableOp executableOp) {
OpBuilder targetBuilder(&executableOp.getBlock().back());
diff --git a/iree/compiler/Dialect/HAL/Target/TargetBackend.h b/iree/compiler/Dialect/HAL/Target/TargetBackend.h
index 1cd4b4f..4f03b80 100644
--- a/iree/compiler/Dialect/HAL/Target/TargetBackend.h
+++ b/iree/compiler/Dialect/HAL/Target/TargetBackend.h
@@ -116,6 +116,10 @@
// 'foo-10?' matches: 'foo-101', 'foo-102'
static bool matchPattern(StringRef value, StringRef pattern);
+ // Returns a generic host-like set of constraints.
+ // (1 GiB maximum allocation size and buffer range, 16-byte alignments.)
+ static BufferConstraintsAttr makeDefaultBufferConstraints(
+ MLIRContext *context);
+
virtual ~TargetBackend() = default;
// Returns a name for the backend used to differentiate between other targets.
@@ -124,6 +128,21 @@
// call to matchPattern. For example, 'vulkan-v1.1' or 'vmla*'.
virtual std::string filter_pattern() const = 0;
+  // Register dependent dialects for the TargetBackend.
+  // Mirrors the method on mlir::Pass of the same name. A TargetBackend is
+  // expected to register the dialects it will create entities for (Operations,
+  // Types, Attributes), other than dialects that exist in the input. These are
+  // the dialects that will be used in |declareTargetOps| and
+  // |buildTranslationPassPipeline|.
+  // TODO(#1036): We might be able to get rid of this with dynamic pass
+  // registration.
+  virtual void getDependentDialects(DialectRegistry &registry) const {}
+
+  // Queries for compile-time known buffer constraints.
+  // These should conservatively represent the min/max values even if the
+  // backend may support others at runtime.
+  // The base implementation returns makeDefaultBufferConstraints(context).
+  virtual BufferConstraintsAttr queryBufferConstraints(MLIRContext *context);
+
// Creates an interface representing the bindings and push constants required
// to dispatch the executable. Interfaces used across backends and executables
// will be deduplicated to reduce code size and runtime overhead and being
@@ -152,16 +171,6 @@
virtual void declareTargetOps(IREE::Flow::ExecutableOp sourceOp,
IREE::HAL::ExecutableOp executableOp);
- // Register dependent dialects for the TargetBackend.
- // Mirrors the method on mlir::Pass of the same name. A TargetBackend is
- // expected to register the dialects it will create entities for (Operations,
- // Types, Attributes), other than dialects that exist in the input. These are
- // the dialects that will be used in |declareTargetOps| and
- // |buildTranslationPassPipeline|.
- // TODO(#1036): We might be able to get rid of this with dynamic pass
- // registration.
- virtual void getDependentDialects(DialectRegistry &registry) const {}
-
// Captured state from the point at which a dispatch is to be recorded.
struct DispatchState {
// The original flow.dispatch op.
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp b/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp
index 752a6fd..ab3c3a4 100644
--- a/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp
@@ -41,7 +41,7 @@
namespace {
-bool AreInterfacesEquivalent(IREE::HAL::InterfaceOp lhs,
+bool areInterfacesEquivalent(IREE::HAL::InterfaceOp lhs,
IREE::HAL::InterfaceOp rhs) {
auto lhsBindings = lhs.getBlock().getOps<IREE::HAL::InterfaceBindingOp>();
auto rhsBindings = rhs.getBlock().getOps<IREE::HAL::InterfaceBindingOp>();
@@ -125,7 +125,8 @@
// }
OpBuilder builder = OpBuilder::atBlockBegin(moduleOp.getBody());
- auto executableOps = moduleOp.getOps<IREE::HAL::ExecutableOp>();
+ auto executableOps =
+ llvm::to_vector<8>(moduleOp.getOps<IREE::HAL::ExecutableOp>());
// Create our new "linked" hal.executable.
auto linkedExecutableOp = builder.create<IREE::HAL::ExecutableOp>(
@@ -143,9 +144,14 @@
auto linkedVmModuleOp =
builder.create<IREE::VM::ModuleOp>(moduleOp.getLoc(), "linked_module");
- int executablesLinked = 0;
llvm::SmallVector<IREE::HAL::InterfaceOp, 4> interfaceOps;
int nextEntryPointOrdinal = 0;
+ DenseMap<StringRef, Operation *> symbolMap;
+ DenseMap<Attribute, Attribute> entryPointRefReplacements;
+ auto linkedExecutableBuilder =
+ OpBuilder::atBlockBegin(linkedExecutableOp.getBody());
+ auto linkedTargetBuilder =
+ OpBuilder::atBlockBegin(linkedTargetOp.getBody());
for (auto executableOp : executableOps) {
auto targetOps = llvm::to_vector<4>(
executableOp.getOps<IREE::HAL::ExecutableTargetOp>());
@@ -157,103 +163,112 @@
IREE::HAL::InterfaceOp interfaceOpForExecutable;
for (auto interfaceOp : interfaceOps) {
- if (AreInterfacesEquivalent(interfaceOp,
+ if (areInterfacesEquivalent(interfaceOp,
executableOp.getFirstInterfaceOp())) {
interfaceOpForExecutable = interfaceOp;
+ break;
}
}
if (!interfaceOpForExecutable) {
- builder.setInsertionPoint(linkedTargetOp);
- interfaceOpForExecutable = dyn_cast<IREE::HAL::InterfaceOp>(
- builder.clone(*executableOp.getFirstInterfaceOp()));
+ interfaceOpForExecutable =
+ dyn_cast<IREE::HAL::InterfaceOp>(linkedExecutableBuilder.clone(
+ *executableOp.getFirstInterfaceOp()));
interfaceOpForExecutable.setName(
llvm::formatv("legacy_io_{0}", interfaceOps.size()).str());
interfaceOps.push_back(interfaceOpForExecutable);
}
- // Clone entry point ops, remapping ordinals and updating symbol refs.
- builder.setInsertionPoint(linkedModuleOp);
+ // Clone entry point ops and queue remapping ordinals and updating
+ // symbol refs.
for (auto entryPointOp :
targetOp.getOps<IREE::HAL::ExecutableEntryPointOp>()) {
auto newEntryPointOp =
- builder.create<IREE::HAL::ExecutableEntryPointOp>(
+ linkedTargetBuilder.create<IREE::HAL::ExecutableEntryPointOp>(
entryPointOp.getLoc(), entryPointOp.sym_nameAttr(),
builder.getI32IntegerAttr(nextEntryPointOrdinal++),
builder.getSymbolRefAttr(interfaceOpForExecutable.getName()),
entryPointOp.signatureAttr());
- // Update references to @executable::@target::@entry symbols.
- // SymbolTable::replaceAllSymbolUses only looks at root symbols,
- // which we can't blindly replace (other targets will map to other
- // linked executables).
- auto executableUses =
- SymbolTable::getSymbolUses(executableOp, moduleOp);
- if (!executableUses.hasValue()) continue;
- for (auto executableUse : executableUses.getValue()) {
- auto executableUser = executableUse.getUser();
- // Only process symbols for this @target::@entry.
- auto nestedRefs =
- executableUse.getSymbolRef().getNestedReferences();
- if (nestedRefs.size() != 2 ||
- nestedRefs[0].getValue() != targetOp.sym_name() ||
- nestedRefs[1].getValue() != entryPointOp.sym_name()) {
- continue;
- }
- if (auto dispatchOp =
- dyn_cast<IREE::HAL::CommandBufferDispatchSymbolOp>(
- executableUser)) {
- // New nested reference to the linked exe/target/entry.
- StringRef newExecutableOpSymName =
- linkedExecutableOp
- .getAttrOfType<StringAttr>(
- SymbolTable::getSymbolAttrName())
- .getValue();
- auto newSymbolRefAttr = builder.getSymbolRefAttr(
- newExecutableOpSymName,
- {builder.getSymbolRefAttr(linkedTargetOp),
- builder.getSymbolRefAttr(newEntryPointOp)});
- dispatchOp.setAttr("entry_point", newSymbolRefAttr);
- }
- }
+ // Add to replacement table for fixing up dispatch calls referencing
+ // this entry point.
+ auto oldSymbolRefAttr = builder.getSymbolRefAttr(
+ executableOp.getName(), {builder.getSymbolRefAttr(targetOp),
+ builder.getSymbolRefAttr(entryPointOp)});
+ auto newSymbolRefAttr = builder.getSymbolRefAttr(
+ linkedExecutableOp.getName(),
+ {builder.getSymbolRefAttr(linkedTargetOp),
+ builder.getSymbolRefAttr(newEntryPointOp)});
+ entryPointRefReplacements[oldSymbolRefAttr] = newSymbolRefAttr;
}
- // Clone vm.module ops, including their contents.
+ // Merge the existing vm.module op into the new linked vm.module op.
auto vmModuleOps =
targetOp.getInnerModule().getOps<IREE::VM::ModuleOp>();
if (vmModuleOps.empty()) {
return targetOp.getInnerModule().emitError()
<< "target's outer module does not contain a vm.module op";
}
- auto vmModuleOp = *vmModuleOps.begin();
- builder.setInsertionPoint(&linkedVmModuleOp.getBlock().back());
- // Use a SymbolTable to guard against inserting duplicate symbols.
- SymbolTable symbolTable(linkedVmModuleOp.getOperation());
+ mergeModuleInto(*vmModuleOps.begin(), linkedVmModuleOp, symbolMap);
- for (auto &op : vmModuleOp.getBody()->getOperations()) {
- if (auto terminatorOp = dyn_cast<IREE::VM::ModuleTerminatorOp>(op)) {
- continue;
- }
- if (op.hasTrait<SymbolOpInterface::Trait>() &&
- symbolTable.lookup(dyn_cast<SymbolOpInterface>(op).getName())) {
- continue;
- }
- builder.clone(op);
- }
-
- // Now that we're done cloning its ops, delete the original target op.
targetOp.erase();
+ }
- executablesLinked++;
+ if (executableOp.getOps<IREE::HAL::ExecutableTargetOp>().empty()) {
+ executableOp.erase();
}
}
- if (executablesLinked == 0) {
+ // Update references to @executable::@target::@entry symbols.
+ replaceEntryPointUses(moduleOp, entryPointRefReplacements);
+
+ // Remove if we didn't add anything.
+ if (linkedTargetOp.getOps<IREE::HAL::ExecutableEntryPointOp>().empty()) {
+ linkedTargetOp.erase();
linkedExecutableOp.erase();
}
return success();
}
+ // Destructively merges |sourceModuleOp| into |targetModuleOp|. Symbol ops
+ // already in |targetSymbolMap| are not moved; moved ones are recorded in it.
+ void mergeModuleInto(IREE::VM::ModuleOp sourceModuleOp,
+ IREE::VM::ModuleOp targetModuleOp,
+ DenseMap<StringRef, Operation *> &targetSymbolMap) {
+ auto allOps = llvm::to_vector<8>(llvm::map_range(
+ sourceModuleOp.getBlock(), [&](Operation &op) { return &op; }));
+ for (auto &op : allOps) {
+ if (op->isKnownTerminator()) continue;
+ if (auto symbolInterface = dyn_cast<SymbolOpInterface>(op)) {
+ if (targetSymbolMap.count(symbolInterface.getName())) {
+ // TODO(scotttodd): compare ops to ensure we aren't copying different
+ // things with the same name.
+ continue;
+ }
+ targetSymbolMap[symbolInterface.getName()] = op;
+ }
+ op->moveBefore(&targetModuleOp.getBlock().back());
+ }
+
+ // Erase the source module, along with any skipped duplicates left inside it.
+ sourceModuleOp.erase();
+ }
+
+  // Walks every function in |moduleOp| and retargets each
+  // CommandBufferDispatchSymbolOp whose entry point is a key in |replacements|
+  // to the mapped symbol ref; dispatches without a mapping are left untouched.
+  void replaceEntryPointUses(
+      mlir::ModuleOp moduleOp,
+      const DenseMap<Attribute, Attribute> &replacements) {
+    for (auto funcOp : moduleOp.getOps<mlir::FuncOp>()) {
+      funcOp.walk([&](IREE::HAL::CommandBufferDispatchSymbolOp dispatchOp) {
+        auto match = replacements.find(dispatchOp.entry_point());
+        if (match == replacements.end()) return;
+        dispatchOp.entry_pointAttr(match->second.cast<SymbolRefAttr>());
+      });
+    }
+  }
+
LogicalResult serializeExecutable(IREE::HAL::ExecutableTargetOp targetOp,
OpBuilder &executableBuilder) override {
// Serialize the VM module to bytes.
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/test/i1_types.mlir b/iree/compiler/Dialect/HAL/Target/VMLA/test/i1_types.mlir
index f9be39e..429db63 100644
--- a/iree/compiler/Dialect/HAL/Target/VMLA/test/i1_types.mlir
+++ b/iree/compiler/Dialect/HAL/Target/VMLA/test/i1_types.mlir
@@ -3,7 +3,7 @@
// CHECK-LABEL: @i1_op_usage(%arg0: !hal.buffer) -> !hal.buffer
func @i1_op_usage(%arg0: tensor<4xi1>) -> tensor<4xi1> {
%c4 = constant 4 : index
- // CHECK: hal.allocator.allocate.const {{.+}} dense<[1, 0, 1, 0]> : tensor<4xi8>
+ // CHECK: %0 = iree.byte_buffer.constant : !iree.byte_buffer = dense<[1, 0, 1, 0]> : tensor<4xi8>
%cst = constant dense<[true, false, true, false]> : tensor<4xi1>
%0 = flow.ex.stream.fragment(%arg1 = %c4 : index, %arg2 = %arg0 : tensor<4xi1>, %arg3 = %cst : tensor<4xi1>) -> tensor<4xi1> {
%1 = flow.dispatch @i1_op_usage_ex_dispatch_0::@i1_op_usage_ex_dispatch_0[%arg1 : index](%arg2, %arg3) : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir b/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir
index 939f06a..b4fb233 100644
--- a/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir
+++ b/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir
@@ -106,7 +106,7 @@
// CHECK-NEXT: hal.executable.entry_point @reduction_ex_dispatch_0 attributes {interface = @legacy_io_0, ordinal = 0 : i32, signature = (tensor<4x8xf32>) -> tensor<4xf32>}
// CHECK-NEXT: module {
// CHECK-NEXT: vm.module @linked_module {
-// CHECK-NEXT: vm.rodata @reduction_ex_dispatch_0_const_0 dense<0.000000e+00> : tensor<f32>
+// CHECK-NEXT: vm.rodata @reduction_ex_dispatch_0_const dense<0.000000e+00> : tensor<1xf32>
// CHECK-NEXT: vm.func @reduction_ex_dispatch_0(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32) {
// CHECK-NEXT: %zero = vm.const.i32.zero : i32
// CHECK-NEXT: %c128 = vm.const.i32 128 : i32
@@ -114,14 +114,16 @@
// CHECK-NEXT: %c4 = vm.const.i32 4 : i32
// CHECK-NEXT: %c8 = vm.const.i32 8 : i32
// CHECK-NEXT: %c1 = vm.const.i32 1 : i32
-// CHECK-NEXT: %reduction_ex_dispatch_0_const_0 = vm.const.ref.rodata @reduction_ex_dispatch_0_const_0 : !vm.ref<!iree.byte_buffer>
-// CHECK-NEXT: %ref = vm.call @vmla.buffer.const(%reduction_ex_dispatch_0_const_0) : (!vm.ref<!iree.byte_buffer>) -> !vm.ref<!vmla.buffer>
-// CHECK-NEXT: %ref_0 = vm.call @vmla.interface.binding(%arg0, %zero, %zero) : (!vm.ref<!vmla.interface>, i32, i32) -> !vm.ref<!vmla.buffer>
-// CHECK-NEXT: %ref_1 = vm.call @vmla.buffer.view(%ref_0, %zero, %c128) : (!vm.ref<!vmla.buffer>, i32, i32) -> !vm.ref<!vmla.buffer>
-// CHECK-NEXT: %ref_2 = vm.call @vmla.buffer.alloc(%c16) : (i32) -> !vm.ref<!vmla.buffer>
-// CHECK-NEXT: vm.call.variadic @vmla.reduce.sum.f32(%ref_1, [%c4, %c8], %ref, [], %c1, %ref_2, [%c4]) : (!vm.ref<!vmla.buffer>, i32 ..., !vm.ref<!vmla.buffer>, i32 ..., i32, !vm.ref<!vmla.buffer>, i32 ...)
-// CHECK-NEXT: %ref_3 = vm.call @vmla.interface.binding(%arg0, %zero, %c1) : (!vm.ref<!vmla.interface>, i32, i32) -> !vm.ref<!vmla.buffer>
-// CHECK-NEXT: vm.call @vmla.buffer.copy(%ref_2, %zero, %ref_3, %zero, %c16) : (!vm.ref<!vmla.buffer>, i32, !vm.ref<!vmla.buffer>, i32, i32) -> ()
+// CHECK-NEXT: %reduction_ex_dispatch_0_const = vm.const.ref.rodata @reduction_ex_dispatch_0_const : !vm.ref<!iree.byte_buffer>
+// CHECK-NEXT: %ref = vm.call @vmla.buffer.const(%reduction_ex_dispatch_0_const) : (!vm.ref<!iree.byte_buffer>) -> !vm.ref<!vmla.buffer>
+// CHECK-NEXT: %ref_0 = vm.call @vmla.buffer.alloc(%c4) : (i32) -> !vm.ref<!vmla.buffer>
+// CHECK-NEXT: vm.call @vmla.buffer.fill(%ref, %ref_0) : (!vm.ref<!vmla.buffer>, !vm.ref<!vmla.buffer>) -> ()
+// CHECK-NEXT: %ref_1 = vm.call @vmla.interface.binding(%arg0, %zero, %zero) : (!vm.ref<!vmla.interface>, i32, i32) -> !vm.ref<!vmla.buffer>
+// CHECK-NEXT: %ref_2 = vm.call @vmla.buffer.view(%ref_1, %zero, %c128) : (!vm.ref<!vmla.buffer>, i32, i32) -> !vm.ref<!vmla.buffer>
+// CHECK-NEXT: %ref_3 = vm.call @vmla.buffer.alloc(%c16) : (i32) -> !vm.ref<!vmla.buffer>
+// CHECK-NEXT: vm.call.variadic @vmla.reduce.sum.f32(%ref_2, [%c4, %c8], %ref_0, [], %c1, %ref_3, [%c4]) : (!vm.ref<!vmla.buffer>, i32 ..., !vm.ref<!vmla.buffer>, i32 ..., i32, !vm.ref<!vmla.buffer>, i32 ...)
+// CHECK-NEXT: %ref_4 = vm.call @vmla.interface.binding(%arg0, %zero, %c1) : (!vm.ref<!vmla.interface>, i32, i32) -> !vm.ref<!vmla.buffer>
+// CHECK-NEXT: vm.call @vmla.buffer.copy(%ref_3, %zero, %ref_4, %zero, %c16) : (!vm.ref<!vmla.buffer>, i32, !vm.ref<!vmla.buffer>, i32, i32) -> ()
// CHECK-NEXT: vm.return
// CHECK-NEXT: }
// CHECK-NEXT: vm.export @reduction_ex_dispatch_0
@@ -130,4 +132,5 @@
// CHECK-NEXT: vm.import @vmla.buffer.alloc(%byte_length : i32) -> !vm.ref<!vmla.buffer>
// CHECK-NEXT: vm.import @vmla.buffer.view(%src : !vm.ref<!vmla.buffer>, %byte_offset : i32, %byte_length : i32) -> !vm.ref<!vmla.buffer>
// CHECK-NEXT: vm.import @vmla.buffer.copy(%src : !vm.ref<!vmla.buffer>, %src_byte_offset : i32, %dst : !vm.ref<!vmla.buffer>, %dst_byte_offset : i32, %byte_length : i32)
+// CHECK-NEXT: vm.import @vmla.buffer.fill(%value : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
// CHECK-NEXT: vm.import @vmla.reduce.sum.f32(%src : !vm.ref<!vmla.buffer>, %src_shape : i32 ..., %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ..., %dimension : i32, %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...)
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
index 23bde9c..91de450c 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
@@ -144,6 +144,22 @@
// clang-format on
}
+  BufferConstraintsAttr queryBufferConstraints(MLIRContext *context) override {
+    // Conservative starting values taken from the Android device reports at:
+    // https://vulkan.gpuinfo.org/displaydevicelimit.php?name=minStorageBufferOffsetAlignment&platform=android
+    // https://vulkan.gpuinfo.org/displaydevicelimit.php?name=maxStorageBufferRange&platform=android
+    // TODO: query these from the vulkan target environment attributes instead.
+    const uint64_t kMaxAllocationSize = 1024ull * 1024ull * 1024ull;
+    const uint64_t kMinBufferOffsetAlignment = 256ull;
+    const uint64_t kMaxBufferRange = 128ull * 1024ull * 1024ull;
+    const uint64_t kMinBufferRangeAlignment = 16ull;
+    Builder builder(context);
+    auto indexAttr = [&](uint64_t v) { return builder.getIndexAttr(v); };
+    return BufferConstraintsAttr::get(
+        indexAttr(kMaxAllocationSize), indexAttr(kMinBufferOffsetAlignment),
+        indexAttr(kMaxBufferRange), indexAttr(kMinBufferRangeAlignment));
+  }
+
void declareTargetOps(IREE::Flow::ExecutableOp sourceOp,
IREE::HAL::ExecutableOp executableOp) override {
spirv::TargetEnvAttr spvTargetEnv =
diff --git a/iree/compiler/Dialect/HAL/Transforms/BUILD b/iree/compiler/Dialect/HAL/Transforms/BUILD
index c507f0c..d240022 100644
--- a/iree/compiler/Dialect/HAL/Transforms/BUILD
+++ b/iree/compiler/Dialect/HAL/Transforms/BUILD
@@ -21,11 +21,15 @@
cc_library(
name = "Transforms",
srcs = [
+ "ConvertToHAL.cpp",
+ "IdentifyConstantPools.cpp",
"InlineDeviceSwitches.cpp",
"LinkExecutables.cpp",
+ "MaterializeConstantPoolBuffers.cpp",
"MaterializeInterfaces.cpp",
"MaterializeResourceCaches.cpp",
"MemoizeDeviceQueries.cpp",
+ "PackConstantPoolStorage.cpp",
"Passes.cpp",
"PublicAbiGeneration.cpp",
"ResolveEntryPointOrdinals.cpp",
@@ -38,12 +42,19 @@
deps = [
"//iree/base:signature_mangle",
"//iree/compiler/Dialect/Flow/IR",
+ "//iree/compiler/Dialect/HAL/Conversion",
"//iree/compiler/Dialect/HAL/Conversion/FlowToHAL",
+ "//iree/compiler/Dialect/HAL/Conversion/HALToHAL",
+ "//iree/compiler/Dialect/HAL/Conversion/IREEToHAL",
+ "//iree/compiler/Dialect/HAL/Conversion/StandardToHAL",
"//iree/compiler/Dialect/HAL/IR",
"//iree/compiler/Dialect/HAL/IR:HALDialect",
"//iree/compiler/Dialect/HAL/Target",
"//iree/compiler/Dialect/HAL/Utils",
+ "//iree/compiler/Dialect/IREE/Conversion:PreserveCompilerHints",
"//iree/compiler/Dialect/IREE/IR",
+ "//iree/compiler/Dialect/IREE/Transforms",
+ "//iree/compiler/Dialect/Shape/IR",
"//iree/compiler/Dialect/Shape/Transforms",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
diff --git a/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt b/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
index 91087a6..ed48d8a 100644
--- a/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
@@ -20,11 +20,15 @@
HDRS
"Passes.h"
SRCS
+ "ConvertToHAL.cpp"
+ "IdentifyConstantPools.cpp"
"InlineDeviceSwitches.cpp"
"LinkExecutables.cpp"
+ "MaterializeConstantPoolBuffers.cpp"
"MaterializeInterfaces.cpp"
"MaterializeResourceCaches.cpp"
"MemoizeDeviceQueries.cpp"
+ "PackConstantPoolStorage.cpp"
"Passes.cpp"
"PublicAbiGeneration.cpp"
"ResolveEntryPointOrdinals.cpp"
@@ -40,12 +44,19 @@
absl::strings
iree::base::signature_mangle
iree::compiler::Dialect::Flow::IR
+ iree::compiler::Dialect::HAL::Conversion
iree::compiler::Dialect::HAL::Conversion::FlowToHAL
+ iree::compiler::Dialect::HAL::Conversion::HALToHAL
+ iree::compiler::Dialect::HAL::Conversion::IREEToHAL
+ iree::compiler::Dialect::HAL::Conversion::StandardToHAL
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::HAL::IR::HALDialect
iree::compiler::Dialect::HAL::Target
iree::compiler::Dialect::HAL::Utils
+ iree::compiler::Dialect::IREE::Conversion::PreserveCompilerHints
iree::compiler::Dialect::IREE::IR
+ iree::compiler::Dialect::IREE::Transforms
+ iree::compiler::Dialect::Shape::IR
iree::compiler::Dialect::Shape::Transforms
PUBLIC
)
diff --git a/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp b/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp
new file mode 100644
index 0000000..340ff06
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp
@@ -0,0 +1,122 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
+#include "iree/compiler/Dialect/HAL/Conversion/ConversionDialectInterface.h"
+#include "iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h"
+#include "iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertFlowToHAL.h"
+#include "iree/compiler/Dialect/HAL/Conversion/HALToHAL/ConvertHALToHAL.h"
+#include "iree/compiler/Dialect/HAL/Conversion/IREEToHAL/ConvertIREEToHAL.h"
+#include "iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStandardToHAL.h"
+#include "iree/compiler/Dialect/HAL/Conversion/TypeConverter.h"
+#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
+#include "iree/compiler/Dialect/IREE/Conversion/PreserveCompilerHints.h"
+#include "iree/compiler/Dialect/IREE/IR/IREEDialect.h"
+#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
+#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
+#include "iree/compiler/Dialect/IREE/Transforms/Passes.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+namespace {
+
+// Converts flow/std/iree ops (plus custom dialect ops) to the IREE HAL dialect.
+class ConvertToHALPass
+    : public PassWrapper<ConvertToHALPass, OperationPass<ModuleOp>> {
+ public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<IREEDialect>();
+    registry.insert<HALDialect>();
+    registry.insert<StandardOpsDialect>();
+  }
+
+  void runOnOperation() override {
+    auto *context = &getContext();
+
+    // Gather all HAL conversion interfaces from the loaded dialects.
+    // These will perform the tensor->buffer mapping for their ops.
+    SmallVector<const HALConversionDialectInterface *, 4> conversionInterfaces;
+    for (auto *dialect : context->getLoadedDialects()) {
+      if (auto *conversionInterface =
+              dialect
+                  ->getRegisteredInterface<HALConversionDialectInterface>()) {
+        conversionInterfaces.emplace_back(conversionInterface);
+      }
+    }
+
+    HALTypeConverter typeConverter(conversionInterfaces);
+    HALConversionTarget conversionTarget(context, typeConverter);
+
+    OwningRewritePatternList patterns;
+
+    setupIREEToHALLegality(context, conversionTarget);
+    populateIREEToHALPatterns(context, patterns);
+
+    setupCompilerHintsLegality(context, conversionTarget, typeConverter);
+    populatePreserveCompilerHintsPatterns(context, patterns);
+
+    setupStandardToHALLegality(context, conversionTarget, typeConverter);
+    populateStandardToHALPatterns(context, patterns, typeConverter);
+
+    setupFlowToHALLegality(context, conversionTarget, typeConverter);
+    populateFlowToHALPatterns(context, patterns, typeConverter);
+
+    setupHALToHALLegality(context, conversionTarget, typeConverter);
+    populateHALToHALPatterns(context, patterns, typeConverter);
+
+    // Gather all HAL dialect conversion patterns from custom dialects.
+    // These will perform the tensor->buffer mapping for their ops.
+    for (auto *conversionInterface : conversionInterfaces) {
+      conversionInterface->setupConversionTarget(conversionTarget, patterns,
+                                                 typeConverter);
+    }
+
+    // NOTE: we allow ops that we don't know about to allow custom dialects
+    // that don't need anything HAL-specific to pass through (presumably via
+    // the fallback type legality of HALConversionTarget; TODO confirm).
+    if (failed(applyPartialConversion(getOperation(), conversionTarget,
+                                      patterns))) {
+      return signalPassFailure();
+    }
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>> createConvertToHALPass() {
+  return std::make_unique<ConvertToHALPass>();  // NOLINT
+}
+
+static PassRegistration<ConvertToHALPass> pass(
+    "iree-convert-to-hal",
+    "Convert input flow/std/etc dialects to the IREE HAL dialect.");
+
+} // namespace HAL
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Transforms/IdentifyConstantPools.cpp b/iree/compiler/Dialect/HAL/Transforms/IdentifyConstantPools.cpp
new file mode 100644
index 0000000..cb31f0d
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Transforms/IdentifyConstantPools.cpp
@@ -0,0 +1,318 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <utility>
+
+#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/HAL/Target/TargetBackend.h"
+#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
+#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+class IdentifyConstantPoolsPass
+    : public PassWrapper<IdentifyConstantPoolsPass, OperationPass<ModuleOp>> {
+ public:
+  IdentifyConstantPoolsPass() : targetOptions_(getTargetOptionsFromFlags()) {}
+  explicit IdentifyConstantPoolsPass(TargetOptions targetOptions)
+      : targetOptions_(std::move(targetOptions)) {}
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<mlir::StandardOpsDialect>();
+    registry.insert<IREE::Flow::FlowDialect>();
+    registry.insert<IREE::HAL::HALDialect>();
+  }
+
+  void runOnOperation() override {
+    auto moduleOp = getOperation();
+
+    // Gather constant variables. We assume that prior passes/pipelines have
+    // hoisted anything worth pooling to flow.variables at the module scope.
+    // We expect that immutable variables have already been de-duped and that
+    // mutable variables that remain may have identical initializers.
+    SmallVector<IREE::Flow::VariableOp, 16> mutableOps;
+    SmallVector<IREE::Flow::VariableOp, 16> immutableOps;
+    for (auto variableOp : moduleOp.getOps<IREE::Flow::VariableOp>()) {
+      if (!variableOp.initial_value().hasValue()) continue;
+      auto variableType = variableOp.type().dyn_cast<RankedTensorType>();
+      if (!variableType) continue;
+      if (variableOp.is_mutable()) {
+        mutableOps.push_back(variableOp);
+      } else {
+        immutableOps.push_back(variableOp);
+      }
+    }
+    if (mutableOps.empty() && immutableOps.empty()) {
+      return;
+    }
+
+    // Derive buffer constraints based on target backends.
+    auto bufferConstraints = computeConservativeBufferConstraints(
+        targetOptions_, moduleOp.getContext());
+    if (!bufferConstraints) {
+      moduleOp.emitWarning() << "no target backends provided buffer "
+                                "constraints; falling back to host default";
+      bufferConstraints =
+          TargetBackend::makeDefaultBufferConstraints(moduleOp.getContext());
+    }
+
+    SymbolTable moduleSymbolTable(moduleOp);
+    auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp.getBody());
+    auto variableUsages = gatherVariableUsages(moduleOp);
+
+    // Process the mutable ops where each constant is only used as an
+    // initializer. The lifetime of these is short as we only use them to
+    // populate the initial variable buffer contents.
+    makeConstantPool("_const_pool_init", mutableOps, bufferConstraints,
+                     variableUsages, moduleOp, moduleSymbolTable,
+                     moduleBuilder);
+
+    // Process the immutable ops where the same buffer will be used for the
+    // lifetime of the module.
+    makeConstantPool("_const_pool", immutableOps, bufferConstraints,
+                     variableUsages, moduleOp, moduleSymbolTable,
+                     moduleBuilder);
+
+    // NOTE: pools now contain the values but they are in an undefined order.
+    // We should have following passes that reorder the values to cluster them
+    // by usage time locality so that there's a better chance of them landing
+    // in the same runtime buffers and prefetched mapped storage pages.
+  }
+
+ private:
+  enum VariableUsage {
+    kAddress = 1 << 0,
+    kLoad = 1 << 1,
+  };
+
+  // Collects the VariableUsage bits (address/load) of each variable in funcs.
+  DenseMap<StringRef, VariableUsage> gatherVariableUsages(
+      mlir::ModuleOp moduleOp) {
+    DenseMap<StringRef, VariableUsage> uses;
+    for (auto funcOp : moduleOp.getOps<mlir::FuncOp>()) {
+      funcOp.walk([&](Operation *op) {
+        if (auto addressOp = dyn_cast<IREE::Flow::VariableAddressOp>(op)) {
+          auto it = uses.find(addressOp.variable());
+          if (it == uses.end()) {
+            uses[addressOp.variable()] = VariableUsage::kAddress;
+          } else {
+            it->second = static_cast<VariableUsage>(it->second |
+                                                    VariableUsage::kAddress);
+          }
+        } else if (auto loadOp = dyn_cast<IREE::Flow::VariableLoadOp>(op)) {
+          auto it = uses.find(loadOp.variable());
+          if (it == uses.end()) {
+            uses[loadOp.variable()] = VariableUsage::kLoad;
+          } else {
+            it->second =
+                static_cast<VariableUsage>(it->second | VariableUsage::kLoad);
+          }
+        }
+      });
+    }
+    return uses;
+  }
+
+  // Tries to find the min/max constraints on buffers across all target
+  // backends. This should really be done per pool based on the usage of the
+  // constants (if pool 0 is used by device A and pool 1 is used by device B
+  // then they should not need to have matching constraints).
+  BufferConstraintsAttr computeConservativeBufferConstraints(
+      const TargetOptions &targetOptions, MLIRContext *context) {
+    auto targetBackends = matchTargetBackends(targetOptions.targets);
+    BufferConstraintsAttr attr = {};
+    for (auto &targetBackend : targetBackends) {
+      if (attr) {
+        attr = intersectBufferConstraints(
+            attr, targetBackend->queryBufferConstraints(context));
+      } else {
+        attr = targetBackend->queryBufferConstraints(context);
+      }
+    }
+    return attr;
+  }
+
+  // Makes a new hal.constant_pool holding the initial values of |variableOps|.
+  // Immutable variables are erased and their loads become constant loads;
+  // mutable ones get initializer funcs. Returns the pool, or None if empty.
+  Optional<ConstantPoolOp> makeConstantPool(
+      StringRef poolName, ArrayRef<IREE::Flow::VariableOp> variableOps,
+      BufferConstraintsAttr bufferConstraints,
+      DenseMap<StringRef, VariableUsage> &variableUsages,
+      mlir::ModuleOp moduleOp, SymbolTable &moduleSymbolTable,
+      OpBuilder &moduleBuilder) {
+    // Create the pool to be filled with constant values.
+    auto poolOp = OpBuilder(moduleBuilder.getContext())
+                      .create<ConstantPoolOp>(moduleBuilder.getUnknownLoc(),
+                                              poolName, bufferConstraints);
+    moduleSymbolTable.insert(poolOp, moduleBuilder.getInsertionPoint());
+    SymbolTable::setSymbolVisibility(poolOp, SymbolTable::Visibility::Private);
+
+    // Replace each variable and keep track of the mapping from variable->value.
+    // This allows us to do one run through the module to replace usages as a
+    // post-processing step.
+    DenseMap<StringRef, IREE::HAL::ConstantPoolValueOp> constantReplacements;
+    SmallVector<Operation *, 4> deadOps;
+    auto poolBuilder = OpBuilder::atBlockBegin(poolOp.getBody());
+    for (auto variableOp : variableOps) {
+      // Grab the constant value from the variable that we'll be pooling.
+      auto value = variableOp.initial_value()
+                       .getValue()
+                       .dyn_cast_or_null<ElementsAttr>();
+      assert(value && "value precondition not met: must be elements attr");
+
+      // Immutable variables are turned into constant loads (avoiding runtime
+      // variable lifetime tracking), but only if their address is never taken.
+      // Check that before adding the value so skipped variables don't leave an
+      // unreferenced value behind (which would also keep an otherwise-empty
+      // pool alive).
+      auto variableUsage = variableUsages.lookup(variableOp.getName());
+      if (!variableOp.is_mutable() &&
+          (variableUsage & VariableUsage::kAddress)) {
+        variableOp.emitWarning() << "variable is used indirectly; currently "
+                                    "unsupported for constant pooling";
+        continue;
+      }
+
+      // Create the constant in the pool.
+      auto valueOp = poolBuilder.create<ConstantPoolValueOp>(
+          variableOp.getLoc(), variableOp.getName(), value);
+      SymbolTable::setSymbolVisibility(valueOp,
+                                       SymbolTable::Visibility::Nested);
+
+      if (!variableOp.is_mutable()) {
+        // Replace all loads of the variable with loads of the constant.
+        // We do the actual replacement in a post-processing step so we don't
+        // modify the IR during this loop.
+        constantReplacements[variableOp.getName()] = valueOp;
+        deadOps.push_back(variableOp);
+      } else {
+        // Build an initializer function to populate the variable with the
+        // constant value on startup.
+        changeToVariableInitializerFunc(variableOp, valueOp);
+      }
+    }
+
+    // Remove the pool (and its symbol table entry) if it holds no constants.
+    if (poolOp.getBody()->front().isKnownTerminator()) {
+      moduleSymbolTable.erase(poolOp);
+      return None;
+    }
+
+    // Process pending usage replacements.
+    replaceConstantVariableLoads(moduleOp, constantReplacements);
+
+    // Cleanup inlined variables (and their symbol table entries).
+    for (auto deadOp : deadOps) {
+      moduleSymbolTable.erase(deadOp);
+    }
+
+    return poolOp;
+  }
+
+  // Constructs a function that can be used as an initializer for a variable
+  // and inserts it by the variable op in the module.
+  FuncOp changeToVariableInitializerFunc(
+      IREE::Flow::VariableOp variableOp,
+      IREE::HAL::ConstantPoolValueOp valueOp) {
+    // Create the function and make the variable point to it for init.
+    OpBuilder moduleBuilder(variableOp.getContext());
+    moduleBuilder.setInsertionPointAfter(variableOp);
+    auto initializerName = (variableOp.getName() + "_initializer").str();
+    auto initializerFunc = moduleBuilder.create<FuncOp>(
+        variableOp.getLoc(), initializerName,
+        moduleBuilder.getFunctionType({}, {variableOp.type()}));
+    SymbolTable::setSymbolVisibility(initializerFunc,
+                                     SymbolTable::Visibility::Private);
+    variableOp.removeAttr("initial_value");
+    variableOp.setAttr("initializer",
+                       moduleBuilder.getSymbolRefAttr(initializerFunc));
+
+    // Emit a constant load that will later on be turned into a runtime buffer
+    // reference.
+    auto funcBuilder = OpBuilder::atBlockBegin(initializerFunc.addEntryBlock());
+    auto constValue = funcBuilder.createOrFold<IREE::HAL::ConstantPoolLoadOp>(
+        variableOp.getLoc(), variableOp.type(),
+        funcBuilder.getSymbolRefAttr(
+            valueOp.getParentOfType<ConstantPoolOp>().getName(),
+            {funcBuilder.getSymbolRefAttr(valueOp)}));
+    funcBuilder.create<mlir::ReturnOp>(variableOp.getLoc(), constValue);
+
+    return initializerFunc;
+  }
+
+  // Replaces uses of each variable with references to the constant pool value.
+  void replaceConstantVariableLoads(
+      mlir::ModuleOp moduleOp,
+      DenseMap<StringRef, IREE::HAL::ConstantPoolValueOp> &replacements) {
+    SmallVector<
+        std::pair<IREE::Flow::VariableLoadOp, IREE::HAL::ConstantPoolValueOp>,
+        8>
+        loadValues;
+    for (auto funcOp : moduleOp.getOps<mlir::FuncOp>()) {
+      funcOp.walk([&](IREE::Flow::VariableLoadOp loadOp) {
+        auto replacement = replacements.find(loadOp.variable());
+        if (replacement != replacements.end()) {
+          loadValues.push_back(std::make_pair(loadOp, replacement->second));
+        }
+      });
+    }
+    for (auto &loadValue : loadValues) {
+      replaceVariableLoadWithConstantLoad(loadValue.first, loadValue.second);
+    }
+  }
+
+  // Replaces a flow.variable.load with a hal.constant_pool.load of a pooled
+  // value.
+  void replaceVariableLoadWithConstantLoad(
+      IREE::Flow::VariableLoadOp variableLoadOp, ConstantPoolValueOp valueOp) {
+    OpBuilder builder(variableLoadOp);
+    auto loadOp = builder.create<ConstantPoolLoadOp>(
+        variableLoadOp.getLoc(), variableLoadOp.getType(),
+        builder.getSymbolRefAttr(
+            valueOp.getParentOfType<ConstantPoolOp>().getName(),
+            {builder.getSymbolRefAttr(valueOp)}));
+    variableLoadOp.replaceAllUsesWith(loadOp.result());
+    variableLoadOp.erase();
+  }
+
+  TargetOptions targetOptions_;
+};
+
+std::unique_ptr<OperationPass<ModuleOp>> createIdentifyConstantPoolsPass(
+    TargetOptions targetOptions) {
+  return std::make_unique<IdentifyConstantPoolsPass>(std::move(targetOptions));
+}
+
+static PassRegistration<IdentifyConstantPoolsPass> pass(
+    "iree-hal-identify-constant-pools",
+    "Combines constant variables into one or more hal.constant_pools based on "
+    "usage semantics.");
+
+} // namespace HAL
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Transforms/InlineDeviceSwitches.cpp b/iree/compiler/Dialect/HAL/Transforms/InlineDeviceSwitches.cpp
index 1491178..52c2b27 100644
--- a/iree/compiler/Dialect/HAL/Transforms/InlineDeviceSwitches.cpp
+++ b/iree/compiler/Dialect/HAL/Transforms/InlineDeviceSwitches.cpp
@@ -73,6 +73,13 @@
// #hal.device.match.id<"pattern"> -> hal.device.match.id
return funcBuilder.createOrFold<IREE::HAL::DeviceMatchIDOp>(
loc, funcBuilder.getI1Type(), device, matchAttr.patternAttr());
+ } else if (auto matchAttr =
+ conditionAttr
+ .dyn_cast<IREE::HAL::DeviceMatchMemoryModelAttr>()) {
+ // #hal.device.match.memory_model<"Unified"> ->
+ // hal.device.match.memory_model
+ return funcBuilder.createOrFold<IREE::HAL::DeviceMatchMemoryModelOp>(
+ loc, funcBuilder.getI1Type(), device, matchAttr.memory_modelAttr());
}
llvm_unreachable("unhandled condition expression attribute");
return {};
diff --git a/iree/compiler/Dialect/HAL/Transforms/MaterializeConstantPoolBuffers.cpp b/iree/compiler/Dialect/HAL/Transforms/MaterializeConstantPoolBuffers.cpp
new file mode 100644
index 0000000..d5e6b16
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Transforms/MaterializeConstantPoolBuffers.cpp
@@ -0,0 +1,310 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <utility>
+
+#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
+#include "iree/compiler/Dialect/HAL/Utils/TypeUtils.h"
+#include "iree/compiler/Dialect/IREE/IR/IREEDialect.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+class MaterializeConstantPoolBuffersPass
+    : public PassWrapper<MaterializeConstantPoolBuffersPass,
+                         OperationPass<ModuleOp>> {
+ public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<mlir::StandardOpsDialect>();
+    registry.insert<IREEDialect>();
+    registry.insert<IREE::HAL::HALDialect>();
+  }
+
+  void runOnOperation() override {
+    auto moduleOp = getOperation();
+
+    // Today we simply materialize a !hal.buffer variable for each storage
+    // buffer and initialize it in a naive way. Really we should be aggregating
+    // all constant pools and issuing a command buffer on a DMA queue to upload
+    // everything but this is a start that works on unified memory systems ok.
+    // TODO(benvanik): command buffer-based DMA/uploads.
+    //
+    // We also handle the specific constant types directly in this file. Instead
+    // we could synthesize pseudo ops (hal.constant.populate %buffer, ...) that
+    // then were rewritten to the logic during conversion. This would let us
+    // more easily add new types (including things like target-specific constant
+    // types).
+    SymbolTable moduleSymbolTable(moduleOp);
+    auto poolOps = llvm::to_vector<4>(moduleOp.getOps<ConstantPoolOp>());
+    for (auto poolOp : poolOps) {
+      auto insertionPoint = ++Block::iterator(poolOp);
+
+      // 1:1 storage to runtime buffers.
+      for (auto storageOp : poolOp.getOps<ConstantStorageOp>()) {
+        makeStorageBufferRuntimeVariable(poolOp, storageOp, moduleSymbolTable,
+                                         insertionPoint);
+      }
+
+      // We currently put all splats on their own so that we are always able to
+      // map the storage buffers above as read-only.
+      auto splatOps = llvm::to_vector<4>(poolOp.getOps<ConstantPoolSplatOp>());
+      if (!splatOps.empty()) {
+        makeSplatRuntimeVariable(poolOp, splatOps, moduleSymbolTable,
+                                 insertionPoint);
+      }
+    }
+  }
+
+ private:
+  // Creates a runtime buffer into which the storage buffer will be mapped or
+  // uploaded.
+  void makeStorageBufferRuntimeVariable(ConstantPoolOp poolOp,
+                                        ConstantStorageOp storageOp,
+                                        SymbolTable &moduleSymbolTable,
+                                        Block::iterator insertionPoint) {
+    auto *context = poolOp.getContext();
+    auto variableName =
+        (poolOp.getName() + storageOp.getName() + "_buffer").str();
+    auto variableOp = OpBuilder(context).create<IREE::HAL::VariableOp>(
+        storageOp.getLoc(), variableName, /*isMutable=*/false,
+        IREE::HAL::BufferType::get(context));
+    moduleSymbolTable.insert(variableOp, insertionPoint);
+    SymbolTable::setSymbolVisibility(variableOp,
+                                     SymbolTable::Visibility::Private);
+
+    // Find all the spans in the pool that map into this storage buffer so that
+    // we can update them with their runtime offsets. Note that since we are
+    // uploading 1:1 today all the offsets are the same as their storage ones.
+    auto variableSymRef = SymbolRefAttr::get(variableOp.getName(), context);
+    for (auto spanOp : poolOp.getOps<ConstantPoolSpanOp>()) {
+      if (spanOp.storage_buffer().getLeafReference() != storageOp.getName()) {
+        continue;
+      }
+      spanOp.runtime_bufferAttr(variableSymRef);
+      spanOp.runtime_rangeAttr(spanOp.storage_range());
+    }
+
+    auto initializerFunc = makeStorageBufferRuntimeInitializerFunc(
+        variableOp.getName(), storageOp, poolOp.buffer_constraints());
+    moduleSymbolTable.insert(initializerFunc, insertionPoint);
+    variableOp.initializerAttr(
+        SymbolRefAttr::get(initializerFunc.getName(), context));
+  }
+
+  // Creates an initializer function that unpacks the given storage op into a
+  // new buffer.
+  FuncOp makeStorageBufferRuntimeInitializerFunc(
+      StringRef variableName, ConstantStorageOp storageOp,
+      BufferConstraintsAttr bufferConstraints) {
+    auto *context = storageOp.getContext();
+    OpBuilder builder(context);
+    auto initializerName = (variableName + "_initializer").str();
+    auto initializerFunc = FuncOp::create(
+        storageOp.getLoc(), initializerName,
+        builder.getFunctionType({}, {IREE::HAL::BufferType::get(context)}));
+    SymbolTable::setSymbolVisibility(initializerFunc,
+                                     SymbolTable::Visibility::Private);
+
+    auto funcBuilder = OpBuilder::atBlockBegin(initializerFunc.addEntryBlock());
+
+    // HACK: use default allocator.
+    auto deviceValue = funcBuilder.createOrFold<IREE::HAL::ExSharedDeviceOp>(
+        storageOp.getLoc());
+    auto allocatorValue = funcBuilder.createOrFold<DeviceAllocatorOp>(
+        storageOp.getLoc(), deviceValue);
+
+    // Today we always map the buffer directly. We should be using a device
+    // switch to schedule the upload if needed.
+    // TODO(benvanik): allocate based on usage tracking.
+    auto memoryTypes = IREE::HAL::MemoryTypeBitfield::HostLocal |
+                       IREE::HAL::MemoryTypeBitfield::DeviceVisible;
+    auto bufferUsage = IREE::HAL::BufferUsageBitfield::Constant |
+                       IREE::HAL::BufferUsageBitfield::All;
+    auto sourceValue =
+        funcBuilder.createOrFold<IREE::HAL::ConstantStorageLookupOp>(
+            storageOp.getLoc(), IREE::ByteBufferType::get(context),
+            funcBuilder.getSymbolRefAttr(
+                storageOp.getParentOfType<ConstantPoolOp>().getName(),
+                {funcBuilder.getSymbolRefAttr(storageOp)}));
+    auto offsetValue =
+        funcBuilder.createOrFold<mlir::ConstantIndexOp>(storageOp.getLoc(), 0);
+    uint64_t runtimeLength =
+        align(storageOp.value().getNumElements(),
+              bufferConstraints.min_buffer_range_alignment());
+    auto lengthValue = funcBuilder.createOrFold<mlir::ConstantIndexOp>(
+        storageOp.getLoc(), runtimeLength);
+    auto bufferValue = funcBuilder.createOrFold<IREE::HAL::AllocatorMapOp>(
+        storageOp.getLoc(), allocatorValue, memoryTypes, bufferUsage,
+        sourceValue, offsetValue, lengthValue);
+    funcBuilder.create<mlir::ReturnOp>(storageOp.getLoc(), bufferValue);
+
+    return initializerFunc;
+  }
+
+  // Creates a runtime buffer for the given constant pool splats and constructs
+  // its initializer to fill the contents.
+  void makeSplatRuntimeVariable(ConstantPoolOp poolOp,
+                                ArrayRef<ConstantPoolSplatOp> splatOps,
+                                SymbolTable &moduleSymbolTable,
+                                Block::iterator insertionPoint) {
+    auto *context = poolOp.getContext();
+    auto variableLoc = FusedLoc::get(
+        llvm::to_vector<8>(llvm::map_range(
+            splatOps, [](ConstantPoolSplatOp op) { return op.getLoc(); })),
+        context);
+    auto variableName = (poolOp.getName() + "_splats").str();
+    auto variableOp = OpBuilder(context).create<IREE::HAL::VariableOp>(
+        variableLoc, variableName, /*isMutable=*/false,
+        IREE::HAL::BufferType::get(context));
+    moduleSymbolTable.insert(variableOp, insertionPoint);
+    SymbolTable::setSymbolVisibility(variableOp,
+                                     SymbolTable::Visibility::Private);
+
+    // Compute the ranges for all the splats at runtime and the required buffer
+    // size based on the constraints provided.
+    auto bufferConstraints = poolOp.buffer_constraints();
+    auto variableSymRef = SymbolRefAttr::get(variableOp.getName(), context);
+    uint64_t bufferLength = 0;
+    for (auto splatOp : splatOps) {
+      uint64_t splatOffset =
+          align(bufferLength, bufferConstraints.min_buffer_offset_alignment());
+      uint64_t unpaddedLength =
+          getRoundedElementByteWidth(
+              splatOp.value().getType().getElementType()) *
+          splatOp.value().getNumElements();
+      uint64_t splatLength =
+          align(unpaddedLength, bufferConstraints.min_buffer_range_alignment());
+      splatOp.runtime_bufferAttr(variableSymRef);
+      splatOp.runtime_rangeAttr(ByteRangeAttr::get(
+          APInt(64, splatOffset), APInt(64, splatLength), context));
+      bufferLength = splatOffset + splatLength;
+    }
+
+    // TODO(benvanik): if we spill here we'll need to create more buffers. We
+    // could flip this loop inside out and first allocate the splats.
+    if (bufferLength > bufferConstraints.max_buffer_range().getZExtValue()) {
+      variableOp.emitError()
+          << "constant splat buffer length " << bufferLength
+          << " spills max buffer range of "
+          << bufferConstraints.max_buffer_range().getZExtValue()
+          << " - contents may not be accessible at runtime";
+    }
+
+    auto initializerFunc = makeSplatRuntimeInitializerFunc(
+        variableOp.getLoc(), variableOp.getName(), splatOps, bufferLength);
+    moduleSymbolTable.insert(initializerFunc, insertionPoint);
+    variableOp.initializerAttr(
+        SymbolRefAttr::get(initializerFunc.getName(), context));
+  }
+
+  // Creates an initializer function that allocates the runtime buffer and
+  // splats the values into it.
+  FuncOp makeSplatRuntimeInitializerFunc(Location variableLoc,
+                                         StringRef variableName,
+                                         ArrayRef<ConstantPoolSplatOp> splatOps,
+                                         uint64_t bufferLength) {
+    auto *context = variableLoc.getContext();
+    OpBuilder builder(context);
+    auto initializerName = (variableName + "_initializer").str();
+    auto initializerFunc = FuncOp::create(
+        variableLoc, initializerName,
+        builder.getFunctionType({}, {IREE::HAL::BufferType::get(context)}));
+    SymbolTable::setSymbolVisibility(initializerFunc,
+                                     SymbolTable::Visibility::Private);
+
+    auto funcBuilder = OpBuilder::atBlockBegin(initializerFunc.addEntryBlock());
+
+    // HACK: use default allocator.
+    auto deviceValue =
+        funcBuilder.createOrFold<IREE::HAL::ExSharedDeviceOp>(variableLoc);
+    auto allocatorValue = funcBuilder.createOrFold<DeviceAllocatorOp>(
+        variableLoc, deviceValue);
+
+    // Allocate buffer with empty contents.
+    // TODO(benvanik): allocate based on usage tracking.
+    auto memoryTypes = IREE::HAL::MemoryTypeBitfield::DeviceLocal |
+                       IREE::HAL::MemoryTypeBitfield::HostVisible;
+    auto bufferUsage = IREE::HAL::BufferUsageBitfield::Constant |
+                       IREE::HAL::BufferUsageBitfield::All;
+    auto allocationSizeValue = funcBuilder.createOrFold<mlir::ConstantIndexOp>(
+        variableLoc, bufferLength);
+    auto bufferValue = funcBuilder.createOrFold<IREE::HAL::AllocatorAllocateOp>(
+        variableLoc, allocatorValue, memoryTypes, bufferUsage,
+        allocationSizeValue);
+
+    // Fill the buffer (memset).
+    // TODO(benvanik): do this via a command buffer/DMA to keep host moving.
+    for (auto splatOp : splatOps) {
+      auto offsetValue = funcBuilder.createOrFold<mlir::ConstantOp>(
+          splatOp.getLoc(), splatOp.runtime_rangeAttr().offsetAttr());
+      auto lengthValue = funcBuilder.createOrFold<mlir::ConstantOp>(
+          splatOp.getLoc(), splatOp.runtime_rangeAttr().lengthAttr());
+      uint32_t pattern = makePatternFromSplatValue(
+          splatOp.value().cast<SplatElementsAttr>().getSplatValue());
+      auto patternValue = funcBuilder.createOrFold<mlir::ConstantIntOp>(
+          splatOp.getLoc(), static_cast<int64_t>(pattern), 32);
+      funcBuilder.create<IREE::HAL::BufferFillOp>(splatOp.getLoc(), bufferValue,
+                                                  offsetValue, lengthValue,
+                                                  patternValue);
+    }
+
+    funcBuilder.create<mlir::ReturnOp>(variableLoc, bufferValue);
+
+    return initializerFunc;
+  }
+
+  // Makes a 4-byte pattern from a splat value for use at runtime.
+  // Asserts if the pattern cannot be constructed (not 4-byte compatible). TBD.
+  // TODO(benvanik): support 8-byte fill patterns (via fallback executable).
+  // Sub-byte integers (i1, i4) occupy a whole padded byte in storage, so they
+  // are zero-extended first: an i1 `true` splat becomes 0x01010101.
+  uint32_t makePatternFromSplatValue(Attribute elementAttr) {
+    assert(elementAttr.getType().getIntOrFloatBitWidth() <= 32);  // i64/f64 TBD
+    if (auto intAttr = elementAttr.dyn_cast<IntegerAttr>()) {
+      APInt bits = intAttr.getValue().zextOrSelf(
+          getRoundedElementByteWidth(intAttr.getType()) * 8);
+      return static_cast<uint32_t>(APInt::getSplat(32, bits).getZExtValue());
+    } else if (auto fltAttr = elementAttr.dyn_cast<FloatAttr>()) {
+      return static_cast<uint32_t>(
+          APInt::getSplat(32, fltAttr.getValue().bitcastToAPInt())
+              .getZExtValue());
+    }
+    assert(false && "unsupported splat type");
+    return 0;
+  }
+};
+
+static PassRegistration<MaterializeConstantPoolBuffersPass> pass(
+    "iree-hal-materialize-constant-pool-buffers",
+    "Materializes runtime buffers for constant pools.");
+
+std::unique_ptr<OperationPass<ModuleOp>>
+createMaterializeConstantPoolBuffersPass() {
+  return std::make_unique<MaterializeConstantPoolBuffersPass>();
+}
+
+} // namespace HAL
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp b/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp
index 3e343eb..7a90486 100644
--- a/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp
+++ b/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp
@@ -75,11 +75,13 @@
moduleBuilder.getFunctionType({}, {moduleBuilder.getI1Type()}));
SymbolTable::setSymbolVisibility(initializerOp,
SymbolTable::Visibility::Private);
+ moduleBuilder.setInsertionPoint(initializerOp);
auto variableOp = moduleBuilder.create<IREE::HAL::VariableOp>(
fusedLoc, variableName,
/*isMutable=*/false, initializerOp);
SymbolTable::setSymbolVisibility(variableOp,
SymbolTable::Visibility::Private);
+ moduleBuilder.setInsertionPointAfter(initializerOp);
auto funcBuilder = OpBuilder::atBlockBegin(initializerOp.addEntryBlock());
auto device =
diff --git a/iree/compiler/Dialect/HAL/Transforms/PackConstantPoolStorage.cpp b/iree/compiler/Dialect/HAL/Transforms/PackConstantPoolStorage.cpp
new file mode 100644
index 0000000..ae2bc46
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Transforms/PackConstantPoolStorage.cpp
@@ -0,0 +1,294 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <utility>
+
+#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
+#include "iree/compiler/Dialect/HAL/Utils/TypeUtils.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+class PackConstantPoolStoragePass
+    : public PassWrapper<PackConstantPoolStoragePass,
+                         OperationPass<ConstantPoolOp>> {
+ public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<IREE::HAL::HALDialect>();
+  }
+
+  void runOnOperation() override {
+    auto poolOp = getOperation();
+    auto bufferConstraints = poolOp.buffer_constraints();
+    if (failed(packConstantPool(poolOp, bufferConstraints))) {
+      signalPassFailure();
+      return;
+    }
+  }
+
+ private:
+  // Packs all constant values within |poolOp| into storage buffers.
+  // Zero or more ConstantStorageOps will be inserted into |poolOp| itself.
+  // Safe to call on constant pools that have already been packed; only newly
+  // inserted constant values will get packed and they will be placed into
+  // new buffers. Splat values become ConstantPoolSplatOps (filled at runtime)
+  // and dense values become ConstantPoolSpanOps referencing their storage.
+  // Always succeeds today; the LogicalResult is reserved for future failures.
+  LogicalResult packConstantPool(ConstantPoolOp poolOp,
+                                 BufferConstraintsAttr bufferConstraints) {
+    // We only pack values either into storage (dense, real data) or represent
+    // them as values that will be filled at runtime (splatted values).
+    // NOTE: values holding any other ElementsAttr kind are left untouched.
+    SmallVector<ConstantPoolValueOp, 8> denseValueOps;
+    SmallVector<ConstantPoolValueOp, 8> splatValueOps;
+    poolOp.walk([&](ConstantPoolValueOp valueOp) {
+      if (valueOp.value().isa<SplatElementsAttr>()) {
+        splatValueOps.push_back(valueOp);
+      } else if (valueOp.value().isa<DenseElementsAttr>()) {
+        denseValueOps.push_back(valueOp);
+      }
+    });
+
+    // Create splat values that future passes handling the runtime work will
+    // use to splat the element value directly into memory.
+    for (auto splatValueOp : splatValueOps) {
+      OpBuilder builder(poolOp.getContext());
+      builder.setInsertionPointAfter(splatValueOp);
+      auto splatOp = builder.create<ConstantPoolSplatOp>(
+          splatValueOp.getLoc(), splatValueOp.getName(), splatValueOp.value(),
+          SymbolRefAttr{}, ByteRangeAttr{});
+      SymbolTable::setSymbolVisibility(splatOp,
+                                       SymbolTable::Visibility::Nested);
+      splatValueOp.erase();
+    }
+
+    // Perform the packing of dense values to compute the storage buffers we
+    // will need and where each value will be placed.
+    auto storageBuffers = computePackingMap(denseValueOps, bufferConstraints,
+                                            poolOp.getContext());
+    if (storageBuffers.empty()) return success();
+
+    // Create the storage buffer variables.
+    SymbolTable poolSymbolTable(poolOp);
+    for (const auto &storageBuffer : storageBuffers) {
+      auto storageBufferLoc = storageBuffer.loc.hasValue()
+                                  ? storageBuffer.loc.getValue()
+                                  : UnknownLoc::get(poolOp.getContext());
+      auto storageBufferOp =
+          OpBuilder(poolOp.getContext())
+              .create<ConstantStorageOp>(storageBufferLoc, "_storage",
+                                         storageBuffer.data);
+      poolSymbolTable.insert(storageBufferOp);
+      SymbolTable::setSymbolVisibility(storageBufferOp,
+                                       SymbolTable::Visibility::Nested);
+
+      // TODO(benvanik): specify alignment attribute for file serialization
+      // (minStorageBufferOffsetAlignment) and get vm.rodata handling it.
+
+      // Replace each constant value with a span referencing the storage
+      // buffers.
+      for (const auto &constantSpan : storageBuffer.spans) {
+        auto valueOp = constantSpan.valueOp;
+        OpBuilder poolBuilder(poolOp.getContext());
+        poolBuilder.setInsertionPointAfter(valueOp);
+        auto spanOp = poolBuilder.create<ConstantPoolSpanOp>(
+            valueOp.getLoc(), valueOp.getName(),
+            TypeAttr::get(valueOp.value().getType()),
+            poolBuilder.getSymbolRefAttr(storageBufferOp),
+            ByteRangeAttr::get(APInt(64, constantSpan.offset),
+                               APInt(64, constantSpan.length),
+                               poolOp.getContext()),
+            SymbolRefAttr{}, ByteRangeAttr{});
+        SymbolTable::setSymbolVisibility(spanOp,
+                                         SymbolTable::Visibility::Nested);
+        valueOp.erase();
+      }
+    }
+
+    return success();
+  }
+
+  struct ConstantSpan {
+    // Original value op this span represents.
+    ConstantPoolValueOp valueOp;
+    // Byte offset within the storage buffer.
+    uint64_t offset = 0;
+    // Length in bytes of the valid (unpadded) data.
+    // This does not include the padding to min_buffer_range_alignment nor any
+    // additional padding for other spans within the buffer (like start
+    // offset alignment).
+    uint64_t length = 0;
+  };
+
+  struct StorageBuffer {
+    // Total size in bytes (including padding).
+    uint64_t totalSize = 0;
+    // Fused location of all spans that make up this storage buffer.
+    Optional<Location> loc;
+    // Constant spans packed into this buffer.
+    SmallVector<ConstantSpan, 8> spans;
+    // Packed byte data that must be embedded in the final module.
+    // It must be written with an alignment as required by the constraints.
+    ElementsAttr data;
+  };
+
+  // Returns zero or more storage buffers and the spans values map into.
+  // Assume that |valueOps| have been ordered by prior passes and that order may
+  // have some performance-sensitivity (constants are grouped by
+  // locality/lifetime/etc).
+  SmallVector<StorageBuffer, 8> computePackingMap(
+      ArrayRef<ConstantPoolValueOp> valueOps,
+      BufferConstraintsAttr bufferConstraints, MLIRContext *context) {
+    // This is literally all my brain has brain for right now. The ideal here is
+    // that we have a basic static (and ideally profile-guided) sorting pass
+    // that keeps constant values that are accessed sorted together.
+    //
+    // We want good spatial locality as being in the same storage buffer means
+    // that constants are likelier to be pulled into memory together (by disk
+    // prefetcher pulling in mapped pages, TLB cache being hot, etc). We want
+    // good temporal locality because then we have a higher chance of the
+    // constants being placed into the same runtime buffer and that reduces the
+    // amount of buffer swapping/bindings we need to manage when recording
+    // commands.
+    //
+    // <story time> Funnily enough, the same reasons we care here (and this same
+    // algorithm) are the same that console developers encountered in the early
+    // days of CD-ROMs; inserting padding to force alignment on block
+    // boundaries, ensuring that temporally related content was together even if
+    // it meant repeating things, etc - like always physically duplicate the
+    // music near each level that uses it and potentially even interleave those
+    // together in block order on the disc, as being able to stream the music
+    // and still seek to blocks in level content was worth the % of lost space.
+    // You could listen for how well-optimized a game was by the noise level of
+    // the read head! (incidentally, same case too with tapes and floppies,
+    // however the space limitations were almost always the top concern there -
+    // it wasn't until CD-ROM and beyond that there was enough space to shuffle
+    // things around and waste on silly things like loading times).
+    //
+    // Here it's all descriptor sets and mapped pages but same thing pretty
+    // much, and passes earlier on may duplicate constants in the pool if it
+    // means they can improve locality at runtime. This pass doesn't dedupe and
+    // just sticks to packing for that reason.
+
+    // Build a list of buffers and spans (append to current or spill to new).
+    auto storageBuffers =
+        bucketValuesIntoStorageBuffers(valueOps, bufferConstraints);
+
+    // Pack each storage buffer bucket into a single data blob.
+    for (auto &storageBuffer : storageBuffers) {
+      packStorageBufferData(storageBuffer, context);
+    }
+
+    return storageBuffers;
+  }
+
+  // Buckets |valueOps| into one or more storage buffers based on
+  // |bufferConstraints|. Empty buckets are never returned.
+  SmallVector<StorageBuffer, 8> bucketValuesIntoStorageBuffers(
+      ArrayRef<ConstantPoolValueOp> valueOps,
+      BufferConstraintsAttr bufferConstraints) {
+    // TODO(benvanik): replace with a better strategy (best-fit, etc).
+    uint64_t maxAllocationSize =
+        bufferConstraints.max_allocation_size().getZExtValue();
+    SmallVector<StorageBuffer, 8> storageBuffers(1);
+    for (auto valueOp : valueOps) {
+      StorageBuffer *currentBuffer = &storageBuffers.back();
+      uint64_t offset = align(currentBuffer->totalSize,
+                              bufferConstraints.min_buffer_offset_alignment());
+      uint64_t unpaddedLength =
+          valueOp.value().cast<DenseElementsAttr>().getRawData().size();
+      uint64_t paddedLength =
+          align(unpaddedLength, bufferConstraints.min_buffer_range_alignment());
+      // Spill to a new buffer when the padded value would overflow this one; a
+      // value that alone exceeds the limit still gets a buffer of its own.
+      if (!currentBuffer->spans.empty() &&
+          offset + paddedLength > maxAllocationSize) {
+        storageBuffers.push_back({});
+        currentBuffer = &storageBuffers.back();
+        offset = 0;
+      }
+      currentBuffer->spans.push_back({valueOp, offset, unpaddedLength});
+      currentBuffer->totalSize =
+          std::max(currentBuffer->totalSize, offset + paddedLength);
+    }
+    if (storageBuffers.back().spans.empty()) storageBuffers.pop_back();
+    return storageBuffers;
+  }
+
+  // Packs all span data into a single data attribute we can tag on the buffer.
+  // The data produced will contain all spans at the specified offsets with no
+  // additional padding.
+  //
+  // NOTE: data can overlap so do not assume that the order between spans
+  // is contiguous or always increasing! Always seek!
+  void packStorageBufferData(StorageBuffer &storageBuffer,
+                             MLIRContext *context) {
+    // The constants get rolled into the buffer. This neat bit of info would
+    // be useful if we wanted to map back a module size through data blobs.
+    // With buffer <-> constant it's possible to build a tree map of
+    // contributions in the source. TBD ;)
+    storageBuffer.loc = FusedLoc::get(
+        llvm::to_vector<8>(llvm::map_range(
+            storageBuffer.spans,
+            [](ConstantSpan &span) { return span.valueOp.getLoc(); })),
+        context);
+
+    // TODO(#3354): replace this with an #iree.composite_buffer attribute or
+    // something so we can reuse the uniqued storage for each constant and just
+    // reference them with the (offset, length) byte range. Otherwise we are
+    // re-uniquing the new constant (and the old ones will likely be around in
+    // various forms at least transiently) meaning that we are potentially
+    // doubling the size of all constants in memory.
+
+    // Construct the buffer in memory.
+    std::vector<char> buffer(storageBuffer.totalSize);
+    for (auto &constantSpan : storageBuffer.spans) {
+      // NOTE: we know the data is dense because we've already filtered out
+      // any splats; we really would not want to be writing splats to a file.
+      auto sourceData = constantSpan.valueOp.value().cast<DenseElementsAttr>();
+      auto rawData = sourceData.getRawData();
+      llvm::copy(rawData, buffer.begin() + constantSpan.offset);
+    }
+    storageBuffer.data = DenseElementsAttr::getFromRawBuffer(
+        VectorType::get({static_cast<int64_t>(storageBuffer.totalSize)},
+                        IntegerType::get(8, context)),
+        buffer,
+        /*isSplatBuffer=*/false);
+  }
+};
+
+static PassRegistration<PackConstantPoolStoragePass> pass(
+    "iree-hal-pack-constant-pool-storage",
+    "Packs all constants in a hal.constant_pool into their possibly "
+    "target-dependent storage formats.");
+
+std::unique_ptr<OperationPass<ConstantPoolOp>>
+createPackConstantPoolStoragePass() {
+  return std::make_unique<PackConstantPoolStoragePass>();
+}
+
+} // namespace HAL
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
index fc05535..b89d649 100644
--- a/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
@@ -16,7 +16,6 @@
#include <memory>
-#include "iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertFlowToHAL.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Pass/PassRegistry.h"
@@ -44,10 +43,20 @@
} // namespace
void buildHALTransformPassPipeline(OpPassManager &passManager,
- TargetOptions targetOptions,
+ const TargetOptions &targetOptions,
const TransformOptions &transformOptions) {
passManager.addPass(createCanonicalizerPass());
+ // Handle large constants (weights/params/etc) first so that we can use the
+ // resulting constant pools to determine the interfaces.
+ passManager.addPass(createIdentifyConstantPoolsPass(targetOptions));
+ passManager.addPass(createPackConstantPoolStoragePass());
+ passManager.addPass(createMaterializeConstantPoolBuffersPass());
+ passManager.addPass(createCanonicalizerPass());
+ passManager.addPass(createSymbolDCEPass());
+
+ // Each executable needs a hal.interface to specify how the host and device
+  // communicate across the ABI boundary.
passManager.addPass(createMaterializeInterfacesPass(targetOptions));
// TODO(#1036): when dynamic pass registration is supported we can just
@@ -56,7 +65,8 @@
// this pass.
passManager.addPass(createTranslateExecutablesPass(targetOptions));
- passManager.addPass(createConvertFlowToHALPass());
+ // Convert supported input dialects (std, flow, etc) into the HAL dialect.
+ passManager.addPass(createConvertToHALPass());
// Phase ordering note: Before this pass, functions signatures will be based
// on explicit shape types (such as ranked_shape). After this pass, these
@@ -113,7 +123,7 @@
}
void buildHALTransformPassPipeline(OpPassManager &passManager,
- TargetOptions targetOptions) {
+ const TargetOptions &targetOptions) {
TransformOptions transformOptions;
buildHALTransformPassPipeline(passManager, targetOptions, transformOptions);
}
diff --git a/iree/compiler/Dialect/HAL/Transforms/Passes.h b/iree/compiler/Dialect/HAL/Transforms/Passes.h
index 3cd5de7..80dd59d 100644
--- a/iree/compiler/Dialect/HAL/Transforms/Passes.h
+++ b/iree/compiler/Dialect/HAL/Transforms/Passes.h
@@ -45,11 +45,18 @@
// buildHALTransformPassPipeline & run
// <run conversion from HAL to vm/etc>
void buildHALTransformPassPipeline(OpPassManager &passManager,
- TargetOptions targetOptions);
+ const TargetOptions &targetOptions);
void registerHALTransformPassPipeline();
//===----------------------------------------------------------------------===//
+// Conversion
+//===----------------------------------------------------------------------===//
+
+// Convert input flow/std/etc dialects to the IREE HAL dialect.
+std::unique_ptr<OperationPass<ModuleOp>> createConvertToHALPass();
+
+//===----------------------------------------------------------------------===//
// Device management
//===----------------------------------------------------------------------===//
@@ -96,6 +103,20 @@
// Resource initialization, caching, and optimization
//===----------------------------------------------------------------------===//
+// Combines constant variables into one or more hal.constant_pools based on
+// usage semantics.
+std::unique_ptr<OperationPass<ModuleOp>> createIdentifyConstantPoolsPass(
+ TargetOptions targetOptions);
+
+// Packs all constant data in a hal.constant_pool into their storage formats
+// and maps them with hal.constant_pool.span.
+std::unique_ptr<OperationPass<ConstantPoolOp>>
+createPackConstantPoolStoragePass();
+
+// Materializes runtime buffers for constant pools.
+std::unique_ptr<OperationPass<ModuleOp>>
+createMaterializeConstantPoolBuffersPass();
+
// Finds all resource lookups (such as hal.executable.lookup), materializes
// their cache storage and initialization, and rewrites the lookups to
// references.
@@ -109,6 +130,7 @@
inline void registerHALPasses() {
registerHALTransformPassPipeline();
auto executableOptions = getTargetOptionsFromFlags();
+ createConvertToHALPass();
createInlineDeviceSwitchesPass();
createMemoizeDeviceQueriesPass();
createMaterializeInterfacesPass(executableOptions);
@@ -117,6 +139,9 @@
createResolveEntryPointOrdinalsPass();
createSerializeExecutablesPass(executableOptions);
createPublicABIGenerationPass();
+ createIdentifyConstantPoolsPass(executableOptions);
+ createPackConstantPoolStoragePass();
+ createMaterializeConstantPoolBuffersPass();
createMaterializeResourceCachesPass(executableOptions);
}
diff --git a/iree/compiler/Dialect/HAL/Transforms/test/identify_constant_pools.mlir b/iree/compiler/Dialect/HAL/Transforms/test/identify_constant_pools.mlir
new file mode 100644
index 0000000..3c58678
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Transforms/test/identify_constant_pools.mlir
@@ -0,0 +1,53 @@
+// RUN: iree-opt -split-input-file -iree-hal-identify-constant-pools -iree-hal-target-backends=vmla -iree-hal-target-backends=vulkan-spirv %s | IreeFileCheck %s
+
+// CHECK: hal.constant_pool @_const_pool attributes
+// CHECK-SAME: buffer_constraints = #hal.buffer_constraints<max_allocation_size = 1073741824, min_buffer_offset_alignment = 256, max_buffer_range = 134217728, min_buffer_range_alignment = 16>
+// CHECK-NEXT: hal.constant_pool.value @cst0 {{.+}} = dense<1.000000e+00> : tensor<1xf32>
+flow.variable @cst0 dense<1.000000e+00> : tensor<1xf32>
+// CHECK-NEXT: hal.constant_pool.value @cst1 {{.+}} = dense<[2.100000e+00, 3.200000e+00, 4.300000e+00, 5.400000e+00]> : tensor<4xf32>
+flow.variable @cst1 dense<[2.1, 3.2, 4.3, 5.4]> : tensor<4xf32>
+// CHECK-NEXT: hal.constant_pool.value @cst2 {{.+}} = dense<[6, 7, 8]> : tensor<3xi8>
+flow.variable @cst2 dense<[6, 7, 8]> : tensor<3xi8>
+
+// CHECK-LABEL: func @immutable_variables
+func @immutable_variables() -> (tensor<1xf32>, tensor<4xf32>, tensor<3xi8>) {
+ // CHECK-NEXT: = hal.constant_pool.load @_const_pool::@cst0 : tensor<1xf32>
+ %cst0 = flow.variable.load @cst0 : tensor<1xf32>
+ // CHECK-NEXT: = hal.constant_pool.load @_const_pool::@cst1 : tensor<4xf32>
+ %cst1 = flow.variable.load @cst1 : tensor<4xf32>
+ // CHECK-NEXT: = hal.constant_pool.load @_const_pool::@cst2 : tensor<3xi8>
+ %cst2 = flow.variable.load @cst2 : tensor<3xi8>
+ return %cst0, %cst1, %cst2 : tensor<1xf32>, tensor<4xf32>, tensor<3xi8>
+}
+
+// -----
+
+// CHECK: hal.constant_pool @_const_pool_init
+// CHECK-NEXT: hal.constant_pool.value @variable_0 {{.+}} = dense<3.000000e+00> : tensor<128xf32>
+
+// CHECK: flow.variable @variable_0 mutable init(@variable_0_initializer)
+flow.variable @variable_0 mutable dense<3.0> : tensor<128xf32>
+// CHECK-NEXT: func @variable_0_initializer() -> tensor<128xf32>
+// CHECK-NEXT: [[CONST:%.+]] = hal.constant_pool.load @_const_pool_init::@variable_0 : tensor<128xf32>
+// CHECK-NEXT: return [[CONST]] : tensor<128xf32>
+// CHECK-NEXT: }
+
+// CHECK-LABEL: func @mutable_variables
+func @mutable_variables() -> tensor<128xf32> {
+ // CHECK: flow.variable.load @variable_0
+ %var_0 = flow.variable.load @variable_0 : tensor<128xf32>
+ return %var_0 : tensor<128xf32>
+}
+
+// -----
+
+// NOTE: indirect variable accesses not currently supported.
+// CHECK: flow.variable @_large_const_0
+flow.variable @_large_const_0 dense<3.0> : tensor<128xf32>
+func @skip_indirect_variables() -> (tensor<128xf32>) {
+ // CHECK: flow.variable.address
+ %0 = flow.variable.address @_large_const_0 : !iree.ptr<tensor<128xf32>>
+ // CHECK: flow.variable.load.indirect
+ %1 = flow.variable.load.indirect %0 : !iree.ptr<tensor<128xf32>> -> tensor<128xf32>
+ return %1 : tensor<128xf32>
+}
diff --git a/iree/compiler/Dialect/HAL/Transforms/test/materialize_constant_pool_buffers.mlir b/iree/compiler/Dialect/HAL/Transforms/test/materialize_constant_pool_buffers.mlir
new file mode 100644
index 0000000..54b5ae2
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Transforms/test/materialize_constant_pool_buffers.mlir
@@ -0,0 +1,61 @@
+// RUN: iree-opt -split-input-file -iree-hal-materialize-constant-pool-buffers %s | IreeFileCheck %s
+
+// CHECK-LABEL: hal.constant_pool @dense_variable_init
+hal.constant_pool @dense_variable_init attributes {buffer_constraints = #hal.buffer_constraints<max_allocation_size = 1073741824, min_buffer_offset_alignment = 32, max_buffer_range = 134217728, min_buffer_range_alignment = 4>} {
+ // CHECK-NEXT: @cst0 {{.+}} -> @dense_variable_init_storage_buffer[#hal.byte_range<0, 512>]
+ hal.constant_pool.span @cst0 : tensor<128xf32> = @_storage[#hal.byte_range<0, 512>]
+ // CHECK-NEXT: @cst1 {{.+}} -> @dense_variable_init_storage_buffer[#hal.byte_range<512, 256>]
+ hal.constant_pool.span @cst1 : tensor<64xf32> = @_storage[#hal.byte_range<512, 256>]
+ hal.constant_storage @_storage = dense<1> : vector<768xi8>
+}
+
+// CHECK: hal.variable @dense_variable_init_storage_buffer init(@dense_variable_init_storage_buffer_initializer) : !hal.buffer
+// CHECK-NEXT: func @dense_variable_init_storage_buffer_initializer() -> !hal.buffer
+// CHECK: [[STORAGE:%.+]] = hal.constant_storage.lookup @dense_variable_init::@_storage : !iree.byte_buffer
+// CHECK: = hal.allocator.map {{.+}} [[STORAGE]][%c0, %c768] : !iree.byte_buffer -> !hal.buffer
+
+// -----
+
+// CHECK-LABEL: hal.constant_pool @splat_variable_init
+hal.constant_pool @splat_variable_init attributes {buffer_constraints = #hal.buffer_constraints<max_allocation_size = 1073741824, min_buffer_offset_alignment = 32, max_buffer_range = 134217728, min_buffer_range_alignment = 4>} {
+ // CHECK-NEXT: @cst0 {{.+}} -> @splat_variable_init_splats[#hal.byte_range<0, 4>]
+ hal.constant_pool.splat @cst0 = dense<1.000000e+00> : tensor<1xf32>
+ // CHECK-NEXT: @cst1 {{.+}} -> @splat_variable_init_splats[#hal.byte_range<32, 32>]
+ hal.constant_pool.splat @cst1 = dense<1234567890> : tensor<8xi32>
+}
+
+// CHECK: hal.variable @splat_variable_init_splats init(@splat_variable_init_splats_initializer) : !hal.buffer
+// CHECK-NEXT: func @splat_variable_init_splats_initializer() -> !hal.buffer
+// CHECK: [[BUFFER:%.+]] = hal.allocator.allocate {{.+}} %c64 : !hal.buffer
+// CHECK: hal.buffer.fill [[BUFFER]], %c0, %c4, %c1065353216_i32
+// CHECK: hal.buffer.fill [[BUFFER]], %c32, %c32_0, %c1234567890_i32
+
+// -----
+
+// CHECK-LABEL: hal.constant_pool @pool
+hal.constant_pool @pool attributes {buffer_constraints = #hal.buffer_constraints<max_allocation_size = 1073741824, min_buffer_offset_alignment = 32, max_buffer_range = 134217728, min_buffer_range_alignment = 4>} {
+ // CHECK-NEXT: @cst0 {{.+}} -> @pool_storage0_buffer[#hal.byte_range<0, 16>]
+ hal.constant_pool.span @cst0 : tensor<4xf32> = @_storage0[#hal.byte_range<0, 16>]
+ // CHECK-NEXT: @cst1 {{.+}} -> @pool_storage1_buffer[#hal.byte_range<0, 3>]
+ hal.constant_pool.span @cst1 : tensor<3xi8> = @_storage1[#hal.byte_range<0, 3>]
+ // CHECK-NEXT: @cst2 {{.+}} -> @pool_splats[#hal.byte_range<0, 4>]
+ hal.constant_pool.splat @cst2 = dense<1.000000e+00> : tensor<1xf32>
+ // CHECK-NEXT: @cst3 {{.+}} -> @pool_splats[#hal.byte_range<32, 32>]
+ hal.constant_pool.splat @cst3 = dense<1234567890> : tensor<8xi32>
+ hal.constant_storage @_storage0 = dense<[102, 102, 6, 64, -51, -52, 76, 64, -102, -103, -119, 64, -51, -52, -84, 64]> : vector<16xi8>
+ hal.constant_storage @_storage1 = dense<[6, 7, 8, 0]> : vector<4xi8>
+}
+
+// CHECK: hal.variable @pool_storage0_buffer init(@pool_storage0_buffer_initializer) : !hal.buffer
+// CHECK-NEXT: func @pool_storage0_buffer_initializer() -> !hal.buffer
+// CHECK: [[STORAGE:%.+]] = hal.constant_storage.lookup @pool::@_storage0 : !iree.byte_buffer
+// CHECK: = hal.allocator.map {{.+}} [[STORAGE]][%c0, %c16] : !iree.byte_buffer -> !hal.buffer
+
+// CHECK: hal.variable @pool_storage1_buffer init(@pool_storage1_buffer_initializer) : !hal.buffer
+// CHECK-NEXT: func @pool_storage1_buffer_initializer() -> !hal.buffer
+
+// CHECK: hal.variable @pool_splats init(@pool_splats_initializer) : !hal.buffer
+// CHECK-NEXT: func @pool_splats_initializer() -> !hal.buffer
+// CHECK: [[BUFFER:%.+]] = hal.allocator.allocate %allocator, "HostVisible|DeviceVisible|DeviceLocal", "Constant|Transfer|Mapping|Dispatch", %c64 : !hal.buffer
+// CHECK: hal.buffer.fill [[BUFFER]], %c0, %c4, %c1065353216_i32
+// CHECK: hal.buffer.fill [[BUFFER]], %c32, %c32_0, %c1234567890_i32
diff --git a/iree/compiler/Dialect/HAL/Transforms/test/memoize_device_queries.mlir b/iree/compiler/Dialect/HAL/Transforms/test/memoize_device_queries.mlir
index de7f576..a081a4c 100644
--- a/iree/compiler/Dialect/HAL/Transforms/test/memoize_device_queries.mlir
+++ b/iree/compiler/Dialect/HAL/Transforms/test/memoize_device_queries.mlir
@@ -1,15 +1,15 @@
// RUN: iree-opt -split-input-file -iree-hal-memoize-device-queries %s | IreeFileCheck %s
+// CHECK: hal.variable @_device_match_id_0 init(@_device_match_id_0_initializer) : i1
// CHECK: func @_device_match_id_0_initializer() -> i1
// CHECK-NEXT: %[[DEVICE:.+]] = hal.ex.shared_device : !hal.device
// CHECK-NEXT: %[[IS_MATCH:.+]] = hal.device.match.id %[[DEVICE]], pattern = ["vulkan-v1.?-*"] : (!hal.device) -> i1
// CHECK-NEXT: return %[[IS_MATCH]] : i1
-// CHECK: hal.variable @_device_match_id_0 init(@_device_match_id_0_initializer) : i1
// CHECK: hal.variable @_device_match_id_1
// CHECK: hal.variable @_device_match_id_2
-// CHECK-LABEL: @device_matchers
+// CHECK-LABEL: func @device_matchers
func @device_matchers(%device : !hal.device) {
// CHECK-NEXT: = hal.variable.load @_device_match_id_0 : i1
%0 = hal.device.match.id %device, pattern = ["vulkan-v1.?-*"] : (!hal.device) -> i1
diff --git a/iree/compiler/Dialect/HAL/Transforms/test/pack_constant_pool_storage.mlir b/iree/compiler/Dialect/HAL/Transforms/test/pack_constant_pool_storage.mlir
new file mode 100644
index 0000000..469a94c
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Transforms/test/pack_constant_pool_storage.mlir
@@ -0,0 +1,36 @@
+// RUN: iree-opt -split-input-file -iree-hal-pack-constant-pool-storage %s | IreeFileCheck %s
+
+// CHECK-LABEL: hal.constant_pool @pool
+hal.constant_pool @pool attributes {
+ buffer_constraints = #hal.buffer_constraints<max_allocation_size = 1073741824,
+ min_buffer_offset_alignment = 32,
+ max_buffer_range = 134217728,
+ min_buffer_range_alignment = 4>
+ } {
+ // CHECK-DAG: hal.constant_pool.splat @cst0 {{.+}} = dense<1.000000e+00> : tensor<1xf32>
+ hal.constant_pool.value @cst0 = dense<1.000000e+00> : tensor<1xf32>
+ // CHECK-DAG: hal.constant_pool.span @cst1 : tensor<4xf32> {{.+}} = @_storage[#hal.byte_range<0, 16>]
+ hal.constant_pool.value @cst1 = dense<[2.1, 3.2, 4.3, 5.4]> : tensor<4xf32>
+ // CHECK-DAG: hal.constant_pool.span @cst2 : tensor<3xi8> {{.+}} = @_storage[#hal.byte_range<32, 3>]
+ hal.constant_pool.value @cst2 = dense<[6, 7, 8]> : tensor<3xi8>
+
+ // CHECK: hal.constant_storage @_storage {{.+}} = dense<[102, 102, 6, 64, -51, -52, 76, 64, -102, -103, -119, 64, -51, -52, -84, 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 7, 8, 0]> : vector<36xi8>
+}
+
+// -----
+
+// CHECK-LABEL: hal.constant_pool @multi_storage
+hal.constant_pool @multi_storage attributes {
+ buffer_constraints = #hal.buffer_constraints<max_allocation_size = 18,
+ min_buffer_offset_alignment = 1,
+ max_buffer_range = 134217728,
+ min_buffer_range_alignment = 1>
+ } {
+ // CHECK-DAG: hal.constant_pool.span @cst0 : tensor<4xf32> {{.+}} = @_storage[#hal.byte_range<0, 16>]
+ hal.constant_pool.value @cst0 = dense<[2.1, 3.2, 4.3, 5.4]> : tensor<4xf32>
+ // CHECK-DAG: hal.constant_pool.span @cst1 : tensor<3xi8> {{.+}} = @_storage_0[#hal.byte_range<0, 3>]
+ hal.constant_pool.value @cst1 = dense<[6, 7, 8]> : tensor<3xi8>
+
+ // CHECK-NEXT: hal.constant_storage @_storage {{.+}} = dense<[102, 102, 6, 64, -51, -52, 76, 64, -102, -103, -119, 64, -51, -52, -84, 64]> : vector<16xi8>
+ // CHECK-NEXT: hal.constant_storage @_storage_0 {{.+}} = dense<[6, 7, 8]> : vector<3xi8>
+}
diff --git a/iree/compiler/Dialect/HAL/Utils/TypeUtils.h b/iree/compiler/Dialect/HAL/Utils/TypeUtils.h
index dbc960d..9c628a7 100644
--- a/iree/compiler/Dialect/HAL/Utils/TypeUtils.h
+++ b/iree/compiler/Dialect/HAL/Utils/TypeUtils.h
@@ -26,6 +26,17 @@
namespace IREE {
namespace HAL {
+// Rounds |value| up to the nearest multiple of |alignment|. Any |alignment|
+// is accepted: non-powers-of-two are rounded exactly (unlike a bitmask) and
+// 0 or 1 leave |value| unchanged.
+static inline uint64_t align(uint64_t value, uint64_t alignment) {
+  if (alignment <= 1) return value;
+  return (value + (alignment - 1)) / alignment * alignment;
+}
+static inline uint64_t align(uint64_t value, const APInt &alignment) {
+  return align(value, alignment.getZExtValue());
+}
+
// Returns the number of bytes an element of the given type occupies
// post-conversion. For example, the size of i1 would be '1 byte'.
int32_t getRoundedElementByteWidth(Type type);
diff --git a/iree/compiler/Dialect/HAL/hal.imports.mlir b/iree/compiler/Dialect/HAL/hal.imports.mlir
index eff8b57..f4aed4e 100644
--- a/iree/compiler/Dialect/HAL/hal.imports.mlir
+++ b/iree/compiler/Dialect/HAL/hal.imports.mlir
@@ -20,33 +20,6 @@
// iree::hal::Allocator
//===----------------------------------------------------------------------===//
-// Computes the byte size required for a buffer of the given shape and type.
-vm.import @allocator.compute_size(
- %allocator : !vm.ref<!hal.allocator>,
- %shape : i32 ...,
- %element_type : i32
-) -> i32
-attributes {nosideeffects}
-
-// Computes an element byte offset within a buffer.
-vm.import @allocator.compute_offset(
- %allocator : !vm.ref<!hal.allocator>,
- %shape : i32 ...,
- %element_type : i32,
- %indices : i32 ...
-) -> i32
-attributes {nosideeffects}
-
-// Computes a byte range within a buffer for one or more elements.
-vm.import @allocator.compute_range(
- %allocator : !vm.ref<!hal.allocator>,
- %shape : i32 ...,
- %element_type : i32,
- %indices : i32 ...,
- %lengths : i32 ...
-) -> (i32, i32)
-attributes {nosideeffects}
-
// Allocates a buffer from the allocator.
vm.import @allocator.allocate(
%allocator : !vm.ref<!hal.allocator>,
@@ -55,14 +28,15 @@
%allocation_size : i32
) -> !vm.ref<!hal.buffer>
-// Allocates a buffer from the allocator with the given constant contents.
-vm.import @allocator.allocate.const(
+// Wraps a subrange of a read-only host memory buffer.
+// Host mapping must be supported by the allocator.
+vm.import @allocator.wrap.byte_buffer(
%allocator : !vm.ref<!hal.allocator>,
%memory_types : i32,
%buffer_usage : i32,
- %shape : i32 ...,
- %element_type : i32,
- %value : !vm.ref<!iree.byte_buffer>
+ %source : !vm.ref<!iree.byte_buffer>,
+ %offset : i32,
+ %length : i32
) -> !vm.ref<!hal.buffer>
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/IREE/Conversion/BUILD b/iree/compiler/Dialect/IREE/Conversion/BUILD
index 5aa54ef..85dbc15 100644
--- a/iree/compiler/Dialect/IREE/Conversion/BUILD
+++ b/iree/compiler/Dialect/IREE/Conversion/BUILD
@@ -19,23 +19,6 @@
)
cc_library(
- name = "ConvertToHAL",
- srcs = [
- "ConvertToHAL.cpp",
- ],
- hdrs = [
- "ConvertToHAL.h",
- ],
- deps = [
- "//iree/compiler/Dialect/HAL/IR",
- "//iree/compiler/Dialect/IREE/IR",
- "@llvm-project//mlir:IR",
- "@llvm-project//mlir:StandardOps",
- "@llvm-project//mlir:Transforms",
- ],
-)
-
-cc_library(
name = "PreserveCompilerHints",
srcs = [
"PreserveCompilerHints.cpp",
diff --git a/iree/compiler/Dialect/IREE/Conversion/CMakeLists.txt b/iree/compiler/Dialect/IREE/Conversion/CMakeLists.txt
index 9d07055..604cf81 100644
--- a/iree/compiler/Dialect/IREE/Conversion/CMakeLists.txt
+++ b/iree/compiler/Dialect/IREE/Conversion/CMakeLists.txt
@@ -16,22 +16,6 @@
iree_cc_library(
NAME
- ConvertToHAL
- HDRS
- "ConvertToHAL.h"
- SRCS
- "ConvertToHAL.cpp"
- DEPS
- MLIRIR
- MLIRStandard
- MLIRTransforms
- iree::compiler::Dialect::HAL::IR
- iree::compiler::Dialect::IREE::IR
- PUBLIC
-)
-
-iree_cc_library(
- NAME
PreserveCompilerHints
HDRS
"PreserveCompilerHints.h"
diff --git a/iree/compiler/Dialect/IREE/Conversion/test/convert_flow_to_hal.mlir b/iree/compiler/Dialect/IREE/Conversion/test/convert_flow_to_hal.mlir
index 0d6d278..21e086e 100644
--- a/iree/compiler/Dialect/IREE/Conversion/test/convert_flow_to_hal.mlir
+++ b/iree/compiler/Dialect/IREE/Conversion/test/convert_flow_to_hal.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --iree-convert-flow-to-hal %s --split-input-file | IreeFileCheck %s
+// RUN: iree-opt -iree-convert-to-hal %s --split-input-file | IreeFileCheck %s
// CHECK-LABEL: @preserve_compiler_hints
func @preserve_compiler_hints() {
@@ -11,15 +11,3 @@
iree.do_not_optimize(%c) : i32
return
}
-
-// -----
-
-// CHECK-LABEL: @dynamic_shape_constant
-func @dynamic_shape_constant() {
- // CHECK: %dev = hal.ex.shared_device
- // CHECK: %allocator = hal.device.allocator %dev
- // CHECK: %view = hal.buffer_view.const %allocator, "HostVisible|DeviceVisible|DeviceLocal", "Constant|Transfer|Mapping|Dispatch" : !hal.buffer_view = dense<2> : tensor<2xi32>
- // CHECK: %[[RES:.+]] = iree.do_not_optimize(%view) : !hal.buffer_view
- %c = iree.dynamic_shape_constant dense<2> : tensor<2xi32> -> tensor<?xi32>
- return
-}
diff --git a/iree/compiler/Dialect/IREE/IR/IREEOps.cpp b/iree/compiler/Dialect/IREE/IR/IREEOps.cpp
index bf48b9f..822e631 100644
--- a/iree/compiler/Dialect/IREE/IR/IREEOps.cpp
+++ b/iree/compiler/Dialect/IREE/IR/IREEOps.cpp
@@ -14,6 +14,7 @@
#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
+#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/SMLoc.h"
diff --git a/iree/compiler/Dialect/IREE/IR/IREEOps.td b/iree/compiler/Dialect/IREE/IR/IREEOps.td
index 10973e5..fee13f7 100644
--- a/iree/compiler/Dialect/IREE/IR/IREEOps.td
+++ b/iree/compiler/Dialect/IREE/IR/IREEOps.td
@@ -32,6 +32,30 @@
IREE_Op<mnemonic, !listconcat(traits, [NoSideEffect])>;
//===----------------------------------------------------------------------===//
+// Byte buffers and host data
+//===----------------------------------------------------------------------===//
+
+def IREE_ByteBufferConstantOp : IREE_PureOp<"byte_buffer.constant"> {
+ let summary = "constant host-side byte buffer";
+ let description = [{
+ Defines a compile-time byte buffer based on the given attribute value.
+ The attribute will be serialized into the canonical IREE format for the
+ chosen host target; VM lowering emits it as a vm.rodata.inline op.
+ }];
+
+ let arguments = (ins
+ ElementsAttr:$value
+ );
+ let results = (outs
+ ByteBufferType:$result
+ );
+
+ let assemblyFormat = [{
+ attr-dict `:` type($result) `=` $value
+ }];
+}
+
+//===----------------------------------------------------------------------===//
// Executable ABI
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/IREE/IR/test/byte_buffer_ops.mlir b/iree/compiler/Dialect/IREE/IR/test/byte_buffer_ops.mlir
new file mode 100644
index 0000000..27a58de
--- /dev/null
+++ b/iree/compiler/Dialect/IREE/IR/test/byte_buffer_ops.mlir
@@ -0,0 +1,8 @@
+// RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s
+
+// CHECK-LABEL: @byte_buffer_constant
+func @byte_buffer_constant() -> !iree.byte_buffer {
+ // CHECK: = iree.byte_buffer.constant : !iree.byte_buffer = dense<[1, 2, 3]> : tensor<3xi32>
+ %0 = iree.byte_buffer.constant : !iree.byte_buffer = dense<[1, 2, 3]> : tensor<3xi32>
+ return %0 : !iree.byte_buffer
+}
diff --git a/iree/compiler/Dialect/VM/Conversion/IREEToVM/ConvertIREEToVM.cpp b/iree/compiler/Dialect/VM/Conversion/IREEToVM/ConvertIREEToVM.cpp
index ddeab10..fdeedc5 100644
--- a/iree/compiler/Dialect/VM/Conversion/IREEToVM/ConvertIREEToVM.cpp
+++ b/iree/compiler/Dialect/VM/Conversion/IREEToVM/ConvertIREEToVM.cpp
@@ -28,9 +28,29 @@
namespace mlir {
namespace iree_compiler {
-
namespace {
+//===----------------------------------------------------------------------===//
+// iree.byte_buffer.*
+//===----------------------------------------------------------------------===//
+
+class ByteBufferConstantOpConversion
+    : public OpConversionPattern<IREE::ByteBufferConstantOp> {
+  using OpConversionPattern::OpConversionPattern;
+  // Lowers to a vm.rodata.inline holding the same payload, typed !vm.ref<T>.
+  LogicalResult matchAndRewrite(
+      IREE::ByteBufferConstantOp cstOp, ArrayRef<Value> /*operands*/,
+      ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<IREE::VM::RodataInlineOp>(
+        cstOp, IREE::VM::RefType::get(cstOp.getType()), cstOp.value());
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Compiler hints
+//===----------------------------------------------------------------------===//
+
class UnreachableOpConversion
: public OpConversionPattern<IREE::UnreachableOp> {
using OpConversionPattern::OpConversionPattern;
@@ -52,6 +72,7 @@
void populateIREEToVMPatterns(MLIRContext *context,
OwningRewritePatternList &patterns) {
+ patterns.insert<ByteBufferConstantOpConversion>(context);
patterns.insert<UnreachableOpConversion>(context);
}
diff --git a/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/byte_buffer_ops.mlir b/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/byte_buffer_ops.mlir
new file mode 100644
index 0000000..28ba88e
--- /dev/null
+++ b/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/byte_buffer_ops.mlir
@@ -0,0 +1,13 @@
+// RUN: iree-opt -split-input-file -iree-vm-conversion %s | IreeFileCheck %s
+
+// CHECK-LABEL: @byte_buffer_constant
+module @byte_buffer_constant {
+module {
+ // CHECK: vm.func @my_fn
+ func @my_fn() {
+ // CHECK-NEXT: = vm.rodata.inline : !vm.ref<!iree.byte_buffer> = dense<[1, 2, 3]> : tensor<3xi32>
+ %0 = iree.byte_buffer.constant : !iree.byte_buffer = dense<[1, 2, 3]> : tensor<3xi32>
+ return
+ }
+}
+}
diff --git a/iree/compiler/Dialect/VM/Conversion/TypeConverter.cpp b/iree/compiler/Dialect/VM/Conversion/TypeConverter.cpp
index 2edcf88..a36fe2a 100644
--- a/iree/compiler/Dialect/VM/Conversion/TypeConverter.cpp
+++ b/iree/compiler/Dialect/VM/Conversion/TypeConverter.cpp
@@ -77,6 +77,12 @@
return IntegerType::get(targetOptions_.indexBits, indexType.getContext());
});
+ // Vectors are used for arbitrary byte storage.
+ addConversion([](VectorType vectorType) -> Optional<Type> {
+ return IREE::VM::RefType::get(
+ IREE::ByteBufferType::get(vectorType.getContext()));
+ });
+
// Convert ranked shape types (expanding all dims).
addConversion([this](Shape::RankedShapeType rankedShape,
SmallVectorImpl<Type> &results) {
diff --git a/iree/compiler/Dialect/VM/IR/VMOps.td b/iree/compiler/Dialect/VM/IR/VMOps.td
index d613c8b..b2a433c 100644
--- a/iree/compiler/Dialect/VM/IR/VMOps.td
+++ b/iree/compiler/Dialect/VM/IR/VMOps.td
@@ -819,6 +819,26 @@
let verifier = [{ return verify$cppClass(*this); }];
}
+
+def VM_RodataInlineOp : VM_PureOp<"rodata.inline", [
+ VM_PseudoOp,
+ ]> {
+ let summary = [{inlined constant rodata}];
+ let description = [{
+ vm.rodata embedded inline in a function; hoisted by -iree-vm-hoist-inlined-rodata.
+ }];
+
+ let arguments = (ins
+ ElementsAttr:$value
+ );
+
+ let results = (outs
+ VM_RefOf<ByteBufferType>:$result
+ );
+
+ let assemblyFormat = "attr-dict `:` type($result) `=` $value";
+}
+
//===----------------------------------------------------------------------===//
// Lists
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/VM/IR/test/const_ops.mlir b/iree/compiler/Dialect/VM/IR/test/const_ops.mlir
index 13bbe5e..821c8d9 100644
--- a/iree/compiler/Dialect/VM/IR/test/const_ops.mlir
+++ b/iree/compiler/Dialect/VM/IR/test/const_ops.mlir
@@ -44,3 +44,14 @@
vm.return %buf0 : !vm.ref<!iree.byte_buffer>
}
}
+
+// -----
+
+vm.module @my_module {
+ // CHECK-LABEL: @inlined_rodata
+ vm.func @inlined_rodata() -> !vm.ref<!iree.byte_buffer> {
+ // CHECK-NEXT: = vm.rodata.inline : !vm.ref<!iree.byte_buffer> = dense<[0, 1, 2]> : tensor<3xi8>
+ %0 = vm.rodata.inline : !vm.ref<!iree.byte_buffer> = dense<[0, 1, 2]> : tensor<3xi8>
+ vm.return %0 : !vm.ref<!iree.byte_buffer>
+ }
+}
diff --git a/iree/compiler/Dialect/VM/Transforms/BUILD b/iree/compiler/Dialect/VM/Transforms/BUILD
index 1edfd36..db938f0 100644
--- a/iree/compiler/Dialect/VM/Transforms/BUILD
+++ b/iree/compiler/Dialect/VM/Transforms/BUILD
@@ -23,6 +23,7 @@
srcs = [
"Conversion.cpp",
"GlobalInitialization.cpp",
+ "HoistInlinedRodata.cpp",
"MarkPublicSymbolsExported.cpp",
"OrdinalAllocation.cpp",
"Passes.cpp",
diff --git a/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt b/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt
index 413d33c..5704913 100644
--- a/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt
@@ -22,6 +22,7 @@
SRCS
"Conversion.cpp"
"GlobalInitialization.cpp"
+ "HoistInlinedRodata.cpp"
"MarkPublicSymbolsExported.cpp"
"OrdinalAllocation.cpp"
"Passes.cpp"
diff --git a/iree/compiler/Dialect/VM/Transforms/HoistInlinedRodata.cpp b/iree/compiler/Dialect/VM/Transforms/HoistInlinedRodata.cpp
new file mode 100644
index 0000000..1c8204f
--- /dev/null
+++ b/iree/compiler/Dialect/VM/Transforms/HoistInlinedRodata.cpp
@@ -0,0 +1,106 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <utility>
+
+#include "iree/compiler/Dialect/IREE/IR/IREEDialect.h"
+#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
+#include "iree/compiler/Dialect/VM/IR/VMDialect.h"
+#include "iree/compiler/Dialect/VM/IR/VMOps.h"
+#include "iree/compiler/Dialect/VM/IR/VMTypes.h"
+#include "iree/compiler/Dialect/VM/Transforms/Passes.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace VM {
+
+// Hoists every vm.rodata.inline op found in a vm.func into a private,
+// module-level vm.rodata op inserted immediately before that function, then
+// replaces the inline op with a vm.const.ref.rodata that references it.
+//
+// Hoisted symbols are named `<function name>_const`; SymbolTable::insert
+// renames them as needed to avoid collisions. Identical payloads are not
+// deduplicated by this pass.
+class HoistInlinedRodataPass
+    : public PassWrapper<HoistInlinedRodataPass,
+                         OperationPass<IREE::VM::ModuleOp>> {
+ public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<IREEDialect>();
+    registry.insert<IREE::VM::VMDialect>();
+  }
+
+  void runOnOperation() override {
+    auto moduleOp = getOperation();
+    SymbolTable moduleSymbolTable(moduleOp);
+
+    // Snapshot the functions: new module-level ops are inserted while walking.
+    auto funcOps = llvm::to_vector<4>(moduleOp.getOps<IREE::VM::FuncOp>());
+    for (auto funcOp : funcOps) {
+      // Snapshot the inline ops too, as each is erased once it is replaced.
+      auto inlineOps =
+          llvm::to_vector<4>(funcOp.getOps<IREE::VM::RodataInlineOp>());
+      if (inlineOps.empty()) continue;
+
+      // Hoisted ops go immediately before the function that used them.
+      OpBuilder moduleBuilder(moduleOp.getContext());
+      moduleBuilder.setInsertionPoint(funcOp);
+
+      // This builder has no insertion point, so the ops it creates stay
+      // detached until SymbolTable::insert places (and uniquely names) them.
+      OpBuilder detachedBuilder(moduleOp.getContext());
+      const std::string rodataName = (funcOp.getName() + "_const").str();
+      for (auto inlineOp : inlineOps) {
+        auto rodataOp = detachedBuilder.create<IREE::VM::RodataOp>(
+            inlineOp.getLoc(), rodataName, inlineOp.value());
+        moduleSymbolTable.insert(rodataOp, moduleBuilder.getInsertionPoint());
+        SymbolTable::setSymbolVisibility(rodataOp,
+                                         SymbolTable::Visibility::Private);
+        replaceInlineOpWithRodataRef(inlineOp, rodataOp);
+      }
+    }
+  }
+
+ private:
+  // Replaces |inlineOp| with a vm.const.ref.rodata op that references the
+  // module-level |rodataOp|, then erases |inlineOp|.
+  void replaceInlineOpWithRodataRef(IREE::VM::RodataInlineOp inlineOp,
+                                    IREE::VM::RodataOp rodataOp) {
+    OpBuilder builder(inlineOp);
+    auto refOp =
+        builder.create<IREE::VM::ConstRefRodataOp>(inlineOp.getLoc(), rodataOp);
+    inlineOp.replaceAllUsesWith(refOp.value());
+    inlineOp.erase();
+  }
+};
+
+std::unique_ptr<OperationPass<IREE::VM::ModuleOp>>
+createHoistInlinedRodataPass() {
+  return std::make_unique<HoistInlinedRodataPass>();
+}
+
+static PassRegistration<HoistInlinedRodataPass> pass(
+    "iree-vm-hoist-inlined-rodata",
+    "Hoists inline iree.byte_buffer values to module-level constant storage.");
+
+}  // namespace VM
+}  // namespace IREE
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Dialect/VM/Transforms/Passes.cpp b/iree/compiler/Dialect/VM/Transforms/Passes.cpp
index 1f84c7e..b72ba7d 100644
--- a/iree/compiler/Dialect/VM/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/VM/Transforms/Passes.cpp
@@ -30,6 +30,7 @@
TargetOptions targetOptions) {
passManager.addPass(createCanonicalizerPass());
passManager.addPass(createConversionPass(targetOptions));
+ passManager.addPass(createHoistInlinedRodataPass());
passManager.addPass(createGlobalInitializationPass());
passManager.addPass(createInlinerPass());
passManager.addPass(createCSEPass());
diff --git a/iree/compiler/Dialect/VM/Transforms/Passes.h b/iree/compiler/Dialect/VM/Transforms/Passes.h
index 85e3cee..a54fc05 100644
--- a/iree/compiler/Dialect/VM/Transforms/Passes.h
+++ b/iree/compiler/Dialect/VM/Transforms/Passes.h
@@ -62,7 +62,16 @@
TargetOptions targetOptions);
//===----------------------------------------------------------------------===//
-// Module Analysis and Assignment
+// Module layout
+//===----------------------------------------------------------------------===//
+
+// Hoists vm.rodata.inline ops out of vm.funcs into private module-level
+// vm.rodata ops, replacing each with a vm.const.ref.rodata reference.
+std::unique_ptr<OperationPass<IREE::VM::ModuleOp>>
+createHoistInlinedRodataPass();
+
+//===----------------------------------------------------------------------===//
+// Module analysis and ordinal assignment
//===----------------------------------------------------------------------===//
// Gathers all module-level global init/deinit functions into single locations
@@ -90,6 +97,7 @@
auto targetOptions = getTargetOptionsFromFlags();
registerVMTransformPassPipeline();
createConversionPass(targetOptions);
+ createHoistInlinedRodataPass();
createGlobalInitializationPass();
createOrdinalAllocationPass();
}
diff --git a/iree/compiler/Dialect/VM/Transforms/test/hoist_inlined_rodata.mlir b/iree/compiler/Dialect/VM/Transforms/test/hoist_inlined_rodata.mlir
new file mode 100644
index 0000000..f949e43
--- /dev/null
+++ b/iree/compiler/Dialect/VM/Transforms/test/hoist_inlined_rodata.mlir
@@ -0,0 +1,11 @@
+// RUN: iree-opt -split-input-file -iree-vm-hoist-inlined-rodata %s | IreeFileCheck %s
+
+vm.module @module {
+ // CHECK: vm.rodata @fn_const dense<[1, 2, 3]> : tensor<3xi32>
+ // CHECK-LABEL: vm.func @fn
+ vm.func @fn() {
+ // CHECK: = vm.const.ref.rodata @fn_const : !vm.ref<!iree.byte_buffer>
+ %0 = vm.rodata.inline : !vm.ref<!iree.byte_buffer> = dense<[1, 2, 3]> : tensor<3xi32>
+ vm.return
+ }
+}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/BUILD b/iree/compiler/Dialect/VMLA/Conversion/BUILD
index db7c68d..6db8a30 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/BUILD
+++ b/iree/compiler/Dialect/VMLA/Conversion/BUILD
@@ -32,7 +32,6 @@
"//iree/compiler/Dialect/IREE/IR",
"//iree/compiler/Dialect/Shape/IR",
"//iree/compiler/Dialect/VMLA/IR",
- "//iree/compiler/Dialect/VMLA/IR:VMLADialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Transforms",
diff --git a/iree/compiler/Dialect/VMLA/Conversion/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Conversion/CMakeLists.txt
index a9ec0ad..52391f5 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/CMakeLists.txt
+++ b/iree/compiler/Dialect/VMLA/Conversion/CMakeLists.txt
@@ -30,6 +30,5 @@
iree::compiler::Dialect::IREE::IR
iree::compiler::Dialect::Shape::IR
iree::compiler::Dialect::VMLA::IR
- iree::compiler::Dialect::VMLA::IR::VMLADialect
PUBLIC
)
diff --git a/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp b/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp
index 6bd35bb..651612c 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp
@@ -38,7 +38,7 @@
// The VMLA dialect expects both standard ops and the VMLA ops (in case some
// conversion has already happened).
addLegalOp<ModuleOp, ModuleTerminatorOp>();
- addLegalDialect<IREE::VMLA::VMLADialect>();
+ addLegalDialect("vmla");
// Pseudo-ops are illegal.
// If we end up with a lot of these, consider using an "is pseudo" trait.
addIllegalOp<IREE::VMLA::BatchMatMulPseudoOp>();
diff --git a/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.h b/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.h
index af1a561..f97a236 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.h
+++ b/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.h
@@ -15,7 +15,6 @@
#ifndef IREE_COMPILER_DIALECT_VMLA_CONVERSION_CONVERSIONTARGET_H_
#define IREE_COMPILER_DIALECT_VMLA_CONVERSION_CONVERSIONTARGET_H_
-#include "iree/compiler/Dialect/VMLA/IR/VMLADialect.h"
#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
diff --git a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/BUILD b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/BUILD
index 40c585d..9f61f7c 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/BUILD
+++ b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/BUILD
@@ -33,6 +33,7 @@
"//iree/compiler/Dialect/VM/Conversion/StandardToVM",
"//iree/compiler/Dialect/VM/IR",
"//iree/compiler/Dialect/VMLA:vmla_imports",
+ "//iree/compiler/Dialect/VMLA/Conversion",
"//iree/compiler/Dialect/VMLA/IR",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
diff --git a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/CMakeLists.txt
index 6c82226..e90f8e3 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/CMakeLists.txt
+++ b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/CMakeLists.txt
@@ -31,6 +31,7 @@
iree::compiler::Dialect::VM::Conversion
iree::compiler::Dialect::VM::Conversion::StandardToVM
iree::compiler::Dialect::VM::IR
+ iree::compiler::Dialect::VMLA::Conversion
iree::compiler::Dialect::VMLA::IR
iree::compiler::Dialect::VMLA::vmla_imports
PUBLIC
diff --git a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
index 1b66485..895a8e5 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
@@ -22,6 +22,7 @@
#include "iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.h"
#include "iree/compiler/Dialect/VM/Conversion/TypeConverter.h"
#include "iree/compiler/Dialect/VM/IR/VMOps.h"
+#include "iree/compiler/Dialect/VMLA/Conversion/TypeConverter.h"
#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.h"
#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
#include "iree/compiler/Dialect/VMLA/vmla.imports.h"
@@ -183,24 +184,37 @@
LogicalResult matchAndRewrite(
IREE::VMLA::ConstantOp op, llvm::ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- // Encode constant data into a rodata segment. These will eventually get
- // deduped and combined.
- auto ip = rewriter.saveInsertionPoint();
- auto parentFuncOp = op.getParentOfType<IREE::VM::FuncOp>();
- rewriter.setInsertionPoint(parentFuncOp);
- auto constName = (parentFuncOp.getName() + "_const_" +
- std::to_string(allocateUniqueId(parentFuncOp)))
- .str();
- auto rodataOp =
- rewriter.create<IREE::VM::RodataOp>(op.getLoc(), constName, op.value());
- rewriter.restoreInsertionPoint(ip);
- auto loadRodataOp =
- rewriter.create<IREE::VM::ConstRefRodataOp>(op.getLoc(), rodataOp);
-
- // Dereference constant data.
- rewriter.replaceOpWithNewOp<IREE::VMLA::BufferConstOp>(
- op, IREE::VMLA::BufferType::get(op.getContext()),
- loadRodataOp.getResult());
+ if (auto splatAttr = op.value().dyn_cast<SplatElementsAttr>()) {
+ // Encode just a single splat element and use a buffer fill.
+ auto rodataValue = rewriter.createOrFold<IREE::VM::RodataInlineOp>(
+ op.getLoc(),
+ IREE::VM::RefType::get(IREE::ByteBufferType::get(op.getContext())),
+ DenseElementsAttr::get(
+ RankedTensorType::get({1}, splatAttr.getSplatValue().getType()),
+ splatAttr.getSplatValue()));
+ auto fillValue = rewriter.createOrFold<IREE::VMLA::BufferConstOp>(
+ op.getLoc(), IREE::VMLA::BufferType::get(op.getContext()),
+ rodataValue);
+ auto bufferLengthValue = rewriter.createOrFold<mlir::ConstantIndexOp>(
+ op.getLoc(), splatAttr.getType().cast<ShapedType>().getNumElements() *
+ VMLATypeConverter::getRoundedElementByteWidth(
+ splatAttr.getSplatValue().getType()));
+ auto bufferValue = rewriter.createOrFold<IREE::VMLA::BufferAllocOp>(
+ op.getLoc(), IREE::VMLA::BufferType::get(op.getContext()),
+ bufferLengthValue);
+ rewriter.create<IREE::VMLA::BufferFillOp>(op.getLoc(), fillValue,
+ bufferValue);
+ rewriter.replaceOp(op, bufferValue);
+ } else {
+ // Encode constant data into a rodata segment. These will eventually get
+ // deduped and combined.
+ auto rodataValue = rewriter.createOrFold<IREE::VM::RodataInlineOp>(
+ op.getLoc(),
+ IREE::VM::RefType::get(IREE::ByteBufferType::get(op.getContext())),
+ op.value());
+ rewriter.replaceOpWithNewOp<IREE::VMLA::BufferConstOp>(
+ op, IREE::VMLA::BufferType::get(op.getContext()), rodataValue);
+ }
return success();
}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/constant_ops.mlir b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/constant_ops.mlir
new file mode 100644
index 0000000..bad91b4
--- /dev/null
+++ b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/constant_ops.mlir
@@ -0,0 +1,23 @@
+// RUN: iree-opt -split-input-file -iree-convert-vmla-to-vm -cse %s | IreeFileCheck %s
+
+// CHECK-LABEL: vm.func @denseConstant
+func @denseConstant() -> !vmla.buffer {
+ // CHECK-NEXT: [[RODATA:%.+]] = vm.rodata.inline : !vm.ref<!iree.byte_buffer> = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf32
+ // CHECK-NEXT: = vm.call @vmla.buffer.const([[RODATA]]) : (!vm.ref<!iree.byte_buffer>) -> !vm.ref<!vmla.buffer>
+ %0 = vmla.constant dense<[1.0, 2.0, 3.0]> : tensor<3xf32> -> !vmla.buffer
+ return %0 : !vmla.buffer
+}
+
+// -----
+
+// CHECK-LABEL: @splatConstant
+func @splatConstant() -> !vmla.buffer {
+ // CHECK-NEXT: [[RODATA:%.+]] = vm.rodata.inline : !vm.ref<!iree.byte_buffer> = dense<0.176776692> : tensor<1xf32>
+ // CHECK-NEXT: [[SPLATTED:%.+]] = vm.call @vmla.buffer.const([[RODATA]]) : (!vm.ref<!iree.byte_buffer>) -> !vm.ref<!vmla.buffer>
+ // CHECK-NEXT: [[LENGTH:%.+]] = vm.const.i32 2359296 : i32
+ // CHECK-NEXT: [[RESULT:%.+]] = vm.call @vmla.buffer.alloc([[LENGTH]]) : (i32) -> !vm.ref<!vmla.buffer>
+ // CHECK-NEXT: vm.call @vmla.buffer.fill([[SPLATTED]], [[RESULT]]) : (!vm.ref<!vmla.buffer>, !vm.ref<!vmla.buffer>) -> ()
+ %0 = vmla.constant dense<0.176776692> : tensor<1x4x384x384xf32> -> !vmla.buffer
+ // CHECK-NEXT: vm.return [[RESULT]] : !vm.ref<!vmla.buffer>
+ return %0 : !vmla.buffer
+}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/conversion.mlir b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/conversion.mlir
index 51b005d..fa3aae2 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/conversion.mlir
+++ b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/conversion.mlir
@@ -1,17 +1,5 @@
// RUN: iree-opt -split-input-file -iree-convert-vmla-to-vm -cse %s | IreeFileCheck %s
-// CHECK: vm.rodata @[[CONST_SYM:.+]] dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf32>
-// CHECK-NEXT: vm.func @constValues
-func @constValues() -> !vmla.buffer {
- // CHECK-NEXT: %[[BYTES_REF:.+]] = vm.const.ref.rodata @[[CONST_SYM]] : !vm.ref<!iree.byte_buffer>
- // CHECK-NEXT: %[[BUFFER:.+]] = vm.call @vmla.buffer.const(%[[BYTES_REF]]) : (!vm.ref<!iree.byte_buffer>) -> !vm.ref<!vmla.buffer>
- %0 = vmla.constant dense<[1.0, 2.0, 3.0]> : tensor<3xf32> -> !vmla.buffer
- // CHECK-NEXT: vm.return %[[BUFFER]] : !vm.ref<!vmla.buffer>
- return %0 : !vmla.buffer
-}
-
-// -----
-
// CHECK-LABEL: vm.func @bufferImport
func @bufferImport() -> !vmla.buffer {
%c0 = std.constant 1 : index
diff --git a/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp b/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp
index 090c7bd..fe87d80 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp
+++ b/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp
@@ -21,6 +21,7 @@
#include "iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.h"
#include "iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/ConvertStandardToVMLA.h"
#include "iree/compiler/Dialect/VMLA/Conversion/TypeConverter.h"
+#include "iree/compiler/Dialect/VMLA/IR/VMLADialect.h"
#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
#include "iree/compiler/Dialect/VMLA/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
diff --git a/iree/compiler/Translation/test/do_not_optimize.mlir b/iree/compiler/Translation/test/do_not_optimize.mlir
index 4a5d367..00becd0 100644
--- a/iree/compiler/Translation/test/do_not_optimize.mlir
+++ b/iree/compiler/Translation/test/do_not_optimize.mlir
@@ -45,7 +45,7 @@
// -----
-// CHECK-LABEL: vm.rodata @dynamic_constant_const_0 dense<3.000000e+00> : tensor<2x3xf32>
+// CHECK-LABEL: vm.rodata @dynamic_constant_const dense<3.000000e+00> : tensor<2x3xf32>
// CHECK: vm.func @dynamic_constant
func @dynamic_constant() -> tensor<?x?xf32> {
// CHECK: vm.call @hal.buffer_view.dim
diff --git a/iree/hal/vmla/vmla_executable.cc b/iree/hal/vmla/vmla_executable.cc
index 34c4b18..336c610 100644
--- a/iree/hal/vmla/vmla_executable.cc
+++ b/iree/hal/vmla/vmla_executable.cc
@@ -141,7 +141,8 @@
void* data = static_cast<HostBuffer*>(binding.buffer->allocated_buffer())
->mutable_data();
data = reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(data) +
- binding.buffer->byte_offset());
+ binding.buffer->byte_offset() +
+ binding.offset);
IREE_ASSIGN_OR_RETURN(
- auto buffer, Buffer::WrapMutable(data, binding.buffer->byte_length(),
+ auto buffer, Buffer::WrapMutable(data, binding.buffer->byte_length() - binding.offset,
iree_allocator_null()));
diff --git a/iree/modules/hal/hal_module.cc b/iree/modules/hal/hal_module.cc
index 3a6524b..d64d643 100644
--- a/iree/modules/hal/hal_module.cc
+++ b/iree/modules/hal/hal_module.cc
@@ -168,42 +168,6 @@
// iree::hal::Allocator
//===--------------------------------------------------------------------===//
- StatusOr<int32_t> AllocatorComputeSize(
- const vm::ref<iree_hal_allocator_t>& allocator,
- absl::Span<const int32_t> shape, iree_hal_element_type_t element_type) {
- iree_device_size_t allocation_size = 0;
- IREE_RETURN_IF_ERROR(iree_hal_allocator_compute_size(
- allocator.get(), shape.data(), shape.size(), element_type,
- &allocation_size));
- return static_cast<int32_t>(allocation_size);
- }
-
- StatusOr<int32_t> AllocatorComputeOffset(
- const vm::ref<iree_hal_allocator_t>& allocator,
- absl::Span<const int32_t> shape, iree_hal_element_type_t element_type,
- absl::Span<const int32_t> indices) {
- iree_device_size_t offset = 0;
- IREE_RETURN_IF_ERROR(iree_hal_allocator_compute_offset(
- allocator.get(), shape.data(), shape.size(), element_type,
- indices.data(), indices.size(), &offset));
- return static_cast<int32_t>(offset);
- }
-
- StatusOr<std::tuple<int32_t, int32_t>> AllocatorComputeRange(
- const vm::ref<iree_hal_allocator_t>& allocator,
- absl::Span<const int32_t> shape, iree_hal_element_type_t element_type,
- absl::Span<const int32_t> start_indices,
- absl::Span<const int32_t> lengths) {
- iree_device_size_t offset = 0;
- iree_device_size_t length = 0;
- IREE_RETURN_IF_ERROR(iree_hal_allocator_compute_range(
- allocator.get(), shape.data(), shape.size(), element_type,
- start_indices.data(), start_indices.size(), lengths.data(),
- lengths.size(), &offset, &length));
- return std::make_tuple(static_cast<int32_t>(offset),
- static_cast<int32_t>(length));
- }
-
StatusOr<vm::ref<iree_hal_buffer_t>> AllocatorAllocate(
const vm::ref<iree_hal_allocator_t>& allocator,
iree_hal_memory_type_t memory_types, iree_hal_buffer_usage_t buffer_usage,
@@ -215,29 +179,34 @@
return std::move(buffer);
}
- StatusOr<vm::ref<iree_hal_buffer_t>> AllocatorAllocateConst(
+ StatusOr<vm::ref<iree_hal_buffer_t>> AllocatorWrapByteBuffer(
const vm::ref<iree_hal_allocator_t>& allocator,
iree_hal_memory_type_t memory_types, iree_hal_buffer_usage_t buffer_usage,
- absl::Span<const int32_t> shape, iree_hal_element_type_t element_type,
- const vm::ref<iree_vm_ro_byte_buffer_t>& value) {
- IREE_TRACE_SCOPE0("HALModuleState::AllocatorAllocateConst");
+ const vm::ref<iree_vm_ro_byte_buffer_t>& source, int32_t offset,
+ int32_t length) {
+ IREE_TRACE_SCOPE0("HALModuleState::AllocatorWrapByteBuffer");
- iree_device_size_t allocation_size = 0;
- IREE_RETURN_IF_ERROR(iree_hal_allocator_compute_size(
- allocator.get(), shape.data(), shape.size(), element_type,
- &allocation_size));
- if (allocation_size < value->data.data_length) {
+ // TODO(benvanik): wrap when supported.
+ // Bounds are compared in size_t so offset + length cannot overflow int32_t.
+ size_t buffer_length = source->data.data_length;
+ if (length == -1) {
+ length = buffer_length;
+ }
+ if (length < 0 || offset < 0 || static_cast<size_t>(offset) > buffer_length ||
+ static_cast<size_t>(length) > buffer_length - static_cast<size_t>(offset)) {
return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Constant data is too large for the minimum allocation size";
+ << "Byte range out of bounds (requested " << offset << "-"
+ << (static_cast<int64_t>(offset) + length - 1) << " of available "
+ << buffer_length << ")";
}
vm::ref<iree_hal_buffer_t> buffer;
IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer(
- allocator.get(), memory_types, buffer_usage, allocation_size, &buffer))
+ allocator.get(), memory_types, buffer_usage, length, &buffer))
<< "Failed to allocate buffer";
IREE_RETURN_IF_ERROR(iree_hal_buffer_write_data(
- buffer.get(), 0, value->data.data, value->data.data_length))
+ buffer.get(), 0, source->data.data + offset, length))
<< "Writing constant data";
return buffer;
@@ -256,13 +225,22 @@
const vm::ref<iree_hal_buffer_t>& source_buffer, int32_t source_offset,
int32_t length) {
IREE_TRACE_SCOPE0("HALModuleState::BufferSubspan");
- return UnimplementedErrorBuilder(IREE_LOC) << "BufferSubspan";
+ vm::ref<iree_hal_buffer_t> target_buffer;
+ IREE_RETURN_IF_ERROR(iree_hal_buffer_subspan(
+ source_buffer.get(), source_offset, length, allocator_, &target_buffer))
+ << "Subspan of an existing buffer (source_offset=" << source_offset
+ << ", length=" << length << ")";
+ return target_buffer;
}
Status BufferFill(const vm::ref<iree_hal_buffer_t>& target_buffer,
int32_t target_offset, int32_t length, int32_t pattern) {
IREE_TRACE_SCOPE0("HALModuleState::BufferFill");
- return UnimplementedErrorBuilder(IREE_LOC) << "BufferFill";
+ IREE_RETURN_IF_ERROR(iree_hal_buffer_fill(
+ target_buffer.get(), target_offset, length, &pattern, sizeof(pattern)))
+ << "Fill range failed (target_offset=" << target_offset
+ << ", length=" << length << ")";
+ return OkStatus();
}
Status BufferReadData(const vm::ref<iree_hal_buffer_t>& source_buffer,
@@ -770,16 +748,10 @@
vm::MakeNativeFunction("ex.submit_and_wait",
&HALModuleState::ExSubmitAndWait),
- vm::MakeNativeFunction("allocator.compute_size",
- &HALModuleState::AllocatorComputeSize),
- vm::MakeNativeFunction("allocator.compute_offset",
- &HALModuleState::AllocatorComputeOffset),
- vm::MakeNativeFunction("allocator.compute_range",
- &HALModuleState::AllocatorComputeRange),
vm::MakeNativeFunction("allocator.allocate",
&HALModuleState::AllocatorAllocate),
- vm::MakeNativeFunction("allocator.allocate.const",
- &HALModuleState::AllocatorAllocateConst),
+ vm::MakeNativeFunction("allocator.wrap.byte_buffer",
+ &HALModuleState::AllocatorWrapByteBuffer),
vm::MakeNativeFunction("buffer.allocator",
&HALModuleState::BufferAllocator),
diff --git a/iree/samples/custom_modules/dialect/test/conversion.mlir b/iree/samples/custom_modules/dialect/test/conversion.mlir
index eb2c84f..7759906 100644
--- a/iree/samples/custom_modules/dialect/test/conversion.mlir
+++ b/iree/samples/custom_modules/dialect/test/conversion.mlir
@@ -16,7 +16,7 @@
// Depending on whether any manual conversion is performed this may get complex,
// such as when versioning imports or performing optimizations.
-// RUN: custom-opt %s -iree-convert-flow-to-hal -iree-shape-expand-function-ranked-shape-dims -iree-vm-conversion -split-input-file | IreeFileCheck %s
+// RUN: custom-opt %s -iree-convert-to-hal -iree-shape-expand-function-ranked-shape-dims -iree-vm-conversion -split-input-file | IreeFileCheck %s
// CHECK-LABEL: @tensorToMessage
func @tensorToMessage(%tensor : tensor<2x4xf32>) {
diff --git a/iree/vm/ref_test.cc b/iree/vm/ref_test.cc
index 496b30a..9f07faf 100644
--- a/iree/vm/ref_test.cc
+++ b/iree/vm/ref_test.cc
@@ -70,8 +70,9 @@
}
static int32_t ReadCounter(iree_vm_ref_t* ref) {
- return *((iree_atomic_ref_count_t*)(((uintptr_t)ref->ptr) +
- ref->offsetof_counter));
+ return iree_atomic_load_int32(
+ (iree_atomic_ref_count_t*)(((uintptr_t)ref->ptr) + ref->offsetof_counter),
+ iree_memory_order_seq_cst);
}
static iree_vm_ref_type_t kCTypeID = IREE_VM_REF_TYPE_NULL;