Adding execution and dataflow dialect. This is mostly a refactoring of the existing IREE dialect and partitioning transforms into a standalone dialect. Improvements were made to match modern ODS style, add testability, add some tests, and clean up how users can integrate the dialect into larger transformation flows. Future work will add shape ops and possibly some of the other standard input legalization (like HLO tuple flattening). PiperOrigin-RevId: 281375678
diff --git a/iree/compiler/Dialect/Flow/Analysis/BUILD b/iree/compiler/Dialect/Flow/Analysis/BUILD new file mode 100644 index 0000000..d2209ac --- /dev/null +++ b/iree/compiler/Dialect/Flow/Analysis/BUILD
@@ -0,0 +1,40 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "Analysis", + srcs = [ + "Dispatchability.cpp", + "DispatchabilityTest.cpp", + ], + hdrs = [ + "Dispatchability.h", + ], + deps = [ + "//iree/compiler/Dialect", + "//iree/compiler/Dialect/Flow/IR", + "@llvm//:support", + "@local_config_mlir//:IR", + "@local_config_mlir//:Pass", + "@local_config_mlir//:StandardOps", + "@local_config_mlir//:Support", + "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo", + ], + alwayslink = 1, +)
diff --git a/iree/compiler/Dialect/Flow/Analysis/Dispatchability.cpp b/iree/compiler/Dialect/Flow/Analysis/Dispatchability.cpp new file mode 100644 index 0000000..a42933e --- /dev/null +++ b/iree/compiler/Dialect/Flow/Analysis/Dispatchability.cpp
@@ -0,0 +1,151 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/compiler/Dialect/Flow/Analysis/Dispatchability.h" + +#include <list> + +#include "llvm/ADT/SetVector.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/SymbolTable.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" + +namespace mlir { +namespace iree_compiler { + +// static +LogicalResult Dispatchability::annotateIR(ModuleOp moduleOp) { + Dispatchability dispatchability; + if (failed(dispatchability.recalculate(moduleOp))) { + moduleOp.emitOpError() + << "failed to analyze dispatchability for the module"; + return failure(); + } + + Builder builder(moduleOp.getContext()); + SymbolTable symbolTable(moduleOp); + for (auto &funcDispatchability : dispatchability.funcDispatchability_) { + auto funcOp = symbolTable.lookup<FuncOp>(funcDispatchability.first); + funcOp.setAttr("dispatchable", + builder.getBoolAttr(funcDispatchability.second)); + } + + return success(); +} + +LogicalResult Dispatchability::recalculate(ModuleOp moduleOp) { + funcDispatchability_.clear(); + funcCloneModuleOp_ = ModuleOp::create(UnknownLoc::get(moduleOp.getContext())); + funcClones_.clear(); + + // Run through all functions until we are able to compute their + // dispatchability. We do this so that we can determine if calls are allowed. + OpBuilder cloneBuilder(funcCloneModuleOp_); + std::vector<FuncOp> nextWorklist(moduleOp.getOps<FuncOp>().begin(), + moduleOp.getOps<FuncOp>().end()); + std::vector<FuncOp> worklist; + bool anyChanged; + do { + anyChanged = false; + worklist.swap(nextWorklist); + nextWorklist.clear(); + for (auto funcOp : worklist) { + auto isDispatchable = computeDispatchability(funcOp); + if (isDispatchable.hasValue()) { + funcDispatchability_[funcOp.getName()] = isDispatchable.getValue(); + if (isDispatchable.getValue()) { + auto clonedFuncOp = cast<FuncOp>(cloneBuilder.clone(*funcOp)); + funcClones_[funcOp.getName()] = clonedFuncOp; + funcCloneModuleOp_.push_back(clonedFuncOp); + } + anyChanged = true; + } else { + nextWorklist.push_back(funcOp); + } + } + } while (anyChanged); + if (!nextWorklist.empty()) { + return moduleOp.emitError() << "cycle detected in dispatchability analysis"; + } + + return success(); +} + +Optional<bool> Dispatchability::computeDispatchability(FuncOp funcOp) { + if (funcOp.isExternal()) { + // We assume all imports have side-effects right now, but that may not be + // the case. We should add an attribute and check for it. + return false; + } + + // TODO(b/144530470): replace with tablegen attributes/interfaces. + for (auto &block : funcOp.getBlocks()) { + for (auto &op : block.getOperations()) { + if (auto callOp = dyn_cast<CallOp>(op)) { + if (callOp.getCallee() == funcOp.getName()) { + // Recursion. + continue; + } + auto it = funcDispatchability_.find(callOp.callee()); + if (it == funcDispatchability_.end()) { + // Not yet calculated - yield. + return llvm::None; + } + return it->second; + } else if (isa<CallIndirectOp>(op)) { + // Indirect calls are not supported and must first be devirtualized. + return false; + } else if (isa<ReturnOp>(op)) { + // TODO(benvanik): widen to all known terminators? sometimes they may + // have side-effects. + continue; + } else if (!op.getDialect() || !op.hasNoSideEffect()) { + // Ops with side-effects cannot be dispatched as we must be able to + // exactly model I/O. + return false; + } else if (isa<xla_hlo::DotOp>(op) || isa<xla_hlo::ConvOp>(op)) { + // Some unfusable ops must remain on their own. + return false; + } else if (isa<xla_hlo::ReduceOp>(op) || + isa<xla_hlo::ReduceWindowOp>(op)) { + // Reductions always become flow ops. + return false; + } + } + } + + // All cases not handled above are (probably) dispatchable. This makes what we + // do here a blacklist, though as we move towards more frontend dialects that + // may not be the best idea. + return true; +} + +void Dispatchability::walkDispatchableOps( + function_ref<void(FuncOp funcOp)> fn) { + for (auto funcOp : funcClones_) { + fn(funcOp.second); + } +} + +bool Dispatchability::isDispatchable(StringRef funcName) { + return funcDispatchability_[funcName]; +} + +bool Dispatchability::isDispatchable(FuncOp funcOp) { + return isDispatchable(funcOp.getName()); +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Analysis/Dispatchability.h b/iree/compiler/Dialect/Flow/Analysis/Dispatchability.h new file mode 100644 index 0000000..e7a67f2 --- /dev/null +++ b/iree/compiler/Dialect/Flow/Analysis/Dispatchability.h
@@ -0,0 +1,67 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_DIALECT_FLOW_ANALYSIS_DISPATCHABILITY_H_ +#define IREE_COMPILER_DIALECT_FLOW_ANALYSIS_DISPATCHABILITY_H_ + +#include "mlir/IR/Function.h" +#include "mlir/IR/Module.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir { +namespace iree_compiler { + +// Analyzes functions in a module to determine whether they can be performed as +// part of a dispatch operation. Functions must meet a set of criteria defining +// "dispatchability" such as the lack of side effects. +class Dispatchability { + public: + // Annotates the IR with the dispatchability information. This is only + // required if the dispatchability information is interesting to persist + // beyond transformation, such as in tests. + static LogicalResult annotateIR(ModuleOp moduleOp); + + Dispatchability() = default; + explicit Dispatchability(Operation *op) { recalculate(cast<ModuleOp>(op)); } + Dispatchability(Dispatchability &&) = default; + Dispatchability &operator=(Dispatchability &&) = default; + Dispatchability(const Dispatchability &) = delete; + Dispatchability &operator=(const Dispatchability &) = delete; + + // Recalculates the dispatchability information for the given module. + LogicalResult recalculate(ModuleOp moduleOp); + + // Calls |fn| for each dispatchable function. + void walkDispatchableOps(function_ref<void(FuncOp funcOp)> fn); + + // Returns true if |funcOp| is dispatchable. + bool isDispatchable(StringRef funcName); + bool isDispatchable(FuncOp funcOp); + + private: + // Returns true if the given function is dispatch compatible. + // Returns None if the dispatchability can't yet be calculated as dependent + // functions have not been processed. + Optional<bool> computeDispatchability(FuncOp funcOp); + + DenseMap<StringRef, bool> funcDispatchability_; + ModuleOp funcCloneModuleOp_; + DenseMap<StringRef, FuncOp> funcClones_; +}; + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_DIALECT_FLOW_ANALYSIS_DISPATCHABILITY_H_
diff --git a/iree/compiler/Dialect/Flow/Analysis/DispatchabilityTest.cpp b/iree/compiler/Dialect/Flow/Analysis/DispatchabilityTest.cpp new file mode 100644 index 0000000..3ac5df0 --- /dev/null +++ b/iree/compiler/Dialect/Flow/Analysis/DispatchabilityTest.cpp
@@ -0,0 +1,37 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/compiler/Dialect/Flow/Analysis/Dispatchability.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" + +namespace mlir { +namespace iree_compiler { + +class DispatchabilityTestPass + : public OperationPass<DispatchabilityTestPass, ModuleOp> { + public: + void runOnOperation() override { + if (failed(Dispatchability::annotateIR(getOperation()))) { + signalPassFailure(); + } + } +}; + +static PassRegistration<DispatchabilityTestPass> pass( + "test-iree-flow-dispatchability", + "Test pass used for dispatchability analysis"); + +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Analysis/test/BUILD b/iree/compiler/Dialect/Flow/Analysis/test/BUILD new file mode 100644 index 0000000..beab62e --- /dev/null +++ b/iree/compiler/Dialect/Flow/Analysis/test/BUILD
@@ -0,0 +1,28 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//iree:build_defs.bzl", "iree_glob_lit_tests", "iree_setup_lit_package") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +iree_setup_lit_package( + data = [ + "//iree/tools:iree-opt", + ], +) + +iree_glob_lit_tests()
diff --git a/iree/compiler/Dialect/Flow/Analysis/test/dispatchability.mlir b/iree/compiler/Dialect/Flow/Analysis/test/dispatchability.mlir new file mode 100644 index 0000000..69c8242 --- /dev/null +++ b/iree/compiler/Dialect/Flow/Analysis/test/dispatchability.mlir
@@ -0,0 +1,97 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// RUN: iree-opt -split-input-file -test-iree-flow-dispatchability %s | FileCheck %s --dump-input=fail + +// CHECK-LABEL: @empty +// CHECK-NEXT: dispatchable = true +func @empty() { + return +} + +// ----- + +// CHECK-LABEL: @simpleMath +// CHECK-NEXT: dispatchable = true +func @simpleMath(%arg0 : tensor<4xf32>) -> tensor<4xf32> { + %0 = xla_hlo.add %arg0, %arg0 : tensor<4xf32> + return %0 : tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: @stdElementwiseOps +// CHECK-NEXT: dispatchable = true +func @stdElementwiseOps(%arg0 : tensor<4xf32>) -> tensor<4xf32> { + %0 = addf %arg0, %arg0 : tensor<4xf32> + %1 = subf %0, %arg0 : tensor<4xf32> + %2 = mulf %1, %arg0 : tensor<4xf32> + return %2 : tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: @hloElementwiseOps +// CHECK-NEXT: dispatchable = true +func @hloElementwiseOps(%arg0 : tensor<4xf32>) -> tensor<4xf32> { + %0 = xla_hlo.add %arg0, %arg0 : tensor<4xf32> + %1 = xla_hlo.sub %0, %arg0 : tensor<4xf32> + %2 = xla_hlo.mul %1, %arg0 : tensor<4xf32> + return %2 : tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: @interleavedDot +// CHECK-NEXT: dispatchable = false +func @interleavedDot(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { + %0 = xla_hlo.add %arg0, %arg0 : tensor<4x4xf32> + %1 = "xla_hlo.dot"(%0, %arg0) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + %2 = xla_hlo.mul %1, %arg0 : tensor<4x4xf32> + return %2 : tensor<4x4xf32> +} + +// ----- + +// CHECK-LABEL: @caller +// CHECK-NEXT: dispatchable = true +func @caller(%arg0 : tensor<4xf32>) -> tensor<4xf32> { + %0 = xla_hlo.add %arg0, %arg0 : tensor<4xf32> + %1 = call @callee(%0) : (tensor<4xf32>) -> tensor<4xf32> + %2 = xla_hlo.mul %1, %arg0 : tensor<4xf32> + return %2 : tensor<4xf32> +} +// CHECK-LABEL: func @callee +// CHECK-NEXT: dispatchable = true +func @callee(%arg0 : tensor<4xf32>) -> tensor<4xf32> { + %0 = xla_hlo.mul %arg0, %arg0 : tensor<4xf32> + return %0 : tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: @dotCaller +// CHECK-NEXT: dispatchable = false +func @dotCaller(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { + %0 = xla_hlo.add %arg0, %arg0 : tensor<4x4xf32> + %1 = call @dotCallee(%0) : (tensor<4x4xf32>) -> tensor<4x4xf32> + %2 = xla_hlo.mul %1, %arg0 : tensor<4x4xf32> + return %2 : tensor<4x4xf32> +} +// CHECK-LABEL: func @dotCallee +// CHECK-NEXT: dispatchable = false +func @dotCallee(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { + %0 = "xla_hlo.dot"(%arg0, %arg0) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + return %0 : tensor<4x4xf32> +}
diff --git a/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/BUILD b/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/BUILD new file mode 100644 index 0000000..e131ae1 --- /dev/null +++ b/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/BUILD
@@ -0,0 +1,39 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "HLOToFlow", + srcs = [ + "ConvertHLOToFlow.cpp", + "ConvertHLOToFlowPass.cpp", + ], + hdrs = [ + "ConvertHLOToFlow.h", + ], + deps = [ + "//iree/compiler/Dialect", + "//iree/compiler/Dialect/Flow/IR", + "@local_config_mlir//:IR", + "@local_config_mlir//:Pass", + "@local_config_mlir//:StandardOps", + "@local_config_mlir//:Transforms", + "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo", + ], + alwayslink = 1, +)
diff --git a/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/ConvertHLOToFlow.cpp b/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/ConvertHLOToFlow.cpp new file mode 100644 index 0000000..f5dfe4b --- /dev/null +++ b/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/ConvertHLOToFlow.cpp
@@ -0,0 +1,70 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/compiler/Dialect/Flow/Conversion/HLOToFlow/ConvertHLOToFlow.h" + +#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Module.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" +#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +struct ConstOpLowering : public OpRewritePattern<xla_hlo::ConstOp> { + using OpRewritePattern::OpRewritePattern; + PatternMatchResult matchAndRewrite(xla_hlo::ConstOp op, + PatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp<ConstantOp>(op, op.value()); + return matchSuccess(); + } +}; + +// TODO(benvanik): dynamic update slice. + +} // namespace + +void setupHLOToFlowConversion(MLIRContext *context, + ConversionTarget &conversionTarget, + OwningRewritePatternList &patterns) { + conversionTarget.addLegalDialect<IREE::Flow::FlowDialect>(); + + // Standard ops always pass through as import code may have produced some + // and control flow should have been legalized from HLO to std. + // The flow dialect uses std.module and std.func for its structure and they + // must be allowed. + conversionTarget.addLegalDialect<StandardOpsDialect>(); + conversionTarget.addLegalOp<ModuleOp, ModuleTerminatorOp, FuncOp>(); + + // Allow all non-blacklisted HLO ops by default. Partitioning will move most + // of them into executables. + conversionTarget.addLegalDialect<xla_hlo::XlaHloDialect>(); + + // Control flow must be converted to standard form via + // xla_hlo::createLegalizeControlFlowPass() prior to conversion. + conversionTarget.addIllegalOp<xla_hlo::ConditionalOp, xla_hlo::WhileOp>(); + + conversionTarget.addIllegalOp<xla_hlo::DotGeneralOp>(); + xla_hlo::PopulateGeneralDotOpLoweringPatterns(&patterns, context); + + conversionTarget.addIllegalOp<xla_hlo::ConstOp>(); + patterns.insert<ConstOpLowering>(context); +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/ConvertHLOToFlow.h b/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/ConvertHLOToFlow.h new file mode 100644 index 0000000..d9e5dac --- /dev/null +++ b/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/ConvertHLOToFlow.h
@@ -0,0 +1,36 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_DIALECT_FLOW_CONVERSION_HLOTOFLOW_CONVERTHLOTOFLOW_H_ +#define IREE_COMPILER_DIALECT_FLOW_CONVERSION_HLOTOFLOW_CONVERTHLOTOFLOW_H_ + +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace iree_compiler { + +// Sets up a |conversionTarget| for the flow dialect and adds some direct HLO +// conversion patterns that should happen prior to partitioning. +// Callers can add additional patterns and define their own legality rules on +// the conversion target in addition to the ones provided here if there are more +// direct conversions from the source dialects to the flow dialect. +void setupHLOToFlowConversion(MLIRContext *context, + ConversionTarget &conversionTarget, + OwningRewritePatternList &patterns); + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_DIALECT_FLOW_CONVERSION_HLOTOFLOW_CONVERTHLOTOFLOW_H_
diff --git a/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/ConvertHLOToFlowPass.cpp b/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/ConvertHLOToFlowPass.cpp new file mode 100644 index 0000000..9f313ea --- /dev/null +++ b/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/ConvertHLOToFlowPass.cpp
@@ -0,0 +1,51 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/compiler/Dialect/Flow/Conversion/HLOToFlow/ConvertHLOToFlow.h" +#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +// A pass converting XLA HLO operations into the IREE flow dialect. +// Used only for testing as in the common case we only rely on rewrite patterns. +class ConvertHLOToFlowPass : public ModulePass<ConvertHLOToFlowPass> { + void runOnModule() override { + OwningRewritePatternList patterns; + auto module = getModule(); + + ConversionTarget target(*module.getContext()); + setupHLOToFlowConversion(module.getContext(), target, patterns); + + // NOTE: we are only looking for specific HLO ops and allow others to + // remain. + if (failed(applyFullConversion(module, target, patterns))) { + return signalPassFailure(); + } + } +}; + +} // namespace + +static PassRegistration<ConvertHLOToFlowPass> pass( + "iree-convert-hlo-to-flow", "Convert XLA HLO ops to the IREE flow dialect"); + +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/test/BUILD b/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/test/BUILD new file mode 100644 index 0000000..beab62e --- /dev/null +++ b/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/test/BUILD
@@ -0,0 +1,28 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//iree:build_defs.bzl", "iree_glob_lit_tests", "iree_setup_lit_package") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +iree_setup_lit_package( + data = [ + "//iree/tools:iree-opt", + ], +) + +iree_glob_lit_tests()
diff --git a/iree/compiler/Dialect/Flow/IR/BUILD b/iree/compiler/Dialect/Flow/IR/BUILD new file mode 100644 index 0000000..ccbfd21 --- /dev/null +++ b/iree/compiler/Dialect/Flow/IR/BUILD
@@ -0,0 +1,85 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@local_config_mlir//:tblgen.bzl", "gentbl") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +filegroup( + name = "td_files", + srcs = glob(["*.td"]), +) + +cc_library( + name = "IR", + srcs = [ + "FlowDialect.cpp", + "FlowEnums.cpp.inc", + "FlowOpFolders.cpp", + "FlowOps.cpp", + "FlowOps.cpp.inc", + "FlowTypes.cpp", + ], + hdrs = [ + "FlowDialect.h", + "FlowEnums.h.inc", + "FlowOps.h", + "FlowOps.h.inc", + "FlowTypes.h", + ], + deps = [ + ":FlowEnumsGen", + ":FlowOpsGen", + "//iree/compiler/Dialect", + "@llvm//:support", + "@local_config_mlir//:IR", + "@local_config_mlir//:StandardOps", + "@local_config_mlir//:Support", + "@local_config_mlir//:TransformUtils", + ], + alwayslink = 1, +) + +gentbl( + name = "FlowEnumsGen", + tbl_outs = [ + ("-gen-enum-decls", "FlowEnums.h.inc"), + ("-gen-enum-defs", "FlowEnums.cpp.inc"), + ], + tblgen = "@local_config_mlir//:mlir-tblgen", + td_file = "FlowBase.td", + td_srcs = [ + ":td_files", + "//iree/compiler/Dialect:td_files", + "@local_config_mlir//:OpBaseTdFiles", + ], +) + +gentbl( + name = "FlowOpsGen", + tbl_outs = [ + ("-gen-op-decls", "FlowOps.h.inc"), + ("-gen-op-defs", "FlowOps.cpp.inc"), + ], + tblgen = "@local_config_mlir//:mlir-tblgen", + td_file = "FlowOps.td", + td_srcs = [ + ":td_files", + "//iree/compiler/Dialect:td_files", + "@local_config_mlir//:OpBaseTdFiles", + ], +)
diff --git a/iree/compiler/Dialect/Flow/IR/FlowBase.td b/iree/compiler/Dialect/Flow/IR/FlowBase.td new file mode 100644 index 0000000..621d714 --- /dev/null +++ b/iree/compiler/Dialect/Flow/IR/FlowBase.td
@@ -0,0 +1,144 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_DIALECT_FLOW_BASE +#define IREE_DIALECT_FLOW_BASE + +#ifndef IREE_DIALECT_COMMON_BASE +include "iree/compiler/Dialect/CommonBase.td" +#endif // IREE_DIALECT_COMMON_BASE + +//===----------------------------------------------------------------------===// +// IREE execution flow dialect +//===----------------------------------------------------------------------===// + +def FLOW_Dialect : Dialect { + let name = "flow"; + let cppNamespace = "IREE::Flow"; + + let summary = [{ + A dialect designed to model execution data flow and partitioning. + }]; + let description = [{ + The flow dialect is used to model regions of dense computation and the data + flow between them. MLIR value-semantic tensors are used as the primary data + type to allow SSA use-def to provide a bulk of the infrastructure required + to perform the computation partitioning and outlining. + + The dialect is designed to ingest relatively high-level linear algebra via + XLA HLO ops (that also operate on the value-semantic tensor types) and + optionally MLIR standard ops for control flow and other actions. After + conversion of any higher-level ops that have special semantics in the flow + dialect, such as global variables, the rest are partitioned into regions + containing simple and compatible computations. Finally, outlining moves the + computations into executables and leaves only the execution flow encoded via + dispatch operations. + + The primary unit of interest is a "dispatch region" containing compatible + computations that can be scheduled together efficiently (and safely). + "Compatible" here is specified as similarly shaped workloads that indicate + how many invocations a computation can be parallelized across when running + in a SPMD execution model. Though it depends on the particular runtime + backends this more concretely means things like the untiled workload + (or tiled workgroups) used in GPU dispatches or similar thread pool + executors. + + After identification of the dispatchable regions a set of transformations + performs folding and simplification to reduce the total number of + dispatches. Heuristics are used in certain cases to more efficiently + schedule special ops (such as GEMM) and the design is amenable to profile- + guided analysis that can be added in the future. + + The resulting outlined executable modules containing the dispatchable code + can be translated to one or more backends (such as SPIR-V for Vulkan, or + LLVM IR for running on the CPU, etc). The IR that is outlined is untouched + and in the input format (such as XLA HLO ops) allowing conversion using any + MLIR target that supports ingesting such input. A few special ops are used + to communicate statically available information such as the expected + workload size, shapes of inputs and outputs, etc. + }]; +} + +//===----------------------------------------------------------------------===// +// Base flow dialect op classes +//===----------------------------------------------------------------------===// + +class FLOW_Op<string mnemonic, list<OpTrait> traits = []> : + Op<FLOW_Dialect, mnemonic, traits> { + let parser = [{ return parse$cppClass(parser, &result); }]; + let printer = [{ return print$cppClass(p, *this); }]; +} + +class FLOW_PureOp<string mnemonic, list<OpTrait> traits = []> : + FLOW_Op<mnemonic, !listconcat(traits, [NoSideEffect])>; + +//===----------------------------------------------------------------------===// +// Flow dialect types +//===----------------------------------------------------------------------===// + +def FLOW_ExecutableRefAttr : AliasedSymbolRefAttr; +def FLOW_VariableRefAttr : AliasedSymbolRefAttr; + +// TODO(benvanik): use index types instead of i32. +def FLOW_Workload : 1DTensorOf<[I32]> { + let typeDescription = [{ + Describes the total untiled invocations along one or more dimensions. + Tiling may later divide this value for workgroup sizes. + }]; +} + +// TODO(cl/277443143): moving into OpBase.td. +// A `width`-bit integer elements attribute. The attribute should be +// ranked and has a shape as specified in `dims`. +class FLOW_RankedIntElementsAttr<int width, list<int> dims> : ElementsAttrBase< + CPred<"$_self.isa<DenseIntElementsAttr>() &&" + "$_self.cast<DenseIntElementsAttr>().getType()." + "getElementType().isInteger(" # width # ") && " + // Check that this is ranked and has the specified shape. + "$_self.cast<DenseIntElementsAttr>().getType().hasRank() && " + "$_self.cast<DenseIntElementsAttr>().getType().getShape() == " + "llvm::ArrayRef<int64_t>({" # StrJoinInt<dims>.result # "})">, + width # "-bit integer elements attribute of shape [" # + StrJoinInt<dims>.result # "]"> { + + let storageType = [{ DenseIntElementsAttr }]; + let returnType = [{ DenseIntElementsAttr }]; + + let constBuilderCall = "DenseElementsAttr::get(" + "RankedTensorType::get({" # StrJoinInt<dims>.result # + "}, $_builder.getIntegerType(" # width # ")), " + "llvm::makeArrayRef($0)).cast<DenseIntElementsAttr>()"; + let convertFromStorage = "$_self"; +} + +def FLOW_WorkloadAttr : FLOW_RankedIntElementsAttr<32, [3]>; + +// Use no padding and clamp the window to the valid area, possibly stopping +// early prior to having covered all data. +def FLOW_PM_ClampWindowToFit : I32EnumAttrCase<"ClampWindowToFit", 0>; +// Use initial values for padding when windows cross dimension boundaries. +def FLOW_PM_PadBorder : I32EnumAttrCase<"PadBorder", 1>; +// Describes the padding applied for a windowed operation like convolution, +// where a window is placed inside a base area. +def FLOW_PaddingModeAttr : + I32EnumAttr<"PaddingMode", "Padding mode", [ + FLOW_PM_ClampWindowToFit, + FLOW_PM_PadBorder, + ]> { + let returnType = "::mlir::iree_compiler::IREE::Flow::PaddingMode"; + let convertFromStorage = "static_cast<::mlir::iree_compiler::IREE::Flow::PaddingMode>($_self.getInt())"; + let cppNamespace = "::mlir::iree_compiler::IREE::Flow"; +} + +#endif // IREE_DIALECT_FLOW_BASE
diff --git a/iree/compiler/Dialect/Flow/IR/FlowDialect.cpp b/iree/compiler/Dialect/Flow/IR/FlowDialect.cpp new file mode 100644 index 0000000..b6a7fe3 --- /dev/null +++ b/iree/compiler/Dialect/Flow/IR/FlowDialect.cpp
@@ -0,0 +1,58 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" + +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" +#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h" +#include "llvm/Support/SourceMgr.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Transforms/FoldUtils.h" +#include "mlir/Transforms/InliningUtils.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace Flow { + +static DialectRegistration<FlowDialect> flow_dialect; + +namespace { + +struct FlowFolderInterface : public OpFolderDialectInterface { + using OpFolderDialectInterface::OpFolderDialectInterface; + + bool shouldMaterializeInto(Region *region) const override { + // TODO(benvanik): redirect constants to the region scope when small. + return false; + } +}; + +} // namespace + +FlowDialect::FlowDialect(MLIRContext *context) + : Dialect(getDialectNamespace(), context) { + addInterfaces<FlowFolderInterface>(); + +#define GET_OP_LIST + addOperations< +#include "iree/compiler/Dialect/Flow/IR/FlowOps.cpp.inc" + >(); +} + +} // namespace Flow +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/IR/FlowDialect.h b/iree/compiler/Dialect/Flow/IR/FlowDialect.h new file mode 100644 index 0000000..d484adb --- /dev/null +++ b/iree/compiler/Dialect/Flow/IR/FlowDialect.h
@@ -0,0 +1,38 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_DIALECT_FLOW_IR_FLOWDIALECT_H_ +#define IREE_COMPILER_DIALECT_FLOW_IR_FLOWDIALECT_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/SymbolTable.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace Flow { + +class FlowDialect : public Dialect { + public: + explicit FlowDialect(MLIRContext *context); + static StringRef getDialectNamespace() { return "flow"; } +}; + +} // namespace Flow +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_DIALECT_FLOW_IR_FLOWDIALECT_H_
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp new file mode 100644 index 0000000..748f0f6 --- /dev/null +++ b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -0,0 +1,132 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <algorithm> + +#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" +#include "llvm/ADT/StringExtras.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace Flow { + +//===----------------------------------------------------------------------===// +// Variables +//===----------------------------------------------------------------------===// + +namespace { + +/// Converts variable initializer functions that evaluate to a constant to a +/// specified initial value. +struct InlineConstVariableOpInitializer : public OpRewritePattern<VariableOp> { + using OpRewritePattern<VariableOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(VariableOp op, + PatternRewriter &rewriter) const override { + if (!op.initializer()) return matchFailure(); + auto *symbolOp = + SymbolTable::lookupNearestSymbolFrom(op, op.initializer().getValue()); + auto initializer = cast<FuncOp>(symbolOp); + if (initializer.getBlocks().size() == 1 && + initializer.getBlocks().front().getOperations().size() == 2 && + isa<mlir::ReturnOp>( + initializer.getBlocks().front().getOperations().back())) { + auto &primaryOp = initializer.getBlocks().front().getOperations().front(); + Attribute constResult; + if (matchPattern(primaryOp.getResult(0), m_Constant(&constResult))) { + rewriter.replaceOpWithNewOp<VariableOp>( + op, op.sym_name(), op.is_mutable(), op.type(), constResult); + return matchSuccess(); + } + } + return matchFailure(); + } +}; + +} // namespace + +void VariableOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert<InlineConstVariableOpInitializer>(context); +} + +namespace { + +/// Erases flow.variable.load ops whose values are unused. +/// We have to do this manually as the load op cannot be marked pure and have it +/// done automatically. +struct EraseUnusedVariableLoadOp : public OpRewritePattern<VariableLoadOp> { + using OpRewritePattern<VariableLoadOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(VariableLoadOp op, + PatternRewriter &rewriter) const override { + if (op.result()->use_empty()) { + rewriter.eraseOp(op); + return matchSuccess(); + } + return matchFailure(); + } +}; + +} // namespace + +void VariableLoadOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert<EraseUnusedVariableLoadOp>(context); +} + +namespace { + +/// Erases flow.variable.store ops that are no-ops. +/// This can happen if there was a variable load, some DCE'd usage, and a +/// store back to the same variable: we want to be able to elide the entire load +/// and store. +struct EraseUnusedVariableStoreOp : public OpRewritePattern<VariableStoreOp> { + using OpRewritePattern<VariableStoreOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(VariableStoreOp op, + PatternRewriter &rewriter) const override { + if (auto loadOp = + dyn_cast_or_null<VariableLoadOp>(op.value()->getDefiningOp())) { + if (loadOp.variable() == op.variable()) { + rewriter.eraseOp(op); + return matchSuccess(); + } + } + return matchFailure(); + } +}; + +} // namespace + +void VariableStoreOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert<EraseUnusedVariableStoreOp>(context); +} + +} // namespace Flow +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/iree/compiler/Dialect/Flow/IR/FlowOps.cpp new file mode 100644 index 0000000..9760076 --- /dev/null +++ b/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
@@ -0,0 +1,880 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" + +#include "iree/compiler/Dialect/Types.h" +#include "llvm/ADT/StringExtras.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace Flow { + +// Returns true if the given |accessType| is compatible with the |variableType|. +// For example, this will return true if the variable type is a tensor<?xf32> +// and the access is tensor<4xf32>. +static bool isVariableTypeCompatible(Type variableType, Type accessType) { + return succeeded(mlir::verifyCompatibleShape(variableType, accessType)); +} + +//===----------------------------------------------------------------------===// +// flow.variable +//===----------------------------------------------------------------------===// + +static ParseResult parseVariableOp(OpAsmParser &parser, + OperationState *result) { + StringAttr nameAttr; + if (failed(parser.parseSymbolName(nameAttr, + mlir::SymbolTable::getSymbolAttrName(), + result->attributes))) { + return failure(); + } + + if (succeeded(parser.parseOptionalKeyword("mutable"))) { + result->addAttribute("is_mutable", UnitAttr::get(result->getContext())); + } + + if (succeeded(parser.parseOptionalKeyword("init"))) { + FlatSymbolRefAttr initializerAttr; + if (failed(parser.parseLParen()) || + failed(parser.parseAttribute(initializerAttr, "initializer", + result->attributes)) || + failed(parser.parseRParen())) { + return failure(); + } + } + + if (failed(parser.parseOptionalColon())) { + Attribute initialValueAttr; + if (failed(parser.parseAttribute(initialValueAttr, "initial_value", + result->attributes))) { + return failure(); + } + result->addAttribute("type", TypeAttr::get(initialValueAttr.getType())); + } else { + Type type; + if (failed(parser.parseType(type))) { + return failure(); + } + result->addAttribute("type", TypeAttr::get(type)); + } + + return success(); +} + +static void printVariableOp(OpAsmPrinter &p, VariableOp op) { + p << op.getOperationName() << ' '; + p.printSymbolName(op.sym_name()); + if (op.is_mutable()) { + p << " mutable"; + } + if (op.initializer().hasValue()) { + p << " init("; + p.printSymbolName(op.initializer().getValue()); + p << ')'; + } + if (op.initial_value().hasValue()) { + p << ' '; + p.printAttribute(op.initial_value().getValue()); + } else { + p << " : "; + p.printType(op.type()); + } +} + +static LogicalResult verifyVariableOp(VariableOp op) { + if (op.initializer().hasValue() && op.initial_value().hasValue()) { + return op.emitOpError() + << "variables can have either an initializer or an initial value"; + } else if (op.initializer().hasValue()) { + // Ensure initializer returns the same type as the variable. + auto *symbolOp = + SymbolTable::lookupNearestSymbolFrom(op, op.initializer().getValue()); + if (!symbolOp) { + return op.emitOpError() << "initializer function " + << op.initializer().getValue() << " not found"; + } + auto initializerOp = dyn_cast<FuncOp>(symbolOp); + if (initializerOp.getNumArguments() != 0 || + initializerOp.getNumResults() != 1 || + initializerOp.getType().getResult(0) != op.type()) { + return op.emitOpError() + << "initializer type mismatch; variable " << op.sym_name() + << " is " << op.type() << " but initializer function " + << initializerOp.getName() << " is " << initializerOp.getType(); + } + } else if (op.initial_value().hasValue()) { + // Ensure the value is something we can store in the variable + if (!isVariableTypeCompatible(op.type(), op.initial_value()->getType())) { + return op.emitOpError() + << "initial value type mismatch; variable " << op.sym_name() + << " is " << op.type() << " but initial value provided is " + << op.initial_value()->getType(); + } + } + return success(); +} + +void VariableOp::build(Builder *builder, OperationState &state, StringRef name, + bool isMutable, FuncOp initializer, + ArrayRef<NamedAttribute> attrs) { + state.addAttribute(SymbolTable::getSymbolAttrName(), + builder->getStringAttr(name)); + if (isMutable) { + state.addAttribute("is_mutable", builder->getUnitAttr()); + } + state.addAttribute("initializer", builder->getSymbolRefAttr(initializer)); + state.addAttribute("type", TypeAttr::get(initializer.getType().getResult(0))); + state.attributes.append(attrs.begin(), attrs.end()); +} + +void VariableOp::build(Builder *builder, OperationState &result, StringRef name, + bool isMutable, Type type, Attribute initialValue, + ArrayRef<NamedAttribute> attrs) { + result.addAttribute(SymbolTable::getSymbolAttrName(), + builder->getStringAttr(name)); + if (isMutable) { + result.addAttribute("is_mutable", builder->getUnitAttr()); + } + result.addAttribute("initial_value", initialValue); + result.addAttribute("type", TypeAttr::get(type)); + result.attributes.append(attrs.begin(), attrs.end()); +} + +void VariableOp::build(Builder *builder, OperationState &result, StringRef name, + bool isMutable, Type type, + ArrayRef<NamedAttribute> attrs) { + result.addAttribute(SymbolTable::getSymbolAttrName(), + builder->getStringAttr(name)); + if (isMutable) { + result.addAttribute("is_mutable", builder->getUnitAttr()); + } + result.addAttribute("type", TypeAttr::get(type)); + result.attributes.append(attrs.begin(), attrs.end()); +} + +//===----------------------------------------------------------------------===// +// flow.variable.load +//===----------------------------------------------------------------------===// + +static ParseResult parseVariableLoadOp(OpAsmParser &parser, + OperationState *result) { + FlatSymbolRefAttr variableAttr; + Type valueType; + if (failed(parser.parseAttribute(variableAttr, "variable", + result->attributes)) || + failed(parser.parseOptionalAttrDict(result->attributes)) || + failed(parser.parseColonType(valueType))) { + return failure(); + } + result->addTypes({valueType}); + return success(); +} + +static void printVariableLoadOp(OpAsmPrinter &p, VariableLoadOp &op) { + p << op.getOperationName() << ' '; + p.printSymbolName(op.variable()); + p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"variable"}); + p << " : "; + p.printType(op.result()->getType()); +} + +static LogicalResult verifyVariableLoadOp(VariableLoadOp &op) { + auto *symbolOp = SymbolTable::lookupNearestSymbolFrom(op, op.variable()); + if (!symbolOp) { + return op.emitOpError() << "undefined variable: " << op.variable(); + } + auto variableOp = dyn_cast<VariableOp>(symbolOp); + auto loadType = op.result()->getType(); + if (!isVariableTypeCompatible(variableOp.type(), loadType)) { + return op.emitOpError() + << "variable type mismatch; variable " << op.variable() << " is " + << variableOp.type() << " but load is " << loadType; + } + return success(); +} + +//===----------------------------------------------------------------------===// +// flow.variable.store +//===----------------------------------------------------------------------===// + +static ParseResult parseVariableStoreOp(OpAsmParser &parser, + OperationState *result) { + FlatSymbolRefAttr variableAttr; + OpAsmParser::OperandType value; + Type valueType; + if (failed(parser.parseAttribute(variableAttr, "variable", + result->attributes)) || + failed(parser.parseComma()) || failed(parser.parseOperand(value)) || + failed(parser.parseOptionalAttrDict(result->attributes)) || + failed(parser.parseColonType(valueType)) || + failed(parser.resolveOperand(value, valueType, result->operands))) { + return failure(); + } + return success(); +} + +static void printVariableStoreOp(OpAsmPrinter &p, VariableStoreOp &op) { + p << op.getOperationName() << ' '; + p.printSymbolName(op.variable()); + p << ", "; + p.printOperand(op.value()); + p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"variable"}); + p << " : "; + p.printType(op.value()->getType()); +} + +static LogicalResult verifyVariableStoreOp(VariableStoreOp &op) { + auto *symbolOp = SymbolTable::lookupNearestSymbolFrom(op, op.variable()); + if (!symbolOp) { + return op.emitOpError() << "undefined variable: " << op.variable(); + } + auto variableOp = dyn_cast<VariableOp>(symbolOp); + auto storeType = op.value()->getType(); + if (!isVariableTypeCompatible(variableOp.type(), storeType)) { + return op.emitOpError() + << "variable type mismatch; variable " << op.variable() << " is " + << variableOp.type() << " but store is " << storeType; + } + if (!variableOp.is_mutable()) { + return op.emitOpError() << "variable " << op.variable() + << " is not mutable and cannot be stored to"; + } + return success(); +} + +//===----------------------------------------------------------------------===// +// flow.dispatch.region +//===----------------------------------------------------------------------===// + +void DispatchRegionOp::build(Builder *builder, OperationState &state, + ArrayRef<Type> resultTypes, Value *workload, + ArrayRef<Value *> operands, + ArrayRef<NamedAttribute> attributes) { + state.addTypes(resultTypes); + state.addOperands({workload}); + state.addOperands(operands); + state.addAttributes(attributes); + state.addRegion(); + state.setOperandListToResizable(); +} + +ParseResult parseDispatchRegionOp(OpAsmParser &parser, OperationState *result) { + // Parse required workload. + OpAsmParser::OperandType workloadArg; + Type workloadArgType; + if (failed(parser.parseLSquare()) || + failed(parser.parseOperand(workloadArg)) || + failed(parser.parseColonType(workloadArgType)) || + failed(parser.parseRSquare()) || + failed(parser.resolveOperand(workloadArg, workloadArgType, + result->operands))) { + return failure(); + } + + // Parse (optional) args. + SmallVector<OpAsmParser::OperandType, 16> regionArgs; + SmallVector<Type, 16> regionArgTypes; + if (failed(parser.parseLParen())) { + return failure(); + } + if (failed(parser.parseOptionalRParen())) { + SmallVector<OpAsmParser::OperandType, 16> regionOperands; + auto argsLoc = parser.getCurrentLocation(); + do { + // Reserve entries in the lists. + regionArgs.emplace_back(); + regionOperands.emplace_back(); + regionArgTypes.emplace_back(); + if (failed(parser.parseRegionArgument(regionArgs.back())) || + failed(parser.parseEqual()) || + failed(parser.parseOperand(regionOperands.back())) || + failed(parser.parseColonType(regionArgTypes.back()))) { + return failure(); + } + } while (succeeded(parser.parseOptionalComma())); + if (failed(parser.parseRParen()) || + failed(parser.resolveOperands(regionOperands, regionArgTypes, argsLoc, + result->operands))) { + return failure(); + } + } + result->setOperandListToResizable(); + + // Parse (optional) results. + if (failed(parser.parseOptionalColonTypeList(result->types))) { + return failure(); + } + + // Parse region body. + Region *body = result->addRegion(); + if (failed(parser.parseRegion(*body, regionArgs, regionArgTypes)) || + failed(parser.parseOptionalAttrDict(result->attributes))) { + return failure(); + } + return success(); +} + +void printDispatchRegionOp(OpAsmPrinter &p, DispatchRegionOp op) { + p << op.getOperationName(); + + // Print the workload argument. + p << "["; + p.printOperand(op.workload()); + p << " : "; + p.printType(op.workload()->getType()); + p << "]"; + + // Print the data argument remapping. + p << "("; + interleaveComma(llvm::zip(op.body().front().getArguments(), op.args()), p, + [&](std::tuple<BlockArgument *, Value *> it) { + p << *std::get<0>(it) << " = " << *std::get<1>(it); + p << " : "; + p << std::get<1>(it)->getType(); + }); + p << ")"; + + // Print the result types, if any. + if (op.getNumResults() > 0) { + p << " : "; + interleaveComma(op.getResultTypes(), p); + } + + p.printRegion(op.body(), /*printEntryBlockArgs=*/false); + p.printOptionalAttrDict(op.getAttrs(), + /*elidedAttrs=*/{}); +} + +//===----------------------------------------------------------------------===// +// flow.reduction.region +//===----------------------------------------------------------------------===// + +void ReductionRegionOp::build(Builder *builder, OperationState &state, + ArrayRef<Type> resultTypes, Value *workload, + ArrayRef<Value *> operands, + ArrayRef<Value *> initialValues, + ArrayRef<int32_t> dimensions, + ArrayRef<NamedAttribute> attributes) { + state.addTypes(resultTypes); + state.addOperands({workload}); + state.addOperands(operands); + state.addOperands(initialValues); + state.addAttribute( + "dimensions", + DenseIntElementsAttr::get( + RankedTensorType::get({static_cast<int32_t>(dimensions.size())}, + builder->getIntegerType(32)), + dimensions)); + state.addAttributes(attributes); + state.addRegion(); + state.setOperandListToResizable(); +} + +ParseResult parseReductionRegionOp(OpAsmParser &parser, + OperationState *result) { + OpAsmParser::OperandType workloadArg; + Type workloadArgType; + if (failed(parser.parseLSquare()) || + failed(parser.parseOperand(workloadArg)) || + failed(parser.parseColonType(workloadArgType)) || + failed(parser.parseRSquare()) || + failed(parser.resolveOperand(workloadArg, workloadArgType, + result->operands))) { + return failure(); + } + + SmallVector<OpAsmParser::OperandType, 8> reductionOperands; + Type reductionType; + auto operandsLoc = parser.getCurrentLocation(); + if (failed(parser.parseLParen()) || + failed(parser.parseOperandList(reductionOperands)) || + failed(parser.parseRParen()) || + failed(parser.parseColonType(reductionType)) || + failed(parser.resolveOperands( + reductionOperands, reductionType.cast<FunctionType>().getInputs(), + operandsLoc, result->operands))) { + return failure(); + } + for (auto type : reductionType.cast<FunctionType>().getResults()) { + result->types.push_back(type); + } + result->setOperandListToResizable(); + + SmallVector<OpAsmParser::OperandType, 8> regionArgs; + SmallVector<Type, 8> regionArgTypes; + if (failed(parser.parseKeyword("invocation")) || + failed(parser.parseLParen())) { + return failure(); + } + do { + Type argType; + SmallVector<OpAsmParser::OperandType, 2> reductionRegionArgs; + OpAsmParser::OperandType initialValue; + if (failed(parser.parseLParen()) || + failed(parser.parseOperandList(reductionRegionArgs, 2)) || + failed(parser.parseRParen()) || failed(parser.parseEqual()) || + failed(parser.parseOperand(initialValue)) || + failed(parser.parseColonType(argType)) || + failed( + parser.resolveOperand(initialValue, argType, result->operands))) { + return failure(); + } + regionArgs.push_back(reductionRegionArgs[0]); + regionArgTypes.push_back(argType); + regionArgs.push_back(reductionRegionArgs[1]); + regionArgTypes.push_back(argType); + } while (succeeded(parser.parseOptionalComma())); + if (failed(parser.parseRParen())) { + return failure(); + } + + // Parse region body. + Region *body = result->addRegion(); + if (failed(parser.parseRegion(*body, regionArgs, regionArgTypes)) || + failed(parser.parseOptionalAttrDict(result->attributes))) { + return failure(); + } + + return success(); +} + +void printReductionRegionOp(OpAsmPrinter &p, ReductionRegionOp op) { + p << op.getOperationName(); + + // Print the workload argument. + p << "["; + p.printOperand(op.workload()); + p << " : "; + p.printType(op.workload()->getType()); + p << "]"; + + p << "("; + p.printOperands(op.operands()); + p << ")"; + if (op.getNumResults() > 0) { + p << " : ("; + interleaveComma(op.operands(), p, + [&](Value *operand) { p.printType(operand->getType()); }); + p << ")"; + p << " -> ("; + interleaveComma(op.getResultTypes(), p); + p << ")"; + } + p << "\n"; + + p << " invocation("; + auto &entryBlock = op.body().getBlocks().front(); + int regionArgIndex = 0; + interleaveComma(op.initial_values(), p, [&](Value *operand) { + p << "("; + p.printOperand(entryBlock.getArgument(regionArgIndex++)); + p << ", "; + p.printOperand(entryBlock.getArgument(regionArgIndex++)); + p << ") = "; + p.printOperand(operand); + p << " : "; + p.printType(operand->getType()); + }); + p << ") "; + + p.printRegion(op.body(), /*printEntryBlockArgs=*/false); + p.printOptionalAttrDict(op.getAttrs(), + /*elidedAttrs=*/{}); +} + +//===----------------------------------------------------------------------===// +// flow.windowed_reduction.region +//===----------------------------------------------------------------------===// + +void WindowedReductionRegionOp::build( + Builder *builder, OperationState &state, ArrayRef<Type> resultTypes, + Value *workload, ArrayRef<Value *> operands, + ArrayRef<Value *> initialValues, ArrayRef<int32_t> windowDimensions, + ArrayRef<int32_t> windowStrides, ArrayRef<int32_t> baseDilations, + ArrayRef<int32_t> windowDilations, PaddingMode paddingMode, + ArrayRef<NamedAttribute> attributes) { + state.addTypes(resultTypes); + state.addOperands({workload}); + state.addOperands(operands); + state.addOperands(initialValues); + state.addAttribute( + "window_dimensions", + DenseIntElementsAttr::get( + RankedTensorType::get({static_cast<int32_t>(windowDimensions.size())}, + builder->getIntegerType(32)), + windowDimensions)); + state.addAttribute( + "window_strides", + DenseIntElementsAttr::get( + RankedTensorType::get({static_cast<int32_t>(windowStrides.size())}, + builder->getIntegerType(32)), + windowStrides)); + state.addAttribute( + "base_dilations", + DenseIntElementsAttr::get( + RankedTensorType::get({static_cast<int32_t>(baseDilations.size())}, + builder->getIntegerType(32)), + baseDilations)); + state.addAttribute( + "window_dilations", + DenseIntElementsAttr::get( + RankedTensorType::get({static_cast<int32_t>(windowDilations.size())}, + builder->getIntegerType(32)), + windowDilations)); + state.addAttribute("padding_mode", builder->getI32IntegerAttr( + static_cast<int32_t>(paddingMode))); + state.addAttributes(attributes); + state.addRegion(); + state.setOperandListToResizable(); +} + +ParseResult parseWindowedReductionRegionOp(OpAsmParser &parser, + OperationState *result) { + return parseReductionRegionOp(parser, result); +} + +void printWindowedReductionRegionOp(OpAsmPrinter &p, + WindowedReductionRegionOp op) { + p << op.getOperationName(); + + // Print the workload argument. + p << "["; + p.printOperand(op.workload()); + p << " : "; + p.printType(op.workload()->getType()); + p << "]"; + + p << "("; + p.printOperands(op.operands()); + p << ")"; + if (op.getNumResults() > 0) { + p << " : ("; + interleaveComma(op.operands(), p, + [&](Value *operand) { p.printType(operand->getType()); }); + p << ")"; + p << " -> ("; + interleaveComma(op.getResultTypes(), p); + p << ")"; + } + p << "\n"; + + p << " invocation("; + auto &entryBlock = op.body().getBlocks().front(); + int regionArgIndex = 0; + interleaveComma(op.initial_values(), p, [&](Value *operand) { + p << "("; + p.printOperand(entryBlock.getArgument(regionArgIndex++)); + p << ", "; + p.printOperand(entryBlock.getArgument(regionArgIndex++)); + p << ") = "; + p.printOperand(operand); + p << " : "; + p.printType(operand->getType()); + }); + p << ") "; + + p.printRegion(op.body(), /*printEntryBlockArgs=*/false); + p.printOptionalAttrDict(op.getAttrs(), + /*elidedAttrs=*/{}); +} + +//===----------------------------------------------------------------------===// +// flow.return +//===----------------------------------------------------------------------===// + +static ParseResult parseReturnOp(OpAsmParser &parser, OperationState *result) { + SmallVector<OpAsmParser::OperandType, 2> opInfo; + SmallVector<Type, 2> types; + llvm::SMLoc loc = parser.getCurrentLocation(); + return failure(parser.parseOperandList(opInfo) || + (!opInfo.empty() && parser.parseColonTypeList(types)) || + parser.resolveOperands(opInfo, types, loc, result->operands)); +} + +static void printReturnOp(OpAsmPrinter &p, ReturnOp op) { + p << op.getOperationName(); + if (op.getNumOperands() > 0) { + p << ' '; + p.printOperands(op.operand_begin(), op.operand_end()); + p << " : "; + interleaveComma(op.getOperandTypes(), p); + } +} + +//===----------------------------------------------------------------------===// +// flow.dispatch +//===----------------------------------------------------------------------===// + +static ParseResult parseDispatchOp(OpAsmParser &parser, + OperationState *result) { + auto executableLoc = parser.getNameLoc(); + + // TODO(benvanik): replace with SymbolRefAttr. + StringAttr executableAttr; + StringAttr entryPointAttr; + if (failed(parser.parseSymbolName(executableAttr, "executable", + result->attributes)) || + failed(parser.parseColon()) || failed(parser.parseColon()) || + failed(parser.parseSymbolName(entryPointAttr, "entry_point", + result->attributes))) { + return failure(); + } + result->attributes[0].second = + parser.getBuilder().getSymbolRefAttr(executableAttr.getValue()); + result->attributes[1].second = + parser.getBuilder().getSymbolRefAttr(entryPointAttr.getValue()); + + OpAsmParser::OperandType workloadArg; + Type workloadArgType; + if (failed(parser.parseLSquare()) || + failed(parser.parseOperand(workloadArg)) || + failed(parser.parseColonType(workloadArgType)) || + failed(parser.parseRSquare()) || + failed(parser.resolveOperand(workloadArg, workloadArgType, + result->operands))) { + return failure(); + } + + SmallVector<OpAsmParser::OperandType, 4> operands; + FunctionType entryPointType; + if (failed( + parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren)) || + failed(parser.parseOptionalAttrDict(result->attributes)) || + failed(parser.parseColonType(entryPointType)) || + failed( + parser.addTypesToList(entryPointType.getResults(), result->types)) || + failed(parser.resolveOperands(operands, entryPointType.getInputs(), + executableLoc, result->operands))) { + return failure(); + } + return success(); +} + +static void printDispatchOp(OpAsmPrinter &p, DispatchOp op) { + p << op.getOperationName() << ' '; + // TODO(benvanik): replace with SymbolRefAttr. + p.printSymbolName(op.executable()); + p << "::"; + p.printSymbolName(op.entry_point()); + p << "["; + p.printOperand(op.workload()); + p << " : "; + p.printType(op.workload()->getType()); + p << "]("; + p.printOperands(op.operands()); + p << ')'; + p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{ + "executable", + "entry_point", + }); + p << " : "; + p.printType(op.getEntryPointType()); +} + +FunctionType DispatchOp::getEntryPointType() { + SmallVector<Type, 4> resultTypes(getResultTypes()); + SmallVector<Type, 8> argTypes(operand_type_range{operands()}); + return FunctionType::get(argTypes, resultTypes, getContext()); +} + +//===----------------------------------------------------------------------===// +// flow.executable +//===----------------------------------------------------------------------===// + +void ExecutableOp::build(Builder *builder, OperationState &state, + StringRef name) { + ensureTerminator(*state.addRegion(), *builder, state.location); + state.addAttribute(mlir::SymbolTable::getSymbolAttrName(), + builder->getStringAttr(name)); +} + +static ParseResult parseExecutableOp(OpAsmParser &parser, + OperationState *result) { + StringAttr nameAttr; + if (failed(parser.parseSymbolName(nameAttr, + mlir::SymbolTable::getSymbolAttrName(), + result->attributes)) || + failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) { + return failure(); + } + + // Parse the module body. + auto *body = result->addRegion(); + if (failed(parser.parseRegion(*body, llvm::None, llvm::None))) { + return failure(); + } + + // Ensure that this module has a valid terminator. + ExecutableOp::ensureTerminator(*body, parser.getBuilder(), result->location); + return success(); +} + +static void printExecutableOp(OpAsmPrinter &p, ExecutableOp op) { + p << op.getOperationName() << ' '; + p.printSymbolName(op.sym_name()); + p.printOptionalAttrDictWithKeyword( + op.getAttrs(), + /*elidedAttrs=*/{mlir::SymbolTable::getSymbolAttrName()}); + p.printRegion(op.body(), /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/false); +} + +static LogicalResult verifyExecutableOp(ExecutableOp op) { + // TODO(benvanik): check export name conflicts. + return success(); +} + +static ParseResult parseRegionEndOp(OpAsmParser &parser, + OperationState *result) { + return parser.parseOptionalAttrDict(result->attributes); +} + +static void printRegionEndOp(OpAsmPrinter &p, Operation *op) { + p << op->getName(); + p.printOptionalAttrDict(op->getAttrs()); +} + +//===----------------------------------------------------------------------===// +// flow.dispatch.entry +//===----------------------------------------------------------------------===// + +static ParseResult parseDispatchEntryOp(OpAsmParser &parser, + OperationState *result) { + FlatSymbolRefAttr functionRefAttr; + if (failed(parser.parseAttribute(functionRefAttr, "function_ref", + result->attributes))) { + return failure(); + } + + if (succeeded(parser.parseOptionalKeyword("as"))) { + StringAttr exportNameAttr; + if (failed(parser.parseLParen()) || + failed(parser.parseAttribute(exportNameAttr, "sym_name", + result->attributes)) || + failed(parser.parseRParen())) { + return failure(); + } + } else { + result->addAttribute("sym_name", parser.getBuilder().getStringAttr( + functionRefAttr.getValue())); + } + + if (failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) { + return failure(); + } + + return success(); +} + +static void printDispatchEntryOp(OpAsmPrinter &p, DispatchEntryOp op) { + p << op.getOperationName() << ' '; + p.printSymbolName(op.function_ref()); + if (op.sym_name() != op.function_ref()) { + p << " as(\"" << op.sym_name() << "\")"; + } + p.printOptionalAttrDictWithKeyword( + op.getAttrs(), /*elidedAttrs=*/{"function_ref", "sym_name"}); +} + +//===----------------------------------------------------------------------===// +// flow.reduction.entry / flow.windowed_reduction.entry +//===----------------------------------------------------------------------===// + +static ParseResult parseReductionEntryOp(OpAsmParser &parser, + OperationState *result) { + FlatSymbolRefAttr functionRefAttr; + FlatSymbolRefAttr applyRefAttr; + if (failed(parser.parseAttribute(functionRefAttr, "function_ref", + result->attributes)) || + failed(parser.parseKeyword("apply")) || failed(parser.parseLParen()) || + failed(parser.parseAttribute(applyRefAttr, "apply_ref", + result->attributes)) || + failed(parser.parseRParen())) { + return failure(); + } + + if (succeeded(parser.parseOptionalKeyword("as"))) { + StringAttr exportNameAttr; + if (failed(parser.parseLParen()) || + failed(parser.parseAttribute(exportNameAttr, "sym_name", + result->attributes)) || + failed(parser.parseRParen())) { + return failure(); + } + } else { + result->addAttribute("sym_name", parser.getBuilder().getStringAttr( + functionRefAttr.getValue())); + } + + if (failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) { + return failure(); + } + + return success(); +} + +static void printReductionEntryOp(OpAsmPrinter &p, ReductionEntryOp op) { + p << op.getOperationName() << ' '; + p.printSymbolName(op.function_ref()); + p << " apply("; + p.printSymbolName(op.apply_ref()); + p << ")"; + if (op.sym_name() != op.function_ref()) { + p << " as(\"" << op.sym_name() << "\")"; + } + p.printOptionalAttrDictWithKeyword( + op.getAttrs(), /*elidedAttrs=*/{"apply_ref", "function_ref", "sym_name"}); +} + +static ParseResult parseWindowedReductionEntryOp(OpAsmParser &parser, + OperationState *result) { + return parseReductionEntryOp(parser, result); +} + +static void printWindowedReductionEntryOp(OpAsmPrinter &p, + WindowedReductionEntryOp op) { + p << op.getOperationName() << ' '; + p.printSymbolName(op.function_ref()); + p << " apply("; + p.printSymbolName(op.apply_ref()); + p << ")"; + if (op.sym_name() != op.function_ref()) { + p << " as(\"" << op.sym_name() << "\")"; + } + p.printOptionalAttrDictWithKeyword( + op.getAttrs(), /*elidedAttrs=*/{"apply_ref", "function_ref", "sym_name"}); +} + +//===----------------------------------------------------------------------===// +// TableGen definitions (intentionally last) +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "iree/compiler/Dialect/Flow/IR/FlowOps.cpp.inc" + +} // namespace Flow +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.h b/iree/compiler/Dialect/Flow/IR/FlowOps.h new file mode 100644 index 0000000..21abd2c --- /dev/null +++ b/iree/compiler/Dialect/Flow/IR/FlowOps.h
@@ -0,0 +1,45 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_DIALECT_FLOW_IR_FLOWOPS_H_ +#define IREE_COMPILER_DIALECT_FLOW_IR_FLOWOPS_H_ + +#include <cstdint> + +#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" +#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h" +#include "iree/compiler/Dialect/Traits.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/FunctionSupport.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/SymbolTable.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace Flow { + +#define GET_OP_CLASSES +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h.inc" + +} // namespace Flow +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_DIALECT_FLOW_IR_FLOWOPS_H_
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.td b/iree/compiler/Dialect/Flow/IR/FlowOps.td new file mode 100644 index 0000000..8170ee2 --- /dev/null +++ b/iree/compiler/Dialect/Flow/IR/FlowOps.td
@@ -0,0 +1,441 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_DIALECT_FLOW_OPS +#define IREE_DIALECT_FLOW_OPS + +#ifndef IREE_DIALECT_FLOW_BASE +include "iree/compiler/Dialect/Flow/IR/FlowBase.td" +#endif // IREE_DIALECT_FLOW_BASE + +//===----------------------------------------------------------------------===// +// Variables +//===----------------------------------------------------------------------===// + +def FLOW_VariableOp : FLOW_Op<"variable", [ + Symbol, + ]> { + let summary = [{stateful variable declaration}]; + let description = [{ + Declares a persistent variable that maintains its value. + }]; + + let arguments = (ins + StrAttr:$sym_name, + // TODO(benvanik): verify AnyRankedTensor. + TypeAttr:$type, + UnitAttr:$is_mutable, + // TODO(benvanik): verify matches $type. + OptionalAttr<FlatSymbolRefAttr>:$initializer, + // TODO(benvanik): verify matches $type. + OptionalAttr<AnyAttr>:$initial_value + ); + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<[{ + Builder *builder, OperationState &result, StringRef name, bool isMutable, + FuncOp initializer, ArrayRef<NamedAttribute> attrs = {} + }]>, + OpBuilder<[{ + Builder *builder, OperationState &result, StringRef name, bool isMutable, + Type type, Attribute initialValue, ArrayRef<NamedAttribute> attrs = {} + }]>, + OpBuilder<[{ + Builder *builder, OperationState &result, StringRef name, bool isMutable, + Type type, ArrayRef<NamedAttribute> attrs = {} + }]>, + ]; + + let verifier = [{ return verifyVariableOp(*this); }]; + + let hasCanonicalizer = 1; +} + +def FLOW_VariableLoadOp : FLOW_Op<"variable.load"> { + let summary = [{loads a value from a global variable}]; + let description = [{ + Returns a copy of the variable value. + }]; + + let arguments = (ins + FLOW_VariableRefAttr:$variable + ); + let results = (outs + AnyRankedTensor:$result + ); + + let verifier = [{ return verifyVariableLoadOp(*this); }]; + + let hasCanonicalizer = 1; +} + +def FLOW_VariableStoreOp : FLOW_Op<"variable.store"> { + let summary = [{stores a value into a global variable}]; + let description = [{ + Stores a copy of the value into a variable. + }]; + + let arguments = (ins + FLOW_VariableRefAttr:$variable, + AnyRankedTensor:$value + ); + + let verifier = [{ return verifyVariableStoreOp(*this); }]; + + let hasCanonicalizer = 1; +} + +// TODO(benvanik): additional resource variable ops (like scatter/gather). + +//===----------------------------------------------------------------------===// +// Partitioned regions +//===----------------------------------------------------------------------===// + +def FLOW_DispatchRegionOp : FLOW_PureOp<"dispatch.region"> { + let summary = [{partitioned region representing a dispatched workload}]; + let description = [{ + A closure that represents a functional dispatch unit. These perform + computations in a way that can be lowered to target executable formats such + as SPIR-V for execution. + + Ops that are identified as "dispatchable" are grouped into dispatch regions + and compatible dispatch regions are folded together. What remains outside of + the dispatch regions is the glue required to schedule the work (commonly + referred to as "host" code, even if it doesn't run on an AP). + + Dispatch regions are modeled using value semantics: it is assumed that all + arguments are read-only and that the dispatch regions themselves have no + side-effects. + }]; + + let arguments = (ins + FLOW_Workload:$workload, + Variadic<AnyType>:$args + ); + let results = (outs + Variadic<AnyType>:$results + ); + + let regions = (region AnyRegion:$body); + + let extraClassDeclaration = [{ + /// Returns the index of the args() operand in the Operation operands list. + unsigned mapArgOperandToOpOperand(unsigned i) { return i + 1; } + }]; + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<[{ + Builder *builder, OperationState &state, ArrayRef<Type> resultTypes, + Value *workload, ArrayRef<Value *> args, + ArrayRef<NamedAttribute> attributes = {} + }]>, + ]; +} + +def FLOW_ReductionRegionOp : FLOW_PureOp<"reduction.region", [ + SameVariadicOperandSize, + // TODO(benvanik): verify operands and initial values have the same element + // types (but NOT the same shapes). + ]> { + let summary = [{partitioned reduction region}]; + let description = [{ + A closure that defines a reduction operation over one or more inputs. + Reductions are dispatches with very specific semantics around the indexing + of work. Parititoning first isolates reduction regions prior to dispatch + regions so that such semantics can be identified for folding. + + This operation follows the XLA Reduce semantics: + https://www.tensorflow.org/xla/operation_semantics#reduce + }]; + + let arguments = (ins + FLOW_Workload:$workload, + Variadic<AnyType>:$operands, + Variadic<AnyType>:$initial_values, + // TODO(benvanik): use index types instead of i32. + OptionalAttr<I32ElementsAttr>:$dimensions + ); + let results = (outs + Variadic<AnyType>:$results + ); + + let regions = (region AnyRegion:$body); + + let extraClassDeclaration = [{ + unsigned getNumReductionOperands() { + return std::distance(operands().begin(), operands().end()); + } + }]; + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<[{ + Builder *builder, OperationState &state, ArrayRef<Type> resultTypes, + Value *workload, ArrayRef<Value *> operands, + ArrayRef<Value *> initialValues, ArrayRef<int32_t> dimensions, + ArrayRef<NamedAttribute> attributes = {} + }]>, + ]; +} + +def FLOW_WindowedReductionRegionOp : FLOW_PureOp<"windowed_reduction.region", [ + SameVariadicOperandSize, + // TODO(benvanik): verify operands and initial values have the same element + // types (but NOT the same shapes). + ]> { + let summary = [{partitioned reduction region}]; + let description = [{ + A closure that defines a reduction operation over one or more inputs. + Reductions are dispatches with very specific semantics around the indexing + of work. Parititoning first isolates reduction regions prior to dispatch + regions so that such semantics can be identified for folding. + + This operation follows the XLA ReduceWindow semantics: + https://www.tensorflow.org/xla/operation_semantics#reducewindow + }]; + + let arguments = (ins + FLOW_Workload:$workload, + Variadic<AnyType>:$operands, + Variadic<AnyType>:$initial_values, + // TODO(benvanik): use index types instead of i32. + I32ElementsAttr:$window_dimensions, + I32ElementsAttr:$window_strides, + I32ElementsAttr:$base_dilations, + I32ElementsAttr:$window_dilations, + FLOW_PaddingModeAttr:$padding_mode + ); + let results = (outs + Variadic<AnyType>:$results + ); + + let regions = (region AnyRegion:$body); + + let extraClassDeclaration = [{ + unsigned getNumReductionOperands() { + return std::distance(operands().begin(), operands().end()); + } + }]; + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<[{ + Builder *builder, OperationState &state, ArrayRef<Type> resultTypes, + Value *workload, ArrayRef<Value *> operands, + ArrayRef<Value *> initialValues, ArrayRef<int32_t> windowDimensions, + ArrayRef<int32_t> windowStrides, ArrayRef<int32_t> baseDilations, + ArrayRef<int32_t> windowDilations, PaddingMode paddingMode, + ArrayRef<NamedAttribute> attributes = {} + }]>, + ]; +} + +def FLOW_ReturnOp : FLOW_Op<"return", [Terminator]> { + let summary = [{return from a flow.dispatch_region}]; + let description = [{ + Returns the given values from the region and back to the host code. + }]; + + let arguments = (ins + Variadic<AnyType>:$operands + ); + + let builders = [ + OpBuilder<[{ + Builder *builder, OperationState &result + }], [{ + build(builder, result, llvm::None); + }]>, + ]; +} + +//===----------------------------------------------------------------------===// +// Dispatch ops +//===----------------------------------------------------------------------===// + +def FLOW_DispatchOp : FLOW_PureOp<"dispatch"> { + let summary = [{a dispatch to an outlined dispatch region}]; + let description = [{ + Dispatches a workload to the specified executable function. + }]; + + let arguments = (ins + // TODO(benvanik): replace with SymbolRefAttr. + // TODO(benvanik): validate target is an executable. + FlatSymbolRefAttr:$executable, + FlatSymbolRefAttr:$entry_point, + FLOW_Workload:$workload, + Variadic<AnyType>:$operands + ); + let results = (outs + Variadic<AnyType>:$results + ); + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<[{ + Builder *builder, OperationState &result, StringRef executable, + StringRef entryPoint, Value *workload, + ArrayRef<Type> results, ArrayRef<Value *> operands = {} + }], [{ + result.addOperands({workload}); + result.addOperands(operands); + result.addAttribute("executable", builder->getSymbolRefAttr(executable)); + result.addAttribute("entry_point", builder->getSymbolRefAttr(entryPoint)); + result.addTypes(results); + }]>, + ]; + + let extraClassDeclaration = [{ + FunctionType getEntryPointType(); + }]; +} + +//===----------------------------------------------------------------------===// +// Executables for outlined regions +//===----------------------------------------------------------------------===// + +def FLOW_ExecutableOp : FLOW_Op<"executable", [ + IsolatedFromAbove, + SingleBlockImplicitTerminator<"IREE::Flow::ExecutableEndOp">, + NativeOpTrait<"SymbolTable">, + Symbol, + ]> { + let summary = [{generic executable module}]; + let description = [{ + An executable module containing one or more public functions. The contents + of the functions are safe to dispatch and can be lowered further to + target-specific backend IR representations. + }]; + + let arguments = (ins + StrAttr:$sym_name + // TODO(benvanik): add compatibility and versioning attributes. + ); + + let regions = (region SizedRegion<1>:$body); + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<[{ + Builder *builder, OperationState &state, StringRef name + }]>, + ]; + + let extraClassDeclaration = [{ + Block& getBlock() { return body().front(); } + + ::mlir::ModuleOp getInnerModule() { + return *getBlock().getOps<::mlir::ModuleOp>().begin(); + } + }]; + + let verifier = [{ return verifyExecutableOp(*this); }]; +} + +def FLOW_ExecutableEndOp : FLOW_Op<"executable_end", [ + HasParent<"IREE::Flow::ExecutableOp">, + Terminator, + ]> { + let summary = [{terminator pseudo-op for the executable op}]; + + let parser = [{ return parseRegionEndOp(parser, &result); }]; + let printer = [{ return printRegionEndOp(p, *this); }]; +} + +def FLOW_DispatchEntryOp : FLOW_Op<"dispatch.entry", [ + HasParent<"IREE::Flow::ExecutableOp">, + Symbol, + ]> { + let summary = [{defines an executable entry point for dispatch operations}]; + let description = [{ + Specifies an exported function with an externally-visible alias. Multiple + exports can reference the same internal function. + }]; + + // TODO(benvanik): add a list of all used workloads. + let arguments = (ins + StrAttr:$sym_name, + // TODO(benvanik): ref into child module. + FlatSymbolRefAttr:$function_ref, + OptionalAttr<FLOW_WorkloadAttr>:$workload, + OptionalAttr<FLOW_WorkloadAttr>:$workgroup_size + ); +} + +def FLOW_ReductionEntryOp : FLOW_Op<"reduction.entry", [ + HasParent<"IREE::Flow::ExecutableOp">, + Symbol, + ]> { + let summary = [{defines an executable entry point for reduction operations}]; + let description = [{ + Specifies an exported function with an externally-visible alias. Multiple + exports can reference the same internal function. The computation represents + a reduction operation that has additional backend-specific semantics that + need to be lowered. + + This operation follows the XLA Reduce semantics: + https://www.tensorflow.org/xla/operation_semantics#reduce + }]; + + // TODO(benvanik): add a list of all used workloads. + let arguments = (ins + // TODO(benvanik): ref into child module. + StrAttr:$sym_name, + FlatSymbolRefAttr:$function_ref, + FlatSymbolRefAttr:$apply_ref, + I32Attr:$dimension + ); +} + +def FLOW_WindowedReductionEntryOp : FLOW_Op<"windowed_reduction.entry", [ + HasParent<"IREE::Flow::ExecutableOp">, + Symbol, + ]> { + let summary = [{defines an executable entry point for reduction operations}]; + let description = [{ + Specifies an exported function with an externally-visible alias. Multiple + exports can reference the same internal function. The computation represents + a reduction operation that has additional backend-specific semantics that + need to be lowered. + + This operation follows the XLA ReduceWindow semantics: + https://www.tensorflow.org/xla/operation_semantics#reducewindow + }]; + + // TODO(benvanik): add a list of all used workloads. + let arguments = (ins + // TODO(benvanik): ref into child module. + StrAttr:$sym_name, + FlatSymbolRefAttr:$function_ref, + FlatSymbolRefAttr:$apply_ref, + I32Attr:$window_dimension, + I32Attr:$window_stride, + I32Attr:$base_dilation, + I32Attr:$window_dilation, + FLOW_PaddingModeAttr:$padding + ); +} + +//===----------------------------------------------------------------------===// +// Tensor ops +//===----------------------------------------------------------------------===// + +// TODO(benvanik): tensor casts for widening/narrowing? or rely on std? +// TODO(benvanik): DynamicUpdateSlice-equivalent? +// TODO(benvanik): structured control flow (if we want it here). + +#endif // IREE_DIALECT_FLOW_OPS
diff --git a/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp b/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp new file mode 100644 index 0000000..ba1ca3a --- /dev/null +++ b/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp
@@ -0,0 +1,17 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h" + +#include "iree/compiler/Dialect/Flow/IR/FlowEnums.cpp.inc"
diff --git a/iree/compiler/Dialect/Flow/IR/FlowTypes.h b/iree/compiler/Dialect/Flow/IR/FlowTypes.h new file mode 100644 index 0000000..dc98f10 --- /dev/null +++ b/iree/compiler/Dialect/Flow/IR/FlowTypes.h
@@ -0,0 +1,29 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_DIALECT_FLOW_IR_FLOWTYPES_H_ +#define IREE_COMPILER_DIALECT_FLOW_IR_FLOWTYPES_H_ + +#include "iree/compiler/Dialect/Types.h" +#include "llvm/ADT/DenseMapInfo.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringSwitch.h" +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" +#include "mlir/Support/LLVM.h" + +// Order matters. +#include "iree/compiler/Dialect/Flow/IR/FlowEnums.h.inc" + +#endif // IREE_COMPILER_DIALECT_FLOW_IR_FLOWTYPES_H_
diff --git a/iree/compiler/Dialect/Flow/IR/test/BUILD b/iree/compiler/Dialect/Flow/IR/test/BUILD new file mode 100644 index 0000000..beab62e --- /dev/null +++ b/iree/compiler/Dialect/Flow/IR/test/BUILD
@@ -0,0 +1,28 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//iree:build_defs.bzl", "iree_glob_lit_tests", "iree_setup_lit_package") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +iree_setup_lit_package( + data = [ + "//iree/tools:iree-opt", + ], +) + +iree_glob_lit_tests()
diff --git a/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir b/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir new file mode 100644 index 0000000..de01602 --- /dev/null +++ b/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir
@@ -0,0 +1,34 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Tests printing and parsing of dispatch ops. + +// RUN: iree-opt -split-input-file %s | iree-opt | FileCheck %s --dump-input=fail + +flow.executable @ex0 { + module { + func @dispatch_fn(%arg0 : tensor<4xf32>) -> tensor<4xf32> { + return %arg0 : tensor<4xf32> + } + } + flow.dispatch.entry @dispatch_fn +} + +// CHECK-LABEL: @dispatch +func @dispatch(%arg0 : tensor<4xf32>) -> tensor<4xf32> { + %cst = constant dense<1> : tensor<3xi32> + // CHECK: %0 = flow.dispatch @ex0::@dispatch_fn[%cst : tensor<3xi32>](%arg0) : (tensor<4xf32>) -> tensor<4xf32> + %0 = flow.dispatch @ex0::@dispatch_fn[%cst : tensor<3xi32>](%arg0) : (tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +}
diff --git a/iree/compiler/Dialect/Flow/IR/test/dispatch_regions.mlir b/iree/compiler/Dialect/Flow/IR/test/dispatch_regions.mlir new file mode 100644 index 0000000..74ef01b --- /dev/null +++ b/iree/compiler/Dialect/Flow/IR/test/dispatch_regions.mlir
@@ -0,0 +1,79 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Tests printing and parsing of dispatch region ops. + +// RUN: iree-opt -split-input-file %s | iree-opt | FileCheck %s --dump-input=fail + +// CHECK-LABEL: @singleArg +func @singleArg(%arg0 : tensor<?xf32>) { + // CHECK-NEXT: %0 = "some.shape" + // CHECK-NEXT: flow.dispatch.region[%0 : tensor<1xi32>](%arg1 = %arg0 : tensor<?xf32>) { + // CHECK-NEXT: flow.return + // CHECK-NEXT: } + %workload = "some.shape"(%arg0) : (tensor<?xf32>) -> tensor<1xi32> + flow.dispatch.region[%workload : tensor<1xi32>](%i0 = %arg0 : tensor<?xf32>) { + flow.return + } + // CHECK-NEXT: return + return +} + +// ----- + +// CHECK-LABEL: @multipleArgs +func @multipleArgs(%arg0 : tensor<?xf32>, %arg1 : tensor<?xf32>) { + // CHECK-NEXT: %0 = "some.shape" + // CHECK-NEXT: flow.dispatch.region[%0 : tensor<1xi32>](%arg2 = %arg0 : tensor<?xf32>, %arg3 = %arg1 : tensor<?xf32>) { + // CHECK-NEXT: flow.return + // CHECK-NEXT: } + %workload = "some.shape"(%arg0) : (tensor<?xf32>) -> tensor<1xi32> + flow.dispatch.region[%workload : tensor<1xi32>](%i0 = %arg0 : tensor<?xf32>, %i1 = %arg1 : tensor<?xf32>) { + flow.return + } + // CHECK-NEXT: return + return +} + +// ----- + +// CHECK-LABEL: @singleResult +func @singleResult(%arg0 : tensor<?xf32>) -> tensor<?xf32> { + // CHECK-NEXT: %0 = "some.shape" + // CHECK-NEXT: %1 = flow.dispatch.region[%0 : tensor<1xi32>](%arg1 = %arg0 : tensor<?xf32>) : tensor<?xf32> { + // CHECK-NEXT: flow.return %arg1 : tensor<?xf32> + // CHECK-NEXT: } + %workload = "some.shape"(%arg0) : (tensor<?xf32>) -> tensor<1xi32> + %ret0 = flow.dispatch.region[%workload : tensor<1xi32>](%i0 = %arg0 : tensor<?xf32>) : tensor<?xf32> { + flow.return %i0 : tensor<?xf32> + } + // CHECK-NEXT: return %1 : tensor<?xf32> + return %ret0 : tensor<?xf32> +} + +// ----- + +// CHECK-LABEL: @multipleResults +func @multipleResults(%arg0 : tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) { + // CHECK-NEXT: %0 = "some.shape" + // CHECK-NEXT: %1:2 = flow.dispatch.region[%0 : tensor<1xi32>](%arg1 = %arg0 : tensor<?xf32>) : tensor<?xf32>, tensor<?xf32> { + // CHECK-NEXT: flow.return %arg1, %arg1 : tensor<?xf32>, tensor<?xf32> + // CHECK-NEXT: } + %workload = "some.shape"(%arg0) : (tensor<?xf32>) -> tensor<1xi32> + %ret0, %ret1 = flow.dispatch.region[%workload : tensor<1xi32>](%i0 = %arg0 : tensor<?xf32>) : tensor<?xf32>, tensor<?xf32> { + flow.return %i0, %i0 : tensor<?xf32>, tensor<?xf32> + } + // CHECK-NEXT: return %1#0, %1#1 : tensor<?xf32>, tensor<?xf32> + return %ret0, %ret1 : tensor<?xf32>, tensor<?xf32> +}
diff --git a/iree/compiler/Dialect/Flow/IR/test/executable_ops.mlir b/iree/compiler/Dialect/Flow/IR/test/executable_ops.mlir new file mode 100644 index 0000000..4525af1 --- /dev/null +++ b/iree/compiler/Dialect/Flow/IR/test/executable_ops.mlir
@@ -0,0 +1,53 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Tests printing and parsing of executable/structural ops. + +// RUN: iree-opt -split-input-file %s | iree-opt | FileCheck %s --dump-input=fail + +// CHECK-LABEL: @dispatch_ex +flow.executable @dispatch_ex { + // CHECK: module { + module { + // CHECK: @dispatch0 + func @dispatch0() { + return + } + } + // CHECK: flow.dispatch.entry @dispatch0 + flow.dispatch.entry @dispatch0 + // CHECK: flow.dispatch.entry @dispatch0 as("dispatch0_alias") + flow.dispatch.entry @dispatch0 as("dispatch0_alias") +} + +// ----- + +// CHECK-LABEL: @reduction_ex +flow.executable @reduction_ex { + // CHECK: module { + module { + // CHECK: @entry + func @entry(tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32> + // CHECK: @apply + func @apply(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> { + %0 = xla_hlo.add %arg0, %arg1 : tensor<f32> + return %0 : tensor<f32> + } + } + // CHECK: flow.reduction.entry @entry + // CHECK-SAME: apply(@apply) + // CHECK-SAME: as("entry_alias") + // CHECK-SAME: attributes {dimension = 1 : i32} + flow.reduction.entry @entry apply(@apply) as("entry_alias") attributes {dimension = 1 : i32} +}
diff --git a/iree/compiler/Dialect/Flow/IR/test/reduction_regions.mlir b/iree/compiler/Dialect/Flow/IR/test/reduction_regions.mlir new file mode 100644 index 0000000..ba850f9 --- /dev/null +++ b/iree/compiler/Dialect/Flow/IR/test/reduction_regions.mlir
@@ -0,0 +1,62 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Tests printing and parsing of reduction region ops. + +// RUN: iree-opt -split-input-file %s | iree-opt | FileCheck %s --dump-input=fail + +// CHECK-LABEL: @singleReduction +func @singleReduction(%arg0 : tensor<5x1xf32>) { + // CHECK: %0 = "some.shape"(%arg0) : (tensor<5x1xf32>) -> tensor<1xi32> + %workload = "some.shape"(%arg0) : (tensor<5x1xf32>) -> tensor<1xi32> + // CHECK: %1 = "some.constant"() : () -> tensor<f32> + %initialValueF = "some.constant"() : () -> tensor<f32> + // CHECK: %2 = flow.reduction.region[%0 : tensor<1xi32>](%arg0) : (tensor<5x1xf32>) -> (tensor<1xf32>) + // CHECK-NEXT: invocation((%arg1, %arg2) = %1 : tensor<f32>) { + // CHECK-NEXT: %3 = "my.add"(%arg1, %arg2) : (tensor<f32>, tensor<f32>) -> tensor<f32> + // CHECK-NEXT: flow.return %3 : tensor<f32> + // CHECK-NEXT: } {dimensions = dense<[1, 2]> : tensor<2xi32>} + %ret = flow.reduction.region[%workload : tensor<1xi32>](%arg0) : (tensor<5x1xf32>) -> (tensor<1xf32>) + invocation((%i0, %i1) = %initialValueF : tensor<f32>) { + %resultF = "my.add"(%i0, %i1) : (tensor<f32>, tensor<f32>) -> tensor<f32> + flow.return %resultF : tensor<f32> + } {dimensions = dense<[1, 2]> : tensor<2xi32>} + return +} + +// ----- + +// CHECK-LABEL: @fusedReduction +func @fusedReduction(%arg0 : tensor<5x1xf32>, %arg1 : tensor<5x1xi32>) { + // CHECK: %0 = "some.shape"(%arg0) : (tensor<5x1xf32>) -> tensor<1xi32> + %workload = "some.shape"(%arg0) : (tensor<5x1xf32>) -> tensor<1xi32> + // CHECK: %1 = "some.constant"() : () -> tensor<f32> + // CHECK: %2 = "some.constant"() : () -> tensor<i32> + %initialValueF = "some.constant"() : () -> tensor<f32> + %initialValueI = "some.constant"() : () -> tensor<i32> + // CHECK: %3:2 = flow.reduction.region[%0 : tensor<1xi32>](%arg0, %arg1) : (tensor<5x1xf32>, tensor<5x1xi32>) -> (tensor<1xf32>, tensor<1xi32>) + // CHECK-NEXT: invocation((%arg2, %arg3) = %1 : tensor<f32>, (%arg4, %arg5) = %2 : tensor<i32>) { + // CHECK-NEXT: %4 = "my.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32> + // CHECK-NEXT: %5 = "my.add"(%arg4, %arg5) : (tensor<i32>, tensor<i32>) -> tensor<i32> + // CHECK-NEXT: flow.return %4, %5 : tensor<f32>, tensor<i32> + // CHECK-NEXT: } {dimensions = dense<[1, 2]> : tensor<2xi32>} + %ret:2 = flow.reduction.region[%workload : tensor<1xi32>](%arg0, %arg1) : (tensor<5x1xf32>, tensor<5x1xi32>) -> (tensor<1xf32>, tensor<1xi32>) + invocation((%i0, %i1) = %initialValueF : tensor<f32>, + (%i2, %i3) = %initialValueI : tensor<i32>) { + %resultF = "my.add"(%i0, %i1) : (tensor<f32>, tensor<f32>) -> tensor<f32> + %resultI = "my.add"(%i2, %i3) : (tensor<i32>, tensor<i32>) -> tensor<i32> + flow.return %resultF, %resultI : tensor<f32>, tensor<i32> + } {dimensions = dense<[1, 2]> : tensor<2xi32>} + return +}
diff --git a/iree/compiler/Dialect/Flow/IR/test/variable_folding.mlir b/iree/compiler/Dialect/Flow/IR/test/variable_folding.mlir new file mode 100644 index 0000000..7d7d818 --- /dev/null +++ b/iree/compiler/Dialect/Flow/IR/test/variable_folding.mlir
@@ -0,0 +1,46 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Tests folding and canonicalization of variable ops. + +// RUN: iree-opt -split-input-file -canonicalize %s | iree-opt | FileCheck %s --dump-input=fail + +// CHECK: flow.variable @v_initialized dense<4> : tensor<4xi32> +flow.variable @v_initialized init(@initializer) : tensor<4xi32> +func @initializer() -> tensor<4xi32> { + %0 = constant dense<4> : tensor<4xi32> + return %0 : tensor<4xi32> +} + +// ----- + +flow.variable @v_unused : tensor<4xi32> +// CHECK-LABEL: @unused_load +func @unused_load() { + // CHECK-NEXT: return + %0 = flow.variable.load @v_unused : tensor<4xi32> + return +} + +// ----- + +flow.variable @v_nop mutable : tensor<4xi32> +// CHECK-LABEL: @nop_load_store +func @nop_load_store() { + // CHECK-NEXT: return + %0 = flow.variable.load @v_nop : tensor<4xi32> + flow.variable.store @v_nop, %0 : tensor<4xi32> + return +} +
diff --git a/iree/compiler/Dialect/Flow/IR/test/variable_ops.mlir b/iree/compiler/Dialect/Flow/IR/test/variable_ops.mlir new file mode 100644 index 0000000..739ed47 --- /dev/null +++ b/iree/compiler/Dialect/Flow/IR/test/variable_ops.mlir
@@ -0,0 +1,55 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Tests printing and parsing of variable ops. + +// RUN: iree-opt -split-input-file %s | iree-opt | FileCheck %s --dump-input=fail + +// CHECK: flow.variable @v_immutable : tensor<i32> +flow.variable @v_immutable : tensor<i32> +// CHECK: flow.variable @v_mutable mutable : tensor<i32> +flow.variable @v_mutable mutable : tensor<i32> + +// ----- + +// CHECK: flow.variable @v_initialized_const dense<4> : tensor<4xi32> +flow.variable @v_initialized_const dense<4> : tensor<4xi32> + +// ----- + +// CHECK: flow.variable @v_initialized init(@initializer) : tensor<4xi32> +flow.variable @v_initialized init(@initializer) : tensor<4xi32> +func @initializer() -> tensor<4xi32> + +// ----- + +flow.variable @v_loaded : tensor<4xi32> +// CHECK-LABEL: @loaded +func @loaded() { + // CHECK-NEXT: %0 = flow.variable.load @v_loaded : tensor<4xi32> + %0 = flow.variable.load @v_loaded : tensor<4xi32> + return +} + +// ----- + +flow.variable @v_stored mutable : tensor<4xi32> +// CHECK-LABEL: @stored +func @stored() { + // CHECK-NEXT: = constant + %cst = constant dense<5> : tensor<4xi32> + // CHECK-NEXT: flow.variable.store @v_stored, %cst : tensor<4xi32> + flow.variable.store @v_stored, %cst : tensor<4xi32> + return +}
diff --git a/iree/compiler/Dialect/Flow/Transforms/AssignExecutableWorkloads.cpp b/iree/compiler/Dialect/Flow/Transforms/AssignExecutableWorkloads.cpp new file mode 100644 index 0000000..cf9c64e --- /dev/null +++ b/iree/compiler/Dialect/Flow/Transforms/AssignExecutableWorkloads.cpp
@@ -0,0 +1,129 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" +#include "llvm/ADT/StringMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/Utils.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace Flow { + +namespace { + +struct WorkloadInfo { + SmallVector<ElementsAttr, 4> staticWorkloads; + SmallVector<Value *, 4> dynamicWorkloads; +}; + +// Finds all dispatches and records their workload attributes mapped by +// (executable ordinal, entry point ordinal). +llvm::StringMap<llvm::StringMap<WorkloadInfo>> gatherExecutableWorkloadInfos( + ModuleOp moduleOp) { + llvm::StringMap<llvm::StringMap<WorkloadInfo>> workloadInfos; + for (auto funcOp : moduleOp.getOps<FuncOp>()) { + funcOp.walk([&](DispatchOp op) { + auto &workloadInfo = workloadInfos[op.executable()][op.entry_point()]; + if (auto constantOp = + dyn_cast<ConstantOp>(op.workload()->getDefiningOp())) { + for (auto existingWorkloadAttr : workloadInfo.staticWorkloads) { + if (existingWorkloadAttr == constantOp.value()) { + return; // Already present, ignore. + } + } + workloadInfo.staticWorkloads.push_back( + constantOp.value().cast<ElementsAttr>()); + } else { + workloadInfo.dynamicWorkloads.push_back(op.workload()); + } + }); + } + return workloadInfos; +} + +// Adds attributes to the given executable entry point describing the workload +// info to the backends that will be processing them. +LogicalResult attributeExecutableEntryPointWorkload( + Operation *entryPointOp, const WorkloadInfo &workloadInfo) { + if (!workloadInfo.dynamicWorkloads.empty()) { + return entryPointOp->emitError() << "dynamic workloads not yet supported"; + } + if (workloadInfo.staticWorkloads.size() != 1) { + return entryPointOp->emitError() << "static workload sizes differ in shape"; + } + + // Easy because we just support static workloads now. + // When this code is adapted to support dynamic workloads we'll want to put + // a pair of attrs describing which dimensions may be static and which args + // have the dynamic values to reference. + entryPointOp->setAttr("workload", workloadInfo.staticWorkloads.front()); + + // Hardwire workgroup size to {32, 1, 1} + SmallVector<int32_t, 3> workGroupInfo = {32, 1, 1}; + auto workGroupAttr = DenseIntElementsAttr::get<int32_t>( + RankedTensorType::get(3, + IntegerType::get(32, entryPointOp->getContext())), + workGroupInfo); + entryPointOp->setAttr("workgroup_size", workGroupAttr); + return success(); +} + +} // namespace + +class AssignExecutableWorkloadsPass + : public ModulePass<AssignExecutableWorkloadsPass> { + public: + void runOnModule() override { + Builder builder(getModule()); + + // Find all dispatches and capture their workload information. + // We store this information by executable and then entry point ordinal. + auto executableWorkloadInfos = gatherExecutableWorkloadInfos(getModule()); + + // Process each executable with the workload information. + SymbolTable symbolTable(getModule()); + for (auto &executableIt : executableWorkloadInfos) { + auto executableOp = + symbolTable.lookup<ExecutableOp>(executableIt.first()); + for (auto &entryPointIt : executableIt.second) { + auto entryPointOp = executableOp.lookupSymbol(entryPointIt.first()); + if (failed(attributeExecutableEntryPointWorkload( + entryPointOp, entryPointIt.second))) { + return signalPassFailure(); + } + } + } + } +}; + +std::unique_ptr<OpPassBase<ModuleOp>> createAssignExecutableWorkloadsPass() { + return std::make_unique<AssignExecutableWorkloadsPass>(); +} + +static PassRegistration<AssignExecutableWorkloadsPass> pass( + "iree-flow-assign-executable-workloads", + "Assigns executable entrypoint workload attributes"); + +} // namespace Flow +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Transforms/BUILD b/iree/compiler/Dialect/Flow/Transforms/BUILD new file mode 100644 index 0000000..91308ba --- /dev/null +++ b/iree/compiler/Dialect/Flow/Transforms/BUILD
@@ -0,0 +1,51 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "Transforms", + srcs = [ + "AssignExecutableWorkloads.cpp", + "DispatchabilityAnalysis.cpp", + "FlattenTuplesInCFG.cpp", + "FoldCompatibleDispatchRegions.cpp", + "IdentifyDispatchRegions.cpp", + "IdentifyReductionRegions.cpp", + "OutlineDispatchRegions.cpp", + "OutlineReductionRegions.cpp", + "Passes.cpp", + "RematerializeDispatchConstants.cpp", + ], + hdrs = [ + "Passes.h", + ], + deps = [ + "//iree/compiler/Dialect/Flow/Analysis", + "//iree/compiler/Dialect/Flow/IR", + "//iree/compiler/Dialect/Flow/Utils", + "@llvm//:support", + "@local_config_mlir//:IR", + "@local_config_mlir//:Pass", + "@local_config_mlir//:StandardOps", + "@local_config_mlir//:Support", + "@local_config_mlir//:TransformUtils", + "@local_config_mlir//:Transforms", + "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo", + ], + alwayslink = 1, +)
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchabilityAnalysis.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchabilityAnalysis.cpp new file mode 100644 index 0000000..37c8cb0 --- /dev/null +++ b/iree/compiler/Dialect/Flow/Transforms/DispatchabilityAnalysis.cpp
@@ -0,0 +1,64 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <utility> + +#include "iree/compiler/Dialect/Flow/Analysis/Dispatchability.h" +#include "mlir/IR/Builders.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace Flow { + +class DispatchabilityAnalysisPass + : public ModulePass<DispatchabilityAnalysisPass> { + public: + DispatchabilityAnalysisPass() = default; + explicit DispatchabilityAnalysisPass( + std::shared_ptr<llvm::StringMap<FuncOp>> dispatchableFuncOps) + : dispatchableFuncOps_(std::move(dispatchableFuncOps)) {} + + void runOnModule() override { + // Force creation (or caching) of dispatchability information. + auto &dispatchability = getAnalysis<Dispatchability>(); + markAllAnalysesPreserved(); + + // Build the dispatchable func table. + if (dispatchableFuncOps_) { + dispatchability.walkDispatchableOps([&](FuncOp funcOp) { + (*dispatchableFuncOps_)[funcOp.getName()] = funcOp; + }); + } + } + + std::shared_ptr<llvm::StringMap<FuncOp>> dispatchableFuncOps_; +}; + +std::unique_ptr<OpPassBase<ModuleOp>> createDispatchabilityAnalysisPass( + std::shared_ptr<llvm::StringMap<FuncOp>> dispatchableFuncOps) { + return std::make_unique<DispatchabilityAnalysisPass>( + std::move(dispatchableFuncOps)); +} + +static PassRegistration<DispatchabilityAnalysisPass> pass( + "iree-flow-dispatchability-analysis", + "Analyzes functions to determine their dispatchability"); + +} // namespace Flow +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Transforms/FlattenTuplesInCFG.cpp b/iree/compiler/Dialect/Flow/Transforms/FlattenTuplesInCFG.cpp new file mode 100644 index 0000000..d9dc18f --- /dev/null +++ b/iree/compiler/Dialect/Flow/Transforms/FlattenTuplesInCFG.cpp
@@ -0,0 +1,336 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/iterator_range.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Transforms/Utils.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace Flow { + +namespace { + +// Given a set of types, unpack to a list of a types, removing all tuples. +void untupleTypes(llvm::ArrayRef<Type> types, + llvm::SmallVectorImpl<Type> *newTypes) { + for (auto &type : types) { + if (type.isa<TupleType>()) { + untupleTypes(type.dyn_cast<TupleType>().getTypes(), newTypes); + } else { + newTypes->push_back(type); + } + } +} + +Value *processTuple(Type type, Location loc, Block *block, OpBuilder &builder) { + if (!type.isa<TupleType>()) { + return block->addArgument(type); + } + + auto tupleType = type.dyn_cast<TupleType>(); + llvm::SmallVector<Value *, 4> values; + values.reserve(tupleType.size()); + for (auto subtype : tupleType.getTypes()) { + values.push_back(processTuple(subtype, loc, block, builder)); + } + + return builder.create<xla_hlo::TupleOp>(loc, tupleType, values); +} + +void copyOperationAttrs(Operation *oldOp, Operation *newOp) { + for (const auto &oldAttr : oldOp->getAttrs()) { + newOp->setAttr(oldAttr.first, oldAttr.second); + } +} + +bool recursiveUntuple(Value *value, Location loc, OpBuilder &builder, + BlockAndValueMapping *mapping, + llvm::SmallVectorImpl<Value *> *newValues) { + Type type = value->getType(); + // We can return the value as is. + if (!type.isa<TupleType>()) { + newValues->push_back(value); + return false; + } + + TupleType tupleType = type.dyn_cast<TupleType>(); + for (int i = 0; i < tupleType.size(); i++) { + auto subType = tupleType.getType(i); + + auto elementOp = builder.create<xla_hlo::GetTupleElementOp>( + loc, subType, value, builder.getI32IntegerAttr(i)); + recursiveUntuple(elementOp.getResult(), loc, builder, mapping, newValues); + } + + return false; +} + +Value *recursiveRetuple( + Type oldType, llvm::iterator_range<Operation::result_iterator> *values, + OpBuilder &builder, Location loc) { + if (!oldType.isa<TupleType>()) { + Value *returnValue = *values->begin(); + *values = llvm::iterator_range<Operation::result_iterator>( + values->begin() + 1, values->end()); + return returnValue; + } + + TupleType tupleType = oldType.dyn_cast<TupleType>(); + llvm::SmallVector<Value *, 10> subValues; + for (auto subtype : tupleType.getTypes()) { + subValues.push_back(recursiveRetuple(subtype, values, builder, loc)); + } + + return builder.create<xla_hlo::TupleOp>(loc, tupleType, subValues) + .getResult(); +} + +template <typename T> +bool untupleAndLookupValues(T values, llvm::SmallVectorImpl<Value *> *newValues, + OpBuilder &builder, Location loc, + BlockAndValueMapping *mapping) { + for (auto operand : values) { + auto newValue = mapping->lookupOrNull(operand); + if (!newValue) { + return true; + } + + recursiveUntuple(newValue, loc, builder, mapping, newValues); + } + + return false; +} + +bool convertReturnOp(ReturnOp *op, OpBuilder &builder, + BlockAndValueMapping *mapping) { + llvm::SmallVector<Value *, 10> newOperands; + if (untupleAndLookupValues(op->getOperands(), &newOperands, builder, + op->getLoc(), mapping)) { + return true; + } + + builder.create<ReturnOp>(op->getLoc(), newOperands); + return false; +} + +bool convertCallOp(CallOp *oldOp, OpBuilder &builder, + BlockAndValueMapping *mapping) { + llvm::SmallVector<Value *, 4> newArgs; + if (untupleAndLookupValues(oldOp->getOperands(), &newArgs, builder, + oldOp->getLoc(), mapping)) { + return true; + } + + SmallVector<Type, 4> originalTypes(oldOp->getOperation()->getResultTypes()); + SmallVector<Type, 4> resultTypes; + untupleTypes(originalTypes, &resultTypes); + auto newOp = builder.create<CallOp>(oldOp->getLoc(), oldOp->getCallee(), + resultTypes, newArgs); + copyOperationAttrs(oldOp->getOperation(), newOp.getOperation()); + + auto newResults = newOp.getResults(); + for (auto oldResult : oldOp->getResults()) { + llvm::SmallVector<Value *, 10> subValues; + auto newResult = recursiveRetuple(oldResult->getType(), &newResults, + builder, oldOp->getLoc()); + mapping->map(oldResult, newResult); + } + + return false; +} + +bool convertIndirectCallOp(CallIndirectOp *oldOp, OpBuilder &builder, + BlockAndValueMapping *mapping) { + llvm::SmallVector<Value *, 4> newArgs; + if (untupleAndLookupValues(oldOp->getOperands(), &newArgs, builder, + oldOp->getLoc(), mapping)) { + return true; + } + + auto newOp = builder.create<CallIndirectOp>(oldOp->getLoc(), + oldOp->getCallee(), newArgs); + copyOperationAttrs(oldOp->getOperation(), newOp.getOperation()); + + for (int i = 0; i < newOp.getNumResults(); ++i) { + auto *oldResult = oldOp->getResult(i); + auto *newResult = newOp.getResult(i); + mapping->map(oldResult, newResult); + } + + return false; +} + +bool convertBranchOp(BranchOp *oldOp, OpBuilder &builder, + BlockAndValueMapping *mapping) { + llvm::SmallVector<Value *, 4> newArgs; + if (untupleAndLookupValues(oldOp->getOperands(), &newArgs, builder, + oldOp->getLoc(), mapping)) { + return true; + } + + auto newOp = builder.create<BranchOp>( + oldOp->getLoc(), mapping->lookupOrNull(oldOp->getDest()), newArgs); + + copyOperationAttrs(oldOp->getOperation(), newOp.getOperation()); + + return false; +} + +bool convertCondBranchOp(CondBranchOp *oldOp, OpBuilder &builder, + BlockAndValueMapping *mapping) { + llvm::SmallVector<Value *, 4> trueArgs; + if (untupleAndLookupValues(oldOp->getTrueOperands(), &trueArgs, builder, + oldOp->getLoc(), mapping)) { + return true; + } + + llvm::SmallVector<Value *, 4> falseArgs; + if (untupleAndLookupValues(oldOp->getFalseOperands(), &falseArgs, builder, + oldOp->getLoc(), mapping)) { + return true; + } + + auto newOp = builder.create<CondBranchOp>( + oldOp->getLoc(), mapping->lookupOrNull(oldOp->getCondition()), + mapping->lookupOrNull(oldOp->getTrueDest()), trueArgs, + mapping->lookupOrNull(oldOp->getFalseDest()), falseArgs); + + copyOperationAttrs(oldOp->getOperation(), newOp.getOperation()); + + return false; +} + +bool convertOperation(Operation *op, OpBuilder &builder, + BlockAndValueMapping *mapping) { + if (auto returnOp = dyn_cast<ReturnOp>(op)) { + return convertReturnOp(&returnOp, builder, mapping); + } else if (auto callOp = dyn_cast<CallOp>(op)) { + return convertCallOp(&callOp, builder, mapping); + } else if (auto callIndirectOp = dyn_cast<CallIndirectOp>(op)) { + return convertIndirectCallOp(&callIndirectOp, builder, mapping); + } else if (auto branchOp = dyn_cast<BranchOp>(op)) { + return convertBranchOp(&branchOp, builder, mapping); + } else if (auto condBranchOp = dyn_cast<CondBranchOp>(op)) { + return convertCondBranchOp(&condBranchOp, builder, mapping); + } + + builder.clone(*op, *mapping); + return false; +} + +bool convertFunction(FuncOp oldFunction, FuncOp newFunction) { + OpBuilder builder(newFunction.getBody()); + BlockAndValueMapping mapping; + + for (auto attr : oldFunction.getAttrs()) { + if (attr.first != oldFunction.getTypeAttrName()) { + newFunction.setAttr(attr.first, attr.second); + } + } + + newFunction.getBlocks().clear(); + for (auto &oldBlock : oldFunction.getBlocks()) { + auto *newBlock = builder.createBlock(&newFunction.getBody()); + for (auto *oldArg : oldBlock.getArguments()) { + llvm::SmallVector<Type, 4> newTypes; + untupleTypes(oldArg->getType(), &newTypes); + + Value *newTuple = processTuple(oldArg->getType(), oldFunction.getLoc(), + newBlock, builder); + if (!newTuple) { + return true; + } + + mapping.map(oldArg, newTuple); + } + mapping.map(&oldBlock, newBlock); + } + + // Convert all ops in the blocks. + for (auto &oldBlock : oldFunction.getBlocks()) { + builder.setInsertionPointToEnd(mapping.lookupOrNull(&oldBlock)); + for (auto &oldOp : oldBlock.getOperations()) { + if (convertOperation(&oldOp, builder, &mapping)) { + return true; + } + } + } + + return false; +} + +class FlattenTuplesInCFGPass : public ModulePass<FlattenTuplesInCFGPass> { + public: + void runOnModule() override { + auto module = getModule(); + Builder builder(module.getContext()); + + // Build a list of (oldFunction, newFunction) for all functions we need to + // replace. This will ensure that when we go to convert function bodies we + // have only new functions defined. + std::vector<std::pair<FuncOp, FuncOp>> convertedFunctions; + + for (auto oldFunction : module.getOps<FuncOp>()) { + auto oldFunctionType = oldFunction.getType(); + llvm::SmallVector<Type, 10> newInputTypes; + untupleTypes(oldFunctionType.getInputs(), &newInputTypes); + + llvm::SmallVector<Type, 10> newResultTypes; + untupleTypes(oldFunctionType.getResults(), &newResultTypes); + + auto newFunctionType = + builder.getFunctionType(newInputTypes, newResultTypes); + auto newFunction = + FuncOp::create(oldFunction.getLoc(), oldFunction.getName(), + newFunctionType, oldFunction.getDialectAttrs()); + convertedFunctions.push_back({oldFunction, newFunction}); + + // Perform the actual body conversion now that we have proper signatures. + if (convertFunction(oldFunction, newFunction)) { + return signalPassFailure(); + } + } + + // Replace functions in the module. + for (auto &pair : convertedFunctions) { + pair.first.erase(); + module.push_back(pair.second); + } + } +}; + +} // namespace + +std::unique_ptr<OpPassBase<ModuleOp>> createFlattenTuplesInCFGPass() { + return std::make_unique<FlattenTuplesInCFGPass>(); +} + +static PassRegistration<FlattenTuplesInCFGPass> pass( + "iree-flow-flatten-tuples-in-cfg", + "Convert functions to remove tuples from method signatures and blocks"); + +} // namespace Flow +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp b/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp new file mode 100644 index 0000000..1ba59db --- /dev/null +++ b/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
@@ -0,0 +1,382 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/Utils.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace Flow { + +namespace { + +// Replaces |returnOp| with a clone including |newOperands| appended. +LogicalResult appendReturnOperands(ReturnOp returnOp, + ArrayRef<Value *> newOperands) { + // Insert prior to the original return. + OpBuilder builder(returnOp); + + // Clone with new args. + SmallVector<Value *, 8> operands; + operands.reserve(returnOp.getNumOperands() + newOperands.size()); + operands.append(returnOp.operand_begin(), returnOp.operand_end()); + operands.append(newOperands.begin(), newOperands.end()); + builder.create<ReturnOp>(returnOp.getLoc(), operands); + + // Remove original. + returnOp.erase(); + + return success(); +} + +// Replaces |regionOp| with a clone including |newArgs| and |newResults|. +DispatchRegionOp appendRegionArgsAndResults(DispatchRegionOp ®ionOp, + ArrayRef<Value *> newArgs, + ArrayRef<Value *> newResults, + Location otherLoc) { + // Insert prior to the original region. + OpBuilder builder(regionOp); + + // Location is original region + new region location (both probably fused). + SmallVector<Location, 2> fusedLocs = {regionOp.getLoc(), otherLoc}; + auto fusedLoc = FusedLoc::get(fusedLocs, regionOp.getContext()); + + // Clone with new results. + SmallVector<Value *, 8> operands; + operands.append(regionOp.args().begin(), regionOp.args().end()); + operands.append(newArgs.begin(), newArgs.end()); + SmallVector<Type, 8> resultTypes; + resultTypes.append(regionOp.result_type_begin(), regionOp.result_type_end()); + for (auto *newResult : newResults) { + resultTypes.push_back(newResult->getType()); + } + auto newRegionOp = builder.create<DispatchRegionOp>( + fusedLoc, resultTypes, regionOp.workload(), operands, + regionOp.getAttrs()); + newRegionOp.body().takeBody(regionOp.body()); + + // Replace uses of original values with the new values. + for (int i = 0; i < regionOp.getNumResults(); ++i) { + regionOp.getResult(i)->replaceAllUsesWith(newRegionOp.getResult(i)); + } + + // Erase the original region. + regionOp.erase(); + + return newRegionOp; +} + +// Removes results that are not used from the dispatch region. +// Returns the new operation. There may be unused ops in the region but DCE +// should take care of that later. +DispatchRegionOp removeUnusedResults(DispatchRegionOp regionOp) { + // Find return value within the region. + auto ®ionBlock = regionOp.body().getBlocks().front(); + auto returnOp = dyn_cast<ReturnOp>(regionBlock.getTerminator()); + if (!returnOp) { + regionBlock.getParent()->getParentOfType<FuncOp>().emitError() + << "block does not contain an flow.return op"; + } + + // Calculate new return values. + SmallVector<Type, 8> newReturnTypes; + SmallVector<Value *, 8> newReturnValues; + SmallVector<Value *, 8> newRegionResults; + for (int i = 0; i < returnOp.getNumOperands(); ++i) { + auto *resultValue = regionOp.getResult(i); + if (!resultValue->use_empty()) { + // Still has uses so we will preserve it. + newReturnTypes.push_back(resultValue->getType()); + newReturnValues.push_back(returnOp.getOperand(i)); + newRegionResults.push_back(resultValue); + } + } + + // Update return op operands. We can do this in-place as we are only shrinking + // the list. + returnOp.getOperation()->setOperands(newReturnValues); + + // Insert prior to the original region. + OpBuilder builder(regionOp); + + // Clone with new results. + auto newRegionOp = builder.create<DispatchRegionOp>( + regionOp.getLoc(), newReturnTypes, regionOp.workload(), + llvm::to_vector<8>(regionOp.args()), regionOp.getAttrs()); + newRegionOp.body().takeBody(regionOp.body()); + + // Replace uses of original values with the new values. + for (int i = 0; i < newRegionResults.size(); ++i) { + newRegionResults[i]->replaceAllUsesWith(newRegionOp.getResult(i)); + } + + // Erase the original region. + regionOp.erase(); + + return newRegionOp; +} + +// Returns true if |lhs| and |rhs| have either an identical workload or one that +// is compatible. +bool areDispatchRegionWorkloadsCompatible(DispatchRegionOp &lhs, + DispatchRegionOp &rhs) { + // TODO(benvanik): more sophisticated checking; right now it's just identical. + return lhs.workload() == rhs.workload(); +} + +// Returns true if |value| depends in any way on |op| through any path. +bool doesValueDependOnOperation(Value *value, Operation *op) { + if (!value->getDefiningOp()) { + return false; + } else if (value->getDefiningOp() == op) { + return true; + } else if (value->getDefiningOp()->getBlock() == op->getBlock() && + value->getDefiningOp()->isBeforeInBlock(op)) { + // Can't depend on |op| as it is defined prior to it. + return false; + } + for (auto *operand : value->getDefiningOp()->getOperands()) { + if (doesValueDependOnOperation(operand, op)) { + return true; + } + } + return true; +} + +// Returns true if |rhs| transitively depends on any out of |lhs|. +// |rhs| may depend directly on the results of |lhs| but no other ops in the +// parent block will use the results prior to |rhs|. +bool areDispatchRegionsTransitivelyDependent(DispatchRegionOp &lhs, + DispatchRegionOp &rhs) { + for (auto *arg : rhs.args()) { + if (arg->getDefiningOp() != lhs && doesValueDependOnOperation(arg, lhs)) { + // Transitively dependent - boo - can't merge yet. + return true; + } + } + return false; +} + +// Returns true if the dispatch region contains only a single block. +// This is because our merge isn't very smart and will not preserve the CFG +// right now. We can fix this when needed. +bool isDispatchRegionMergable(DispatchRegionOp ®ionOp) { + // Disallow merging of dispatch regions containing matmuls and other big ops. + // We do this to allow backends to lower the big op as entirely isolated such + // that substituting library calls is easier. + for (auto &block : regionOp.body().getBlocks()) { + for (auto &op : block) { + // TODO(b/144530470): replace with tablegen attributes/interfaces. + if (isa<xla_hlo::DotOp>(op) || isa<xla_hlo::ConvOp>(op)) { + return false; + } + } + } + return regionOp.body().getBlocks().size() == 1; +} + +// Merges |rhs| into |lhs| and returns the new |lhs| op. +// Precondition: !areDispatchRegionsTransitivelyDependent +DispatchRegionOp mergeDispatchRegions(DispatchRegionOp &lhs, + DispatchRegionOp &rhs) { + auto &lhsBlock = lhs.body().front(); + auto &rhsBlock = rhs.body().front(); + + // Find the values used as return values in the lhs. + // We'll need to replace the uses in rhs with these. + auto lhsReturnOp = cast<ReturnOp>(lhsBlock.getTerminator()); + SmallVector<Value *, 8> lhsReturnValues; + lhsReturnValues.reserve(lhsReturnOp.getNumOperands()); + lhsReturnValues.append(lhsReturnOp.operand_begin(), + lhsReturnOp.operand_end()); + + // Find the values used as return values in the rhs. + // We'll add these to the results of the lhs region. + auto rhsReturnOp = cast<ReturnOp>(rhsBlock.getTerminator()); + SmallVector<Value *, 8> rhsReturnValues; + rhsReturnValues.reserve(rhsReturnOp.getNumOperands()); + rhsReturnValues.append(rhsReturnOp.operand_begin(), + rhsReturnOp.operand_end()); + + // Compute new args. + BlockAndValueMapping mapping; + SmallVector<Value *, 8> newArgs; + auto lhsArgs = llvm::to_vector<8>(lhs.args()); + auto rhsArgs = llvm::to_vector<8>(rhs.args()); + for (int rhsOpIdx = 0; rhsOpIdx < rhsArgs.size(); ++rhsOpIdx) { + bool didElide = false; + // Find if the rhs arg already exists on the lhs and dedupe. + for (int lhsOpIdx = 0; lhsOpIdx < lhsArgs.size(); ++lhsOpIdx) { + if (rhsArgs[rhsOpIdx] == lhsArgs[lhsOpIdx]) { + mapping.map(rhsBlock.getArgument(rhsOpIdx), + lhsBlock.getArgument(lhsOpIdx)); + didElide = true; + break; + } + } + // Find if the arg has a direct dependency on the results of the lhs. + for (int lhsResultIdx = 0; lhsResultIdx < lhs.getNumResults(); + ++lhsResultIdx) { + if (rhsArgs[rhsOpIdx] == lhs.getResult(lhsResultIdx)) { + // Direct dependency; can elide. We'll skip adding it to the new region + // args and instead just remap it later. + mapping.map(rhsBlock.getArgument(rhsOpIdx), + lhsReturnValues[lhsResultIdx]); + didElide = true; + break; + } + } + if (!didElide) { + // Add to the lhs block. + auto *oldArg = rhs.getOperand(rhsOpIdx + 1); + auto *newArg = lhsBlock.addArgument(oldArg->getType()); + mapping.map(rhsBlock.getArgument(rhsOpIdx), newArg); + newArgs.push_back(oldArg); + } + } + + OpBuilder regionBuilder(&lhsBlock); + + // Copy ops (replacing any args as needed). + // Note that we need to insert prior to the terminator. + regionBuilder.setInsertionPoint(lhsReturnOp); + for (auto &op : rhsBlock) { + // Note that this updates the mapping with the new values (so at the end + // we have those new values). + // + // We avoid the return op here as we have already merged it above. + if (!op.isKnownTerminator()) { + regionBuilder.clone(op, mapping); + } + } + + // Compute new results and add to both region and return op. + SmallVector<Value *, 8> newResults; + for (auto *rhsResult : rhsReturnValues) { + newResults.push_back(mapping.lookupOrDefault(rhsResult)); + } + if (failed(appendReturnOperands(lhsReturnOp, newResults))) { + return nullptr; + } + auto newRegionOp = + appendRegionArgsAndResults(lhs, newArgs, newResults, rhs.getLoc()); + + // Replace uses of original values with the new values. + for (int i = 0; i < rhs.getNumResults(); ++i) { + rhs.getResult(i)->replaceAllUsesWith( + newRegionOp.getResult(lhsReturnValues.size() + i)); + } + + // Remove rhs region. + rhs.erase(); + + // Remove results from the lhs that aren't used anymore as they may have been + // elided when we merged as only the rhs was using them. + newRegionOp = removeUnusedResults(newRegionOp); + + return newRegionOp; +} + +// Merges multiple dispatch regions within a block into the same region, +// if possible. Operations may be reordered if it's possible to merge more while +// still obeying data dependencies. +LogicalResult mergeBlockDispatchRegions(FuncOp func, Block *parentBlock) { + SmallVector<DispatchRegionOp, 8> mergableRegions; + for (auto &op : *parentBlock) { + if (auto regionOp = dyn_cast<DispatchRegionOp>(op)) { + if (isDispatchRegionMergable(regionOp)) { + mergableRegions.push_back(regionOp); + } else { + regionOp.emitRemark( + "unable to merge into following dispatch region; " + "contains non-trivial control flow"); + } + } + } + for (int i = 0; i < mergableRegions.size(); ++i) { + if (!mergableRegions[i]) continue; + auto &lhs = mergableRegions[i]; + for (int j = i + 1; j < mergableRegions.size(); ++j) { + if (!mergableRegions[j]) continue; + auto &rhs = mergableRegions[j]; + if (!areDispatchRegionWorkloadsCompatible(lhs, rhs) || + areDispatchRegionsTransitivelyDependent(lhs, rhs)) { + continue; + } + if (!isDispatchRegionMergable(rhs)) { + // TODO(b/134675461): support non-trivial control flow. + rhs.emitRemark( + "unable to merge into previous dispatch region; " + "contains non-trivial control flow"); + } + mergableRegions[i] = mergeDispatchRegions(lhs, rhs); + if (!mergableRegions[i]) { + return failure(); + } + mergableRegions[j] = nullptr; + --i; // Try again to see if there are subsequent regions to merge. + break; + } + } + + return success(); +} + +} // namespace + +// Identifies dispatch regions that have compatible workloads and folds them. +// This relies on CSE having deduped workloads to simplify the logic to simply +// looking for dispatch regions using the same values. +class FoldCompatibleDispatchRegionsPass + : public FunctionPass<FoldCompatibleDispatchRegionsPass> { + public: + void runOnFunction() override { + auto func = getFunction(); + for (auto &block : func) { + if (failed(mergeBlockDispatchRegions(func, &block))) { + return signalPassFailure(); + } + } + } +}; + +std::unique_ptr<OpPassBase<FuncOp>> createFoldCompatibleDispatchRegionsPass() { + return std::make_unique<FoldCompatibleDispatchRegionsPass>(); +} + +static PassRegistration<FoldCompatibleDispatchRegionsPass> pass( + "iree-flow-fold-compatible-dispatch-regions", + "Folds dispatch regions that have compatible workloads"); + +} // namespace Flow +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions.cpp b/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions.cpp new file mode 100644 index 0000000..176dd19 --- /dev/null +++ b/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions.cpp
@@ -0,0 +1,276 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <algorithm> + +#include "iree/compiler/Dialect/Flow/Analysis/Dispatchability.h" +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" +#include "iree/compiler/Dialect/Flow/Utils/DispatchUtils.h" +#include "iree/compiler/Dialect/Flow/Utils/WorkloadUtils.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/Utils.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace Flow { + +namespace { + +// Returns true if the given |op| can be dispatched in all cases. +// Other passes may handle special cases of these ops but this initial +// identification is conservative. +bool isDispatchableOp(Operation *op, Dispatchability &dispatchability) { + // TODO(b/144530470): replace with tablegen attributes/interfaces. + if (op->getDialect() && op->getDialect()->getNamespace().startswith("flow")) { + // Ignore things we've already produced as they should only relate to + // sequencer operations. + return false; + } else if (op->isKnownTerminator()) { + // Currently we skip all terminators as we want to leave them in the block + // to keep it valid. Future folding passes may take care of them if they are + // worth bringing into the dispatch region. + return false; + } else if (auto callOp = dyn_cast<CallOp>(op)) { + return dispatchability.isDispatchable(callOp.getCallee()); + } else if (isa<CallIndirectOp>(op)) { + // Indirect calls are not supported in dispatch code. + return false; + } else if (isa<ConstantOp>(op)) { + // Constants are handled in the RematerializeDispatchConstants pass. + // We do that independently so that we can more easily see the use of + // constants across all dispatches instead of just on an individual basis + // as we do here. + return false; + } else if (isa<xla_hlo::DynamicUpdateSliceOp>(op)) { + // TODO(benvanik): lower these to the sequencer dialect prior to ID'ing. + return false; + } + return true; +} + +// Returns true if the given |op| can have other ops fused into it. +// This is sketchy and it'd be nice to define this as an op property instead. +// +// What we are looking for in foldable ops is whether the execution of the op +// when fused has some possible benefit (or at least, a non-negative cost). +// Eventually we want to allow backends to vote on this and allow multiple +// folding strategies within the same executable. For now we just hardcode what +// we know for the ops we have. +// +// Preconditions: isDispatchableOp(op) == true. +bool isFusionRootOp(Operation *op) { + // TODO(b/144530470): replace with tablegen attributes/interfaces. + if (isa<xla_hlo::DotOp>(op) || isa<xla_hlo::ConvOp>(op)) { + // We have hand-written kernels for these right now we want to stand alone. + // When we do a bit more magic we should allow these ops to fold. + return false; + } + return true; +} + +// Returns true if the given |op| can be fused into other ops. +// +// Ops that perform narrowing on shapes (such as reduction ops) should not +// generally be fused with other downstream ops (probably...). This avoids +// potential oversampling and indexing issues and allows backends to perform +// more efficient rooted cascading reduction dispatches. +// +// Preconditions: isDispatchableOp(op) == true. +bool isFusableOp(Operation *op) { + // TODO(b/144530470): replace with tablegen attributes/interfaces. + if (isa<xla_hlo::DotOp>(op) || isa<xla_hlo::ConvOp>(op)) { + return false; + } else if (isa<xla_hlo::ReduceOp>(op)) { + // Reduction is usually a dedicated root operation - we can shove things in + // the front of it but not behind. + return false; + } + return true; +} + +// Puts all of the |unsortedOps| into |sortedOps| in an arbitrary topological +// order. +// https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search +// +// Preconditions: |unsortedOps| has no cycles within the set of ops. +std::vector<Operation *> sortOpsTopologically( + const llvm::SetVector<Operation *> &unsortedOps) { + llvm::SetVector<Operation *> unmarkedOps; + unmarkedOps.insert(unsortedOps.begin(), unsortedOps.end()); + llvm::SetVector<Operation *> markedOps; + + using VisitFn = std::function<void(Operation * op)>; + VisitFn visit = [&](Operation *op) { + if (markedOps.count(op) > 0) return; + for (auto *result : op->getResults()) { + for (auto *user : result->getUsers()) { + // Don't visit ops not in our set. + if (unsortedOps.count(user) == 0) continue; + visit(user); + } + } + markedOps.insert(op); + }; + + while (!unmarkedOps.empty()) { + auto *op = unmarkedOps.pop_back_val(); + visit(op); + } + + auto sortedOps = markedOps.takeVector(); + std::reverse(sortedOps.begin(), sortedOps.end()); + return sortedOps; +} + +// Recursively traverses the IR DAG along the operand edges to find ops we are +// able to fuse and appends them to |subgraph|. +void gatherFusionOps(Operation *op, Dispatchability &dispatchability, + llvm::SetVector<Operation *> *subgraph) { + // Skip ops that are used outside of the subgraph we are building. + for (auto *result : op->getResults()) { + if (result->use_empty() || result->hasOneUse()) continue; + for (auto *user : result->getUsers()) { + if (subgraph->count(user) == 0) { + // Op that consumes the result is not (yet) in the subgraph. + // For now we'll ignore these as it may represent a fork that we don't + // want to join too early. + return; + } + } + } + + // Walk backward up to ops providing our input operands. + for (auto *operand : op->getOperands()) { + auto *sourceOp = operand->getDefiningOp(); + if (!sourceOp) continue; + if (subgraph->count(sourceOp) == 0) { + if (isDispatchableOp(sourceOp, dispatchability) && + isFusableOp(sourceOp)) { + gatherFusionOps(sourceOp, dispatchability, subgraph); + } + } + } + + subgraph->insert(op); +} + +// Finds all ops that can be fused together with the given |rootOp| by searching +// backwards in the op order through input edges. +// Returns a topologically sorted list of all fused ops with |rootOp| at the +// end. +std::vector<Operation *> findFusionSubgraphFromRoot( + Operation *rootOp, Dispatchability &dispatchability) { + if (!isFusionRootOp(rootOp)) { + return {rootOp}; + } + llvm::SetVector<Operation *> subgraph; + subgraph.insert(rootOp); + gatherFusionOps(rootOp, dispatchability, &subgraph); + return sortOpsTopologically(subgraph); +} + +// Identifies ranges of dispatchable ops and moves them into dispatch regions. +LogicalResult identifyBlockDispatchRegions(FuncOp func, Block *block, + Dispatchability &dispatchability) { + // Fixed point iteration until we can no longer fuse anything. + bool didFindAnyNewRegions; + do { + // Iterate in reverse so we root further along in the op list. + didFindAnyNewRegions = false; + for (auto &rootOp : llvm::reverse(*block)) { + if (!isDispatchableOp(&rootOp, dispatchability)) { + // Op should remain at the sequencer level. + continue; + } + + // Attempt to find all operations, including rootOp, that can be fused. + // The ops will be sorted in topological order with rootOp as the last op. + // Worst case we may end up with a subgraph of only the rootOp. + auto fusedSubgraph = findFusionSubgraphFromRoot(&rootOp, dispatchability); + + // Compute the workload based on the output shape. + // When variadic all output shapes match so we can just take the first. + auto *workload = calculateWorkload( + &rootOp, rootOp.getResult(0)->getType().cast<ShapedType>()); + + // Try to build a dispatch region from this root. + if (failed(buildDispatchRegion(func, block, workload, fusedSubgraph))) { + return failure(); + } + + // Successfully created a dispatch region from the ops and we must now + // start over again as we've likely trashed the whole block structure. + didFindAnyNewRegions = true; + break; + } + } while (didFindAnyNewRegions); + return success(); +} + +} // namespace + +// Identifies dispatchable ops and moves them into iree.dispatch_regions. +// Some ops, such as call, will be deferred until following passes. +class IdentifyDispatchRegionsPass + : public FunctionPass<IdentifyDispatchRegionsPass> { + public: + void runOnFunction() override { + // NOTE: we require the DispatchabilityAnalysisPass to have run first. + auto dispatchability = getCachedParentAnalysis<Dispatchability>(); + if (!dispatchability.hasValue()) { + getFunction().emitError() + << "dispatchability analysis not performed " + "on module; run -iree-flow-dispatchability-analysis first"; + return signalPassFailure(); + } + + for (auto &block : getFunction()) { + if (failed(identifyBlockDispatchRegions(getFunction(), &block, + dispatchability.getValue()))) { + return signalPassFailure(); + } + } + } +}; + +std::unique_ptr<OpPassBase<FuncOp>> createIdentifyDispatchRegionsPass() { + return std::make_unique<IdentifyDispatchRegionsPass>(); +} + +static PassRegistration<IdentifyDispatchRegionsPass> pass( + "iree-flow-identify-dispatch-regions", + "Conservatively identifies dispatch regions in functions"); + +} // namespace Flow +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Transforms/IdentifyReductionRegions.cpp b/iree/compiler/Dialect/Flow/Transforms/IdentifyReductionRegions.cpp new file mode 100644 index 0000000..a1ac50e --- /dev/null +++ b/iree/compiler/Dialect/Flow/Transforms/IdentifyReductionRegions.cpp
@@ -0,0 +1,168 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <algorithm> + +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" +#include "iree/compiler/Dialect/Flow/Utils/DispatchUtils.h" +#include "iree/compiler/Dialect/Flow/Utils/WorkloadUtils.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/Utils.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace Flow { + +namespace { + +// Builds a new iree.reduction_region with the given |invocationRegion|. +// The new region will be inserted after |originalOp|. +// +// All |invocationRegion| ops must be compatible with the |workload| specified +// as they will all be dispatched with the same workgroup structure. The +// |invocationRegion| will not be modified. +LogicalResult buildReductionRegion(Operation *originalOp, + ArrayRef<Value *> operands, + ArrayRef<Value *> initialValues, + ArrayRef<int32_t> dimensions, + Region &invocationRegion) { + OpBuilder parentBuilder(originalOp); + + // Compute the workload based on the output shape. + // When variadic all output shapes match so we can just take the first. + auto *workload = calculateWorkload( + originalOp, originalOp->getResult(0)->getType().cast<ShapedType>()); + + // Build the region op and add it to the parent block. + SmallVector<Type, 4> resultTypes{originalOp->getResultTypes()}; + auto reductionRegionOp = parentBuilder.create<ReductionRegionOp>( + originalOp->getLoc(), resultTypes, workload, operands, initialValues, + dimensions); + + // Create the block and setup the arg mapping for captured values. + BlockAndValueMapping mapping; + invocationRegion.cloneInto(&reductionRegionOp.body(), mapping); + + // Replace xla_hlo.return -> flow.return. + OpBuilder regionBuilder(reductionRegionOp.body()); + reductionRegionOp.walk([&](xla_hlo::ReturnOp returnOp) { + regionBuilder.setInsertionPoint(returnOp); + SmallVector<Value *, 4> returnValues(returnOp.getOperands()); + regionBuilder.create<ReturnOp>(returnOp.getLoc(), returnValues); + returnOp.erase(); + }); + + // Replace usage of values with the results of the region. + for (int i = 0; i < originalOp->getNumResults(); ++i) { + originalOp->getResult(i)->replaceAllUsesWith( + reductionRegionOp.getResult(i)); + } + + return success(); +} + +// Converts an xla_hlo::ReduceOp to a reduction region and inlines the target +// computation into the region body. +LogicalResult buildReductionRegionFromXLAReduceOp(xla_hlo::ReduceOp reduceOp) { + SmallVector<Value *, 4> operands(reduceOp.getOperands()); + OperandAdaptor<xla_hlo::ReduceOp> adaptor(operands); + + SmallVector<int32_t, 4> dimensions; + for (auto dim : reduceOp.dimensions().getIntValues()) { + dimensions.push_back(dim.getSExtValue()); + } + + // Create the iree.reduction_region. + if (failed(buildReductionRegion(reduceOp, adaptor.operands(), + adaptor.init_values(), dimensions, + reduceOp.body()))) { + return failure(); + } + + // Remove original XLA reduction op. + reduceOp.erase(); + + return success(); +} + +// Identifies reduction ops and moves them into reduction regions. +LogicalResult identifyBlockReductionRegions(FuncOp funcOp, Block *block) { + // Fixed point iteration until we can no longer fuse anything. + bool didFindAnyNewRegions; + do { + // Iterate in reverse so we root further along in the op list. + didFindAnyNewRegions = false; + for (auto &rootOp : llvm::reverse(*block)) { + if (auto reduceOp = dyn_cast<xla_hlo::ReduceOp>(rootOp)) { + if (failed(buildReductionRegionFromXLAReduceOp(reduceOp))) { + return failure(); + } + + // Successfully created a dispatch region from the ops and we must now + // start over again as we've likely trashed the whole block structure. + didFindAnyNewRegions = true; + break; + } + } + } while (didFindAnyNewRegions); + return success(); +} + +} // namespace + +// Identifies reduction ops and moves their targets into iree.reduction_regions. +class IdentifyReductionRegionsPass + : public ModulePass<IdentifyReductionRegionsPass> { + public: + void runOnModule() override { + for (auto funcOp : getModule().getOps<FuncOp>()) { + for (auto &block : funcOp) { + if (failed(identifyBlockReductionRegions(funcOp, &block))) { + return signalPassFailure(); + } + } + } + } +}; + +std::unique_ptr<OpPassBase<ModuleOp>> createIdentifyReductionRegionsPass() { + return std::make_unique<IdentifyReductionRegionsPass>(); // NOLINT +} + +static PassRegistration<IdentifyReductionRegionsPass> pass( + "iree-flow-identify-reduction-regions", + "Identifies reduction regions based on input reduction ops"); + +} // namespace Flow +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp b/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp new file mode 100644 index 0000000..1faab2f --- /dev/null +++ b/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp
@@ -0,0 +1,130 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <utility> + +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" +#include "iree/compiler/Dialect/Flow/Utils/DispatchUtils.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace Flow { + +namespace { + +// Converts a dispatch_region into a dispatch to the outlined region function. +LogicalResult convertToDispatchOp(DispatchRegionOp regionOp, + ExecutableOp executableOp, + DispatchEntryOp entryPointOp, + FuncOp outlinedFuncOp) { + // Insert at the same place as the original region. + OpBuilder builder(regionOp); + + // Create the dispatch op to the executable function. + auto dispatchOp = builder.create<DispatchOp>( + regionOp.getLoc(), executableOp.getName(), entryPointOp.getName(), + regionOp.workload(), outlinedFuncOp.getType().getResults(), + llvm::to_vector<8>(regionOp.args())); + + // Replace uses of the existing results with the new results. + for (int i = 0; i < regionOp.getNumResults(); ++i) { + regionOp.getResult(i)->replaceAllUsesWith(dispatchOp.getResult(i)); + } + + // Erase original region. + regionOp.erase(); + + return success(); +} + +// Outlines a dispatch region into a flow.executable. +LogicalResult outlineDispatchRegion( + DispatchRegionOp regionOp, int outlinedRegionOrdinal, + llvm::StringMap<FuncOp> &dispatchableFuncOps) { + // Build function type matching 1:1 with the region signature. + SmallVector<Type, 8> operandTypes( + Operation::operand_type_range{regionOp.args()}); + SmallVector<Type, 8> resultTypes(regionOp.getResultTypes()); + auto functionType = + FunctionType::get(operandTypes, resultTypes, regionOp.getContext()); + + // Create the executable with the region cloned into it. + ExecutableOp executableOp; + FuncOp outlinedFuncOp; + std::tie(executableOp, outlinedFuncOp) = createRegionExecutable( + regionOp, functionType, + "_dispatch_" + std::to_string(outlinedRegionOrdinal), + dispatchableFuncOps); + + // Add dispatch export pointing at the function. + OpBuilder builder(executableOp.body()); + auto entryPointOp = builder.create<DispatchEntryOp>( + regionOp.getLoc(), builder.getStringAttr(outlinedFuncOp.getName()), + builder.getSymbolRefAttr(outlinedFuncOp), DenseIntElementsAttr{}, + DenseIntElementsAttr{}); + + // Finally convert the dispatch region into a dispatch to the outlined func. + return convertToDispatchOp(regionOp, executableOp, entryPointOp, + outlinedFuncOp); +} + +} // namespace + +class OutlineDispatchRegionsPass + : public ModulePass<OutlineDispatchRegionsPass> { + public: + OutlineDispatchRegionsPass() = default; + OutlineDispatchRegionsPass( + std::shared_ptr<llvm::StringMap<FuncOp>> dispatchableFuncOps) + : dispatchableFuncOps_(std::move(dispatchableFuncOps)) {} + + void runOnModule() override { + // TODO(benvanik): replace with a pattern rewriter? + auto funcOps = llvm::to_vector<32>(getModule().getOps<FuncOp>()); + for (auto funcOp : funcOps) { + // Outline all of the dispatch regions ops in this function. + SmallVector<DispatchRegionOp, 8> dispatchRegionOps; + funcOp.walk( + [&](DispatchRegionOp op) { dispatchRegionOps.push_back(op); }); + for (int i = 0; i < dispatchRegionOps.size(); ++i) { + if (failed(outlineDispatchRegion(dispatchRegionOps[i], i, + *dispatchableFuncOps_))) { + return signalPassFailure(); + } + } + } + } + + private: + std::shared_ptr<llvm::StringMap<FuncOp>> dispatchableFuncOps_; +}; + +std::unique_ptr<OpPassBase<ModuleOp>> createOutlineDispatchRegionsPass( + std::shared_ptr<llvm::StringMap<FuncOp>> dispatchableFuncOps) { + return std::make_unique<OutlineDispatchRegionsPass>( + std::move(dispatchableFuncOps)); +} + +static PassRegistration<OutlineDispatchRegionsPass> pass( + "iree-flow-outline-dispatch-regions", + "Outlines dispatch regions into standalone functions"); + +} // namespace Flow +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Transforms/OutlineReductionRegions.cpp b/iree/compiler/Dialect/Flow/Transforms/OutlineReductionRegions.cpp new file mode 100644 index 0000000..53a377c --- /dev/null +++ b/iree/compiler/Dialect/Flow/Transforms/OutlineReductionRegions.cpp
@@ -0,0 +1,392 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <utility> + +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" +#include "iree/compiler/Dialect/Flow/Utils/DispatchUtils.h" +#include "iree/compiler/Dialect/Flow/Utils/WorkloadUtils.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace Flow { + +namespace { + +// Determines the shapes involved with reducing this dimension. +SmallVector<int64_t, 4> calculateResultShape(Value *input, + int windowDimension) { + SmallVector<int64_t, 4> resultShape; + for (auto it : + llvm::enumerate(input->getType().cast<ShapedType>().getShape())) { + if (it.index() != windowDimension) { + resultShape.push_back(it.value()); + } + } + return resultShape; +} + +// Converts a reduction_region into a dispatch to the outlined region function +// for a single reduction dimension. +// Returns the results of the reduction or empty if the construction fails. +SmallVector<Value *, 4> convertToDispatchOp( + Operation *regionOp, ExecutableOp executableOp, StringRef entryPointName, + int reductionDimension, SmallVector<Value *, 4> initialValues, + SmallVector<Value *, 4> inputs, OpBuilder &dispatcherBuilder) { + SmallVector<Type, 4> resultTypes; + for (auto resultType : llvm::enumerate(regionOp->getResultTypes())) { + // Allocate output buffer in the dispatcher to pass in to the region. + auto shapedType = resultType.value().cast<ShapedType>(); + auto reducedType = RankedTensorType::get( + calculateResultShape(inputs[resultType.index()], reductionDimension), + shapedType.getElementType()); + resultTypes.push_back(reducedType); + } + + // Calculate workload from the result shape. + auto *workload = + calculateWorkload(regionOp, resultTypes.front().cast<ShapedType>()); + + // Create the reduce op to the executable function. + std::vector<Value *> allOperands; + allOperands.insert(allOperands.end(), inputs.begin(), inputs.end()); + allOperands.insert(allOperands.end(), initialValues.begin(), + initialValues.end()); + auto dispatchOp = dispatcherBuilder.create<DispatchOp>( + regionOp->getLoc(), executableOp.getName(), entryPointName, workload, + resultTypes, allOperands); + + return llvm::to_vector<4>(dispatchOp.getResults()); +} + +// Creates an executable that holds the given elemental reduction region. +// The executable will have an entry point taking the specified reduction values +// and writing the results to output arguments. +std::pair<ExecutableOp, ReductionEntryOp> createReductionExecutable( + ReductionRegionOp regionOp, int outlinedRegionOrdinal, + int separatedReductionIndex, int reductionDimension, + SmallVector<Value *, 4> initialValues, SmallVector<Value *, 4> inputs, + llvm::StringMap<FuncOp> &dispatchableFuncOps) { + // Build function type matching 1:1 with the region signature. + SmallVector<Type, 8> elementalOperandTypes; + SmallVector<Type, 8> elementalResultTypes; + for (auto *arg : regionOp.initial_values()) { + // (in0, in1) -> out0 + elementalOperandTypes.push_back(arg->getType()); + elementalOperandTypes.push_back(arg->getType()); + elementalResultTypes.push_back(arg->getType()); + } + auto elementalFunctionType = FunctionType::get( + elementalOperandTypes, elementalResultTypes, regionOp.getContext()); + + // Create the executable with the region cloned into it. + ExecutableOp executableOp; + FuncOp elementalFuncOp; + std::tie(executableOp, elementalFuncOp) = createRegionExecutable( + regionOp, elementalFunctionType, + "_reduce_" + std::to_string(outlinedRegionOrdinal) + "_dim_" + + std::to_string(separatedReductionIndex), + dispatchableFuncOps); + + // Create a new entry point that we can use with the signature for this + // dimension. + SmallVector<Type, 8> allOperandTypes; + auto inputTypes = + llvm::map_range(inputs, [](Value *value) { return value->getType(); }); + allOperandTypes.append(inputTypes.begin(), inputTypes.end()); + auto initialValueTypes = llvm::map_range( + initialValues, [](Value *value) { return value->getType(); }); + allOperandTypes.append(initialValueTypes.begin(), initialValueTypes.end()); + SmallVector<Type, 4> resultTypes; + for (auto resultType : llvm::enumerate(regionOp.getResultTypes())) { + auto shapedType = resultType.value().cast<ShapedType>(); + auto reducedType = RankedTensorType::get( + calculateResultShape(inputs[resultType.index()], reductionDimension), + shapedType.getElementType()); + resultTypes.push_back(reducedType); + } + auto entryFuncType = + FunctionType::get(allOperandTypes, resultTypes, regionOp.getContext()); + auto entryFuncOp = FuncOp::create( + regionOp.getLoc(), (elementalFuncOp.getName() + "_entry").str(), + entryFuncType); + elementalFuncOp.getOperation()->getBlock()->push_back(entryFuncOp); + entryFuncOp.getOperation()->moveBefore(elementalFuncOp); + + // Add dispatch export pointing at the function. + OpBuilder builder(executableOp.body()); + auto entryPointOp = builder.create<ReductionEntryOp>( + regionOp.getLoc(), builder.getStringAttr(entryFuncOp.getName()), + builder.getSymbolRefAttr(entryFuncOp), + builder.getSymbolRefAttr(elementalFuncOp), + builder.getI32IntegerAttr(reductionDimension)); + + return {executableOp, entryPointOp}; +} + +// Outlines a reduction region into one or more executables. +// This separates the reduction into multiple dispatches, one for each reduction +// dimension (thankfully XLA's operation semantics state this is ok). We then +// special case the first dispatch such that it takes the constant initial +// values so that we don't have to materialize a buffer for them. +LogicalResult outlineReductionRegion( + ReductionRegionOp regionOp, int outlinedRegionOrdinal, + llvm::StringMap<FuncOp> &dispatchableFuncOps) { + // Insert at the same place as the original region. + OpBuilder dispatcherBuilder(regionOp); + + SmallVector<Value *, 4> initialValues{regionOp.initial_values()}; + SmallVector<Value *, 4> temps{regionOp.operands()}; + + // Create one dispatch per dimension being reduced. + // We'll do this by chaining the original input through with the temporary + // reduction results. The results we end up with will be the originally + // requested shape and we can just substitute them. + auto dimensions = regionOp.dimensions().getValue(); + SmallVector<int32_t, 4> sortedDimensions; + for (uint32_t i = 0; i < dimensions.getNumElements(); ++i) { + sortedDimensions.push_back(dimensions.getValue<IntegerAttr>({i}).getInt()); + } + llvm::sort(sortedDimensions, [](int32_t a, int32_t b) { return a - b; }); + for (auto dimension : llvm::enumerate(sortedDimensions)) { + // Create the executable with the region cloned into it. + ExecutableOp executableOp; + ReductionEntryOp entryPointOp; + std::tie(executableOp, entryPointOp) = createReductionExecutable( + regionOp, outlinedRegionOrdinal, dimension.index(), dimension.value(), + initialValues, temps, dispatchableFuncOps); + + // Finally convert the dispatch region into a dispatch to the outlined func. + temps = convertToDispatchOp(regionOp, executableOp, entryPointOp.getName(), + dimension.value(), initialValues, + std::move(temps), dispatcherBuilder); + if (temps.empty()) { + return regionOp.emitOpError() + << "failed to construct reduction for dimension " + << dimension.value(); + } + } + + // Replace uses of the existing results with the new results. + for (int i = 0; i < regionOp.getNumResults(); ++i) { + regionOp.getResult(i)->replaceAllUsesWith(temps[i]); + } + + // Erase original region. + regionOp.erase(); + + return success(); +} + +// Creates an executable that holds the given elemental reduction region. +// The executable will have an entry point taking the specified reduction values +// and writing the results to output arguments. +std::pair<ExecutableOp, WindowedReductionEntryOp> +createWindowedReductionExecutable( + WindowedReductionRegionOp regionOp, int outlinedRegionOrdinal, + int separatedReductionIndex, int32_t windowDimension, int32_t windowStride, + int32_t baseDilation, int32_t windowDilation, + SmallVector<Value *, 4> initialValues, SmallVector<Value *, 4> inputs, + llvm::StringMap<FuncOp> &dispatchableFuncOps) { + // Build function type matching 1:1 with the region signature. + SmallVector<Type, 8> elementalOperandTypes; + SmallVector<Type, 8> elementalResultTypes; + for (auto *arg : regionOp.initial_values()) { + // (in0, in1) -> out0 + elementalOperandTypes.push_back(arg->getType()); + elementalOperandTypes.push_back(arg->getType()); + elementalResultTypes.push_back(arg->getType()); + } + auto elementalFunctionType = FunctionType::get( + elementalOperandTypes, elementalResultTypes, regionOp.getContext()); + + // Create the executable with the region cloned into it. + ExecutableOp executableOp; + FuncOp elementalFuncOp; + std::tie(executableOp, elementalFuncOp) = createRegionExecutable( + regionOp, elementalFunctionType, + "_reduce_" + std::to_string(outlinedRegionOrdinal) + "_dim_" + + std::to_string(separatedReductionIndex), + dispatchableFuncOps); + + // Create a new entry point that we can use with the signature for this + // dimension. + SmallVector<Type, 8> allOperandTypes; + auto inputTypes = + llvm::map_range(inputs, [](Value *value) { return value->getType(); }); + allOperandTypes.append(inputTypes.begin(), inputTypes.end()); + auto initialValueTypes = llvm::map_range( + initialValues, [](Value *value) { return value->getType(); }); + allOperandTypes.append(initialValueTypes.begin(), initialValueTypes.end()); + SmallVector<Type, 4> resultTypes; + for (auto resultType : llvm::enumerate(regionOp.getResultTypes())) { + auto shapedType = resultType.value().cast<ShapedType>(); + auto reducedType = RankedTensorType::get( + calculateResultShape(inputs[resultType.index()], windowDimension), + shapedType.getElementType()); + resultTypes.push_back(reducedType); + } + auto entryFuncType = + FunctionType::get(allOperandTypes, resultTypes, regionOp.getContext()); + auto entryFuncOp = FuncOp::create( + regionOp.getLoc(), (elementalFuncOp.getName() + "_entry").str(), + entryFuncType); + elementalFuncOp.getOperation()->getBlock()->push_back(entryFuncOp); + entryFuncOp.getOperation()->moveBefore(elementalFuncOp); + + // Add dispatch export pointing at the function. + OpBuilder builder(executableOp.body()); + auto entryPointOp = builder.create<WindowedReductionEntryOp>( + regionOp.getLoc(), builder.getStringAttr(entryFuncOp.getName()), + builder.getSymbolRefAttr(entryFuncOp), + builder.getSymbolRefAttr(elementalFuncOp), + builder.getI32IntegerAttr(windowDimension), + builder.getI32IntegerAttr(windowStride), + builder.getI32IntegerAttr(baseDilation), + builder.getI32IntegerAttr(windowDilation), + builder.getI32IntegerAttr( + static_cast<uint32_t>(regionOp.padding_mode()))); + + return {executableOp, entryPointOp}; +} + +// Outlines a windowed reduction region into one or more executables. +// This separates the reduction into multiple dispatches, one for each reduction +// dimension (thankfully XLA's operation semantics state this is ok). We then +// special case the first dispatch such that it takes the constant initial +// values so that we don't have to materialize a buffer for them. +LogicalResult outlineWindowedReductionRegion( + WindowedReductionRegionOp regionOp, int outlinedRegionOrdinal, + llvm::StringMap<FuncOp> &dispatchableFuncOps) { + // Insert at the same place as the original region. + OpBuilder dispatcherBuilder(regionOp); + + SmallVector<Value *, 4> initialValues{regionOp.initial_values()}; + SmallVector<Value *, 4> temps{regionOp.operands()}; + + // Create one dispatch per dimension being reduced. + // We'll do this by chaining the original input through with the temporary + // reduction results. The results we end up with will be the originally + // requested shape and we can just substitute them. + auto windowDimensions = regionOp.window_dimensions(); + auto windowStrides = regionOp.window_strides(); + auto baseDilations = regionOp.base_dilations(); + auto windowDilations = regionOp.window_dilations(); + SmallVector<std::tuple<int32_t, int32_t, int32_t, int32_t>, 4> + sortedWindowAttrs; + for (uint32_t i = 0; i < windowDimensions.getNumElements(); ++i) { + int32_t windowDimension = + windowDimensions.getValue<IntegerAttr>({i}).getInt(); + int32_t windowStride = windowStrides.getValue<IntegerAttr>({i}).getInt(); + int32_t baseDilation = baseDilations.getValue<IntegerAttr>({i}).getInt(); + int32_t windowDilation = + windowDilations.getValue<IntegerAttr>({i}).getInt(); + sortedWindowAttrs.push_back( + {windowDimension, windowStride, baseDilation, windowDilation}); + } + llvm::sort(sortedWindowAttrs, + [](std::tuple<int32_t, int32_t, int32_t, int32_t> a, + std::tuple<int32_t, int32_t, int32_t, int32_t> b) { + return std::get<0>(a) - std::get<0>(b); + }); + for (auto windowAttrs : llvm::enumerate(sortedWindowAttrs)) { + int32_t windowDimension = std::get<0>(windowAttrs.value()); + int32_t windowStride = std::get<1>(windowAttrs.value()); + int32_t baseDilation = std::get<2>(windowAttrs.value()); + int32_t windowDilation = std::get<3>(windowAttrs.value()); + ExecutableOp executableOp; + WindowedReductionEntryOp entryPointOp; + std::tie(executableOp, entryPointOp) = createWindowedReductionExecutable( + regionOp, outlinedRegionOrdinal, windowAttrs.index(), windowDimension, + windowStride, baseDilation, windowDilation, initialValues, temps, + dispatchableFuncOps); + temps = convertToDispatchOp(regionOp, executableOp, entryPointOp.getName(), + windowDimension, initialValues, + std::move(temps), dispatcherBuilder); + if (temps.empty()) { + return regionOp.emitOpError() + << "failed to construct reduction for windowed dimension " + << windowDimension; + } + } + + // Replace uses of the existing results with the new results. + for (int i = 0; i < regionOp.getNumResults(); ++i) { + regionOp.getResult(i)->replaceAllUsesWith(temps[i]); + } + + // Erase original region. + regionOp.erase(); + + return success(); +} + +} // namespace + +class OutlineReductionRegionsPass + : public ModulePass<OutlineReductionRegionsPass> { + public: + OutlineReductionRegionsPass() = default; + explicit OutlineReductionRegionsPass( + std::shared_ptr<llvm::StringMap<FuncOp>> dispatchableFuncOps) + : dispatchableFuncOps_(std::move(dispatchableFuncOps)) {} + + void runOnModule() override { + // TODO(benvanik): replace with a pattern rewriter? + auto funcOps = llvm::to_vector<32>(getModule().getOps<FuncOp>()); + for (auto funcOp : funcOps) { + SmallVector<ReductionRegionOp, 4> reductionRegionOps; + funcOp.walk( + [&](ReductionRegionOp op) { reductionRegionOps.push_back(op); }); + for (int i = 0; i < reductionRegionOps.size(); ++i) { + if (failed(outlineReductionRegion(reductionRegionOps[i], i, + *dispatchableFuncOps_))) { + return signalPassFailure(); + } + } + SmallVector<WindowedReductionRegionOp, 4> windowedReductionRegionOps; + funcOp.walk([&](WindowedReductionRegionOp op) { + windowedReductionRegionOps.push_back(op); + }); + for (int i = 0; i < windowedReductionRegionOps.size(); ++i) { + if (failed(outlineWindowedReductionRegion(windowedReductionRegionOps[i], + i, *dispatchableFuncOps_))) { + return signalPassFailure(); + } + } + } + } + + private: + std::shared_ptr<llvm::StringMap<FuncOp>> dispatchableFuncOps_; +}; + +std::unique_ptr<OpPassBase<ModuleOp>> createOutlineReductionRegionsPass( + std::shared_ptr<llvm::StringMap<FuncOp>> dispatchableFuncOps) { + return std::make_unique<OutlineReductionRegionsPass>( + std::move(dispatchableFuncOps)); // NOLINT +} + +static PassRegistration<OutlineReductionRegionsPass> pass( + "iree-flow-outline-reduction-regions", + "Outlines reduction regions into standalone functions"); + +} // namespace Flow +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp new file mode 100644 index 0000000..9d7e77d --- /dev/null +++ b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -0,0 +1,83 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/compiler/Dialect/Flow/Transforms/Passes.h" + +#include <memory> + +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Transforms/Passes.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace Flow { + +void buildFlowTransformPassPipeline(OpPassManager &passManager) { + // Flatten tuples (like tuple<tensor<...>, tensor<...>>) so we can do + // fine-grained tensor tracking. + passManager.addPass(IREE::Flow::createFlattenTuplesInCFGPass()); + + // Find reduction ops and create flow.reduction.regions. We do this prior to + // performing dispatch region identification so that we can build as big of + // fused reduction regions as possible. The remaining ops will be put into + // dispatch regions. + passManager.addPass(IREE::Flow::createIdentifyReductionRegionsPass()); + passManager.addNestedPass<FuncOp>(createCSEPass()); + + // First perform module-level analysis that following passes will use to query + // per-function dispatchability information. We run this first so that it only + // needs to run once and will be cached for all of the following passes. + // TODO(b/144784188): avoid this and instead rely on AnalysisManager cache. + auto dispatchableFuncOps = std::make_shared<llvm::StringMap<FuncOp>>(); + passManager.addPass( + IREE::Flow::createDispatchabilityAnalysisPass(dispatchableFuncOps)); + + // Create all of the dispatch regions, CSE their workloads, and fold. + passManager.addPass(IREE::Flow::createIdentifyDispatchRegionsPass()); + passManager.addNestedPass<FuncOp>(createCSEPass()); + passManager.addPass(IREE::Flow::createFoldCompatibleDispatchRegionsPass()); + + // Note that as we are rematerializing things here it's critical we do not run + // the canonicalizer/CSE between now and when we outline - otherwise it'll + // undo all of our work! + passManager.addPass(IREE::Flow::createRematerializeDispatchConstantsPass()); + + // Outline the dispatch regions into their own functions. This separates the + // sequencer functions performing dispatches from the dispatchees. + passManager.addPass( + IREE::Flow::createOutlineDispatchRegionsPass(dispatchableFuncOps)); + passManager.addPass( + IREE::Flow::createOutlineReductionRegionsPass(dispatchableFuncOps)); + + // Cleanup identity ops that clutter up the IR and canonicalize. + passManager.addNestedPass<FuncOp>(createCanonicalizerPass()); + + // Assign attributes and negotiate each executable's ABI signature. + passManager.addPass(IREE::Flow::createAssignExecutableWorkloadsPass()); + + // TODO(benvanik): run symbol DCE pass. +} + +static PassPipelineRegistration<> transformPassPipeline( + "iree-flow-transformation-pipeline", + "Runs the full IREE flow dialect transformation pipeline", + [](OpPassManager &passManager) { + buildFlowTransformPassPipeline(passManager); + }); + +} // namespace Flow +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.h b/iree/compiler/Dialect/Flow/Transforms/Passes.h new file mode 100644 index 0000000..f11d2e0 --- /dev/null +++ b/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -0,0 +1,108 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_DIALECT_FLOW_TRANSFORMS_PASSES_H_ +#define IREE_COMPILER_DIALECT_FLOW_TRANSFORMS_PASSES_H_ + +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" +#include "llvm/ADT/StringMap.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Module.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace Flow { + +//===----------------------------------------------------------------------===// +// Helpers +//===----------------------------------------------------------------------===// + +// Adds a set of passes to the given pass manager that run the required flow +// transforms in the canonical order. +// +// Most translation code should prefer to use this instead of manually adding +// the passes themselves to ensure that expected pass ordering is observed. +// +// The expected usage is: +// <run conversion from TF/HLO/etc to flow> +// buildFlowTransformPassPipeline & run +// <run conversion from flow to sequencer/hal/vm/etc> +void buildFlowTransformPassPipeline(OpPassManager &passManager); + +//===----------------------------------------------------------------------===// +// Input canonicalization and legalization +//===----------------------------------------------------------------------===// + +// Flattens tuple values in function signatures and blocks. +std::unique_ptr<OpPassBase<ModuleOp>> createFlattenTuplesInCFGPass(); + +//===----------------------------------------------------------------------===// +// Dispatches (flow.dispatch.region) +//===----------------------------------------------------------------------===// + +// Analyzes a module to identify which functions are dispatchable. +// This information is cached on the module and is used by other FuncOp-scoped +// passes to quickly access the module-level dispatchability information. +std::unique_ptr<OpPassBase<ModuleOp>> createDispatchabilityAnalysisPass( + std::shared_ptr<llvm::StringMap<FuncOp>> dispatchableFuncOps); + +// Identifies dispatchable regions of functions and wraps them in +// flow.dispatch_regions. +std::unique_ptr<OpPassBase<FuncOp>> createIdentifyDispatchRegionsPass(); + +// Folds multiple dispatch regions together that have compatible workloads. +std::unique_ptr<OpPassBase<FuncOp>> createFoldCompatibleDispatchRegionsPass(); + +// Rematerializes small previously-CSE'd constants into dispatch regions. +std::unique_ptr<OpPassBase<FuncOp>> createRematerializeDispatchConstantsPass(); + +// Outlines dispatch regions into executables. +std::unique_ptr<OpPassBase<ModuleOp>> createOutlineDispatchRegionsPass( + std::shared_ptr<llvm::StringMap<FuncOp>> dispatchableFuncOps); + +//===----------------------------------------------------------------------===// +// Reductions (flow.reduction.region) +//===----------------------------------------------------------------------===// + +// Identifies reduction regions and wraps them in flow.reduction_regions. +std::unique_ptr<OpPassBase<ModuleOp>> createIdentifyReductionRegionsPass(); + +// Outlines dispatch regions into executables. +std::unique_ptr<OpPassBase<ModuleOp>> createOutlineReductionRegionsPass( + std::shared_ptr<llvm::StringMap<FuncOp>> dispatchableFuncOps); + +//===----------------------------------------------------------------------===// +// Optimizations +//===----------------------------------------------------------------------===// + +// TODO(benvanik): pass to dedupe similar executables (by making dynamically +// shaped, adjusting types, etc). + +//===----------------------------------------------------------------------===// +// Module Analysis and Finalization +//===----------------------------------------------------------------------===// + +// Assigns workload attributes to executable entry points based on dispatches. +std::unique_ptr<OpPassBase<ModuleOp>> createAssignExecutableWorkloadsPass(); + +} // namespace Flow +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_DIALECT_FLOW_TRANSFORMS_PASSES_H_
diff --git a/iree/compiler/Dialect/Flow/Transforms/RematerializeDispatchConstants.cpp b/iree/compiler/Dialect/Flow/Transforms/RematerializeDispatchConstants.cpp new file mode 100644 index 0000000..c8c956d --- /dev/null +++ b/iree/compiler/Dialect/Flow/Transforms/RematerializeDispatchConstants.cpp
@@ -0,0 +1,231 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <algorithm> + +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/Utils.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace Flow { + +namespace { + +// Chosen randomly for now. We can measure and see what makes sense. +constexpr int64_t kMaxRematerializedConstantSizeInBytes = 1 * 1024; + +// Returns true if the constant value is under a certain threshold. +// This threshold is fixed for all backends as a value that is assumed small +// enough to be worth inlining possibly several times (at the cost of binary +// bloat). +bool isConstantSmall(ConstantOp constantOp) { + if (auto shapedType = constantOp.getType().dyn_cast<ShapedType>()) { + return shapedType.getSizeInBits() / 8 <= + kMaxRematerializedConstantSizeInBytes; + } + + // Assume anything unshaped is small. This may not always be true in custom + // dialects but is in std for now. + return true; +} + +// Returns true if the dispatch region is allowed to have constants inside. +// Certain regions that may get replaced or turned into kernel imports shouldn't +// have the constants moved into them as they'll just get lost. +bool canDispatchRegionContainConstants( + DispatchRegionOp dispatchRegionOp) { + for (auto &block : dispatchRegionOp.body()) { + for (auto &op : block) { + // TODO(b/144530470): replace with tablegen attributes/interfaces. + if (isa<xla_hlo::DotOp>(&op) || isa<xla_hlo::ConvOp>(&op)) { + return false; + } + } + } + return true; +} + +// Recursively clones the given |sourceOp| and returns the newly cloned op. +Operation *recursivelyCloneOp(Operation *sourceOp, OpBuilder &builder, + BlockAndValueMapping *mapping) { + // Note that we dedupe required operands in the case of multiple arguments + // coming from the same source operation. + SmallPtrSet<Operation *, 4> operandOps; + for (auto *operand : sourceOp->getOperands()) { + operandOps.insert(operand->getDefiningOp()); + } + for (auto *operandOp : operandOps) { + recursivelyCloneOp(operandOp, builder, mapping); + } + return builder.clone(*sourceOp, *mapping); +} + +// Clones the |sourceValue| op tree into |targetBlock|. +// |mapping| is used to lookup existing values that may be present in the block +// such as block arguments or already cloned ancestor ops. |mapping| will be +// updated as the tree is cloned. +Value *cloneOpTreeIntoBlock(Value *sourceValue, Block *targetBlock, + BlockAndValueMapping *mapping) { + // If the op has already been cloned we can just reuse that. + // This happens if multiple arguments reference the same trees. + if (auto *existingValue = mapping->lookupOrNull(sourceValue)) { + return existingValue; + } + + OpBuilder builder(targetBlock); + builder.setInsertionPointToStart(targetBlock); + auto *sourceOp = sourceValue->getDefiningOp(); + auto *clonedOp = recursivelyCloneOp(sourceOp, builder, mapping); + + // Return only the result matching our source value (in the case of multiple + // results). + int resultIndex = std::distance( + sourceOp->result_begin(), + std::find(sourceOp->result_begin(), sourceOp->result_end(), sourceValue)); + return clonedOp->getResult(resultIndex); +} + +// Inlines use of the given |value| from outside of a dispatch region to inside +// of it and removes the argument. Supports multiple arguments that reference +// |value| and will clone the entire value tree. +LogicalResult inlineDispatchRegionOperandsUsingValue( + DispatchRegionOp dispatchRegionOp, Value *value) { + // Find all args that are using this value. + SmallVector<unsigned, 4> argIndices; + for (auto arg : llvm::enumerate(dispatchRegionOp.args())) { + if (arg.value() == value) { + argIndices.push_back(arg.index()); + } + } + if (argIndices.empty()) { + // Not used? Wasteful call! + return success(); + } + + // Clone the value (and the ops required to create it) into the entry block. + auto &entryBlock = dispatchRegionOp.body().getBlocks().front(); + BlockAndValueMapping mapping; + auto *clonedValue = cloneOpTreeIntoBlock(value, &entryBlock, &mapping); + + // Replace all uses of the inner operand with the new value. + for (unsigned argIndex : argIndices) { + entryBlock.getArgument(argIndex)->replaceAllUsesWith(clonedValue); + } + + // Remove the dispatch region args and the block args that have been + // replaced. + for (unsigned argIndex : llvm::reverse(argIndices)) { + dispatchRegionOp.getOperation()->eraseOperand( + dispatchRegionOp.mapArgOperandToOpOperand(argIndex)); + entryBlock.eraseArgument(argIndex); + } + + return success(); +} + +// Rematerializes a constant inside of all dispatch regions that use it. +// Afterward the constant is only removed if there are no other uses within the +// non-dispatch block (such as by sequencer ops). +LogicalResult rematerializeConstantInDispatchRegions(ConstantOp constantOp) { + Value *constantValue = constantOp.getResult(); + SmallVector<DispatchRegionOp, 4> usingRegionOps; + for (auto *user : constantValue->getUsers()) { + if (auto dispatchRegionOp = dyn_cast<DispatchRegionOp>(user)) { + // Ensure this isn't just the workload and is used as an arg. + if (std::find(dispatchRegionOp.args().begin(), + dispatchRegionOp.args().end(), + constantValue) != dispatchRegionOp.args().end()) { + if (canDispatchRegionContainConstants(dispatchRegionOp)) { + usingRegionOps.push_back(dispatchRegionOp); + } + } + } + } + for (auto &dispatchRegionOp : usingRegionOps) { + if (failed(inlineDispatchRegionOperandsUsingValue(dispatchRegionOp, + constantValue))) { + return failure(); + } + } + + // Remove if there are no other uses within the block. + if (constantOp.use_empty()) { + constantOp.erase(); + } + + return success(); +} + +} // namespace + +// Finds constant arguments to dispatch regions that are too small to be worth +// putting into constant pools. This prevents things like a CSE'd scalar +// constant of 0.0 being passed by reference to a bunch of regions. Later +// backend-specific passes running on the dispatch regions may also be able to +// improve their constant propagation chances by having the full constant value +// available. +// +// Note that this currently only operates at the block level. Constants that are +// pushed across branches are assumed to have been rematerialized within blocks +// already, but if that isn't the case then this pass can be extended to do +// that. +class RematerializeDispatchConstantsPass + : public FunctionPass<RematerializeDispatchConstantsPass> { + public: + void runOnFunction() override { + for (auto &block : getFunction()) { + SmallVector<ConstantOp, 8> smallConstantOps; + for (auto constantOp : block.getOps<ConstantOp>()) { + if (isConstantSmall(constantOp)) { + smallConstantOps.push_back(constantOp); + } + } + // Note: we iterate in reverse so that the rematerialized constants appear + // in the same order they did originally (as insertion is at the top). + for (auto constantOp : llvm::reverse(smallConstantOps)) { + if (failed(rematerializeConstantInDispatchRegions(constantOp))) { + return signalPassFailure(); + } + } + } + } +}; + +std::unique_ptr<OpPassBase<FuncOp>> createRematerializeDispatchConstantsPass() { + return std::make_unique<RematerializeDispatchConstantsPass>(); +} + +static PassRegistration<RematerializeDispatchConstantsPass> pass( + "iree-flow-rematerialize-dispatch-constants", + "Rematerializes small previously-CSE'd constants into dispatch regions"); + +} // namespace Flow +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/BUILD b/iree/compiler/Dialect/Flow/Transforms/test/BUILD new file mode 100644 index 0000000..beab62e --- /dev/null +++ b/iree/compiler/Dialect/Flow/Transforms/test/BUILD
@@ -0,0 +1,28 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//iree:build_defs.bzl", "iree_glob_lit_tests", "iree_setup_lit_package") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +iree_setup_lit_package( + data = [ + "//iree/tools:iree-opt", + ], +) + +iree_glob_lit_tests()
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/assign_executable_workloads.mlir b/iree/compiler/Dialect/Flow/Transforms/test/assign_executable_workloads.mlir new file mode 100644 index 0000000..51ec348 --- /dev/null +++ b/iree/compiler/Dialect/Flow/Transforms/test/assign_executable_workloads.mlir
@@ -0,0 +1,57 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// RUN: iree-opt -split-input-file -iree-flow-assign-executable-workloads %s | FileCheck %s --dump-input=fail + +flow.executable @singleStaticWorkload_ex_dispatch_0 { + // CHECK-LABEL: flow.dispatch.entry @singleStaticWorkload_rgn_dispatch_0 + // CHECK-SAME: workgroup_size = dense<[32, 1, 1]> : tensor<3xi32> + // CHECK-SAME: workload = dense<[4, 1, 1]> : tensor<3xi32> + flow.dispatch.entry @singleStaticWorkload_rgn_dispatch_0 + module { + func @singleStaticWorkload_rgn_dispatch_0(%arg0: tensor<4xf32>) -> tensor<4xf32> { + %0 = addf %arg0, %arg0 : tensor<4xf32> + %1 = subf %0, %arg0 : tensor<4xf32> + %2 = mulf %1, %arg0 : tensor<4xf32> + return %2 : tensor<4xf32> + } + } +} +func @singleStaticWorkload(%arg0: tensor<4xf32>) -> tensor<4xf32> { + %cst = constant dense<[4, 1, 1]> : tensor<3xi32> + %0 = flow.dispatch @singleStaticWorkload_ex_dispatch_0::@singleStaticWorkload_rgn_dispatch_0[%cst : tensor<3xi32>](%arg0) : (tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// ----- + +flow.executable @reduction_ex_reduce_0_dim_0 { + // CHECK-LABEL: flow.reduction.entry @reduction_rgn_reduce_0_dim_0_entry + // CHECK-SAME: workgroup_size = dense<[32, 1, 1]> : tensor<3xi32> + // CHECK-SAME: workload = dense<[4, 1, 1]> : tensor<3xi32> + flow.reduction.entry @reduction_rgn_reduce_0_dim_0_entry apply(@reduction_rgn_reduce_0_dim_0) attributes {dimension = 1 : i32} + module { + func @reduction_rgn_reduce_0_dim_0_entry(tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32> + func @reduction_rgn_reduce_0_dim_0(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> { + %0 = xla_hlo.add %arg0, %arg1 : tensor<f32> + return %0 : tensor<f32> + } + } +} +func @reduction(%arg0: tensor<4x8xf32>) -> tensor<4xf32> { + %cst = constant dense<0.000000e+00> : tensor<f32> + %cst_0 = constant dense<[4, 1, 1]> : tensor<3xi32> + %0 = flow.dispatch @reduction_ex_reduce_0_dim_0::@reduction_rgn_reduce_0_dim_0_entry[%cst_0 : tensor<3xi32>](%arg0, %cst) : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +}
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir b/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir new file mode 100644 index 0000000..2c43c88 --- /dev/null +++ b/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir
@@ -0,0 +1,100 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// RUN: iree-opt -split-input-file -iree-flow-fold-compatible-dispatch-regions %s | FileCheck %s --dump-input=fail + +func @noFolding(%arg0 : tensor<4xf32>) -> tensor<4xf32> { + %cst = constant dense<[4, 1, 1]> : tensor<3xi32> + %0 = flow.dispatch.region[%cst : tensor<3xi32>](%arg1 = %arg0 : tensor<4xf32>) : tensor<4xf32> { + %1 = xla_hlo.add %arg1, %arg1 : tensor<4xf32> + flow.return %1 : tensor<4xf32> + } + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: func @noFolding +// CHECK-NEXT: %cst = constant dense<[4, 1, 1]> : tensor<3xi32> +// CHECK-NEXT: %0 = flow.dispatch.region[%cst : tensor<3xi32>](%arg1 = %arg0 : tensor<4xf32>) : tensor<4xf32> { +// CHECK-NEXT: %1 = xla_hlo.add %arg1, %arg1 : tensor<4xf32> +// CHECK-NEXT: flow.return %1 : tensor<4xf32> +// CHECK-NEXT: } +// CHECK-NEXT: return %0 : tensor<4xf32> + +// ----- + +func @elementwiseOps(%arg0 : tensor<4xf32>) -> tensor<4xf32> { + %cst = constant dense<[4, 1, 1]> : tensor<3xi32> + %0 = flow.dispatch.region[%cst : tensor<3xi32>](%arg1 = %arg0 : tensor<4xf32>) : tensor<4xf32> { + %1 = xla_hlo.add %arg1, %arg1 : tensor<4xf32> + flow.return %1 : tensor<4xf32> + } + %2 = flow.dispatch.region[%cst : tensor<3xi32>](%arg2 = %arg0 : tensor<4xf32>, %arg3 = %0 : tensor<4xf32>) : tensor<4xf32> { + %3 = xla_hlo.sub %arg3, %arg2 : tensor<4xf32> + flow.return %3 : tensor<4xf32> + } + %4 = flow.dispatch.region[%cst : tensor<3xi32>](%arg4 = %arg0 : tensor<4xf32>, %arg5 = %2 : tensor<4xf32>) : tensor<4xf32> { + %5 = xla_hlo.mul %arg4, %arg5 : tensor<4xf32> + flow.return %5 : tensor<4xf32> + } + return %4 : tensor<4xf32> +} + +// CHECK-LABEL: func @elementwiseOps +// CHECK-NEXT: %cst = constant dense<[4, 1, 1]> +// CHECK-NEXT: %0 = flow.dispatch.region[%cst : tensor<3xi32>](%arg1 = %arg0 : tensor<4xf32>) : tensor<4xf32> { +// CHECK-NEXT: %1 = xla_hlo.add %arg1, %arg1 : tensor<4xf32> +// CHECK-NEXT: %2 = xla_hlo.sub %1, %arg1 : tensor<4xf32> +// CHECK-NEXT: %3 = xla_hlo.mul %arg1, %2 : tensor<4xf32> +// CHECK-NEXT: flow.return %3 : tensor<4xf32> +// CHECK-NEXT: } +// CHECK-NEXT: return %0 : tensor<4xf32> + +// ----- + +func @interleavedDot(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { + %cst = constant dense<[4, 4, 1]> : tensor<3xi32> + %0 = flow.dispatch.region[%cst : tensor<3xi32>](%arg1 = %arg0 : tensor<4x4xf32>) : tensor<4x4xf32> { + %3 = xla_hlo.add %arg1, %arg1 : tensor<4x4xf32> + flow.return %3 : tensor<4x4xf32> + } + %cst_0 = constant dense<[4, 4, 1]> : tensor<3xi32> + %1 = flow.dispatch.region[%cst_0 : tensor<3xi32>](%arg1 = %0 : tensor<4x4xf32>, %arg2 = %arg0 : tensor<4x4xf32>) : tensor<4x4xf32> { + %3 = "xla_hlo.dot"(%arg1, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + flow.return %3 : tensor<4x4xf32> + } + %cst_1 = constant dense<[4, 4, 1]> : tensor<3xi32> + %2 = flow.dispatch.region[%cst_1 : tensor<3xi32>](%arg1 = %1 : tensor<4x4xf32>, %arg2 = %arg0 : tensor<4x4xf32>) : tensor<4x4xf32> { + %3 = xla_hlo.mul %arg1, %arg2 : tensor<4x4xf32> + flow.return %3 : tensor<4x4xf32> + } + return %2 : tensor<4x4xf32> +} + +// CHECK-LABEL: func @interleavedDot +// CHECK-NEXT: %cst = constant dense<[4, 4, 1]> : tensor<3xi32> +// CHECK-NEXT: %0 = flow.dispatch.region[%cst : tensor<3xi32>](%arg1 = %arg0 : tensor<4x4xf32>) : tensor<4x4xf32> { +// CHECK-NEXT: %3 = xla_hlo.add %arg1, %arg1 : tensor<4x4xf32> +// CHECK-NEXT: flow.return %3 : tensor<4x4xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %cst_0 = constant dense<[4, 4, 1]> : tensor<3xi32> +// CHECK-NEXT: %1 = flow.dispatch.region[%cst_0 : tensor<3xi32>](%arg1 = %0 : tensor<4x4xf32>, %arg2 = %arg0 : tensor<4x4xf32>) : tensor<4x4xf32> { +// CHECK-NEXT: %3 = "xla_hlo.dot"(%arg1, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> +// CHECK-NEXT: flow.return %3 : tensor<4x4xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %cst_1 = constant dense<[4, 4, 1]> : tensor<3xi32> +// CHECK-NEXT: %2 = flow.dispatch.region[%cst_1 : tensor<3xi32>](%arg1 = %1 : tensor<4x4xf32>, %arg2 = %arg0 : tensor<4x4xf32>) : tensor<4x4xf32> { +// CHECK-NEXT: %3 = xla_hlo.mul %arg1, %arg2 : tensor<4x4xf32> +// CHECK-NEXT: flow.return %3 : tensor<4x4xf32> +// CHECK-NEXT: } +// CHECK-NEXT: return %2 : tensor<4x4xf32>
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions.mlir b/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions.mlir new file mode 100644 index 0000000..f5b1afe --- /dev/null +++ b/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions.mlir
@@ -0,0 +1,142 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// RUN: iree-opt -split-input-file -iree-flow-dispatchability-analysis -iree-flow-identify-dispatch-regions %s | FileCheck %s --dump-input=always + +// CHECK-LABEL: @empty +func @empty() { + // CHECK-NEXT: return + return +} + +// ----- + +// CHECK-LABEL: @simpleMath +func @simpleMath(%arg0 : tensor<4xf32>) -> tensor<4xf32> { + // CHECK-NEXT: constant dense<[4, 1, 1]> + // CHECK-NEXT: %0 = flow.dispatch.region + // CHECK-SAME: [%cst : tensor<3xi32>] + // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) : tensor<4xf32> { + // CHECK-NEXT: %1 = xla_hlo.add %arg1, %arg1 : tensor<4xf32> + %0 = xla_hlo.add %arg0, %arg0 : tensor<4xf32> + // CHECK-NEXT: flow.return %1 : tensor<4xf32> + // CHECK-NEXT: } + // CHECK-NEXT: return %0 : tensor<4xf32> + return %0 : tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: @stdElementwiseOps +func @stdElementwiseOps(%arg0 : tensor<4xf32>) -> tensor<4xf32> { + // CHECK-NEXT: constant dense<[4, 1, 1]> + // CHECK-NEXT: %0 = flow.dispatch.region + // CHECK-SAME: [%cst : tensor<3xi32>] + // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) : tensor<4xf32> { + // CHECK-NEXT: %1 = addf %arg1, %arg1 : tensor<4xf32> + %0 = addf %arg0, %arg0 : tensor<4xf32> + // CHECK-NEXT: %2 = subf %1, %arg1 : tensor<4xf32> + %1 = subf %0, %arg0 : tensor<4xf32> + // CHECK-NEXT: %3 = mulf %2, %arg1 : tensor<4xf32> + %2 = mulf %1, %arg0 : tensor<4xf32> + // CHECK-NEXT: flow.return %3 : tensor<4xf32> + // CHECK-NEXT: } + // CHECK-NEXT: return %0 : tensor<4xf32> + return %2 : tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: @hloElementwiseOps +func @hloElementwiseOps(%arg0 : tensor<4xf32>) -> tensor<4xf32> { + // CHECK-NEXT: constant dense<[4, 1, 1]> + // CHECK-NEXT: %0 = flow.dispatch.region + // CHECK-SAME: [%cst : tensor<3xi32>] + // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) : tensor<4xf32> { + // CHECK-NEXT: %1 = xla_hlo.add %arg1, %arg1 : tensor<4xf32> + %0 = xla_hlo.add %arg0, %arg0 : tensor<4xf32> + // CHECK-NEXT: %2 = xla_hlo.sub %1, %arg1 : tensor<4xf32> + %1 = xla_hlo.sub %0, %arg0 : tensor<4xf32> + // CHECK-NEXT: %3 = xla_hlo.mul %2, %arg1 : tensor<4xf32> + %2 = xla_hlo.mul %1, %arg0 : tensor<4xf32> + // CHECK-NEXT: flow.return %3 : tensor<4xf32> + // CHECK-NEXT: } + // CHECK-NEXT: return %0 : tensor<4xf32> + return %2 : tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: @interleavedDot +func @interleavedDot(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { + // CHECK-NEXT: %cst = constant dense<[4, 4, 1]> + // CHECK-NEXT: %0 = flow.dispatch.region + // CHECK-SAME: [%cst : tensor<3xi32>] + // CHECK-SAME: (%arg1 = %arg0 : tensor<4x4xf32>) : tensor<4x4xf32> { + // CHECK-NEXT: %3 = xla_hlo.add %arg1, %arg1 : tensor<4x4xf32> + %0 = xla_hlo.add %arg0, %arg0 : tensor<4x4xf32> + // CHECK-NEXT: flow.return %3 : tensor<4x4xf32> + // CHECK-NEXT: } + // CHECK-NEXT: %cst_0 = constant dense<[4, 4, 1]> : tensor<3xi32> + // CHECK-NEXT: %1 = flow.dispatch.region + // CHECK-SAME: [%cst_0 : tensor<3xi32>] + // CHECK-SAME: (%arg1 = %0 : tensor<4x4xf32>, %arg2 = %arg0 : tensor<4x4xf32>) : tensor<4x4xf32> { + // CHECK-NEXT: %3 = "xla_hlo.dot"(%arg1, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + %1 = "xla_hlo.dot"(%0, %arg0) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + // CHECK-NEXT: flow.return %3 : tensor<4x4xf32> + // CHECK-NEXT: } + // CHECK-NEXT: %cst_1 = constant dense<[4, 4, 1]> : tensor<3xi32> + // CHECK-NEXT: %2 = flow.dispatch.region + // CHECK-SAME: [%cst_1 : tensor<3xi32>] + // CHECK-SAME: (%arg1 = %1 : tensor<4x4xf32>, %arg2 = %arg0 : tensor<4x4xf32>) : tensor<4x4xf32> { + // CHECK-NEXT: %3 = xla_hlo.mul %arg1, %arg2 : tensor<4x4xf32> + %2 = xla_hlo.mul %1, %arg0 : tensor<4x4xf32> + // CHECK-NEXT: flow.return %3 : tensor<4x4xf32> + // CHECK-NEXT: } + // CHECK-NEXT: return %2 : tensor<4x4xf32> + return %2 : tensor<4x4xf32> +} + +// ----- + +// CHECK-LABEL: func @caller +func @caller(%arg0 : tensor<4xf32>) -> tensor<4xf32> { + // CHECK-NEXT: constant dense<[4, 1, 1]> + // CHECK-NEXT: %0 = flow.dispatch.region + // CHECK-SAME: [%cst : tensor<3xi32>] + // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) : tensor<4xf32> { + // CHECK-NEXT: %1 = xla_hlo.add %arg1, %arg1 : tensor<4xf32> + %0 = xla_hlo.add %arg0, %arg0 : tensor<4xf32> + // CHECK-NEXT: %2 = call @callee(%1) : (tensor<4xf32>) -> tensor<4xf32> + %1 = call @callee(%0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %3 = xla_hlo.mul %2, %arg1 : tensor<4xf32> + %2 = xla_hlo.mul %1, %arg0 : tensor<4xf32> + // CHECK-NEXT: flow.return %3 : tensor<4xf32> + // CHECK-NEXT: } + // CHECK-NEXT: return %0 : tensor<4xf32> + return %2 : tensor<4xf32> +} +// CHECK-LABEL: func @callee +func @callee(%arg0 : tensor<4xf32>) -> tensor<4xf32> { + // CHECK-NEXT: constant dense<[4, 1, 1]> + // CHECK-NEXT: %0 = flow.dispatch.region + // CHECK-SAME: [%cst : tensor<3xi32>] + // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) : tensor<4xf32> { + // CHECK-NEXT: %1 = xla_hlo.mul %arg1, %arg1 : tensor<4xf32> + %0 = xla_hlo.mul %arg0, %arg0 : tensor<4xf32> + // CHECK-NEXT: flow.return %1 : tensor<4xf32> + // CHECK-NEXT: } + // CHECK-NEXT: return %0 : tensor<4xf32> + return %0 : tensor<4xf32> +}
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/identify_reduction_regions.mlir b/iree/compiler/Dialect/Flow/Transforms/test/identify_reduction_regions.mlir new file mode 100644 index 0000000..8b2bbf3 --- /dev/null +++ b/iree/compiler/Dialect/Flow/Transforms/test/identify_reduction_regions.mlir
@@ -0,0 +1,64 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// RUN: iree-opt -split-input-file -iree-flow-identify-reduction-regions %s | FileCheck %s --dump-input=fail + +// CHECK-LABEL: @single_reduction +func @single_reduction(%arg0 : tensor<4x8xf32>) -> tensor<4xf32> { + %0 = constant dense<0.0> : tensor<f32> + // CHECK: constant dense<[4, 1, 1]> + // CHECK-NEXT: %0 = flow.reduction.region + // CHECK-SAME: [%cst_0 : tensor<3xi32>] + // CHECK-SAME: (%arg0) : (tensor<4x8xf32>) -> (tensor<4xf32>) + %1 = "xla_hlo.reduce"(%arg0, %0) ( { + // CHECK-NEXT: invocation((%arg1, %arg2) = %cst : tensor<f32>) { + ^bb0(%arg1 : tensor<f32>, %arg2 : tensor<f32>): + // CHECK-NEXT: %1 = xla_hlo.add %arg1, %arg2 : tensor<f32> + %2 = xla_hlo.add %arg1, %arg2 : tensor<f32> + // CHECK-NEXT: flow.return %1 : tensor<f32> + "xla_hlo.return"(%2) : (tensor<f32>) -> () + // CHECK-NEXT: } {dimensions = dense<1> : tensor<1xi32>} + }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32> + // CHECK-NEXT: return %0 : tensor<4xf32> + return %1 : tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: @multi_reduction +func @multi_reduction(%arg0 : tensor<4x8xf32>, %arg1 : tensor<4x8xf32>) -> (tensor<4xf32>, tensor<4xf32>) { + %0 = constant dense<0.0> : tensor<f32> + %1 = constant dense<1.0> : tensor<f32> + // CHECK: constant dense<[4, 1, 1]> + // CHECK-NEXT: %0:2 = flow.reduction.region + // CHECK-SAME: [%cst_1 : tensor<3xi32>] + // CHECK-SAME: (%arg0, %arg1) : (tensor<4x8xf32>, tensor<4x8xf32>) -> (tensor<4xf32>, tensor<4xf32>) + %2, %3 = "xla_hlo.reduce"(%arg0, %arg1, %0, %1) ( { + // CHECK-NEXT: invocation((%arg2, %arg3) = %cst : tensor<f32>, (%arg4, %arg5) = %cst_0 : tensor<f32>) { + ^bb0(%arg0_lhs : tensor<f32>, %arg1_lhs : tensor<f32>, %arg0_rhs : tensor<f32>, %arg1_rhs : tensor<f32>): + // CHECK-NEXT: %1 = xla_hlo.add %arg2, %arg4 : tensor<f32> + %4 = xla_hlo.add %arg0_lhs, %arg0_rhs : tensor<f32> + // CHECK-NEXT: %2 = xla_hlo.add %arg3, %arg5 : tensor<f32> + %5 = xla_hlo.add %arg1_lhs, %arg1_rhs : tensor<f32> + // CHECK-NEXT: flow.return %1, %2 : tensor<f32>, tensor<f32> + "xla_hlo.return"(%4, %5) : (tensor<f32>, tensor<f32>) -> () + // CHECK-NEXT: } {dimensions = dense<1> : tensor<1xi32>} + }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<4x8xf32>, tensor<f32>, tensor<f32>) -> (tensor<4xf32>, tensor<4xf32>) + // CHECK-NEXT: return %0#0, %0#1 : tensor<4xf32>, tensor<4xf32> + return %2, %3 : tensor<4xf32>, tensor<4xf32> +} + +// ----- + +// TODO(benvanik): windowed reduction.
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/rematerialize_dispatch_constants.mlir b/iree/compiler/Dialect/Flow/Transforms/test/rematerialize_dispatch_constants.mlir new file mode 100644 index 0000000..d399e15 --- /dev/null +++ b/iree/compiler/Dialect/Flow/Transforms/test/rematerialize_dispatch_constants.mlir
@@ -0,0 +1,61 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// RUN: iree-opt -split-input-file -iree-flow-rematerialize-dispatch-constants %s | FileCheck %s --dump-input=always + +// CHECK-LABEL: func @rematerializeSmall +func @rematerializeSmall(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { + %cst = constant dense<[4, 4, 1]> : tensor<3xi32> + %small = constant dense<1.23> : tensor<4x4xf32> + // CHECK: %0 = flow.dispatch.region[%cst : tensor<3xi32>](%arg1 = %arg0 : tensor<4x4xf32>) : tensor<4x4xf32> { + %0 = flow.dispatch.region[%cst : tensor<3xi32>](%arg1 = %arg0 : tensor<4x4xf32>, %arg2 = %small : tensor<4x4xf32>) : tensor<4x4xf32> { + // CHECK-NEXT: %cst_0 = constant dense<1.230000e+00> : tensor<4x4xf32> + // CHECK-NEXT: %1 = xla_hlo.add %arg1, %cst_0 : tensor<4x4xf32> + %3 = xla_hlo.add %arg1, %arg2 : tensor<4x4xf32> + flow.return %3 : tensor<4x4xf32> + } + return %0 : tensor<4x4xf32> +} + +// ----- + +// CHECK-LABEL: func @noRematerializeLarge +func @noRematerializeLarge(%arg0 : tensor<4096x4xf32>) -> tensor<4096x4xf32> { + %cst = constant dense<[4, 4, 1]> : tensor<3xi32> + // CHECK: %cst_0 = constant dense<1.230000e+00> : tensor<4096x4xf32> + %large = constant dense<1.23> : tensor<4096x4xf32> + // CHECK-NEXT: %0 = flow.dispatch.region[%cst : tensor<3xi32>](%arg1 = %arg0 : tensor<4096x4xf32>, %arg2 = %cst_0 : tensor<4096x4xf32>) : tensor<4096x4xf32> { + %0 = flow.dispatch.region[%cst : tensor<3xi32>](%arg1 = %arg0 : tensor<4096x4xf32>, %arg2 = %large : tensor<4096x4xf32>) : tensor<4096x4xf32> { + // CHECK-NEXT: %1 = xla_hlo.add %arg1, %arg2 : tensor<4096x4xf32> + %3 = xla_hlo.add %arg1, %arg2 : tensor<4096x4xf32> + flow.return %3 : tensor<4096x4xf32> + } + return %0 : tensor<4096x4xf32> +} + +// ----- + +// CHECK-LABEL: func @noRematerializeIntoDot +func @noRematerializeIntoDot(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { + %cst = constant dense<[4, 4, 1]> : tensor<3xi32> + // CHECK: %cst_0 = constant dense<1.230000e+00> : tensor<4x4xf32> + %small = constant dense<1.23> : tensor<4x4xf32> + // CHECK-NEXT: %0 = flow.dispatch.region[%cst : tensor<3xi32>](%arg1 = %arg0 : tensor<4x4xf32>, %arg2 = %cst_0 : tensor<4x4xf32>) : tensor<4x4xf32> { + %0 = flow.dispatch.region[%cst : tensor<3xi32>](%arg1 = %arg0 : tensor<4x4xf32>, %arg2 = %small : tensor<4x4xf32>) : tensor<4x4xf32> { + // CHECK-NEXT: %1 = "xla_hlo.dot"(%arg1, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + %3 = "xla_hlo.dot"(%arg1, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + flow.return %3 : tensor<4x4xf32> + } + return %0 : tensor<4x4xf32> +}
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/transformation.mlir b/iree/compiler/Dialect/Flow/Transforms/test/transformation.mlir new file mode 100644 index 0000000..bffe99a --- /dev/null +++ b/iree/compiler/Dialect/Flow/Transforms/test/transformation.mlir
@@ -0,0 +1,233 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// RUN: iree-opt -split-input-file -iree-flow-transformation-pipeline %s | FileCheck %s --dump-input=fail + +// CHECK-LABEL: @empty +func @empty() { + // CHECK-NEXT: return + return +} + +// ----- + +func @simpleMath(%arg0 : tensor<4xf32>) -> tensor<4xf32> { + %0 = xla_hlo.add %arg0, %arg0 : tensor<4xf32> + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: flow.executable @simpleMath_ex_dispatch_0 { +// CHECK-NEXT: flow.dispatch.entry @simpleMath_rgn_dispatch_0 +// CHECK-SAME: workgroup_size = dense<[32, 1, 1]> : tensor<3xi32> +// CHECK-SAME: workload = dense<[4, 1, 1]> : tensor<3xi32> +// CHECK-NEXT: module { +// CHECK-NEXT: func @simpleMath_rgn_dispatch_0(%arg0: tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NEXT: %0 = xla_hlo.add %arg0, %arg0 : tensor<4xf32> +// CHECK-NEXT: return %0 : tensor<4xf32> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: func @simpleMath(%arg0: tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NEXT: %cst = constant dense<[4, 1, 1]> : tensor<3xi32> +// CHECK-NEXT: %0 = flow.dispatch @simpleMath_ex_dispatch_0::@simpleMath_rgn_dispatch_0[%cst : tensor<3xi32>](%arg0) : (tensor<4xf32>) -> tensor<4xf32> +// CHECK-NEXT: return %0 : tensor<4xf32> +// CHECK-NEXT: } + +// ----- + +func @stdElementwiseOps(%arg0 : tensor<4xf32>) -> tensor<4xf32> { + %0 = addf %arg0, %arg0 : tensor<4xf32> + %1 = subf %0, %arg0 : tensor<4xf32> + %2 = mulf %1, %arg0 : tensor<4xf32> + return %2 : tensor<4xf32> +} + +// CHECK-LABEL: flow.executable @stdElementwiseOps_ex_dispatch_0 { +// CHECK-NEXT: flow.dispatch.entry @stdElementwiseOps_rgn_dispatch_0 +// CHECK-SAME: workgroup_size = dense<[32, 1, 1]> : tensor<3xi32> +// CHECK-SAME: workload = dense<[4, 1, 1]> : tensor<3xi32> +// CHECK-NEXT: module { +// CHECK-NEXT: func @stdElementwiseOps_rgn_dispatch_0(%arg0: tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NEXT: %0 = addf %arg0, %arg0 : tensor<4xf32> +// CHECK-NEXT: %1 = subf %0, %arg0 : tensor<4xf32> +// CHECK-NEXT: %2 = mulf %1, %arg0 : tensor<4xf32> +// CHECK-NEXT: return %2 : tensor<4xf32> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: func @stdElementwiseOps(%arg0: tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NEXT: %cst = constant dense<[4, 1, 1]> : tensor<3xi32> +// CHECK-NEXT: %0 = flow.dispatch @stdElementwiseOps_ex_dispatch_0::@stdElementwiseOps_rgn_dispatch_0[%cst : tensor<3xi32>](%arg0) : (tensor<4xf32>) -> tensor<4xf32> +// CHECK-NEXT: return %0 : tensor<4xf32> +// CHECK-NEXT: } + +// ----- + +func @hloElementwiseOps(%arg0 : tensor<4xf32>) -> tensor<4xf32> { + %0 = xla_hlo.add %arg0, %arg0 : tensor<4xf32> + %1 = xla_hlo.sub %0, %arg0 : tensor<4xf32> + %2 = xla_hlo.mul %1, %arg0 : tensor<4xf32> + return %2 : tensor<4xf32> +} + +// CHECK-LABEL: flow.executable @hloElementwiseOps_ex_dispatch_0 { +// CHECK-NEXT: flow.dispatch.entry @hloElementwiseOps_rgn_dispatch_0 +// CHECK-SAME: workgroup_size = dense<[32, 1, 1]> : tensor<3xi32> +// CHECK-SAME: workload = dense<[4, 1, 1]> : tensor<3xi32> +// CHECK-NEXT: module { +// CHECK-NEXT: func @hloElementwiseOps_rgn_dispatch_0(%arg0: tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NEXT: %0 = xla_hlo.add %arg0, %arg0 : tensor<4xf32> +// CHECK-NEXT: %1 = xla_hlo.sub %0, %arg0 : tensor<4xf32> +// CHECK-NEXT: %2 = xla_hlo.mul %1, %arg0 : tensor<4xf32> +// CHECK-NEXT: return %2 : tensor<4xf32> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: func @hloElementwiseOps(%arg0: tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NEXT: %cst = constant dense<[4, 1, 1]> : tensor<3xi32> +// CHECK-NEXT: %0 = flow.dispatch @hloElementwiseOps_ex_dispatch_0::@hloElementwiseOps_rgn_dispatch_0[%cst : tensor<3xi32>](%arg0) : (tensor<4xf32>) -> tensor<4xf32> +// CHECK-NEXT: return %0 : tensor<4xf32> +// CHECK-NEXT: } + +// ----- + +func @interleavedDot(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { + %0 = xla_hlo.add %arg0, %arg0 : tensor<4x4xf32> + %1 = "xla_hlo.dot"(%0, %arg0) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + %2 = xla_hlo.mul %1, %arg0 : tensor<4x4xf32> + return %2 : tensor<4x4xf32> +} + +// CHECK-LABEL: flow.executable @interleavedDot_ex_dispatch_0 { +// CHECK-NEXT: flow.dispatch.entry @interleavedDot_rgn_dispatch_0 +// CHECK-SAME: workgroup_size = dense<[32, 1, 1]> : tensor<3xi32> +// CHECK-SAME: workload = dense<[4, 4, 1]> : tensor<3xi32> +// CHECK-NEXT: module { +// CHECK-NEXT: func @interleavedDot_rgn_dispatch_0(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { +// CHECK-NEXT: %0 = xla_hlo.add %arg0, %arg0 : tensor<4x4xf32> +// CHECK-NEXT: return %0 : tensor<4x4xf32> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: flow.executable @interleavedDot_ex_dispatch_1 { +// CHECK-NEXT: flow.dispatch.entry @interleavedDot_rgn_dispatch_1 +// CHECK-SAME: workgroup_size = dense<[32, 1, 1]> : tensor<3xi32> +// CHECK-SAME: workload = dense<[4, 4, 1]> : tensor<3xi32> +// CHECK-NEXT: module { +// CHECK-NEXT: func @interleavedDot_rgn_dispatch_1(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32> { +// CHECK-NEXT: %0 = "xla_hlo.dot"(%arg0, %arg1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> +// CHECK-NEXT: return %0 : tensor<4x4xf32> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: flow.executable @interleavedDot_ex_dispatch_2 { +// CHECK-NEXT: flow.dispatch.entry @interleavedDot_rgn_dispatch_2 +// CHECK-SAME: workgroup_size = dense<[32, 1, 1]> : tensor<3xi32> +// CHECK-SAME: workload = dense<[4, 4, 1]> : tensor<3xi32> +// CHECK-NEXT: module { +// CHECK-NEXT: func @interleavedDot_rgn_dispatch_2(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32> { +// CHECK-NEXT: %0 = xla_hlo.mul %arg0, %arg1 : tensor<4x4xf32> +// CHECK-NEXT: return %0 : tensor<4x4xf32> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: func @interleavedDot(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { +// CHECK-NEXT: %cst = constant dense<[4, 4, 1]> : tensor<3xi32> +// CHECK-NEXT: %0 = flow.dispatch @interleavedDot_ex_dispatch_0::@interleavedDot_rgn_dispatch_0[%cst : tensor<3xi32>](%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32> +// CHECK-NEXT: %1 = flow.dispatch @interleavedDot_ex_dispatch_1::@interleavedDot_rgn_dispatch_1[%cst : tensor<3xi32>](%0, %arg0) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> +// CHECK-NEXT: %2 = flow.dispatch @interleavedDot_ex_dispatch_2::@interleavedDot_rgn_dispatch_2[%cst : tensor<3xi32>](%1, %arg0) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> +// CHECK-NEXT: return %2 : tensor<4x4xf32> +// CHECK-NEXT: } + +// ----- + +func @caller(%arg0 : tensor<4xf32>) -> tensor<4xf32> { + %0 = xla_hlo.add %arg0, %arg0 : tensor<4xf32> + %1 = call @callee(%0) : (tensor<4xf32>) -> tensor<4xf32> + %2 = xla_hlo.mul %1, %arg0 : tensor<4xf32> + return %2 : tensor<4xf32> +} +func @callee(%arg0 : tensor<4xf32>) -> tensor<4xf32> { + %0 = xla_hlo.mul %arg0, %arg0 : tensor<4xf32> + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: flow.executable @caller_ex_dispatch_0 { +// CHECK-NEXT: flow.dispatch.entry @caller_rgn_dispatch_0 +// CHECK-SAME: workgroup_size = dense<[32, 1, 1]> : tensor<3xi32> +// CHECK-SAME: workload = dense<[4, 1, 1]> : tensor<3xi32> +// CHECK-NEXT: module { +// CHECK-NEXT: func @caller_rgn_dispatch_0(%arg0: tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NEXT: %0 = xla_hlo.add %arg0, %arg0 : tensor<4xf32> +// CHECK-NEXT: %1 = call @callee(%0) : (tensor<4xf32>) -> tensor<4xf32> +// CHECK-NEXT: %2 = xla_hlo.mul %1, %arg0 : tensor<4xf32> +// CHECK-NEXT: return %2 : tensor<4xf32> +// CHECK-NEXT: } +// CHECK-NEXT: func @callee(%arg0: tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NEXT: %0 = xla_hlo.mul %arg0, %arg0 : tensor<4xf32> +// CHECK-NEXT: return %0 : tensor<4xf32> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: func @caller(%arg0: tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NEXT: %cst = constant dense<[4, 1, 1]> : tensor<3xi32> +// CHECK-NEXT: %0 = flow.dispatch @caller_ex_dispatch_0::@caller_rgn_dispatch_0[%cst : tensor<3xi32>](%arg0) : (tensor<4xf32>) -> tensor<4xf32> +// CHECK-NEXT: return %0 : tensor<4xf32> +// CHECK-NEXT: } +// CHECK-NEXT: flow.executable @callee_ex_dispatch_0 { +// CHECK-NEXT: flow.dispatch.entry @callee_rgn_dispatch_0 +// CHECK-NEXT: module { +// CHECK-NEXT: func @callee_rgn_dispatch_0(%arg0: tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NEXT: %0 = xla_hlo.mul %arg0, %arg0 : tensor<4xf32> +// CHECK-NEXT: return %0 : tensor<4xf32> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: func @callee(%arg0: tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NEXT: %cst = constant dense<[4, 1, 1]> : tensor<3xi32> +// CHECK-NEXT: %0 = flow.dispatch @callee_ex_dispatch_0::@callee_rgn_dispatch_0[%cst : tensor<3xi32>](%arg0) : (tensor<4xf32>) -> tensor<4xf32> +// CHECK-NEXT: return %0 : tensor<4xf32> +// CHECK-NEXT: } + +// ----- + +func @reduction(%arg0 : tensor<4x8xf32>) -> tensor<4xf32> { + %0 = constant dense<0.0> : tensor<f32> + %1 = "xla_hlo.reduce"(%arg0, %0) ( { + ^bb0(%arg1 : tensor<f32>, %arg2 : tensor<f32>): + %2 = xla_hlo.add %arg1, %arg2 : tensor<f32> + "xla_hlo.return"(%2) : (tensor<f32>) -> () + }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32> + return %1 : tensor<4xf32> +} + +// CHECK-LABEL: flow.executable @reduction_ex_reduce_0_dim_0 { +// CHECK-NEXT: flow.reduction.entry @reduction_rgn_reduce_0_dim_0_entry apply(@reduction_rgn_reduce_0_dim_0) +// CHECK-SAME: dimension = 1 : i32 +// CHECK-SAME: workgroup_size = dense<[32, 1, 1]> : tensor<3xi32> +// CHECK-SAME: workload = dense<[4, 1, 1]> : tensor<3xi32> +// CHECK-NEXT: module { +// CHECK-NEXT: func @reduction_rgn_reduce_0_dim_0_entry(tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32> +// CHECK-NEXT: func @reduction_rgn_reduce_0_dim_0(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> { +// CHECK-NEXT: %0 = xla_hlo.add %arg0, %arg1 : tensor<f32> +// CHECK-NEXT: return %0 : tensor<f32> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: func @reduction(%arg0: tensor<4x8xf32>) -> tensor<4xf32> { +// CHECK-NEXT: %cst = constant dense<0.000000e+00> : tensor<f32> +// CHECK-NEXT: %cst_0 = constant dense<[4, 1, 1]> : tensor<3xi32> +// CHECK-NEXT: %0 = flow.dispatch @reduction_ex_reduce_0_dim_0::@reduction_rgn_reduce_0_dim_0_entry[%cst_0 : tensor<3xi32>](%arg0, %cst) : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32> +// CHECK-NEXT: return %0 : tensor<4xf32> +// CHECK-NEXT: }
diff --git a/iree/compiler/Dialect/Flow/Utils/BUILD b/iree/compiler/Dialect/Flow/Utils/BUILD new file mode 100644 index 0000000..64e78e4 --- /dev/null +++ b/iree/compiler/Dialect/Flow/Utils/BUILD
@@ -0,0 +1,38 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "Utils", + srcs = [ + "DispatchUtils.cpp", + "WorkloadUtils.cpp", + ], + hdrs = [ + "DispatchUtils.h", + "WorkloadUtils.h", + ], + deps = [ + "//iree/compiler/Dialect/Flow/IR", + "@llvm//:support", + "@local_config_mlir//:IR", + "@local_config_mlir//:StandardOps", + "@local_config_mlir//:Support", + ], + alwayslink = 1, +)
diff --git a/iree/compiler/Dialect/Flow/Utils/DispatchUtils.cpp b/iree/compiler/Dialect/Flow/Utils/DispatchUtils.cpp new file mode 100644 index 0000000..701b006 --- /dev/null +++ b/iree/compiler/Dialect/Flow/Utils/DispatchUtils.cpp
@@ -0,0 +1,208 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/compiler/Dialect/Flow/Utils/DispatchUtils.h" + +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" +#include "llvm/ADT/SetVector.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace Flow { + +namespace { + +// Returns the set of values that must be captured for use by |ops| and the +// set of values defined by |ops| that are used outside of the set. +LogicalResult analyzeOpRangeValues( + const llvm::SmallDenseSet<Operation *> &opSet, + llvm::SetVector<Value *> *capturedValues, + llvm::SetVector<Value *> *escapingValues) { + for (auto *op : opSet) { + for (auto *value : op->getOperands()) { + if (!llvm::is_contained(opSet, value->getDefiningOp())) { + // Op is using a value not in the ops set, ensure we capture it. + capturedValues->insert(value); + } + } + for (auto *value : op->getResults()) { + for (auto &use : value->getUses()) { + if (!llvm::is_contained(opSet, use.getOwner())) { + // An op outside of the ops set is using the value, needs to escape. + escapingValues->insert(value); + } + } + } + } + return success(); +} + +} // namespace + +LogicalResult buildDispatchRegion(FuncOp func, Block *parentBlock, + Value *workload, ArrayRef<Operation *> ops) { + // Fused location with all ops. + SmallVector<Location, 16> opLocs; + for (auto *op : ops) { + opLocs.push_back(op->getLoc()); + } + auto regionLoc = FusedLoc::get(opLocs, func.getContext()); + + // Get a list of values that we need to capture and values that escape the + // region and need to be returned. + llvm::SmallDenseSet<Operation *> opSet; + opSet.reserve(ops.size()); + opSet.insert(ops.begin(), ops.end()); + llvm::SetVector<Value *> capturedValues; + llvm::SetVector<Value *> escapingValues; + if (failed(analyzeOpRangeValues(opSet, &capturedValues, &escapingValues))) { + return failure(); + } + SmallVector<Type, 8> escapingTypes; + for (auto *value : escapingValues) escapingTypes.push_back(value->getType()); + + // Build the region op and add it to the parent block. + OpBuilder parentBuilder(parentBlock); + parentBuilder.setInsertionPoint(ops.back()); + auto dispatchRegionOp = parentBuilder.create<IREE::Flow::DispatchRegionOp>( + regionLoc, escapingTypes, workload, capturedValues.getArrayRef()); + + // Create the block and setup the arg mapping for captured values. + auto *regionBlock = new Block(); + dispatchRegionOp.body().push_back(regionBlock); + OpBuilder regionBuilder(regionBlock); + BlockAndValueMapping mapping; + for (auto *capturedValue : capturedValues) { + auto *blockArg = regionBlock->addArgument(capturedValue->getType()); + mapping.map(capturedValue, blockArg); + } + + // Clone ops into the new region block. + for (auto *op : ops) { + // Note that this updates the mapping with the new values (so at the end + // we have those new values). + regionBuilder.clone(*op, mapping); + } + + // Return results (as we need a terminator in our block). + // These are all of the values that escape our region. + SmallVector<Value *, 8> resultValues; + for (auto *oldValue : escapingValues) { + resultValues.push_back(mapping.lookupOrDefault(oldValue)); + } + regionBuilder.create<IREE::Flow::ReturnOp>(opLocs.back(), resultValues); + + // Replace usage of values with the results of the region. + for (int i = 0; i < escapingValues.size(); ++i) { + escapingValues[i]->replaceAllUsesWith(dispatchRegionOp.getResult(i)); + } + + // Remove original ops from the parent region. + for (auto it = ops.rbegin(); it != ops.rend(); ++it) { + (*it)->erase(); + } + + return success(); +} + +namespace { + +// Recursively finds all reachable functions from the given |rootFunc| and adds +// them to the |reachableFuncs| set. +// +// Note that indirect calls are not supported, however we don't allow those in +// dispatch regions anyway so they should not be present here. +LogicalResult findReachableFunctions( + FuncOp rootFuncOp, llvm::SetVector<FuncOp> &reachableFuncs, + llvm::StringMap<FuncOp> &dispatchableFuncOps) { + llvm::SetVector<FuncOp> worklist; + worklist.insert(rootFuncOp); + while (!worklist.empty()) { + auto funcOp = worklist.pop_back_val(); + funcOp.walk([&](CallOp callOp) { + auto calleeOp = dispatchableFuncOps.find(callOp.callee())->second; + if (reachableFuncs.insert(calleeOp)) { + worklist.insert(calleeOp); + } + }); + } + return success(); +} + +} // namespace + +std::pair<IREE::Flow::ExecutableOp, FuncOp> createRegionExecutable( + Operation *op, FunctionType functionType, StringRef symbolSuffix, + llvm::StringMap<FuncOp> &dispatchableFuncOps) { + // Create the function and take the region body directly. + // NOTE: this will get uniquified if we have multiple in the same block. + auto parentFunc = op->getParentOfType<FuncOp>(); + std::string functionName = + (parentFunc.getName().str() + "_rgn" + symbolSuffix).str(); + auto outlinedFunc = FuncOp::create(op->getLoc(), functionName, functionType); + BlockAndValueMapping mapping; + op->getRegion(0).cloneInto(&outlinedFunc.getBody(), mapping); + + // Replace flow.return with std.return. + for (auto &block : outlinedFunc.getBlocks()) { + if (auto returnOp = dyn_cast<IREE::Flow::ReturnOp>(block.back())) { + OpBuilder builder(returnOp); + builder.create<mlir::ReturnOp>( + returnOp.getLoc(), llvm::to_vector<4>(returnOp.getOperands())); + returnOp.erase(); + } + } + + // Gather all reachable functions. + llvm::SetVector<FuncOp> reachableFuncs; + findReachableFunctions(outlinedFunc, reachableFuncs, dispatchableFuncOps); + + // Create the executable that will contain the outlined region. + // NOTE: this will get uniquified if we have multiple in the same block. + auto parentModule = parentFunc.getParentOfType<ModuleOp>(); + OpBuilder parentModuleBuilder(parentModule); + parentModuleBuilder.setInsertionPoint(parentFunc); + std::string executableName = + (parentFunc.getName().str() + "_ex" + symbolSuffix).str(); + auto executableOp = parentModuleBuilder.create<IREE::Flow::ExecutableOp>( + outlinedFunc.getLoc(), executableName); + + // Create the inner ModuleOp that contains the original functions. We need + // to provide this shim as some ops (like std.call) look for the + // containing module to provide symbol resolution. + OpBuilder executableBuilder(executableOp); + executableBuilder.setInsertionPointToStart(&executableOp.getBlock()); + auto innerModule = executableBuilder.create<ModuleOp>(outlinedFunc.getLoc()); + innerModule.push_back(outlinedFunc); + + // Copy all reachable functions into the executable. + // Linker passes may dedupe these later on. + OpBuilder innerModuleBuilder(innerModule.getBody()); + innerModuleBuilder.setInsertionPoint(innerModule.getBody(), + ++innerModule.getBody()->begin()); + for (auto reachableFunc : reachableFuncs) { + innerModuleBuilder.clone(*reachableFunc); + } + + return std::make_pair(executableOp, outlinedFunc); +} + +} // namespace Flow +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Utils/DispatchUtils.h b/iree/compiler/Dialect/Flow/Utils/DispatchUtils.h new file mode 100644 index 0000000..6db31e2 --- /dev/null +++ b/iree/compiler/Dialect/Flow/Utils/DispatchUtils.h
@@ -0,0 +1,60 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Utilities for dispatch region and function manipulation. +// These are shared between all dispatchable types such as the standard +// iree.dispatch_region as well as dispatch-related types like +// iree.reduction_region. + +#ifndef IREE_COMPILER_DIALECT_FLOW_UTILS_DISPATCHUTILS_H_ +#define IREE_COMPILER_DIALECT_FLOW_UTILS_DISPATCHUTILS_H_ + +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace Flow { + +// Builds a new dispatch region with the given |ops|. +// The region will capture all required values and return all values used +// outside of the |ops| provided. The region will be inserted at the location of +// the last operation in the set. +// +// All |ops| must be compatible with the |workload| specified as they will all +// be dispatched with the same workgroup structure. +// TODO(benvanik): ensure we want to insert at end. Maybe front? +LogicalResult buildDispatchRegion(FuncOp func, Block *parentBlock, + Value *workload, ArrayRef<Operation *> ops); + +// Creates an executable containing exported function containing the body region +// of |op|. Created executables will be named for their original function +// concatenated with |symbolSuffix|. All functions reachable by the region will +// be added to the executable by looking them up in |dispatchableFuncOps|. +std::pair<IREE::Flow::ExecutableOp, FuncOp> createRegionExecutable( + Operation *op, FunctionType functionType, StringRef symbolSuffix, + llvm::StringMap<FuncOp> &dispatchableFuncOps); + +} // namespace Flow +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_DIALECT_FLOW_UTILS_DISPATCHUTILS_H_
diff --git a/iree/compiler/Dialect/Flow/Utils/WorkloadUtils.cpp b/iree/compiler/Dialect/Flow/Utils/WorkloadUtils.cpp new file mode 100644 index 0000000..543fa45 --- /dev/null +++ b/iree/compiler/Dialect/Flow/Utils/WorkloadUtils.cpp
@@ -0,0 +1,70 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/compiler/Dialect/Flow/Utils/WorkloadUtils.h" + +#include <array> + +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/StandardTypes.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace Flow { + +Value *calculateWorkload(Operation *op, ShapedType baseOperandType) { + OpBuilder builder(op); + + std::array<int32_t, 3> workload = {1, 1, 1}; + + // TODO(b/139353314): lookup/calculate based on type/etc. + if (!baseOperandType.hasStaticShape()) { + op->emitOpError() << "Dynamic shapes not yet supported"; + return nullptr; + } + auto shape = baseOperandType.getShape(); + // Drop the trailing ones from the shape. + while (shape.size() > 1 && shape.back() == 1) { + shape = shape.drop_back(); + } + if (shape.size() <= 3) { + // Maps to XYZ (possibly with 1's for unused dimensions). + for (auto dim : enumerate(shape)) { + workload[shape.size() - 1 - dim.index()] = dim.value(); + } + } else { + // Need to flatten the shape to fit XYZ. For now we just squash from LHS. + workload[2] = 1; + for (int i = 0; i < shape.size(); ++i) { + workload[2] *= shape[i]; + } + workload[1] = shape[shape.size() - 2]; + workload[0] = shape.back(); + } + + // TODO(b/139353314): optimize workload layout. + + auto constantType = RankedTensorType::get({3}, builder.getIntegerType(32)); + return builder.create<ConstantOp>( + op->getLoc(), constantType, + DenseIntElementsAttr::get<int32_t>(constantType, workload)); +} + +} // namespace Flow +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Utils/WorkloadUtils.h b/iree/compiler/Dialect/Flow/Utils/WorkloadUtils.h new file mode 100644 index 0000000..de9cfad --- /dev/null +++ b/iree/compiler/Dialect/Flow/Utils/WorkloadUtils.h
@@ -0,0 +1,34 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_DIALECT_FLOW_UTILS_WORKLOADUTILS_H_ +#define IREE_COMPILER_DIALECT_FLOW_UTILS_WORKLOADUTILS_H_ + +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace Flow { + +// Calculates the workload for |op| based on the op type. +Value *calculateWorkload(Operation *op, ShapedType baseOperandType); + +} // namespace Flow +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_DIALECT_FLOW_UTILS_WORKLOADUTILS_H_
diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.h b/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.h index 14ead96..8b6cf92 100644 --- a/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.h +++ b/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.h
@@ -54,7 +54,7 @@ // See iree/schemas/bytecode_module_def.fbs for the description of the // serialized module format. // -// Exposed via the --vm-mlir-to-bytecode-module translation. +// Exposed via the --iree-vm-ir-to-bytecode-module translation. LogicalResult translateModuleToBytecode(BytecodeTargetOptions targetOptions, IREE::VM::ModuleOp moduleOp, llvm::raw_ostream &output);
diff --git a/iree/tools/BUILD b/iree/tools/BUILD index f9ae673..8b18ec3 100644 --- a/iree/tools/BUILD +++ b/iree/tools/BUILD
@@ -32,6 +32,9 @@ deps = [ "//integrations/tensorflow/compiler:tensorflow", "//iree/compiler/Dialect", + "//iree/compiler/Dialect/Flow/Analysis", + "//iree/compiler/Dialect/Flow/IR", + "//iree/compiler/Dialect/Flow/Transforms", "//iree/compiler/Dialect/VM/Analysis", "//iree/compiler/Dialect/VM/Conversion/StandardToVM", "//iree/compiler/Dialect/VM/IR",