Merge pull request #3467 from google/benvanik-vm-fold-arithmetic
diff --git a/SUBMODULE_VERSIONS b/SUBMODULE_VERSIONS
index b61cdfe..48180ae 100644
--- a/SUBMODULE_VERSIONS
+++ b/SUBMODULE_VERSIONS
@@ -5,7 +5,7 @@
a5d9d0f7d368054fd1691aedf1db4116efcc233e third_party/flatbuffers
4fb0ff7069bd88ee85902f4d0bb62794e5f6d021 third_party/flatcc
f2fb48c3b3d79a75a88a99fba6576b25d42ec528 third_party/googletest
-93377888ae89560ba6d3976e2762d3d4724c4dfd third_party/llvm-project
+9b3c2a72e4cb3b0ae27f87064c11f728452b2af9 third_party/llvm-project
17b12a4481daa150e2d1ea3ada086b551b856707 third_party/marl
d2cdb70e038370b5e28f353fe98ccd70af1cbc25 third_party/mlir-emitc
d8c7ee00a687ac369e62e2032514a93a9b413502 third_party/pybind11
@@ -14,7 +14,7 @@
685f86471e9d26b3eb7676695a2e2cefb4551ae9 third_party/spirv_cross
f8bf11a0253a32375c32cad92c841237b96696c0 third_party/spirv_headers
57eb48aed36160c4876bc8310d9ca84d42ee9e2a third_party/swiftshader
-2e56481abcef1dd1625fba465a5d02ee6b347842 third_party/tensorflow
+090b691fbf7b7823c41345004d12eddaa6c86118 third_party/tensorflow
a9a09ab0940408898fccfdcfe2bb8dc19b50f13c third_party/tracy
9bd3f561bcee3f01d22912de10bb07ce4e23d378 third_party/vulkan_headers
909f36b714c9239ee0b112a321220213a474ba53 third_party/vulkan_memory_allocator
diff --git a/build_tools/bazel/third_party_import/llvm-project/overlay/mlir/BUILD.bazel b/build_tools/bazel/third_party_import/llvm-project/overlay/mlir/BUILD.bazel
index 31a839d..4aeb2b8 100644
--- a/build_tools/bazel/third_party_import/llvm-project/overlay/mlir/BUILD.bazel
+++ b/build_tools/bazel/third_party_import/llvm-project/overlay/mlir/BUILD.bazel
@@ -3918,6 +3918,7 @@
"include/mlir/Dialect/Linalg/EDSC/Builders.h",
"include/mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h",
"include/mlir/Dialect/Linalg/Passes.h",
+ "include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h",
"include/mlir/Dialect/Linalg/Transforms/Hoisting.h",
"include/mlir/Dialect/Linalg/Transforms/Transforms.h",
"include/mlir/Dialect/Linalg/Utils/Utils.h",
@@ -3945,6 +3946,7 @@
":Transforms",
":TransformsPassIncGen",
":VectorOps",
+ ":VectorToSCF",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:Support",
],
@@ -4033,7 +4035,6 @@
":EDSC",
":IR",
":LLVMDialect",
- ":LinalgTransforms",
":Pass",
":SCFDialect",
":StandardOps",
diff --git a/build_tools/bazel/third_party_import/llvm-project/overlay/mlir/test/BUILD.bazel b/build_tools/bazel/third_party_import/llvm-project/overlay/mlir/test/BUILD.bazel
index d88190c..c467880 100644
--- a/build_tools/bazel/third_party_import/llvm-project/overlay/mlir/test/BUILD.bazel
+++ b/build_tools/bazel/third_party_import/llvm-project/overlay/mlir/test/BUILD.bazel
@@ -17,6 +17,21 @@
includes = ["."],
)
+filegroup(
+ name = "TestOpTdFiles",
+ srcs = [
+ "lib/Dialect/Test/TestOps.td",
+ "@llvm-project//mlir:OpBaseTdFiles",
+ "@llvm-project//mlir:include/mlir/IR/OpAsmInterface.td",
+ "@llvm-project//mlir:include/mlir/IR/RegionKindInterface.td",
+ "@llvm-project//mlir:include/mlir/IR/SymbolInterfaces.td",
+ "@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td",
+ "@llvm-project//mlir:include/mlir/Interfaces/ControlFlowInterfaces.td",
+ "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
+ "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
+ ],
+)
+
gentbl(
name = "TestOpsIncGen",
strip_include_prefix = "lib/Dialect/Test",
@@ -57,14 +72,7 @@
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "lib/Dialect/Test/TestOps.td",
td_srcs = [
- "@llvm-project//mlir:OpBaseTdFiles",
- "@llvm-project//mlir:include/mlir/IR/OpAsmInterface.td",
- "@llvm-project//mlir:include/mlir/IR/RegionKindInterface.td",
- "@llvm-project//mlir:include/mlir/IR/SymbolInterfaces.td",
- "@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td",
- "@llvm-project//mlir:include/mlir/Interfaces/ControlFlowInterfaces.td",
- "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
- "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
+ ":TestOpTdFiles",
],
test = True,
)
@@ -90,11 +98,34 @@
test = True,
)
+gentbl(
+ name = "TestTypeDefsIncGen",
+ strip_include_prefix = "lib/Dialect/Test",
+ tbl_outs = [
+ (
+ "-gen-typedef-decls",
+ "lib/Dialect/Test/TestTypeDefs.h.inc",
+ ),
+ (
+ "-gen-typedef-defs",
+ "lib/Dialect/Test/TestTypeDefs.cpp.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "lib/Dialect/Test/TestTypeDefs.td",
+ td_srcs = [
+ ":TestOpTdFiles",
+ ],
+ test = True,
+)
+
cc_library(
name = "TestDialect",
srcs = [
"lib/Dialect/Test/TestDialect.cpp",
"lib/Dialect/Test/TestPatterns.cpp",
+ "lib/Dialect/Test/TestTraits.cpp",
+ "lib/Dialect/Test/TestTypes.cpp",
],
hdrs = [
"lib/Dialect/Test/TestDialect.h",
@@ -106,6 +137,7 @@
deps = [
":TestInterfacesIncGen",
":TestOpsIncGen",
+ ":TestTypeDefsIncGen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ControlFlowInterfaces",
"@llvm-project//mlir:DerivedAttributeOpInterface",
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index 00ea2c2..dcd2341 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -313,8 +313,10 @@
yield call
@staticmethod
- def compare_traces(ref_trace: "Trace", tar_trace: "Trace") -> bool:
+ def compare_traces(ref_trace: "Trace",
+ tar_trace: "Trace") -> Tuple[bool, Sequence[str]]:
traces_match = True
+ error_messages = []
# Check that all method invocations match.
ref_methods = [(call.method, call.rtol, call.atol) for call in ref_trace]
@@ -330,13 +332,17 @@
logging.info("Comparing calls to '%s'", ref_call.method)
rtol, atol = ref_call.get_tolerances()
- inputs_match = Trace._check_same(ref_call.inputs, tar_call.inputs, rtol,
- atol)
+ inputs_match, error_message = Trace._check_same(ref_call.inputs,
+ tar_call.inputs, rtol,
+ atol)
if not inputs_match:
+ error_messages.append(error_message)
logging.error("Inputs did not match.")
- outputs_match = Trace._check_same(ref_call.outputs, tar_call.outputs,
- rtol, atol)
+ outputs_match, error_message = Trace._check_same(ref_call.outputs,
+ tar_call.outputs, rtol,
+ atol)
if not outputs_match:
+ error_messages.append(error_message)
logging.error("Outputs did not match.")
calls_match = inputs_match and outputs_match
@@ -349,83 +355,96 @@
logging.error("Target call '%s':\n%s", tar_trace.backend_id, tar_call)
traces_match = traces_match and calls_match
- return traces_match
+ return traces_match, error_messages
@staticmethod
- def _check_same(ref: Any, tar: Any, rtol: float, atol: float) -> bool:
+ def _check_same(ref: Any, tar: Any, rtol: float,
+ atol: float) -> Tuple[bool, Union[str, None]]:
"""Checks that ref and tar have identical datastructures and values."""
# Check for matching types.
if not isinstance(tar, type(ref)):
- logging.error(
- "Expected ref and tar to have the same type but got '%s' and '%s'",
- type(ref), type(tar))
- return False
+ error = ("Expected ref and tar to have the same type but got "
+ f"'{type(ref)}' and '{type(tar)}'")
+ logging.error(error)
+ return False, error
if ref is None:
# Nothing to compare (e.g. the called method had no outputs).
- return True
+ return True, None
# Recursive check for dicts.
if isinstance(ref, dict):
if ref.keys() != tar.keys():
- logging.error(
- "Expected ref and tar to have the same keys, but got '%s' and '%s'",
- ref.keys(), tar.keys())
- return False
+ error = ("Expected ref and tar to have the same keys, but got "
+ f"'{ref.keys()}' and '{tar.keys()}'")
+ logging.error(error)
+ return False, error
# Check that all of the dictionaries' values are the same.
for key in ref:
- if not Trace._check_same(ref[key], tar[key], rtol, atol):
- return False
+ same, error = Trace._check_same(ref[key], tar[key], rtol, atol)
+ if not same:
+ return same, error
# Recursive check for iterables.
elif isinstance(ref, list) or isinstance(ref, tuple):
if len(ref) != len(tar):
- logging.error(
- "Expected ref and tar to have the same length, but got %s and %s",
- len(ref), len(tar))
- return False
+ error = ("Expected ref and tar to have the same length, but got "
+ f"{len(ref)} and {len(tar)}")
+ logging.error(error)
+ return False, error
# Check that all of the iterables' values are the same.
for i in range(len(ref)):
- if not Trace._check_same(ref[i], tar[i], rtol, atol):
- return False
+ same, error = Trace._check_same(ref[i], tar[i], rtol, atol)
+ if not same:
+ return same, error
# Base check for numpy arrays.
elif isinstance(ref, np.ndarray):
if ref.dtype != tar.dtype:
- logging.error(
- "Expected ref and tar to have the same dtype, but got %s and %s",
- ref.dtype, tar.dtype)
- return False
+ error = ("Expected ref and tar to have the same dtype, but got "
+ f"'{ref.dtype}' and '{tar.dtype}'")
+ logging.error(error)
+ return False, error
if ref.size == tar.size == 0:
- return True
+ return True, None
if np.issubdtype(ref.dtype, np.floating):
same = np.allclose(ref, tar, rtol=rtol, atol=atol)
abs_diff = np.max(np.abs(ref - tar))
rel_diff = np.max(np.abs(ref - tar) / np.max(np.abs(tar)))
+ diff_string = (f"Max abs diff: {abs_diff:.2e}, atol: {atol:.2e}, "
+ f"max relative diff: {rel_diff:.2e}, rtol: {rtol:.2e}")
if not same:
- logging.error(
- "Floating point difference between ref and tar was too large. "
- "Max abs diff: %s, atol: %s, max relative diff: %s, rtol: %s",
- abs_diff, atol, rel_diff, rtol)
+ error = ("Floating point difference between ref and tar was too "
+ f"large. {diff_string}")
+ logging.error(error)
else:
+ error = None
logging.info(
"Floating point difference between ref and tar was within "
- "tolerance. "
- "Max abs diff: %s, atol: %s, max relative diff: %s, rtol: %s",
- abs_diff, atol, rel_diff, rtol)
- return same
+ "tolerance. %s", diff_string)
+ return same, error
+ elif np.issubdtype(ref.dtype, np.integer):
+ same = np.array_equal(ref, tar)
+ if not same:
+ abs_diff = np.max(np.abs(ref - tar))
+ error = ("Expected array equality between ref and tar, but got "
+ f"a max elementwise difference of {abs_diff}")
+ logging.error(error)
+ else:
+ error = None
+ return same, error
else:
- return np.array_equal(ref, tar)
+ return np.array_equal(ref, tar), None
# Base check for native number types.
elif isinstance(ref, (int, float)):
- return ref == tar
+ return ref == tar, None
# If outputs end up here then an extra branch for that type should be added.
else:
raise TypeError(f"Encountered results with unexpected type {type(ref)}")
- return True
+ return True, None
def save_plaintext(self, trace_dir: str, summarize: bool = True) -> None:
"""Saves a human-readable string representation of this trace to disk.
@@ -718,12 +737,14 @@
# Compare each target trace of trace_function with the reference trace.
failed_backend_indices = []
+ error_messages = []
for i, tar_trace in enumerate(tar_traces):
logging.info("Comparing the reference backend '%s' with '%s'",
ref_trace.backend_id, tar_trace.backend_id)
- traces_match = Trace.compare_traces(ref_trace, tar_trace)
+ traces_match, errors = Trace.compare_traces(ref_trace, tar_trace)
if not traces_match:
failed_backend_indices.append(i)
+ error_messages.extend(errors)
# Save the results to disk before validating.
ref_trace_dir = _get_trace_dir(modules.artifacts_dir, ref_trace)
@@ -740,10 +761,11 @@
failed_backends = [
tar_traces[i].backend_id for i in failed_backend_indices
]
+ error_list = ''.join([f'\n - {message}' for message in error_messages])
self.fail(
- "Comparision between the reference backend and the following targets "
- f"failed: {failed_backends}. The errors above show the inputs and "
- "outputs of the non-matching calls.")
+ "Comparison between the reference backend and the following targets "
+ f"failed: {failed_backends}. Errors: {error_list}\n"
+ "See the logs above for more details about the non-matching calls.")
@classmethod
def tearDownClass(cls) -> None:
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
index e1fb6e3..efd09ff 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
@@ -104,7 +104,7 @@
],
}
# yapf: enable
- same = tf_test_utils.Trace._check_same(ref, tar, rtol=1e-6, atol=1e-6)
+ same, _ = tf_test_utils.Trace._check_same(ref, tar, rtol=1e-6, atol=1e-6)
self.assertEqual(tar_same, same)
def test_trace_inputs_and_outputs(self):
@@ -171,7 +171,9 @@
vmla_trace = tf_test_utils.Trace(vmla_module, vmla_function)
vmla_function(tf_test_utils.TracedModule(vmla_module, vmla_trace))
- self.assertFalse(tf_test_utils.Trace.compare_traces(tf_trace, vmla_trace))
+ same, error_messages = tf_test_utils.Trace.compare_traces(
+ tf_trace, vmla_trace)
+ self.assertFalse(same)
def test_trace_serialize_and_load(self):
diff --git a/integrations/tensorflow/compiler/dialect/tf_strings/conversion/convert_tf_to_tf_strings.cc b/integrations/tensorflow/compiler/dialect/tf_strings/conversion/convert_tf_to_tf_strings.cc
index e5c1a3e..d5c3961 100644
--- a/integrations/tensorflow/compiler/dialect/tf_strings/conversion/convert_tf_to_tf_strings.cc
+++ b/integrations/tensorflow/compiler/dialect/tf_strings/conversion/convert_tf_to_tf_strings.cc
@@ -156,7 +156,7 @@
void populateTFToTFStringsPatterns(MLIRContext *ctx,
OwningRewritePatternList &patterns) {
- populateWithGenerated(ctx, &patterns);
+ populateWithGenerated(ctx, patterns);
patterns.insert<StringFormatOpLowering>(ctx);
}
diff --git a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.cc b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.cc
index 472b890..17d2230 100644
--- a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.cc
+++ b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.cc
@@ -83,7 +83,7 @@
// It only knows how to handle blindly convert one type to another type.
OwningRewritePatternList patterns;
- populateWithGenerated(&getContext(), &patterns);
+ populateWithGenerated(&getContext(), patterns);
patterns.insert<ConvertTfTensorlistConcatV2>(&getContext());
ConversionTarget target(getContext());
diff --git a/integrations/tensorflow/e2e/keras/BUILD b/integrations/tensorflow/e2e/keras/BUILD
index 29e0acc..7753475 100644
--- a/integrations/tensorflow/e2e/keras/BUILD
+++ b/integrations/tensorflow/e2e/keras/BUILD
@@ -200,7 +200,7 @@
datasets = ["cifar10"],
failing_configurations = [
{
- # Failing all but tf and vmla:
+ # Failing on llvm and vulkan:
"models": [
"NASNetLarge",
"NASNetMobile",
@@ -210,7 +210,6 @@
],
"datasets": ["cifar10"],
"backends": [
- "tflite",
"iree_llvmjit",
"iree_vulkan",
],
@@ -284,10 +283,19 @@
datasets = ["imagenet"],
failing_configurations = [
{
- # Failing all but tf and vmla:
+ # Failing vulkan:
"models": [
"InceptionResNetV2",
"InceptionV3",
+ ],
+ "datasets": ["imagenet"],
+ "backends": [
+ "iree_vulkan",
+ ],
+ },
+ {
+ # Failing llvm and vulkan:
+ "models": [
"NASNetLarge",
"NASNetMobile",
"ResNet50V2",
@@ -297,7 +305,6 @@
],
"datasets": ["imagenet"],
"backends": [
- "tflite",
"iree_llvmjit",
"iree_vulkan",
],
diff --git a/integrations/tensorflow/e2e/keras/train/model_train_test.py b/integrations/tensorflow/e2e/keras/train/model_train_test.py
index 49a3949..a77ff95 100644
--- a/integrations/tensorflow/e2e/keras/train/model_train_test.py
+++ b/integrations/tensorflow/e2e/keras/train/model_train_test.py
@@ -14,6 +14,7 @@
# limitations under the License.
"""Test keras Model training."""
+from absl import app
from absl import flags
import numpy as np
from pyiree.tf.support import tf_test_utils
@@ -77,6 +78,11 @@
class ModelTrainTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, *args, **kwargs):
+ super(ModelTrainTest, self).__init__(*args, **kwargs)
+ self._modules = tf_test_utils.compile_tf_module(
+ ModelTrain.CreateModule, exported_names=["train_step"])
+
def generate_regression_data(self, size=8):
x = np.arange(size) - size // 2
y = 1.0 * x**3 + 1.0 * x**2 + 1.0 * x + np.random.randn(size) * size
@@ -105,9 +111,13 @@
self.compare_backends(train_step, self._modules)
-if __name__ == "__main__":
+def main(argv):
+ del argv # Unused
if hasattr(tf, "enable_v2_behavior"):
tf.enable_v2_behavior()
- tf_test_utils.compile_tf_module(ModelTrain.CreateModule,
- exported_names=["train_step"])
+
tf.test.main()
+
+
+if __name__ == "__main__":
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/keras/vision_model_test.py b/integrations/tensorflow/e2e/keras/vision_model_test.py
index 0e7a263..f533915 100644
--- a/integrations/tensorflow/e2e/keras/vision_model_test.py
+++ b/integrations/tensorflow/e2e/keras/vision_model_test.py
@@ -150,10 +150,10 @@
self._modules = tf_test_utils.compile_tf_module(VisionModule,
exported_names=['predict'])
- def test_application(self):
+ def test_predict(self):
def predict(module):
- module.predict(tf_utils.uniform(get_input_shape()))
+ module.predict(tf_utils.uniform(get_input_shape()), atol=1e-5, rtol=1e-5)
self.compare_backends(predict, self._modules)
diff --git a/integrations/tensorflow/e2e/slim_vision_models/BUILD b/integrations/tensorflow/e2e/slim_vision_models/BUILD
index 42ce76d..fe87fb7 100644
--- a/integrations/tensorflow/e2e/slim_vision_models/BUILD
+++ b/integrations/tensorflow/e2e/slim_vision_models/BUILD
@@ -72,56 +72,36 @@
],
},
{
- # Failing all but tf and vmla:
+ # Failing llvmjit and vulkan:
"models": [
- "inception_resnet_v2",
- # tflite: RuntimeError: tensorflow/lite/kernels/reshape.cc:66 num_input_elements != num_output_elements (38400 != -1571481807)Node number 333 (RESHAPE) failed to prepare.
- # llvmjit: *** Received signal 6 ***
- # vulkan: Floating point difference was too large. Max abs diff: 1.3961234, atol: 5e-05, max relative diff: 0.27304956, rtol: 1e-06
+ "nasnet_mobile",
+ "nasnet_large",
+ "pnasnet_large",
+ "resnet_v2_50",
"resnet_v2_101",
- # tflite: RuntimeError: tensorflow/lite/core/subgraph.cc BytesRequired number of elements overflowed.
- # llvmjit: Floating point difference was too large. Max abs diff: 11.668068, atol: 5e-05, max relative diff: 0.93950737, rtol: 1e-06
- # vulkan: Floating point difference was too large. Max abs diff: 11.668067, atol: 5e-05, max relative diff: 0.93950737, rtol: 1e-06
"resnet_v2_152",
- # tflite: RuntimeError: tensorflow/lite/core/subgraph.cc BytesRequired number of elements overflowed.
- # llvmjit: Floating point difference was too large. Max abs diff: 7.080696, atol: 5e-05, max relative diff: 0.97750616, rtol: 1e-06
- # vulkan: Floating point difference was too large. Max abs diff: 7.08069, atol: 5e-05, max relative diff: 0.97750485, rtol: 1e-06
],
"backends": [
- "tflite",
"iree_llvmjit",
"iree_vulkan",
],
},
{
- # Failing llvmjit and vulkan:
+ # Failing vulkan:
"models": [
+ # [ERROR]: cannot separate Linalg/Parallel ops into multiple kernels
+ "inception_v1",
"inception_v2",
- # llvmjit: double free or corruption (!prev); *** Received signal 6 ***
- # vulkan: Floating point difference was too large. Max abs diff: 1.0769763, atol: 5e-05, max relative diff: 0.19576924, rtol: 1e-06
"inception_v3",
- # llvmjit: double free or corruption (!prev); *** Received signal 6 ***
- # vulkan: Floating point difference was too large. Max abs diff: 2.5201874, atol: 5e-05, max relative diff: 0.53700095, rtol: 1e-06
- "nasnet_mobile",
- # llvmjit: corrupted size vs. prev_size; *** Received signal 6 ***
- # vulkan: *** Received signal 11 ***
- "nasnet_large",
- # llvmjit: *** Received signal 6 ***
- # vulkan: *** Received signal 11 ***
- "pnasnet_large",
- # llvmjit: Floating point difference was too large. Max abs diff: 1.0411791, atol: 5e-05, max relative diff: 0.20533353, rtol: 1e-06
- # vulkan: *** Received signal 11 ***
- "resnet_v2_50",
- # llvmjit: Floating point difference was too large. Max abs diff: 5.8187943, atol: 5e-05, max relative diff: 0.7946711, rtol: 1e-06
- # vulkan: Floating point difference was too large. Max abs diff: 5.8187933, atol: 5e-05, max relative diff: 0.79467094, rtol: 1e-06
+ "inception_resnet_v2",
],
"backends": [
- "iree_llvmjit",
"iree_vulkan",
],
},
],
models = [
+ "amoebanet_a_n18_f448",
"inception_resnet_v2",
"inception_v1",
"inception_v2",
diff --git a/iree/base/atomics.h b/iree/base/atomics.h
index a2d0e71..b97d803 100644
--- a/iree/base/atomics.h
+++ b/iree/base/atomics.h
@@ -79,9 +79,9 @@
typedef struct {
int64_t __val;
} iree_atomic_int64_t;
-typedef __declspec(align(16)) struct {
- uint64_t __val[2];
-} iree_atomic_int128_t;
+// typedef __declspec(align(16)) struct {
+// uint64_t __val[2];
+// } iree_atomic_int128_t;
#define iree_atomic_load_int32(object, order) \
InterlockedExchangeAdd((volatile LONG*)object, 0)
@@ -153,7 +153,8 @@
typedef _Atomic int32_t iree_atomic_int32_t;
typedef _Atomic int64_t iree_atomic_int64_t;
-typedef _Atomic __int128 iree_atomic_int128_t;
+// TODO(#3453): check for __int128 support before using
+// typedef _Atomic __int128 iree_atomic_int128_t;
#define iree_atomic_load_auto(object, order) \
__c11_atomic_load((object), (order))
@@ -192,7 +193,7 @@
typedef int32_t iree_atomic_int32_t;
typedef int64_t iree_atomic_int64_t;
-typedef __int128 iree_atomic_int128_t;
+// typedef __int128 iree_atomic_int128_t;
#ifdef __cplusplus
// Equiv to C++ auto keyword in C++ mode.
diff --git a/iree/base/ref_ptr_test.cc b/iree/base/ref_ptr_test.cc
index 5087791..efe0bd6 100644
--- a/iree/base/ref_ptr_test.cc
+++ b/iree/base/ref_ptr_test.cc
@@ -218,6 +218,8 @@
struct MyBaseType : public RefObject<MyBaseType> {
int x = 5;
using RefObject<MyBaseType>::counter_; // Expose for testing.
+
+ virtual ~MyBaseType() = default;
};
struct MyTypeA : public MyBaseType {
int a = 6;
diff --git a/iree/base/threading_pthreads.c b/iree/base/threading_pthreads.c
index a5f99b6..5482414 100644
--- a/iree/base/threading_pthreads.c
+++ b/iree/base/threading_pthreads.c
@@ -241,11 +241,15 @@
iree_thread_t* thread, iree_thread_priority_class_t priority_class) {
IREE_TRACE_ZONE_BEGIN(z0);
+#if defined(IREE_PLATFORM_ANDROID)
+ // TODO(benvanik): Some sort of solution on Android, if possible (see above)
+#else
int policy = 0;
struct sched_param param;
pthread_getschedparam(thread->handle, &policy, ¶m);
param = iree_thread_sched_param_for_priority_class(policy, priority_class);
pthread_setschedparam(thread->handle, policy, ¶m);
+#endif // IREE_PLATFORM_ANDROID
IREE_TRACE_ZONE_END(z0);
}
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
index 0b20c43..5421cc9 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
@@ -605,8 +605,9 @@
// If the `linalgOp` writes to workgroup memory insert barrier after the
// op.
if (llvm::any_of(linalgOp.getOperands(), [](Value output) {
- return output.getType().cast<MemRefType>().getMemorySpace() ==
- getWorkgroupMemorySpace();
+ MemRefType outputType = output.getType().dyn_cast<MemRefType>();
+ return outputType &&
+ outputType.getMemorySpace() == getWorkgroupMemorySpace();
})) {
rewriter.create<spirv::ControlBarrierOp>(
linalgOp.getLoc(), spirv::Scope::Workgroup, spirv::Scope::Workgroup,
@@ -751,6 +752,7 @@
MapLinalgOpToGlobalInvocationId<linalg::IndexedGenericOp>,
MapLinalgOpToLocalInvocationId<linalg::ConvOp>,
MapLinalgOpToLocalInvocationId<linalg::CopyOp>,
+ MapLinalgOpToLocalInvocationId<linalg::FillOp>,
MapLinalgOpToLocalInvocationId<linalg::MatmulOp>,
MapLinalgOpToLocalInvocationId<linalg::BatchMatmulOp>,
MapLinalgOpToLocalInvocationId<linalg::PoolingMaxOp>,
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
index fdae78b..0021c77 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
@@ -63,23 +63,22 @@
entryPointFn.emitError("unable to find num workgroups fn ") << attr;
return nullptr;
}
- if (!numWorkgroupsFn.empty()) {
- entryPointFn.emitError("num workgroups fn expected to be empty");
- return nullptr;
- }
return numWorkgroupsFn;
}
/// Computes the bounds of the parallel loops partitioned across workgroups.
static Optional<SmallVector<Value, 2>> getParallelLoopRange(
- PatternRewriter &rewriter, Location loc, linalg::LinalgOp linalgOp) {
- FuncOp numWorkgroupsFn =
- getNumWorkgroupsFn(linalgOp.getParentOfType<FuncOp>());
- if (!numWorkgroupsFn) return {};
+ PatternRewriter &rewriter, FuncOp numWorkgroupsFn, Location loc,
+ linalg::LinalgOp linalgOp) {
+ if (!numWorkgroupsFn.empty()) {
+ numWorkgroupsFn.emitError("num workgroups fn expected to be empty");
+ return {};
+ }
LLVM_DEBUG({
llvm::dbgs() << "Found num workgroups function : "
<< numWorkgroupsFn.getName();
});
+
rewriter.setInsertionPointToEnd(numWorkgroupsFn.addEntryBlock());
llvm::SetVector<Operation *> slice;
getBackwardSlice(linalgOp, &slice);
@@ -127,10 +126,14 @@
linalg::LinalgOp linalgOp,
FuncOp entryPointFn,
ArrayRef<int64_t> tileSizes) {
+ FuncOp numWorkgroupsFn =
+ getNumWorkgroupsFn(linalgOp.getParentOfType<FuncOp>());
+ if (!numWorkgroupsFn) return failure();
+
Location loc = linalgOp.getLoc();
OpBuilder::InsertionGuard gaurd(rewriter);
Optional<SmallVector<Value, 2>> parallelLoopRange =
- getParallelLoopRange(rewriter, loc, linalgOp);
+ getParallelLoopRange(rewriter, numWorkgroupsFn, loc, linalgOp);
if (!parallelLoopRange) return failure();
Value one = rewriter.create<ConstantIndexOp>(loc, 1);
SmallVector<Value, 3> returnValues(3, one);
@@ -148,10 +151,23 @@
LogicalResult createNumWorkgroupsFromLinearizedResultShape(
PatternRewriter &rewriter, linalg::LinalgOp linalgOp, FuncOp entryPointFn,
int64_t workgroupSizeX) {
+ FuncOp numWorkgroupsFn =
+ getNumWorkgroupsFn(linalgOp.getParentOfType<FuncOp>());
+ if (!numWorkgroupsFn) return failure();
+ if (!numWorkgroupsFn.empty()) {
+ // TODO(ravishankarm): We can end up with multiple linalg operations
+ // (typically linalg.generic operations) that have the same workload in a
+ // dispatch region. In that case, the first linalg.generic creates the body
+ // of number of workgroups. For now, just returning if the body is not empty
+ // assuming that it is correct for all the ops in the dispatch region. This
+ // needs to be enforced somehow.
+ return success();
+ }
+
Location loc = linalgOp.getLoc();
OpBuilder::InsertionGuard gaurd(rewriter);
Optional<SmallVector<Value, 2>> parallelLoopRange =
- getParallelLoopRange(rewriter, loc, linalgOp);
+ getParallelLoopRange(rewriter, numWorkgroupsFn, loc, linalgOp);
if (!parallelLoopRange) return failure();
Value one = rewriter.create<ConstantIndexOp>(loc, 1);
SmallVector<Value, 3> returnValues(3, one);
@@ -168,6 +184,10 @@
// Launch config calculation.
//===----------------------------------------------------------------------===//
+/// Name of the StrAttr that can be used to get the key to access the tile size
+/// information.
+static const char kLaunchInfoKey[] = "launch_info_key";
+
/// Given `nprocs` try to distribute it evenly across 2 logical x and y.
static std::tuple<int64_t, int64_t> distributeProcs2D(int64_t nprocs) {
int64_t nprocs_x = std::max<int64_t>(
@@ -387,12 +407,27 @@
#undef DEFINE_POOLINGOP_CONFIG
-LogicalResult LaunchConfig::init(const SPIRVCodegenOptions &options,
- ArrayRef<linalg::LinalgOp> linalgOps) {
+Optional<StringRef> LaunchConfig::getKey(Operation *op) const {
+ StringAttr attr = op->getAttrOfType<StringAttr>(kLaunchInfoKey);
+ if (!attr) return {};
+ return attr.getValue();
+}
+
+LogicalResult LaunchConfig::init(MLIRContext *context,
+ const SPIRVCodegenOptions &options,
+ ArrayRef<Operation *> linalgOps) {
+ unsigned numTiledOps = 0;
+ auto setKey = [&](Operation *op) -> std::string {
+ std::string key = llvm::formatv("__op_num_{0}__", numTiledOps++).str();
+ op->setAttr(Identifier::get(kLaunchInfoKey, context),
+ StringAttr::get(key, context));
+ return key;
+ };
+
if (!options.workgroupSize.empty()) {
- for (linalg::LinalgOp op : linalgOps)
- tileSizes[op.getOperation()->getName().getStringRef()].emplace_back(
- options.tileSizes.begin(), options.tileSizes.end());
+ for (Operation *linalgOp : linalgOps)
+ tileSizes[setKey(linalgOp)].emplace_back(options.tileSizes.begin(),
+ options.tileSizes.end());
workgroupSize = {1, 1, 1};
for (unsigned i = 0,
e = std::min<unsigned>(3, options.workgroupSize.size());
@@ -406,17 +441,18 @@
spirv::ResourceLimitsAttr resourceLimits =
spirv::lookupTargetEnv(*linalgOps.begin()).getResourceLimits();
- for (linalg::LinalgOp op : linalgOps) {
- StringRef key = op.getOperation()->getName().getStringRef();
- if (tileSizes.count(key)) {
- return op.emitError("unexpected multiple ")
- << key << " operations within dispatch region";
- }
+ Optional<linalg::LinalgOp> rootOperation = {};
- TileSizesListType &tileSizesInfo = tileSizes[key];
-
+ for (Operation *op : linalgOps) {
+ linalg::LinalgOp linalgOp = cast<linalg::LinalgOp>(op);
#define DISPATCH(opName) \
- if (auto lOp = dyn_cast<opName>(op.getOperation())) { \
+ if (auto lOp = dyn_cast<opName>(linalgOp.getOperation())) { \
+ if (rootOperation) { \
+ return lOp.emitError( \
+ "unhandled multiple root operations in dispatch region"); \
+ } \
+ rootOperation = cast<linalg::LinalgOp>(lOp.getOperation()); \
+ TileSizesListType &tileSizesInfo = tileSizes[setKey(*rootOperation)]; \
if (failed(getOpLaunchConfig(lOp, options, resourceLimits, tileSizesInfo, \
workgroupSize, numSubgroups))) { \
return failure(); \
@@ -439,5 +475,12 @@
return success();
}
+void LaunchConfig::finalize(FuncOp funcOp) {
+ // Strip the launch-info key attribute that `init` attached via `setKey`.
+ funcOp.walk([&](linalg::LinalgOp linalgOp) {
+ linalgOp.removeAttr(Identifier::get(kLaunchInfoKey, funcOp.getContext()));
+ });
+}
+
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
index d690a0a..919464e 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
@@ -26,6 +26,7 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
+#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"
@@ -39,7 +40,7 @@
namespace linalg {
class LinalgOp;
-}
+} // namespace linalg
namespace iree_compiler {
struct SPIRVCodegenOptions;
}
@@ -94,17 +95,26 @@
/// - tile sizes for each level,
/// - the workgroup size, and
/// - number of subgroups to use.
- LogicalResult init(const SPIRVCodegenOptions &options,
- ArrayRef<linalg::LinalgOp> linalgOps);
+ LogicalResult init(MLIRContext *context, const SPIRVCodegenOptions &options,
+ ArrayRef<Operation *> linalgOps);
+
+ /// Removes the attributes that `init` added to operations for retrieving
+ /// tile size information.
+ void finalize(FuncOp funcOp);
/// Gets the tile size computed for an operation at all levels.
TileSizesListType getTileSizes(Operation *op) const {
- return tileSizes.lookup(op->getName().getStringRef());
+ auto key = getKey(op);
+ if (!key) return {};
+ // Unknown key -> empty list (no end() dereference), like the other overload.
+ return tileSizes.lookup(*key);
}
/// Gets the tile size computed for an operation for an level.
ArrayRef<int64_t> getTileSizes(Operation *op, size_t level) const {
- auto it = tileSizes.find(op->getName().getStringRef());
+ auto key = getKey(op);
+ if (!key) return {};
+ auto it = tileSizes.find(*key);
if (it == tileSizes.end() || level >= it->second.size()) return {};
return it->second[level];
}
@@ -115,14 +125,19 @@
/// Returns the number of subgroups to use.
ArrayRef<int64_t> getNumSubgroups() const { return numSubgroups; }
- protected:
- /// Current tile size configuration per operation.
+ /// Returns true if tile sizes have been computed for the operation. If tile
+ /// sizes arent set, it implies operation is not to be tiled.
+ bool hasTileSizes(Operation *op, size_t level = 0) const {
+ return !getTileSizes(op, level).empty();
+ }
- // TODO: For now just use the operation name for the mapping. The tile sizes
- // will be selected only for operations like matmul, conv, pool, etc. and
- // assume that there is only one such operation per dispatch
- // region. Eventually this might need to be relaxed, and some name-marker
- // based mechanism might be needed.
+ protected:
+ /// Current tile size configuration per operation. They key used here to
+ /// retrieve the tile size information per operation is the value of a StrAttr
+ /// added to operations during `init`. When tiled this attribute is copied
+ /// over to the tiled operation, thereby the same key can be used to retrieve
+ /// the tile sizes for the next level of tiling. The `finalize` method removes
+ /// these attributes.
llvm::StringMap<TileSizesListType> tileSizes;
/// Workgroup size to use.
@@ -130,6 +145,11 @@
/// Number of subgroups that are logically distributed along x, y & z.
std::array<int64_t, 3> numSubgroups;
+
+ private:
+ /// Retrieves the key to use to get the `tileSizes` for a given
+ /// `operation`. Returns llvm::None on failure.
+ Optional<StringRef> getKey(Operation *op) const;
};
} // namespace iree_compiler
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
index 5c5abbd..c2ec755 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -27,6 +27,7 @@
#include "iree/compiler/Conversion/LinalgToSPIRV/Utils.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
@@ -123,16 +124,19 @@
namespace {
/// Pattern for tiling operations. Updates the workgroup size in the surrounding
-/// function operation if tiling succeeds.
+/// function operation if tiling succeeds, and generates the function that
+/// computes the number of workgroups for the launch.
template <typename LinalgOpTy>
struct TileToWorkgroupsPattern : public linalg::LinalgBaseTilingPattern {
using Base = linalg::LinalgBaseTilingPattern;
TileToWorkgroupsPattern(MLIRContext *context,
+ const linalg::LinalgDependenceGraph &dependenceGraph,
linalg::LinalgTilingOptions options,
linalg::LinalgMarker marker,
const LaunchConfig &launchConfig,
PatternBenefit benefit = 1)
: Base(LinalgOpTy::getOperationName(), context, options, marker, benefit),
+ dependenceGraph(dependenceGraph),
launchConfig(launchConfig) {}
LogicalResult matchAndRewrite(Operation *op,
@@ -151,18 +155,59 @@
launchConfig.getTileSizes(op, 0))))) {
return failure();
}
- rewriter.eraseOp(op);
+ setMarker(op, getDeleteMarker());
return success();
}
+ const linalg::LinalgDependenceGraph &dependenceGraph;
+ const LaunchConfig &launchConfig;
+};
+
+/// Pattern for tile + fuse of operations. Updates the workgroup size in the
+/// surrounding function operation if tiling succeeds, and generates the
+/// function that computes the number of workgroups for the launch.
+template <typename LinalgOpTy>
+struct TileAndFuseToWorkgroupsPattern
+ : public linalg::LinalgTileAndFusePattern<LinalgOpTy> {
+ using Base = linalg::LinalgTileAndFusePattern<LinalgOpTy>;
+ TileAndFuseToWorkgroupsPattern(
+ MLIRContext *context,
+ const linalg::LinalgDependenceGraph &dependenceGraph,
+ linalg::LinalgTilingOptions tilingOptions, linalg::LinalgMarker marker,
+ const LaunchConfig &launchConfig, PatternBenefit benefit = 1)
+ : Base(context, dependenceGraph, tilingOptions,
+ linalg::LinalgFusionOptions().setIndicesToFuse({2}), marker,
+ marker,
+ linalg::LinalgMarker(ArrayRef<Identifier>(),
+ Identifier::get(getDeleteMarker(), context)),
+ benefit),
+ dependenceGraph(dependenceGraph),
+ launchConfig(launchConfig) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ FuncOp funcOp = op->getParentOfType<FuncOp>();
+ linalg::LinalgOp linalgOp = cast<linalg::LinalgOp>(op);
+ if (!funcOp || !dependenceGraph.hasDependentOperations(linalgOp) ||
+ failed(Base::matchAndRewrite(op, rewriter)) ||
+ failed(updateWorkGroupSize(funcOp, launchConfig.getWorkgroupSize())) ||
+ (funcOp.getAttr(getNumWorkgroupsFnAttrName()) &&
+ failed(createNumWorkgroupsFromResultShape(
+ rewriter, linalgOp, funcOp, launchConfig.getTileSizes(op, 0))))) {
+ return failure();
+ }
+ return success();
+ }
+
+ const linalg::LinalgDependenceGraph &dependenceGraph;
const LaunchConfig &launchConfig;
};
} // namespace
/// Populate patterns for first-level tiling.
static void populateTilingToWorkgroupPatterns(
- MLIRContext *context, const LaunchConfig &launchConfig,
- OwningRewritePatternList &patterns) {
+ MLIRContext *context, const linalg::LinalgDependenceGraph &dependenceGraph,
+ const LaunchConfig &launchConfig, OwningRewritePatternList &patterns) {
// Function to compute first level tiling values.
std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>
getOuterTileSizeFn =
@@ -178,13 +223,19 @@
}
return tileSizesVal;
};
- patterns.insert<TileToWorkgroupsPattern<linalg::ConvOp>,
- TileToWorkgroupsPattern<linalg::MatmulOp>,
+ patterns.insert<TileAndFuseToWorkgroupsPattern<linalg::BatchMatmulOp>,
+ TileAndFuseToWorkgroupsPattern<linalg::ConvOp>,
+ TileAndFuseToWorkgroupsPattern<linalg::MatmulOp>,
+ TileAndFuseToWorkgroupsPattern<linalg::PoolingMaxOp>,
+ TileAndFuseToWorkgroupsPattern<linalg::PoolingMinOp>,
+ TileAndFuseToWorkgroupsPattern<linalg::PoolingSumOp>,
TileToWorkgroupsPattern<linalg::BatchMatmulOp>,
+ TileToWorkgroupsPattern<linalg::ConvOp>,
+ TileToWorkgroupsPattern<linalg::MatmulOp>,
TileToWorkgroupsPattern<linalg::PoolingMaxOp>,
TileToWorkgroupsPattern<linalg::PoolingMinOp>,
TileToWorkgroupsPattern<linalg::PoolingSumOp>>(
- context,
+ context, dependenceGraph,
linalg::LinalgTilingOptions()
.setDistributionOptions(workgroupDistributionOptions)
.setTileSizeComputationFunction(getOuterTileSizeFn)
@@ -380,9 +431,9 @@
if (linalgOps.empty()) continue;
LaunchConfig launchConfig;
- SmallVector<linalg::LinalgOp, 4> linalgOpsVec(linalgOps.begin(),
- linalgOps.end());
- if (failed(launchConfig.init(options, linalgOpsVec))) {
+ SmallVector<Operation *, 4> linalgOpsVec(linalgOps.begin(),
+ linalgOps.end());
+ if (failed(launchConfig.init(context, options, linalgOpsVec))) {
funcOp.emitError("unable to find launch configuration");
return signalPassFailure();
}
@@ -406,11 +457,24 @@
}
});
- OwningRewritePatternList firstLevelTilingPatterns;
- populateTilingToWorkgroupPatterns(context, launchConfig,
- firstLevelTilingPatterns);
- applyPatternsAndFoldGreedily(funcOp, firstLevelTilingPatterns);
- applyCanonicalizationPatterns(context, funcOp);
+ {
+ // Compute the Linalg Dependence Graph.
+ linalg::Aliases aliases;
+ linalg::LinalgDependenceGraph dependenceGraph =
+ linalg::LinalgDependenceGraph::buildDependenceGraph(aliases, funcOp);
+
+ OwningRewritePatternList firstLevelTilingPatterns;
+ populateTilingToWorkgroupPatterns(context, dependenceGraph, launchConfig,
+ firstLevelTilingPatterns);
+ applyPatternsAndFoldGreedily(funcOp, firstLevelTilingPatterns);
+ applyCanonicalizationPatterns(context, funcOp);
+
+ // Delete the ops that are marked for deletion.
+ funcOp.walk([](linalg::LinalgOp linalgOp) {
+ if (hasMarker(linalgOp.getOperation(), getDeleteMarker()))
+ linalgOp.getOperation()->erase();
+ });
+ }
if (options.useWorkgroupMemory) {
// The promotion patterns are put separate from the tiling patterns to
@@ -434,6 +498,13 @@
vectorizationPatterns);
applyPatternsAndFoldGreedily(funcOp, vectorizationPatterns);
}
+
+ launchConfig.finalize(funcOp);
+ // Erase any remaining ops that are still tagged with the delete marker.
+ funcOp.walk([](linalg::LinalgOp linalgOp) {
+ if (hasMarker(linalgOp.getOperation(), getDeleteMarker()))
+ linalgOp.erase();
+ });
}
}
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp
index 51d3585..dff5292 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp
@@ -27,6 +27,8 @@
const StringLiteral VectorTransforms::kVectorTransformMarker =
"__internal_vector_transform__";
+StringRef getFusedMarker() { return "fused_numprocs_ge_numiters"; }
+
StringRef getWorkgroupMarker() { return "workgroup"; }
StringRef getWorkgroupMemoryMarker() { return "workgroup_memory"; }
@@ -45,6 +47,8 @@
StringRef getVectorizeMarker() { return "vectorize"; }
+StringRef getDeleteMarker() { return "delete"; }
+
bool hasMarker(Operation *op, ArrayRef<StringRef> marker) {
StringAttr attr = op->getAttrOfType<StringAttr>(
linalg::LinalgTransforms::kLinalgTransformMarker);
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h b/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h
index 72d5385..78d4304 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h
@@ -44,6 +44,12 @@
/// Marker for operations that are going to be vectorized.
StringRef getVectorizeMarker();
+/// Marker for tagging an operation for deletion. The workgroup tiling (and
+/// tile + fuse) patterns do not erase the original operation, so as not to
+/// invalidate the `linalg::LinalgDependenceGraph` data structure. Instead the
+/// operation is tagged with this marker so that it can be erased later.
+StringRef getDeleteMarker();
+
/// Returns true if an operation has the specified `marker`. When `marker` is
/// empty, returns true if the operation has any marker.
bool hasMarker(Operation *, ArrayRef<StringRef> markers = {});
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
index 04044a7..d751065 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
@@ -85,21 +85,17 @@
// - All Linalg ops have buffer semantics.
//
// Post-conditions:
+ // - The operations that cannot be fused at buffer levels are split into
+ // separate entry points.
// - If the input Linalg ops are tilable:
// - loop.parallel ops are generated for mapping to workgroups.
// - Linalg ops are nested inside loop.parallel ops and ready for mapping
// to workitems.
+ // - If multiple Linalg ops are present, they are tiled and fused to get
+ // outer loop.parallel ops that can be mapped to workgroups.
// - Otherwise:
// - The Linalg op is kept untouched.
- // - Dispatch functions might be split into multiple ones.
//
- // Note:
- // We first try to tile and fuse the dispatch function as a whole. If there
- // are multiple Linalg ops inside, they may not share any number of common
- // outer parallel iterators. Then the first tile and fuse pass will do
- // nothing. We split all the Linalg ops into their own dispatch functions
- // afterwards. This gives each Linalg op a second chance to be tiled,
- // with the second tile and fuse pass.
//===--------------------------------------------------------------------===//
pm.addPass(createSplitDispatchFunctionPass());
pm.addPass(createLinalgTileAndFusePass(options));
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
index 94436c6..7e8d1aa 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
@@ -35,6 +35,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/FormatVariadic.h"
+#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -55,37 +56,94 @@
// Utility functions
//===----------------------------------------------------------------------===//
-namespace {
+/// Returns true if an op can be fused with the list of ops that are to be put
+/// in the same entry point function. This should be consistent with what the
+/// downstream passes can handle.
+static bool isFusableWithCurrentOpsList(
+ Operation *nextOp, ArrayRef<Operation *> currOpsList,
+ const linalg::LinalgDependenceGraph &dependenceGraph) {
+ if (currOpsList.empty()) return true;
-/// Returns true if the Linalg ops can be separated to multiple kernels.
-bool canSeparateOps(ArrayRef<Operation *> ops) {
- if (llvm::any_of(ops, [](Operation *op) {
- if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op))
- return !linalgOp.hasBufferSemantics();
- return false;
- }))
- return false;
+ linalg::LinalgOp dstOp = dyn_cast<linalg::LinalgOp>(nextOp);
+ linalg::LinalgOp srcOp = dyn_cast<linalg::LinalgOp>(currOpsList.back());
+ if (dstOp && srcOp) {
+ // TODO(#2963): This splits independent linalg operations into their own
+ // dispatch, but in reality if the iteration domain of the ops are the same,
+ // and they have all iterator types parallel, they could be put in the same
+ // dispatch region.
+ if (!dependenceGraph.hasDependenceFrom(srcOp, dstOp)) return false;
- // Require no other non-metadata ops interleave with Linalg structured ops for
- // now. This is the common case and it simplifies further analysis.
+#define ADD_FUSABLE_PAIR(SrcOpTy, DstOpTy, DependenceTy) \
+ if (isa<SrcOpTy>(srcOp.getOperation()) && \
+ isa<DstOpTy>(dstOp.getOperation()) && \
+ dependenceGraph.hasDependenceFrom(srcOp, dstOp, DependenceTy)) \
+ return true;
+
+ ADD_FUSABLE_PAIR(linalg::FillOp, linalg::ConvOp,
+ linalg::LinalgDependenceGraph::DependenceType::WAW)
+ ADD_FUSABLE_PAIR(linalg::FillOp, linalg::MatmulOp,
+ linalg::LinalgDependenceGraph::DependenceType::WAW)
+ ADD_FUSABLE_PAIR(linalg::FillOp, linalg::PoolingMaxOp,
+ linalg::LinalgDependenceGraph::DependenceType::WAW)
+ ADD_FUSABLE_PAIR(linalg::FillOp, linalg::PoolingMinOp,
+ linalg::LinalgDependenceGraph::DependenceType::WAW)
+ ADD_FUSABLE_PAIR(linalg::FillOp, linalg::PoolingSumOp,
+ linalg::LinalgDependenceGraph::DependenceType::WAW)
+
+#undef ADD_FUSABLE_PAIR
+ }
+ return false;
+}
+
+/// For the list of operations in `ops` returns a list of lists where each list
+/// contains the operations that need to be put in a separate dispatch function.
+static LogicalResult separateOps(
+ ArrayRef<Operation *> ops,
+ const linalg::LinalgDependenceGraph &dependenceGraph,
+ SmallVectorImpl<SmallVector<Operation *, 1>> &fusedOpList) {
+ assert(!ops.empty() &&
+ "expected at least one separable op for splitting dispatch function");
+ SmallVector<Operation *, 1> currList;
for (auto currOp = ops.begin(), nextOp = std::next(ops.begin());
nextOp != ops.end(); ++currOp, ++nextOp) {
+ // Check buffer semantics of both ops; `*nextOp` covers the last op too.
+ for (Operation *op : {*currOp, *nextOp})
+ if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op))
+ if (!linalgOp.hasBufferSemantics()) return failure();
+
+ // Require no other non-metadata ops interleave with Linalg structured ops
+ // for now. This is the common case and it simplifies further analysis.
Operation *iter = (*currOp)->getNextNode();
while (iter != *nextOp && (MemoryEffectOpInterface::hasNoEffect(iter) ||
isa<IREE::PlaceholderOp>(iter)))
iter = iter->getNextNode();
- if (iter != *nextOp) return false;
- }
+ if (iter != *nextOp) return failure();
- return true;
+ currList.push_back(*currOp);
+
+ // If the nextOp is not fusible with the currOp, then record the list of ops
+ // so far, and start a new list.
+ if (isFusableWithCurrentOpsList(*nextOp, currList, dependenceGraph)) {
+ continue;
+ }
+
+ // Push the current list of ops into the list of lists `fusedOpList` and
+ // start a new list.
+ fusedOpList.emplace_back();
+ std::swap(fusedOpList.back(), currList);
+ }
+ currList.push_back(ops.back());
+ fusedOpList.emplace_back(std::move(currList));
+ return success();
}
/// Recursively collects all the operations that are referenced by given
/// `rootOp` into `closure`.
-void collectAllReferencedOps(Operation *rootOp,
- llvm::SmallPtrSetImpl<Operation *> &closure) {
+static void collectAllReferencedOps(
+ ArrayRef<Operation *> rootOps,
+ llvm::SmallPtrSetImpl<Operation *> &closure) {
llvm::SmallVector<Operation *, 8> workList;
- workList.push_back(rootOp);
+ workList.assign(rootOps.begin(), rootOps.end());
while (!workList.empty()) {
Operation *curOp = workList.pop_back_val();
@@ -104,8 +162,6 @@
}
}
-} // namespace
-
//===----------------------------------------------------------------------===//
// Pass and patterns
//===----------------------------------------------------------------------===//
@@ -161,10 +217,16 @@
separableOps.push_back(&op);
if (separableOps.size() <= 1) return success();
- if (!canSeparateOps(separableOps)) {
+
+ linalg::Aliases aliases;
+ linalg::LinalgDependenceGraph dependenceGraph =
+ linalg::LinalgDependenceGraph::buildDependenceGraph(aliases, oldFn);
+ SmallVector<SmallVector<Operation *, 1>, 1> fusedOpsList;
+ if (failed(separateOps(separableOps, dependenceGraph, fusedOpsList))) {
return oldFn.emitError(
"cannot separate Linalg/Parallel ops into multiple kernels");
}
+ if (fusedOpsList.size() <= 1) return success();
ModuleOp moduleOp = cast<ModuleOp>(oldFn.getParentOp());
Block &oldFnBlock = oldFn.getBlocks().front();
@@ -174,10 +236,11 @@
splitKernels.reserve(separableOps.size());
llvm::SmallPtrSet<Operation *, 16> closure;
- for (const auto &separableOp : llvm::enumerate(separableOps)) {
+ for (const auto &fusedOps : llvm::enumerate(fusedOpsList)) {
+ if (fusedOps.value().empty()) continue;
// Create a new function for hosting this op.
- splitKernels.emplace_back(llvm::formatv("{0}_dispatch_{1}", oldFn.getName(),
- separableOp.index()));
+ splitKernels.emplace_back(
+ llvm::formatv("{0}_dispatch_{1}", oldFn.getName(), fusedOps.index()));
StringRef newFnName = splitKernels.back();
builder.setInsertionPointToStart(moduleOp.getBody());
auto newFn = builder.create<FuncOp>(loc, newFnName, oldFn.getType());
@@ -210,7 +273,7 @@
// Collect the closure for the current Linalg op.
closure.clear();
- collectAllReferencedOps(separableOp.value(), closure);
+ collectAllReferencedOps(fusedOps.value(), closure);
// Clone all ops in the closure to the new function.
Block *newFnBlock = newFn.addEntryBlock();
@@ -219,7 +282,7 @@
for (Operation &op : oldFnBlock) {
if (closure.count(&op) == 0) continue;
builder.insert(op.clone(remapper));
- if (&op == separableOp.value()) break;
+ if (&op == fusedOps.value().back()) break;
}
builder.insert(oldFnBlock.getTerminator()->clone(remapper));
}
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir
index 9737bb6..52aeff7 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir
@@ -1,13 +1,13 @@
// RUN: iree-opt -split-input-file -iree-codegen-convert-to-spirv %s | IreeFileCheck %s
module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- // CHECK: spv.globalVariable @__push_constant_var__ : !spv.ptr<!spv.struct<!spv.array<5 x i32, stride=4> [0]>, PushConstant>
+ // CHECK: spv.globalVariable @__push_constant_var__ : !spv.ptr<!spv.struct<(!spv.array<5 x i32, stride=4> [0])>, PushConstant>
// CHECK: spv.func @push_constant()
func @push_constant() {
// CHECK: %[[INDEX_0:.+]] = spv.constant 0 : i32
// CHECK: %[[INDEX_1:.+]] = spv.constant 2 : i32
- // CHECK: %[[ADDR:.+]] = spv._address_of @__push_constant_var__ : !spv.ptr<!spv.struct<!spv.array<5 x i32, stride=4> [0]>, PushConstant>
- // CHECK: %[[AC:.+]] = spv.AccessChain %[[ADDR]][%[[INDEX_0]], %[[INDEX_1]]] : !spv.ptr<!spv.struct<!spv.array<5 x i32, stride=4> [0]>, PushConstant>
+ // CHECK: %[[ADDR:.+]] = spv._address_of @__push_constant_var__ : !spv.ptr<!spv.struct<(!spv.array<5 x i32, stride=4> [0])>, PushConstant>
+ // CHECK: %[[AC:.+]] = spv.AccessChain %[[ADDR]][%[[INDEX_0]], %[[INDEX_1]]] : !spv.ptr<!spv.struct<(!spv.array<5 x i32, stride=4> [0])>, PushConstant>
// CHECK: spv.Load "PushConstant" %[[AC]] : i32
%0 = hal.interface.load.constant offset = 2 : index
return
@@ -23,12 +23,12 @@
module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- // CHECK: spv.globalVariable @__resource_var_3_4__ bind(3, 4) : !spv.ptr<!spv.struct<!spv.array<16 x f32, stride=4> [0]>, StorageBuffer>
- // CHECK: spv.globalVariable @__resource_var_1_2__ bind(1, 2) : !spv.ptr<!spv.struct<!spv.array<16 x f32, stride=4> [0]>, StorageBuffer>
+ // CHECK: spv.globalVariable @__resource_var_3_4__ bind(3, 4) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
+ // CHECK: spv.globalVariable @__resource_var_1_2__ bind(1, 2) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
// CHECK: spv.func @resource_variable()
func @resource_variable() {
- // CHECK: spv._address_of @__resource_var_1_2__ : !spv.ptr<!spv.struct<!spv.array<16 x f32, stride=4> [0]>, StorageBuffer>
- // CHECK: spv._address_of @__resource_var_3_4__ : !spv.ptr<!spv.struct<!spv.array<16 x f32, stride=4> [0]>, StorageBuffer>
+ // CHECK: spv._address_of @__resource_var_1_2__ : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
+ // CHECK: spv._address_of @__resource_var_3_4__ : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
%0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<4x4xf32>
%1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<4x4xf32>
return
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
index 0d5adc6..580d32a 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
@@ -302,3 +302,114 @@
// CHECK: %[[T1:.+]] = addi %[[DIM0]], %[[C3]]
// CHECK-DAG: %[[NBY:.+]] = divi_signed %[[T1]], %[[C4]]
// CHECK: return %[[NBX]], %[[NBY]], %[[C1]]
+
+// -----
+
+module attributes {
+ spv.target_env =
+ #spv.target_env<#spv.vce<v1.3,
+ [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+ {max_compute_workgroup_invocations = 128 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+ func @matmul_fusion() attributes {vkspv.num_workgroups_fn = @matmul_fusion__num_workgroups__} {
+ %0 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<?x?xf32>
+ %1 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg1, operand_result_index = 1 : i32} : memref<?x?xf32>
+ %2 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@ret0, operand_result_index = 2 : i32} : memref<?x?xf32>
+ %cst = constant 0.000000e+00 : f32
+ linalg.fill(%2, %cst) : memref<?x?xf32>, f32
+ linalg.matmul ins(%0, %1 : memref<?x?xf32>, memref<?x?xf32>)
+ outs(%2 : memref<?x?xf32>)
+ return
+ }
+ func @matmul_fusion__num_workgroups__
+ (!shapex.ranked_shape<[?,?]>, !shapex.ranked_shape<[?,?]>,
+ !shapex.ranked_shape<[?,?]>) -> (index, index, index)
+ attributes {sym_visibility = "private"}
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+ }
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0] -> (s0 * 16)>
+// CHECK: func @matmul_fusion()
+// CHECK-SAME: local_size = dense<[16, 8, 1]>
+// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
+// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
+// CHECK-DAG: %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@ret0
+// CHECK-DAG: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
+// CHECK-DAG: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
+// CHECK-NOT: scf.parallel
+// CHECK-NOT: scf.for
+// CHECK: %[[LBY:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[VIEW0:.+]] = subview %[[ARG0]][%[[LBY]], 0]
+// CHECK: %[[LBX:.+]] = affine.apply #[[MAP3]]()[%[[BIDX]]]
+// CHECK: %[[VIEW1:.+]] = subview %[[ARG1]][0, %[[LBX]]]
+// CHECK: %[[LBY_2:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[LBX_2:.+]] = affine.apply #[[MAP3]]()[%[[BIDX]]]
+// CHECK: %[[VIEW2:.+]] = subview %[[RET0]][%[[LBY_2]], %[[LBX_2]]]
+// CHECK: %[[VIEW3:.+]] = subview %[[RET0]][%[[LBY_2]], %[[LBX_2]]]
+// CHECK: linalg.fill(%[[VIEW3]], %{{.+}})
+// CHECK-SAME: "workgroup"
+// CHECK: linalg.matmul
+// CHECK-SAME: "workgroup"
+// CHECK-SAME: ins(%[[VIEW0]], %[[VIEW1]]
+// CHECK-SAME: outs(%[[VIEW2]]
+
+// -----
+
+module attributes {
+ spv.target_env =
+ #spv.target_env<#spv.vce<v1.3,
+ [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+ {max_compute_workgroup_invocations = 128 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+ func @conv_no_padding_fusion() {
+ %0 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<?x?x?x?xf32>
+ %1 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg1, operand_result_index = 1 : i32} : memref<?x?x?x?xf32>
+ %2 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@ret0, operand_result_index = 2 : i32} : memref<?x?x?x?xf32>
+ %cst = constant 0.000000e+00 : f32
+ linalg.fill(%2, %cst) : memref<?x?x?x?xf32>, f32
+ linalg.conv(%0, %1, %2) {dilations = [1, 1], strides = [1, 1]} :
+ memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
+ return
+ }
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+ }
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 4)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 32)>
+// CHECK: func @conv_no_padding_fusion()
+// CHECK-SAME: local_size = dense<[32, 4, 1]>
+// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
+// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
+// CHECK-DAG: %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@ret0
+// CHECK-DAG: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
+// CHECK-DAG: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
+// CHECK-DAG: %[[BIDZ:.+]] = "gpu.block_id"() {dimension = "z"}
+// CHECK: %[[LBY:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[LBX:.+]] = affine.apply #[[MAP1]]()[%[[BIDX]]]
+// CHECK: %[[VIEW1:.+]] = subview %[[ARG1]]
+// CHECK-SAME: [%[[BIDZ]], %[[LBY]], %[[LBX]], 0]
+// CHECK: %[[LBY_2:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[LBX_2:.+]] = affine.apply #[[MAP1]]()[%[[BIDX]]]
+// CHECK: %[[VIEW2:.+]] = subview %[[RET0]]
+// CHECK-SAME: [%[[BIDZ]], %[[LBY_2]], %[[LBX_2]], 0]
+// CHECK: %[[VIEW3:.+]] = subview %[[RET0]]
+// CHECK-SAME: [%[[BIDZ]], %[[LBY_2]], %[[LBX_2]], 0]
+// CHECK: linalg.fill(%[[VIEW3]], %{{.*}})
+// CHECK-SAME: "workgroup"
+// CHECK: linalg.conv
+// CHECK-SAME: %[[ARG0]], %[[VIEW1]], %[[VIEW2]]
+// CHECK-SAME: "workgroup"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
index 5d3ff91..e2e8d7e 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
@@ -1,8 +1,117 @@
// RUN: iree-opt -split-input-file -iree-codegen-split-dispatch-function -verify-diagnostics %s | IreeFileCheck %s
+module {
+ // CHECK: func @kernel_fusable_fill_conv_ops
+ // CHECK: linalg.fill
+ // CHECK: linalg.conv
+
+ func @kernel_fusable_fill_conv_ops()
+ attributes {vkspv.num_workgroups_fn = @kernel_fusable_fill_conv_ops_num_workgroups__} {
+ %cst = constant 0.000000e+00 : f32
+ %dim = hal.interface.load.constant offset = 0 : index
+ %shape1 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,2,2,512]>
+ %shape2 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,1,1,512]>
+ %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x2x2x512xf32>
+ %ts1 = shapex.tie_shape %0, %shape1 : memref<?x2x2x512xf32>, !shapex.ranked_shape<[?,2,2,512]>
+ %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
+ %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
+ %ts2 = shapex.tie_shape %2, %shape2 : memref<?x1x1x512xf32>, !shapex.ranked_shape<[?,1,1,512]>
+ linalg.fill(%ts2, %cst) : memref<?x1x1x512xf32>, f32
+ linalg.conv(%1, %ts1, %ts2) {dilations = [1, 1], padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, strides = [2, 2]} : memref<3x3x512x1xf32>, memref<?x2x2x512xf32>, memref<?x1x1x512xf32>
+ return
+ }
+ func @kernel_fusable_fill_conv_ops_num_workgroups__(!shapex.ranked_shape<[?,2,2,512]>,
+ !shapex.ranked_shape<[3,3,512,1]>,
+ !shapex.ranked_shape<[?,1,1,512]>)
+ -> (index, index, index)
+ attributes {sym_visibility = "private"}
+ hal.interface @legacy_io attributes {push_constants = 1 : i32, sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ }
+}
+
+// -----
+
+module {
+ // CHECK: func @kernel_fusable_fill_matmul_ops
+ // CHECK: linalg.fill
+ // CHECK: linalg.matmul
+
+ func @kernel_fusable_fill_matmul_ops()
+ attributes {vkspv.num_workgroups_fn = @kernel_fusable_fill_matmul_ops_num_workgroups__} {
+ %cst = constant 0.000000e+00 : f32
+ %dimM = hal.interface.load.constant offset = 0 : index
+ %dimN = hal.interface.load.constant offset = 1 : index
+ %shape1 = shapex.make_ranked_shape %dimM : (index) -> !shapex.ranked_shape<[?,512]>
+ %shape2 = shapex.make_ranked_shape %dimN : (index) -> !shapex.ranked_shape<[512,?]>
+ %shape3 = shapex.make_ranked_shape %dimM, %dimN : (index, index) -> !shapex.ranked_shape<[?,?]>
+ %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x512xf32>
+ %ts1 = shapex.tie_shape %0, %shape1 : memref<?x512xf32>, !shapex.ranked_shape<[?,512]>
+ %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<512x?xf32>
+ %ts2 = shapex.tie_shape %1, %shape2 : memref<512x?xf32>, !shapex.ranked_shape<[512, ?]>
+ %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x?xf32>
+ %ts3 = shapex.tie_shape %2, %shape3 : memref<?x?xf32>, !shapex.ranked_shape<[?,?]>
+ linalg.fill(%ts3, %cst) : memref<?x?xf32>, f32
+ linalg.matmul ins(%ts1, %ts2 : memref<?x512xf32>, memref<512x?xf32>)
+ outs(%ts3 : memref<?x?xf32>)
+ return
+ }
+ func @kernel_fusable_fill_matmul_ops_num_workgroups__(!shapex.ranked_shape<[?,512]>,
+ !shapex.ranked_shape<[512,?]>,
+ !shapex.ranked_shape<[?,?]>)
+ -> (index, index, index)
+ attributes {sym_visibility = "private"}
+ hal.interface @legacy_io attributes {push_constants = 1 : i32, sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ }
+}
+
+// -----
+
+module {
+ // CHECK: func @kernel_fusable_pooling()
+ // CHECK: linalg.fill
+ // CHECK: linalg.pooling
+ func @kernel_fusable_pooling() attributes {vkspv.num_workgroups_fn = @kernel_fusable_pooling__num_workgroups__} {
+ %cst = constant 0.000000e+00 : f32
+ %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x?xf32>
+ %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<?x?xf32>
+ %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x?xf32>
+ linalg.fill(%2, %cst) : memref<?x?xf32>, f32
+ linalg.pooling_sum(%1, %0, %2) {dilations = [1, 1], strides = [1, 1]} :
+ memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
+ return
+ }
+ func @kernel_fusable_pooling__num_workgroups__(!shapex.ranked_shape<[?,?]>,
+ !shapex.ranked_shape<[?,?]>,
+ !shapex.ranked_shape<[?,?]>)
+ -> (index, index, index)
+ attributes {sym_visibility = "private"}
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ }
+}
+
+// -----
+
// CHECK: module attributes {vkspv.entry_point_schedule = ["kernel_dispatch_0", "kernel_dispatch_1"]}
module {
// CHECK: func @kernel_dispatch_1()
+ // CHECK: %[[ZERO:.+]] = constant
+ // CHECK: %[[DIM:.+]] = hal.interface.load.constant
+ // CHECK: %[[SHAPE:.+]] = shapex.make_ranked_shape %[[DIM]]
+ // CHECK: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
+ // CHECK: %[[TS:.+]] = shapex.tie_shape %[[OUT]], %[[SHAPE]]
+ // CHECK: linalg.fill(%[[TS]], %[[ZERO]])
+ // CHECK: return
+
+ // CHECK: func @kernel_dispatch_0()
// CHECK: %[[DIM:.+]] = hal.interface.load.constant
// CHECK: %[[SHAPE1:.+]] = shapex.make_ranked_shape %[[DIM]]
// CHECK: %[[SHAPE2:.+]] = shapex.make_ranked_shape %[[DIM]]
@@ -14,15 +123,6 @@
// CHECK: linalg.conv(%[[IN2]], %[[TS1]], %[[TS2]])
// CHECK: return
- // CHECK: func @kernel_dispatch_0()
- // CHECK: %[[ZERO:.+]] = constant
- // CHECK: %[[DIM:.+]] = hal.interface.load.constant
- // CHECK: %[[SHAPE:.+]] = shapex.make_ranked_shape %[[DIM]]
- // CHECK: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
- // CHECK: %[[TS:.+]] = shapex.tie_shape %[[OUT]], %[[SHAPE]]
- // CHECK: linalg.fill(%[[TS]], %[[ZERO]])
- // CHECK: return
-
func @kernel() attributes {vkspv.num_workgroups_fn = @kernel__num_workgroups__} {
%cst = constant 0.000000e+00 : f32
%dim = hal.interface.load.constant offset = 0 : index
@@ -33,8 +133,8 @@
%1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
%2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
%ts2 = shapex.tie_shape %2, %shape2 : memref<?x1x1x512xf32>, !shapex.ranked_shape<[?,1,1,512]>
- linalg.fill(%ts2, %cst) : memref<?x1x1x512xf32>, f32
linalg.conv(%1, %ts1, %ts2) {dilations = [1, 1], padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, strides = [2, 2]} : memref<3x3x512x1xf32>, memref<?x2x2x512xf32>, memref<?x1x1x512xf32>
+ linalg.fill(%ts2, %cst) : memref<?x1x1x512xf32>, f32
return
}
func @kernel__num_workgroups__(!shapex.ranked_shape<[?,2,2,512]>,
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp
index 62d9433..f2b0b27 100644
--- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp
@@ -349,8 +349,9 @@
dispatchOp, dispatchOp.executable()));
// TODO(benvanik): support multiple interfaces. We'd probably want to
- // store each executable+interface as a variable.
- auto interfaceOp = executableOp.getInterfaceOp();
+ // store each executable+interface as a variable, or follow interface
+ // references stored on entry points.
+ auto interfaceOp = executableOp.getFirstInterfaceOp();
auto executableLayout =
rewriter.createOrFold<IREE::HAL::ExecutableLayoutLookupOp>(
dispatchOp.getLoc(),
@@ -415,9 +416,9 @@
dispatchState.results = resultAdaptors;
// Ask each target backend to record their dispatch logic.
- IREE::HAL::DeviceSwitchBuilder switchBuilder(dispatchOp.getLoc(),
- /*resultTypes=*/TypeRange{},
- device, rewriter);
+ IREE::HAL::DeviceSwitchRewriter switchRewriter(dispatchOp.getLoc(),
+ /*resultTypes=*/TypeRange{},
+ device, rewriter);
for (auto targetOp :
executableOp.getBlock().getOps<IREE::HAL::ExecutableTargetOp>()) {
for (auto &targetBackend : IREE::HAL::matchTargetBackends(
@@ -432,15 +433,15 @@
// sequence them together during the call to |recordDispatch| below.
dispatchState.entryPointOp = *entryPointOps.begin();
- if (failed(targetBackend->recordDispatch(dispatchOp.getLoc(),
- dispatchState, switchBuilder))) {
+ if (failed(targetBackend->recordDispatch(
+ dispatchOp.getLoc(), dispatchState, switchRewriter))) {
return dispatchOp.emitError()
<< "unable to record dispatch for target backend "
<< targetBackend->name();
}
}
}
- switchBuilder.build();
+ switchRewriter.build();
// Full barriers for now as we aren't scheduling things.
// TODO(benvanik): don't add at the end of the command buffer (we could
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index 7550964..bc5e21f 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -1134,9 +1134,10 @@
// hal.executable
//===----------------------------------------------------------------------===//
-InterfaceOp ExecutableOp::getFirstInterfaceOp() {
+// Returns the first hal.interface op (in block order) in this executable's
+// body. Asserts that at least one interface is present; callers that must
+// handle executables with multiple interfaces should iterate
+// getBlock().getOps<InterfaceOp>() instead.
+InterfaceOp ExecutableOp::getFirstInterfaceOp() {
auto interfaceOps = llvm::to_vector<1>(getBlock().getOps<InterfaceOp>());
- assert(interfaceOps.size() == 1 && "executable must have one interface");
+ assert(!interfaceOps.empty() &&
+ "executable must have at least one interface");
return interfaceOps.front();
}
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.td b/iree/compiler/Dialect/HAL/IR/HALOps.td
index 6234e49..1541a46 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -1601,7 +1601,7 @@
let extraClassDeclaration = [{
Block& getBlock() { return body().front(); }
- IREE::HAL::InterfaceOp getInterfaceOp();
+ IREE::HAL::InterfaceOp getFirstInterfaceOp();
}];
let verifier = [{ return verifyExecutableOp(*this); }];
diff --git a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.cpp
index 2b18457..25cea52 100644
--- a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.cpp
@@ -125,7 +125,7 @@
LogicalResult SPIRVTargetBackend::recordDispatch(
Location loc, DispatchState dispatchState,
- DeviceSwitchBuilder &switchBuilder) {
+ DeviceSwitchRewriter &switchRewriter) {
// Multiple entry points might be generated for a single dispatch function.
// Under such circumstances, we will have a special attribute indicating the
// schedule of the split entry points. Try to see if we can find such
@@ -177,7 +177,7 @@
}
}
- auto *region = switchBuilder.addConditionRegion(
+ auto *region = switchRewriter.addConditionRegion(
IREE::HAL::DeviceMatchIDAttr::get(filter_pattern(), loc.getContext()),
{
dispatchState.workload,
@@ -185,7 +185,7 @@
});
auto &entryBlock = region->front();
- ConversionPatternRewriter &rewriter = switchBuilder.getRewriter();
+ ConversionPatternRewriter &rewriter = switchRewriter.getRewriter();
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToEnd(&entryBlock);
auto commandBuffer = entryBlock.getArgument(1);
@@ -214,7 +214,7 @@
<< " that computes the number of workgroups to use";
}
workgroupCount = calculateWorkgroupCountFromNumWorkgroupsFn(
- loc, numWorkgroupsFn, dispatchState.executableOp.getInterfaceOp(),
+ loc, numWorkgroupsFn, dispatchState.executableOp.getFirstInterfaceOp(),
dispatchState.operands, dispatchState.results, rewriter);
if (llvm::any_of(workgroupCount,
diff --git a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.h b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.h
index d24409b..83950b0 100644
--- a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.h
+++ b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.h
@@ -40,7 +40,7 @@
OpPassManager &passManager) override;
LogicalResult recordDispatch(Location loc, DispatchState dispatchState,
- DeviceSwitchBuilder &switchBuilder) override;
+ DeviceSwitchRewriter &switchRewriter) override;
// Finds the spv.ExecutionMode operation to get the workgroup size from.
std::array<Value, 3> calculateDispatchWorkgroupSize(
diff --git a/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp b/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp
index b126991..882c7aa 100644
--- a/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp
+++ b/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp
@@ -133,8 +133,8 @@
LogicalResult TargetBackend::recordDispatch(
Location loc, DispatchState dispatchState,
- DeviceSwitchBuilder &switchBuilder) {
- auto *region = switchBuilder.addConditionRegion(
+ DeviceSwitchRewriter &switchRewriter) {
+ auto *region = switchRewriter.addConditionRegion(
IREE::HAL::DeviceMatchIDAttr::get(filter_pattern(), loc.getContext()),
{
dispatchState.workload,
diff --git a/iree/compiler/Dialect/HAL/Target/TargetBackend.h b/iree/compiler/Dialect/HAL/Target/TargetBackend.h
index 2f0a8c5..1cd4b4f 100644
--- a/iree/compiler/Dialect/HAL/Target/TargetBackend.h
+++ b/iree/compiler/Dialect/HAL/Target/TargetBackend.h
@@ -237,7 +237,7 @@
// such as by inserting an `hal.command_buffer.execution_barrier`.
virtual LogicalResult recordDispatch(Location loc,
DispatchState dispatchState,
- DeviceSwitchBuilder &switchBuilder);
+ DeviceSwitchRewriter &switchRewriter);
// Inserts passes used to translate the `hal.executable.target` op contents.
// The pass manager will be nested on `hal.executable` such that the pipeline
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp b/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp
index ae429f5..752a6fd 100644
--- a/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp
@@ -26,9 +26,11 @@
#include "iree/schemas/vmla_executable_def_generated.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Module.h"
+#include "mlir/IR/OperationSupport.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
@@ -37,6 +39,29 @@
namespace IREE {
namespace HAL {
+namespace {
+
+// Returns true if |lhs| and |rhs| describe the same interface: identical
+// interface-level attributes (push constants) and the same sequence of
+// hal.interface.binding ops, compared pairwise in block order.
+bool AreInterfacesEquivalent(IREE::HAL::InterfaceOp lhs,
+                             IREE::HAL::InterfaceOp rhs) {
+  // The push constant count is part of the executable layout derived from an
+  // interface, so interfaces that disagree on it must never be merged even if
+  // all of their bindings match.
+  if (lhs.push_constantsAttr() != rhs.push_constantsAttr()) return false;
+
+  auto lhsBindings = lhs.getBlock().getOps<IREE::HAL::InterfaceBindingOp>();
+  auto rhsBindings = rhs.getBlock().getOps<IREE::HAL::InterfaceBindingOp>();
+  auto lhsIt = lhsBindings.begin(), lhsEnd = lhsBindings.end();
+  auto rhsIt = rhsBindings.begin(), rhsEnd = rhsBindings.end();
+  for (; lhsIt != lhsEnd && rhsIt != rhsEnd; ++lhsIt, ++rhsIt) {
+    // Assume bindings are in order, check equivalence of each pairing.
+    if (!OperationEquivalence::isEquivalentTo(*lhsIt, *rhsIt)) return false;
+  }
+
+  // Both ranges must run out together; otherwise the number of interface
+  // bindings differs.
+  return lhsIt == lhsEnd && rhsIt == rhsEnd;
+}
+
+} // namespace
+
VMLATargetOptions getVMLATargetOptionsFromFlags() {
VMLATargetOptions targetOptions;
// TODO(benvanik): flags.
@@ -66,6 +91,169 @@
passManager, IREE::VM::getTargetOptionsFromFlags());
}
+  // Moves every hal.executable.target matching this backend's filter into a
+  // single private "linked_vmla" executable, de-duplicating interfaces and
+  // vm.module symbols and rewriting hal.command_buffer.dispatch.symbol
+  // references to point at the linked entry points. Emptied executables are
+  // left in place (the LinkExecutables pass removes them afterwards).
+  LogicalResult linkExecutables(mlir::ModuleOp moduleOp) override {
+    // --- Linking overview ---
+    //
+    // We start with a `module` containing multiple `hal.executable`s, each with
+    // potentially multiple `hal.executable.target`s. We want to move all
+    // compatible VMLA functions into a new "linked" executable, de-duping
+    // symbols, and updating references as we go.
+    //
+    // Sample IR after:
+    //   hal.executable @linked_vmla {
+    //     hal.interface @legacy_io_0 { ... }
+    //     hal.interface @legacy_io_1 { ... }
+    //     hal.executable.target @vmla, filter="vmla" {
+    //       hal.executable.entry_point @main_dispatch_0 attributes { ... }
+    //       hal.executable.entry_point @main_dispatch_1 attributes { ... }
+    //       hal.executable.entry_point @main_dispatch_2 attributes { ... }
+    //       module {
+    //         vm.module @module {
+    //           vm.func @main_0(...) { ... }
+    //           vm.func @main_1(...) { ... }
+    //           vm.func @main_2(...) { ... }
+    //         }
+    //       }
+    //     }
+    //   }
+    //   hal.executable @main_dispatch_0 {
+    //     hal.interface @legacy_io { ... }
+    //     hal.executable.target @other, filter="other" {
+    //       hal.executable.entry_point @main_dispatch_0 attributes { ... }
+    //       module { ... }
+    //     }
+    //   }
+
+    // Snapshot the source executables before inserting the linked executable
+    // so the linked one can never be visited as a link source, independent of
+    // where it is inserted or how the op list changes during linking.
+    auto executableOps =
+        llvm::to_vector<8>(moduleOp.getOps<IREE::HAL::ExecutableOp>());
+
+    OpBuilder builder = OpBuilder::atBlockBegin(moduleOp.getBody());
+
+    // Create our new "linked" hal.executable.
+    auto linkedExecutableOp = builder.create<IREE::HAL::ExecutableOp>(
+        moduleOp.getLoc(), "linked_vmla");
+    SymbolTable::setSymbolVisibility(linkedExecutableOp,
+                                     SymbolTable::Visibility::Private);
+    // Add our VMLA hal.executable.target with an empty module.
+    builder.setInsertionPointToStart(linkedExecutableOp.getBody());
+    auto linkedTargetOp = builder.create<IREE::HAL::ExecutableTargetOp>(
+        moduleOp.getLoc(), name(), filter_pattern());
+    builder.setInsertionPoint(&linkedTargetOp.getBlock().back());
+    auto linkedModuleOp = builder.create<ModuleOp>(moduleOp.getLoc());
+    // Add an empty vm.module to that module.
+    builder.setInsertionPointToStart(linkedModuleOp.getBody());
+    auto linkedVmModuleOp =
+        builder.create<IREE::VM::ModuleOp>(moduleOp.getLoc(), "linked_module");
+
+    // Root symbol that rewritten @executable::@target::@entry refs use.
+    StringRef linkedExecutableSymName =
+        linkedExecutableOp
+            .getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
+            .getValue();
+
+    int executablesLinked = 0;
+    llvm::SmallVector<IREE::HAL::InterfaceOp, 4> interfaceOps;
+    int nextEntryPointOrdinal = 0;
+    for (auto executableOp : executableOps) {
+      auto targetOps = llvm::to_vector<4>(
+          executableOp.getOps<IREE::HAL::ExecutableTargetOp>());
+      for (auto targetOp : targetOps) {
+        // Only process targets matching our pattern.
+        if (!matchPattern(targetOp.target_backend_filter(), filter_pattern())) {
+          continue;
+        }
+
+        // Reuse the first already-linked interface equivalent to this
+        // executable's interface; otherwise clone it under a fresh name.
+        auto sourceInterfaceOp = executableOp.getFirstInterfaceOp();
+        IREE::HAL::InterfaceOp interfaceOpForExecutable;
+        for (auto interfaceOp : interfaceOps) {
+          if (AreInterfacesEquivalent(interfaceOp, sourceInterfaceOp)) {
+            interfaceOpForExecutable = interfaceOp;
+            break;
+          }
+        }
+        if (!interfaceOpForExecutable) {
+          builder.setInsertionPoint(linkedTargetOp);
+          // clone() always yields an op of the same kind, so cast is safe.
+          interfaceOpForExecutable = cast<IREE::HAL::InterfaceOp>(
+              builder.clone(*sourceInterfaceOp));
+          interfaceOpForExecutable.setName(
+              llvm::formatv("legacy_io_{0}", interfaceOps.size()).str());
+          interfaceOps.push_back(interfaceOpForExecutable);
+        }
+
+        // Update references to @executable::@target::@entry symbols.
+        // SymbolTable::replaceAllSymbolUses only looks at root symbols,
+        // which we can't blindly replace (other targets will map to other
+        // linked executables). The set of uses depends only on the
+        // executable, so gather it once per target instead of per entry.
+        auto executableUses =
+            SymbolTable::getSymbolUses(executableOp, moduleOp);
+
+        // Clone entry point ops, remapping ordinals and updating symbol refs.
+        builder.setInsertionPoint(linkedModuleOp);
+        for (auto entryPointOp :
+             targetOp.getOps<IREE::HAL::ExecutableEntryPointOp>()) {
+          auto newEntryPointOp =
+              builder.create<IREE::HAL::ExecutableEntryPointOp>(
+                  entryPointOp.getLoc(), entryPointOp.sym_nameAttr(),
+                  builder.getI32IntegerAttr(nextEntryPointOrdinal++),
+                  builder.getSymbolRefAttr(interfaceOpForExecutable.getName()),
+                  entryPointOp.signatureAttr());
+
+          if (!executableUses.hasValue()) continue;
+          for (auto executableUse : executableUses.getValue()) {
+            // Only process symbols for this @target::@entry.
+            auto nestedRefs =
+                executableUse.getSymbolRef().getNestedReferences();
+            if (nestedRefs.size() != 2 ||
+                nestedRefs[0].getValue() != targetOp.sym_name() ||
+                nestedRefs[1].getValue() != entryPointOp.sym_name()) {
+              continue;
+            }
+            if (auto dispatchOp =
+                    dyn_cast<IREE::HAL::CommandBufferDispatchSymbolOp>(
+                        executableUse.getUser())) {
+              // New nested reference to the linked exe/target/entry.
+              auto newSymbolRefAttr = builder.getSymbolRefAttr(
+                  linkedExecutableSymName,
+                  {builder.getSymbolRefAttr(linkedTargetOp),
+                   builder.getSymbolRefAttr(newEntryPointOp)});
+              dispatchOp.setAttr("entry_point", newSymbolRefAttr);
+            }
+          }
+        }
+
+        // Clone vm.module ops, including their contents.
+        auto vmModuleOps =
+            targetOp.getInnerModule().getOps<IREE::VM::ModuleOp>();
+        if (vmModuleOps.empty()) {
+          return targetOp.getInnerModule().emitError()
+                 << "target's outer module does not contain a vm.module op";
+        }
+        auto vmModuleOp = *vmModuleOps.begin();
+        builder.setInsertionPoint(&linkedVmModuleOp.getBlock().back());
+        // Use a SymbolTable to guard against inserting duplicate symbols.
+        // NOTE(review): a colliding symbol is dropped, not renamed. That is
+        // only correct when the duplicates are identical (e.g. repeated
+        // imports); confirm for functions/rodata whose names may clash.
+        SymbolTable symbolTable(linkedVmModuleOp.getOperation());
+
+        for (auto &op : vmModuleOp.getBody()->getOperations()) {
+          // The linked vm.module already has its own terminator.
+          if (isa<IREE::VM::ModuleTerminatorOp>(op)) continue;
+          if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
+            if (symbolTable.lookup(symbolOp.getName())) continue;
+          }
+          builder.clone(op);
+        }
+
+        // Now that we're done cloning its ops, delete the original target op.
+        targetOp.erase();
+
+        executablesLinked++;
+      }
+    }
+
+    // Nothing matched our filter: drop the empty linked executable.
+    if (executablesLinked == 0) {
+      linkedExecutableOp.erase();
+    }
+
+    return success();
+  }
+
LogicalResult serializeExecutable(IREE::HAL::ExecutableTargetOp targetOp,
OpBuilder &executableBuilder) override {
// Serialize the VM module to bytes.
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/test/i1_types.mlir b/iree/compiler/Dialect/HAL/Target/VMLA/test/i1_types.mlir
index 4d5c661..f9be39e 100644
--- a/iree/compiler/Dialect/HAL/Target/VMLA/test/i1_types.mlir
+++ b/iree/compiler/Dialect/HAL/Target/VMLA/test/i1_types.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -pass-pipeline='iree-hal-transformation-pipeline{serialize-executables=false},canonicalize' -iree-hal-target-backends=vmla %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline='iree-hal-transformation-pipeline{serialize-executables=false link-executables=false},canonicalize' -iree-hal-target-backends=vmla %s | IreeFileCheck %s
// CHECK-LABEL: @i1_op_usage(%arg0: !hal.buffer) -> !hal.buffer
func @i1_op_usage(%arg0: tensor<4xi1>) -> tensor<4xi1> {
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/test/linking.mlir b/iree/compiler/Dialect/HAL/Target/VMLA/test/linking.mlir
new file mode 100644
index 0000000..052bd7e
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/VMLA/test/linking.mlir
@@ -0,0 +1,204 @@
+// RUN: iree-opt -split-input-file -iree-hal-link-executables -iree-hal-target-backends=vmla %s | IreeFileCheck %s
+
+module {
+ hal.executable @dispatch_0 attributes {sym_visibility = "private"} {
+ hal.interface @legacy_io {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ }
+ hal.executable.target @vmla, filter="vmla" {
+ hal.executable.entry_point @dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>}
+ module {
+ vm.module @module {
+ vm.func @dispatch_0(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32) {
+ vm.return
+ }
+ vm.export @dispatch_0
+ }
+ }
+ }
+ }
+ hal.executable @dispatch_1 attributes {sym_visibility = "private"} {
+ hal.interface @legacy_io {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ }
+ hal.executable.target @vmla, filter="vmla" {
+ hal.executable.entry_point @dispatch_1 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>}
+ module {
+ vm.module @module {
+ vm.func @dispatch_1(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32) {
+ vm.return
+ }
+ vm.export @dispatch_1
+ }
+ }
+ }
+ }
+ hal.executable @dispatch_2 attributes {sym_visibility = "private"} {
+ hal.interface @legacy_io {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg2, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ }
+ hal.executable.target @vmla, filter="vmla" {
+ hal.executable.entry_point @dispatch_2 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>}
+ module {
+ vm.module @module {
+ vm.func @dispatch_2(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32) {
+ vm.return
+ }
+ vm.export @dispatch_2
+ }
+ }
+ }
+ }
+ func @main() -> () {
+ %dev = hal.ex.shared_device : !hal.device
+ %cmd = hal.command_buffer.create %dev, "OneShot", "Transfer|Dispatch" : !hal.command_buffer
+ %c1 = constant 1 : index
+ hal.command_buffer.dispatch.symbol %cmd, @dispatch_0::@vmla::@dispatch_0, workgroup_xyz = [%c1, %c1, %c1]
+ hal.command_buffer.dispatch.symbol %cmd, @dispatch_1::@vmla::@dispatch_1, workgroup_xyz = [%c1, %c1, %c1]
+ hal.command_buffer.dispatch.symbol %cmd, @dispatch_2::@vmla::@dispatch_2, workgroup_xyz = [%c1, %c1, %c1]
+ return
+ }
+}
+
+// All executables (including their interfaces and entry points) should be linked together into @linked_vmla
+// CHECK-NOT: hal.executable @dispatch_0
+// CHECK-NOT: hal.executable @dispatch_1
+// CHECK-NOT: hal.executable @dispatch_2
+// CHECK: hal.executable @linked_vmla attributes {sym_visibility = "private"} {
+// CHECK-NEXT: hal.interface @legacy_io_0 {
+// CHECK-NEXT: hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+// CHECK-NEXT: hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+// CHECK-NEXT: hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+// CHECK-NEXT: }
+// CHECK-NEXT: hal.interface @legacy_io_1 {
+// CHECK-NEXT: hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+// CHECK-NEXT: hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+// CHECK-NEXT: hal.interface.binding @arg2, set=0, binding=1, type="StorageBuffer", access="Read"
+// CHECK-NEXT: hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+// CHECK-NEXT: }
+// CHECK-NEXT: hal.executable.target @vmla, filter="vmla" {
+// CHECK-NEXT: hal.executable.entry_point @dispatch_0 attributes {interface = @legacy_io_0, ordinal = 0 : i32, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>}
+// CHECK-NEXT: hal.executable.entry_point @dispatch_1 attributes {interface = @legacy_io_0, ordinal = 1 : i32, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>}
+// CHECK-NEXT: hal.executable.entry_point @dispatch_2 attributes {interface = @legacy_io_1, ordinal = 2 : i32, signature = (tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>}
+// CHECK-NEXT: module {
+// CHECK-NEXT: vm.module @linked_module {
+// CHECK-NEXT: vm.func @dispatch_0(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32) {
+// CHECK-NEXT: vm.return
+// CHECK-NEXT: }
+// CHECK-NEXT: vm.export @dispatch_0
+// CHECK-NEXT: vm.func @dispatch_1(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32) {
+// CHECK-NEXT: vm.return
+// CHECK-NEXT: }
+// CHECK-NEXT: vm.export @dispatch_1
+// CHECK-NEXT: vm.func @dispatch_2(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32) {
+// CHECK-NEXT: vm.return
+// CHECK-NEXT: }
+// CHECK-NEXT: vm.export @dispatch_2
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+//
+// CHECK: func @main() {
+// CHECK: hal.command_buffer.dispatch.symbol %cmd, @linked_vmla::@vmla::@dispatch_0, workgroup_xyz = [%c1, %c1, %c1]
+// CHECK-NEXT: hal.command_buffer.dispatch.symbol %cmd, @linked_vmla::@vmla::@dispatch_1, workgroup_xyz = [%c1, %c1, %c1]
+// CHECK-NEXT: hal.command_buffer.dispatch.symbol %cmd, @linked_vmla::@vmla::@dispatch_2, workgroup_xyz = [%c1, %c1, %c1]
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+// -----
+
+module {
+ hal.executable @dispatch_0 attributes {sym_visibility = "private"} {
+ hal.interface @legacy_io {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ }
+ hal.executable.target @vmla, filter="vmla" {
+ hal.executable.entry_point @dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>}
+ module {
+ vm.module @module {
+ vm.func @dispatch_0(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32) {
+ vm.return
+ }
+ vm.export @dispatch_0
+ }
+ }
+ }
+ hal.executable.target @othertarget, filter="othertarget" {
+ module {
+ }
+ }
+ }
+ func @main() -> () {
+ %dev = hal.ex.shared_device : !hal.device
+ %cmd = hal.command_buffer.create %dev, "OneShot", "Transfer|Dispatch" : !hal.command_buffer
+ hal.device.switch(%dev : !hal.device)
+ #hal.device.match.id<"vmla">(%arg1 = %cmd : !hal.command_buffer) {
+ %c1 = constant 1 : index
+ hal.command_buffer.dispatch.symbol %arg1, @dispatch_0::@vmla::@dispatch_0, workgroup_xyz = [%c1, %c1, %c1]
+ hal.return
+ },
+ #hal.device.match.id<"othertarget">(%arg1 = %cmd : !hal.command_buffer) {
+ %c1 = constant 1 : index
+ hal.command_buffer.dispatch.symbol %arg1, @dispatch_0::@othertarget::@dispatch_0, workgroup_xyz = [%c1, %c1, %c1]
+ hal.return
+ }
+ return
+ }
+}
+
+// VMLA target should be pulled out from @dispatch_0
+// CHECK: hal.executable @linked_vmla attributes {sym_visibility = "private"} {
+// CHECK-NEXT: hal.interface @legacy_io_0 {
+// CHECK-NEXT: hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+// CHECK-NEXT: hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+// CHECK-NEXT: hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+// CHECK-NEXT: }
+// CHECK-NEXT: hal.executable.target @vmla, filter="vmla" {
+// CHECK-NEXT: hal.executable.entry_point @dispatch_0 attributes {interface = @legacy_io_0, ordinal = 0 : i32, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>}
+// CHECK-NEXT: module {
+// CHECK-NEXT: vm.module @linked_module {
+// CHECK-NEXT: vm.func @dispatch_0(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32) {
+// CHECK-NEXT: vm.return
+// CHECK-NEXT: }
+// CHECK-NEXT: vm.export @dispatch_0
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+//
+// @dispatch_0 should remain, with just @othertarget
+// CHECK: hal.executable @dispatch_0 attributes {sym_visibility = "private"} {
+// CHECK-NEXT: hal.interface @legacy_io {
+// CHECK-NEXT: hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+// CHECK-NEXT: hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+// CHECK-NEXT: hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+// CHECK-NEXT: }
+// CHECK-NEXT: hal.executable.target @othertarget, filter="othertarget" {
+// CHECK-NEXT: module {
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+//
+// Only the vmla dispatch is rewritten to the linked executable; the
+// othertarget dispatch keeps referencing @dispatch_0.
+// CHECK: func @main() {
+// CHECK: hal.device.switch(%dev : !hal.device)
+// CHECK-NEXT: #hal.device.match.id<"vmla">(%arg0 = %cmd : !hal.command_buffer) {
+// CHECK-NEXT: %c1 = constant 1 : index
+// CHECK-NEXT: hal.command_buffer.dispatch.symbol %arg0, @linked_vmla::@vmla::@dispatch_0, workgroup_xyz = [%c1, %c1, %c1]
+// CHECK-NEXT: hal.return
+// CHECK-NEXT: },
+// CHECK-NEXT: #hal.device.match.id<"othertarget">(%arg0 = %cmd : !hal.command_buffer) {
+// CHECK-NEXT: %c1 = constant 1 : index
+// CHECK-NEXT: hal.command_buffer.dispatch.symbol %arg0, @dispatch_0::@othertarget::@dispatch_0, workgroup_xyz = [%c1, %c1, %c1]
+// CHECK-NEXT: hal.return
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir b/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir
index 48c889a..939f06a 100644
--- a/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir
+++ b/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir
@@ -12,15 +12,15 @@
}
}
-// CHECK-LABEL: hal.executable @simpleMath_ex_dispatch_0
-// CHECK-NEXT: hal.interface @legacy_io {
+// CHECK-LABEL: hal.executable @linked_vmla
+// CHECK-NEXT: hal.interface @legacy_io_0 {
// CHECK-NEXT: hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
// CHECK-NEXT: hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard"
// CHECK-NEXT: }
// CHECK-NEXT: hal.executable.target @vmla, filter="vmla" {
-// CHECK-NEXT: hal.executable.entry_point @simpleMath_rgn_dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<4xf32>) -> tensor<4xf32>}
+// CHECK-NEXT: hal.executable.entry_point @simpleMath_rgn_dispatch_0 attributes {interface = @legacy_io_0, ordinal = 0 : i32, signature = (tensor<4xf32>) -> tensor<4xf32>}
// CHECK-NEXT: module {
-// CHECK-NEXT: vm.module @module {
+// CHECK-NEXT: vm.module @linked_module {
// CHECK-NEXT: vm.func @simpleMath_rgn_dispatch_0(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32) {
// CHECK-DAG: %zero = vm.const.i32.zero : i32
// CHECK-DAG: %c16 = vm.const.i32 16 : i32
@@ -55,15 +55,15 @@
}
}
-// CHECK-LABEL: hal.executable @shaped_dispatch
-// CHECK-NEXT: hal.interface @legacy_io attributes {push_constants = 1 : i32} {
+// CHECK-LABEL: hal.executable @linked_vmla
+// CHECK-NEXT: hal.interface @legacy_io_0 attributes {push_constants = 1 : i32} {
// CHECK-NEXT: hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
// CHECK-NEXT: hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard"
// CHECK-NEXT: }
// CHECK-NEXT: hal.executable.target @vmla, filter="vmla" {
-// CHECK-NEXT: hal.executable.entry_point @entry attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<4x?xf32>, index) -> tensor<4x?xf32>}
+// CHECK-NEXT: hal.executable.entry_point @entry attributes {interface = @legacy_io_0, ordinal = 0 : i32, signature = (tensor<4x?xf32>, index) -> tensor<4x?xf32>}
// CHECK-NEXT: module {
-// CHECK-NEXT: vm.module @module {
+// CHECK-NEXT: vm.module @linked_module {
// CHECK-NEXT: vm.func @entry(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32) {
// CHECK-DAG: %zero = vm.const.i32.zero : i32
// CHECK-DAG: %c16 = vm.const.i32 16 : i32
@@ -97,15 +97,15 @@
}
}
-// CHECK-LABEL: hal.executable @reduction_ex_dispatch_0
-// CHECK-NEXT: hal.interface @legacy_io {
+// CHECK-LABEL: hal.executable @linked_vmla
+// CHECK-NEXT: hal.interface @legacy_io_0 {
// CHECK-NEXT: hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
// CHECK-NEXT: hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard"
// CHECK-NEXT: }
// CHECK-NEXT: hal.executable.target @vmla, filter="vmla" {
-// CHECK-NEXT: hal.executable.entry_point @reduction_ex_dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<4x8xf32>) -> tensor<4xf32>}
+// CHECK-NEXT: hal.executable.entry_point @reduction_ex_dispatch_0 attributes {interface = @legacy_io_0, ordinal = 0 : i32, signature = (tensor<4x8xf32>) -> tensor<4xf32>}
// CHECK-NEXT: module {
-// CHECK-NEXT: vm.module @module {
+// CHECK-NEXT: vm.module @linked_module {
// CHECK-NEXT: vm.rodata @reduction_ex_dispatch_0_const_0 dense<0.000000e+00> : tensor<f32>
// CHECK-NEXT: vm.func @reduction_ex_dispatch_0(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32) {
// CHECK-NEXT: %zero = vm.const.i32.zero : i32
diff --git a/iree/compiler/Dialect/HAL/Transforms/LinkExecutables.cpp b/iree/compiler/Dialect/HAL/Transforms/LinkExecutables.cpp
index 3f2891e..e2e5a64 100644
--- a/iree/compiler/Dialect/HAL/Transforms/LinkExecutables.cpp
+++ b/iree/compiler/Dialect/HAL/Transforms/LinkExecutables.cpp
@@ -47,6 +47,17 @@
return signalPassFailure();
}
}
+
+ // Backends may move target ops from executables into linked executables.
+ // If an executable ends up with no targets, remove it.
+ auto executableOps =
+ llvm::to_vector<4>(moduleOp.getOps<IREE::HAL::ExecutableOp>());
+ for (auto executableOp : executableOps) {
+ auto targetOps = executableOp.getOps<IREE::HAL::ExecutableTargetOp>();
+ if (targetOps.empty()) {
+ executableOp.erase();
+ }
+ }
}
private:
diff --git a/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp b/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
index 94b567d..d40113e 100644
--- a/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
+++ b/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
@@ -16,7 +16,9 @@
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/Target/TargetBackend.h"
+#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
+#include "iree/compiler/Dialect/HAL/Utils/DeviceSwitchBuilder.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -55,10 +57,12 @@
// other nice thing is that we get ordering similar to the executable
// variables above.
for (auto executableOp : executableOps) {
- auto interfaceOp = executableOp.getInterfaceOp();
- defineExecutableLayoutOp(interfaceOp.getLoc(),
- interfaceOp.getExecutableSetLayoutsAttr(),
- interfaceOp.push_constantsAttr());
+ for (auto interfaceOp :
+ executableOp.getBlock().getOps<IREE::HAL::InterfaceOp>()) {
+ defineExecutableLayoutOp(interfaceOp.getLoc(),
+ interfaceOp.getExecutableSetLayoutsAttr(),
+ interfaceOp.push_constantsAttr());
+ }
}
// Generate cached resource singletons and replace lookup ops with direct
@@ -221,33 +225,71 @@
loc, executableCacheType, deviceValue,
blockBuilder.getStringAttr("default"));
- // TODO(benvanik): use targetOptions_ to determine these flags.
- auto cachingMode = ExecutableCachingModeBitfield::AliasProvidedData |
- ExecutableCachingModeBitfield::AllowPersistentCaching |
- ExecutableCachingModeBitfield::AllowOptimization;
- for (auto executableOp : executableOps) {
- auto executableIt = executableCache_.find(executableOp.sym_name());
- assert(executableIt != executableCache_.end() &&
- "executable must have been cached");
- auto executableVariableOp = executableIt->second;
+ // Create a switch statement with a case for each backend.
+ // Each case should then cache only executables which contain a matching
+ // ExecutableTargetOp.
+ // Afterwards, we could inline and de-dup across switch cases.
+ DeviceSwitchBuilder switchBuilder(loc, /*resultTypes=*/TypeRange{},
+ deviceValue, blockBuilder);
+ auto targetBackends = matchTargetBackends(targetOptions_.targets);
+ for (auto &targetBackend : targetBackends) {
+ auto *region = switchBuilder.addConditionRegion(
+ IREE::HAL::DeviceMatchIDAttr::get(targetBackend->filter_pattern(),
+ blockBuilder.getContext()),
+ {
+ executableCacheValue,
+ });
+ auto &entryBlock = region->front();
+ auto executableCache = entryBlock.getArgument(0);
+ auto caseBuilder = OpBuilder::atBlockBegin(&entryBlock);
- // TODO(benvanik): support multiple interfaces. We'd probably want to
- // store each executable+interface as a variable.
- auto interfaceOp = executableOp.getInterfaceOp();
+ // TODO(benvanik): use targetOptions_ to determine these flags.
+ auto cachingMode = ExecutableCachingModeBitfield::AliasProvidedData |
+ ExecutableCachingModeBitfield::AllowPersistentCaching |
+ ExecutableCachingModeBitfield::AllowOptimization;
+ for (auto executableOp : executableOps) {
+ // Skip executables with no matching target ops.
+ auto executableTargetOps =
+ executableOp.getOps<IREE::HAL::ExecutableTargetOp>();
+ bool hasMatchingTarget = false;
+ for (auto executableTargetOp : executableTargetOps) {
+ if (TargetBackend::matchPattern(
+ executableTargetOp.target_backend_filter(),
+ targetBackend->filter_pattern())) {
+ hasMatchingTarget = true;
+ }
+ }
+ if (!hasMatchingTarget) continue;
- auto executableLayoutVariableOp = defineExecutableLayoutOp(
- executableOp.getLoc(), interfaceOp.getExecutableSetLayoutsAttr(),
- interfaceOp.push_constantsAttr());
- auto executableLayoutValue = blockBuilder.createOrFold<VariableLoadOp>(
- loc, ExecutableLayoutType::get(loc.getContext()),
- executableLayoutVariableOp.sym_name());
- auto executableValue =
- blockBuilder.createOrFold<ExecutableCachePrepareOp>(
- loc, ExecutableType::get(loc.getContext()), executableCacheValue,
- executableLayoutValue, cachingMode, executableOp.sym_name());
- blockBuilder.create<VariableStoreOp>(loc, executableValue,
- executableVariableOp.sym_name());
+ auto executableIt = executableCache_.find(executableOp.sym_name());
+ assert(executableIt != executableCache_.end() &&
+ "executable must have been cached");
+ auto executableVariableOp = executableIt->second;
+
+ // TODO(benvanik): support multiple interfaces. We'd probably want to
+ // store each executable+interface as a variable.
+ //
+ // This is *only* safe now because any backends that support multiple
+ // interfaces during compilation do *not* use layouts during executable
+ // cache preparation.
+ auto interfaceOp = executableOp.getFirstInterfaceOp();
+
+ auto executableLayoutVariableOp = defineExecutableLayoutOp(
+ executableOp.getLoc(), interfaceOp.getExecutableSetLayoutsAttr(),
+ interfaceOp.push_constantsAttr());
+ auto executableLayoutValue = caseBuilder.createOrFold<VariableLoadOp>(
+ loc, ExecutableLayoutType::get(loc.getContext()),
+ executableLayoutVariableOp.sym_name());
+ auto executableValue =
+ caseBuilder.createOrFold<ExecutableCachePrepareOp>(
+ loc, ExecutableType::get(loc.getContext()), executableCache,
+ executableLayoutValue, cachingMode, executableOp.sym_name());
+ caseBuilder.create<VariableStoreOp>(loc, executableValue,
+ executableVariableOp.sym_name());
+ }
+ caseBuilder.create<IREE::HAL::ReturnOp>(loc);
}
+ switchBuilder.build();
blockBuilder.create<mlir::ReturnOp>(loc, executableCacheValue);
diff --git a/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
index 1a9440d..fc05535 100644
--- a/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
@@ -35,6 +35,10 @@
llvm::cl::desc("Whether to serialize hal.executable.target ops to "
"hal.executable.binary ops."),
llvm::cl::init(true)};
+ Option<bool> linkExecutables{
+ *this, "link-executables",
+ llvm::cl::desc("Whether to link hal.executable ops together."),
+ llvm::cl::init(true)};
};
} // namespace
@@ -52,12 +56,6 @@
// this pass.
passManager.addPass(createTranslateExecutablesPass(targetOptions));
- // After all executables are translated we allow the backends to link them
- // together. For example, the LLVM AOT backend may combine all executable
- // targets for the same architecture into a single executable and link it as
- // a shared library.
- passManager.addPass(createLinkExecutablesPass(targetOptions));
-
passManager.addPass(createConvertFlowToHALPass());
// Phase ordering note: Before this pass, functions signatures will be based
@@ -75,6 +73,17 @@
// been expanded to primitives.
passManager.addPass(createPublicABIGenerationPass());
+ // After all executables are translated and before resolving entry point
+ // ordinals, we allow the backends to link executables together. For example,
+ // the LLVM AOT backend may combine all executable targets for the same
+ // architecture into a single executable and link it as a shared library.
+ // TODO(scotttodd): Move after createTranslateExecutablesPass
+ // * ConvertStreamOps under ConvertFlowToHALPass assumes one entry point.
+ // Once it is adjusted to handle multiple entry points, this can move up.
+ if (transformOptions.linkExecutables) {
+ passManager.addPass(createLinkExecutablesPass(targetOptions));
+ }
+
// Resolve entry point ordinals from nested symbol references prior to
// serialization. As this pass creates lookup ops it should run before
// MaterializeResourceCachesPass.
diff --git a/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir b/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir
index d88af3e..590df15 100644
--- a/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir
+++ b/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -iree-hal-materialize-resource-caches %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -iree-hal-materialize-resource-caches %s -iree-hal-target-backends=vmla | IreeFileCheck %s
// CHECK: hal.variable @_descriptor_set_layout_0 init(@_descriptor_set_layout_0_initializer) : !hal.descriptor_set_layout
// CHECK-NEXT: func @_descriptor_set_layout_0_initializer() -> !hal.descriptor_set_layout attributes {sym_visibility = "private"} {
@@ -87,12 +87,14 @@
// -----
+// TODO(scotttodd): Test without depending on a specific HAL target? Or move to HAL/Target/*/test/?
+// - If there is no matching hal.executable.target then the executable will not be cached
hal.executable @exe {
hal.interface @interface {
hal.interface.binding @s0b0, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @s0b1, set=0, binding=1, type="StorageBuffer", access="Read|Write"
}
- hal.executable.target @target, filter="target" {
+ hal.executable.target @vmla, filter="vmla" {
hal.executable.entry_point @entry attributes {
interface = @interface,
ordinal = 0 : i32,
@@ -119,8 +121,12 @@
// CHECK: hal.variable @_executable_cache init(@_executable_cache_initializer) : !hal.executable_cache
// CHECK-NEXT: func @_executable_cache_initializer
// CHECK: %[[CACHE:.+]] = hal.executable_cache.create %dev, identifier = "default" : !hal.executable_cache
-// CHECK-NEXT: %[[LAYOUT:.+]] = hal.variable.load @_executable_layout_0 : !hal.executable_layout
-// CHECK-NEXT: %[[EXE:.+]] = hal.executable_cache.prepare %[[CACHE]], layout = %[[LAYOUT]], caching_mode = "AliasProvidedData|AllowPersistentCaching|AllowOptimization", @exe : !hal.executable
+// CHECK-NEXT: hal.device.switch(%dev : !hal.device)
+// CHECK-NEXT: #hal.device.match.id<"vmla">(%[[CACHE_CAPTURE:.+]] = %executable_cache_default : !hal.executable_cache) {
+// CHECK-NEXT: %[[LAYOUT:.+]] = hal.variable.load @_executable_layout_0 : !hal.executable_layout
+// CHECK-NEXT: %[[EXE:.+]] = hal.executable_cache.prepare %[[CACHE_CAPTURE]], layout = %[[LAYOUT]], caching_mode = "AliasProvidedData|AllowPersistentCaching|AllowOptimization", @exe : !hal.executable
+// CHECK-NEXT: hal.variable.store %[[EXE]], @_executable_exe : !hal.executable
+// CHECK-NEXT: hal.return
// CHECK-LABEL: @exeLookup
func @exeLookup(%arg0 : !hal.device) -> !hal.executable {
diff --git a/iree/compiler/Dialect/HAL/Utils/DeviceSwitchBuilder.h b/iree/compiler/Dialect/HAL/Utils/DeviceSwitchBuilder.h
index b1a2d23..15fc7b0 100644
--- a/iree/compiler/Dialect/HAL/Utils/DeviceSwitchBuilder.h
+++ b/iree/compiler/Dialect/HAL/Utils/DeviceSwitchBuilder.h
@@ -94,8 +94,9 @@
// Builder for hal.device.switch ops that allows for nesting of conditions.
//
// Example:
-// DeviceSwitchBuilder b0(.../*initialCondition=*/Z);
-// b0.addRegion(); // condition: Z
+// DeviceSwitchBuilder builder(loc, resultTypes, device, opBuilder);
+// auto b0 = builder.nest(Z);
+// b0.addRegion(); // condition: Z
// b0.addConditionRegion(A); // condition: Z && A
// auto b1 = b0.nest(B);
// b1.addConditionRegion(C); // condition: Z && B && C
@@ -111,7 +112,61 @@
class DeviceSwitchBuilder {
public:
DeviceSwitchBuilder(Location loc, TypeRange resultTypes, Value device,
- ConversionPatternRewriter &rewriter)
+ OpBuilder builder)
+ : loc_(loc),
+ resultTypes_(resultTypes),
+ device_(device),
+ builder_(builder) {
+ // FIXME: Keep the provided listener; build() erases case ops behind its back.
+ builder_.setListener(nullptr);
+ }
+
+ // Returns a case builder whose regions all require |conditionAttr| to be met.
+ // Cases created through it are recorded here and emitted by build().
+ DeviceSwitchCaseBuilder nest(Attribute conditionAttr) {
+ return DeviceSwitchCaseBuilder(loc_, resultTypes_, device_, conditionAttr,
+ caseOps_, builder_);
+ }
+
+ // Adds a new top-level condition region that must satisfy |conditionAttr|.
+ // The region will have a single entry block with the given |args|.
+ // Returns the region so callers can populate its entry block.
+ Region *addConditionRegion(Attribute conditionAttr,
+ const SmallVector<Value, 4> &args) {
+ return nest(conditionAttr).addRegion(args);
+ }
+
+ // Constructs a hal.device.switch owning all added regions; erases case ops.
+ IREE::HAL::DeviceSwitchOp build() {
+ SmallVector<Attribute, 4> conditionAttrs;
+ SmallVector<SmallVector<Value, 4>, 4> conditionArgs;
+ // Collect each case's condition and captured args in caseOps_ order.
+ for (auto caseOp : caseOps_) {
+ conditionAttrs.push_back(caseOp.conditions().getValue()[0]);
+ conditionArgs.push_back(caseOp.args());
+ }
+ auto switchOp = builder_.create<IREE::HAL::DeviceSwitchOp>(
+ loc_, resultTypes_, device_, conditionAttrs, conditionArgs);
+ for (unsigned i = 0, e = caseOps_.size(); i < e; ++i) {
+ switchOp.getRegion(i).takeBody(caseOps_[i].getRegion(0));
+ caseOps_[i].erase();
+ }
+ return switchOp;
+ }
+
+ private:
+ Location loc_;
+ SmallVector<Type, 4> resultTypes_;
+ Value device_;
+ SmallVector<IREE::HAL::DeviceSwitchOp, 4> caseOps_;
+ OpBuilder builder_;
+};
+
+// Rewriter-compatible version of DeviceSwitchBuilder.
+class DeviceSwitchRewriter {
+ public:
+ DeviceSwitchRewriter(Location loc, TypeRange resultTypes, Value device,
+ ConversionPatternRewriter &rewriter)
: loc_(loc),
resultTypes_(resultTypes),
device_(device),
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 9337788..9b3c2a7 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 93377888ae89560ba6d3976e2762d3d4724c4dfd
+Subproject commit 9b3c2a72e4cb3b0ae27f87064c11f728452b2af9
diff --git a/third_party/tensorflow b/third_party/tensorflow
index 2e56481..090b691 160000
--- a/third_party/tensorflow
+++ b/third_party/tensorflow
@@ -1 +1 @@
-Subproject commit 2e56481abcef1dd1625fba465a5d02ee6b347842
+Subproject commit 090b691fbf7b7823c41345004d12eddaa6c86118