Transition vec3 flow workload to index invocation count. * This was part of a larger CL to also enable dynamic dimension propagation, so it includes code to perform those calculations. * Disables the special-casing for the hand-coded Vulkan conv shader and the test. Will re-enable the test once either a) the vmla conv op is implemented or b) the linalg-based patch lands. I'd rather spend my time getting the new/right path working than triaging this old part. PiperOrigin-RevId: 300664492
diff --git a/integrations/tensorflow/e2e/vulkan_conv_test.py b/integrations/tensorflow/e2e/vulkan_conv_test.py index cf7c064..d68f18e 100644 --- a/integrations/tensorflow/e2e/vulkan_conv_test.py +++ b/integrations/tensorflow/e2e/vulkan_conv_test.py
@@ -100,9 +100,12 @@ @tf_test_utils.compile_modules( backends=[ - "iree_vulkan", + # TODO(laurenzo): Enable for all backends once vmla reference + # and Linalg vulkan impl lands. + # "iree_vulkan", "tf", - ], conv2d=Conv2dModule) + ], + conv2d=Conv2dModule) class ConvTest(tf_test_utils.SavedModelTestCase): def test_id_batch_size_1(self):
diff --git a/iree/compiler/Dialect/Flow/Conversion/TypeConverter.cpp b/iree/compiler/Dialect/Flow/Conversion/TypeConverter.cpp index 1bc796d..5ca37eb 100644 --- a/iree/compiler/Dialect/Flow/Conversion/TypeConverter.cpp +++ b/iree/compiler/Dialect/Flow/Conversion/TypeConverter.cpp
@@ -23,10 +23,10 @@ FlowTypeConverter::FlowTypeConverter() { // Allow types through by default. addConversion([](Type type) { return type; }); - addConversion([](IndexType type) { - // Always treat as 32-bit. - return IntegerType::get(32, type.getContext()); - }); + // addConversion([](IndexType type) { + // // Always treat as 32-bit. + // return IntegerType::get(32, type.getContext()); + // }); addConversion([](IntegerType integerType) -> Optional<Type> { if (integerType.isSignlessInteger() && integerType.getWidth() > 32) { // Don't support 64-bit types in general. Rewrite to i32 (if desired).
diff --git a/iree/compiler/Dialect/Flow/IR/FlowBase.td b/iree/compiler/Dialect/Flow/IR/FlowBase.td index 32495ab..725675a 100644 --- a/iree/compiler/Dialect/Flow/IR/FlowBase.td +++ b/iree/compiler/Dialect/Flow/IR/FlowBase.td
@@ -131,7 +131,7 @@ def FLOW_VariablePtr : AnyPtrOf<[FLOW_Tensor, FLOW_PrimitiveType]>; // TODO(benvanik): use index types instead of i32. -def FLOW_Workload : VectorOfLengthAndType<[3], [I32]> { +def FLOW_Workload : AnyTypeOf<[Index]> { let typeDescription = [{ Describes the total untiled invocations along one or more dimensions. Tiling may later divide this value for workgroup sizes.
diff --git a/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir b/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir index 3665478..d201e53 100644 --- a/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir +++ b/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir
@@ -13,8 +13,9 @@ // CHECK-LABEL: @dispatch func @dispatch(%arg0 : tensor<4xf32>) -> tensor<4xf32> { - %cst = constant dense<1> : vector<3xi32> - // CHECK: %0 = flow.dispatch @ex0::@dispatch_fn[%cst : vector<3xi32>](%arg0) : (tensor<4xf32>) -> tensor<4xf32> - %0 = flow.dispatch @ex0::@dispatch_fn[%cst : vector<3xi32>](%arg0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: %[[CST:.+]] = constant + %cst = constant 4 : index + // CHECK: %0 = flow.dispatch @ex0::@dispatch_fn[%[[CST]] : index](%arg0) : (tensor<4xf32>) -> tensor<4xf32> + %0 = flow.dispatch @ex0::@dispatch_fn[%cst : index](%arg0) : (tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> }
diff --git a/iree/compiler/Dialect/Flow/IR/test/dispatch_regions.mlir b/iree/compiler/Dialect/Flow/IR/test/dispatch_regions.mlir index 4e93bad..77749a5 100644 --- a/iree/compiler/Dialect/Flow/IR/test/dispatch_regions.mlir +++ b/iree/compiler/Dialect/Flow/IR/test/dispatch_regions.mlir
@@ -4,12 +4,12 @@ // CHECK-LABEL: @singleArg func @singleArg(%arg0 : tensor<?xf32>) { - // CHECK-NEXT: %0 = "some.shape" - // CHECK-NEXT: flow.dispatch.region[%0 : vector<3xi32>](%arg1 = %arg0 : tensor<?xf32>) { + // CHECK-NEXT: %[[WORKLOAD:.+]] = "some.shape" + // CHECK-NEXT: flow.dispatch.region[%[[WORKLOAD]] : index](%arg1 = %arg0 : tensor<?xf32>) { // CHECK-NEXT: flow.return // CHECK-NEXT: } - %workload = "some.shape"(%arg0) : (tensor<?xf32>) -> vector<3xi32> - flow.dispatch.region[%workload : vector<3xi32>](%i0 = %arg0 : tensor<?xf32>) { + %workload = "some.shape"(%arg0) : (tensor<?xf32>) -> index + flow.dispatch.region[%workload : index](%i0 = %arg0 : tensor<?xf32>) { flow.return } // CHECK-NEXT: return @@ -20,12 +20,12 @@ // CHECK-LABEL: @multipleArgs func @multipleArgs(%arg0 : tensor<?xf32>, %arg1 : tensor<?xf32>) { - // CHECK-NEXT: %0 = "some.shape" - // CHECK-NEXT: flow.dispatch.region[%0 : vector<3xi32>](%arg2 = %arg0 : tensor<?xf32>, %arg3 = %arg1 : tensor<?xf32>) { + // CHECK-NEXT: %[[WORKLOAD:.+]] = "some.shape" + // CHECK-NEXT: flow.dispatch.region[%[[WORKLOAD]] : index](%arg2 = %arg0 : tensor<?xf32>, %arg3 = %arg1 : tensor<?xf32>) { // CHECK-NEXT: flow.return // CHECK-NEXT: } - %workload = "some.shape"(%arg0) : (tensor<?xf32>) -> vector<3xi32> - flow.dispatch.region[%workload : vector<3xi32>](%i0 = %arg0 : tensor<?xf32>, %i1 = %arg1 : tensor<?xf32>) { + %workload = "some.shape"(%arg0) : (tensor<?xf32>) -> index + flow.dispatch.region[%workload : index](%i0 = %arg0 : tensor<?xf32>, %i1 = %arg1 : tensor<?xf32>) { flow.return } // CHECK-NEXT: return @@ -36,12 +36,12 @@ // CHECK-LABEL: @singleResult func @singleResult(%arg0 : tensor<?xf32>) -> tensor<?xf32> { - // CHECK-NEXT: %0 = "some.shape" - // CHECK-NEXT: %1 = flow.dispatch.region[%0 : vector<3xi32>](%arg1 = %arg0 : tensor<?xf32>) -> tensor<?xf32> { + // CHECK-NEXT: %[[WORKLOAD:.+]] = "some.shape" + // CHECK-NEXT: %1 = flow.dispatch.region[%[[WORKLOAD]] : index](%arg1 = %arg0 : tensor<?xf32>) -> 
tensor<?xf32> { // CHECK-NEXT: flow.return %arg1 : tensor<?xf32> // CHECK-NEXT: } - %workload = "some.shape"(%arg0) : (tensor<?xf32>) -> vector<3xi32> - %ret0 = flow.dispatch.region[%workload : vector<3xi32>](%i0 = %arg0 : tensor<?xf32>) -> tensor<?xf32> { + %workload = "some.shape"(%arg0) : (tensor<?xf32>) -> index + %ret0 = flow.dispatch.region[%workload : index](%i0 = %arg0 : tensor<?xf32>) -> tensor<?xf32> { flow.return %i0 : tensor<?xf32> } // CHECK-NEXT: return %1 : tensor<?xf32> @@ -52,12 +52,12 @@ // CHECK-LABEL: @multipleResults func @multipleResults(%arg0 : tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) { - // CHECK-NEXT: %0 = "some.shape" - // CHECK-NEXT: %1:2 = flow.dispatch.region[%0 : vector<3xi32>](%arg1 = %arg0 : tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) { + // CHECK-NEXT: %[[WORKLOAD:.+]] = "some.shape" + // CHECK-NEXT: %1:2 = flow.dispatch.region[%[[WORKLOAD]] : index](%arg1 = %arg0 : tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) { // CHECK-NEXT: flow.return %arg1, %arg1 : tensor<?xf32>, tensor<?xf32> // CHECK-NEXT: } - %workload = "some.shape"(%arg0) : (tensor<?xf32>) -> vector<3xi32> - %ret0, %ret1 = flow.dispatch.region[%workload : vector<3xi32>](%i0 = %arg0 : tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) { + %workload = "some.shape"(%arg0) : (tensor<?xf32>) -> index + %ret0, %ret1 = flow.dispatch.region[%workload : index](%i0 = %arg0 : tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) { flow.return %i0, %i0 : tensor<?xf32>, tensor<?xf32> } // CHECK-NEXT: return %1#0, %1#1 : tensor<?xf32>, tensor<?xf32>
diff --git a/iree/compiler/Dialect/Flow/IR/test/stream_ops.mlir b/iree/compiler/Dialect/Flow/IR/test/stream_ops.mlir index ae61a9a..1d13a98 100644 --- a/iree/compiler/Dialect/Flow/IR/test/stream_ops.mlir +++ b/iree/compiler/Dialect/Flow/IR/test/stream_ops.mlir
@@ -14,11 +14,12 @@ // CHECK-LABEL: func @fragment func @fragment(%arg0 : tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { - %cst = constant dense<[4, 1, 1]> : vector<3xi32> - // CHECK: %0:2 = flow.ex.stream.fragment(%arg1 = %cst : vector<3xi32>, %arg2 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { - %0:2 = flow.ex.stream.fragment(%arg1 = %cst : vector<3xi32>, %arg2 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { + // CHECK: %[[WORKLOAD:.+]] = constant + %cst = constant 4 : index + // CHECK: %0:2 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { + %0:2 = flow.ex.stream.fragment(%arg1 = %cst : index, %arg2 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { // CHECK-NEXT: flow.dispatch - %1 = flow.dispatch @dispatch_0::@rgn_dispatch_0[%arg1 : vector<3xi32>](%arg2) : (tensor<4xf32>) -> tensor<4xf32> + %1 = flow.dispatch @dispatch_0::@rgn_dispatch_0[%arg1 : index](%arg2) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return flow.return %1, %1 : tensor<4xf32>, tensor<4xf32> // CHECK-NEXT: }
diff --git a/iree/compiler/Dialect/Flow/Transforms/AssignExecutableWorkloads.cpp b/iree/compiler/Dialect/Flow/Transforms/AssignExecutableWorkloads.cpp deleted file mode 100644 index 0b65592..0000000 --- a/iree/compiler/Dialect/Flow/Transforms/AssignExecutableWorkloads.cpp +++ /dev/null
@@ -1,122 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" -#include "llvm/ADT/StringMap.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/Utils.h" - -namespace mlir { -namespace iree_compiler { -namespace IREE { -namespace Flow { - -namespace { - -struct WorkloadInfo { - SmallVector<ElementsAttr, 4> staticWorkloads; - SmallVector<Value, 4> dynamicWorkloads; -}; - -// Finds all dispatches and records their workload attributes mapped by -// (executable ordinal, entry point ordinal). -llvm::StringMap<llvm::StringMap<WorkloadInfo>> gatherExecutableWorkloadInfos( - ModuleOp moduleOp) { - llvm::StringMap<llvm::StringMap<WorkloadInfo>> workloadInfos; - for (auto funcOp : moduleOp.getOps<FuncOp>()) { - funcOp.walk([&](DispatchOp op) { - auto &workloadInfo = workloadInfos[op.executable()][op.entry_point()]; - if (auto constantOp = - dyn_cast<ConstantOp>(op.workload().getDefiningOp())) { - for (auto existingWorkloadAttr : workloadInfo.staticWorkloads) { - if (existingWorkloadAttr == constantOp.value()) { - return; // Already present, ignore. 
- } - } - workloadInfo.staticWorkloads.push_back( - constantOp.value().cast<ElementsAttr>()); - } else { - workloadInfo.dynamicWorkloads.push_back(op.workload()); - } - }); - } - return workloadInfos; -} - -// Adds attributes to the given executable entry point describing the workload -// info to the backends that will be processing them. -LogicalResult attributeExecutableEntryPointWorkload( - Operation *entryPointOp, const WorkloadInfo &workloadInfo) { - if (!workloadInfo.dynamicWorkloads.empty()) { - return entryPointOp->emitError() << "dynamic workloads not yet supported"; - } - if (workloadInfo.staticWorkloads.size() != 1) { - return entryPointOp->emitError() << "static workload sizes differ in shape"; - } - - // Easy because we just support static workloads now. - // When this code is adapted to support dynamic workloads we'll want to put - // a pair of attrs describing which dimensions may be static and which args - // have the dynamic values to reference. - entryPointOp->setAttr("workload", workloadInfo.staticWorkloads.front()); - - return success(); -} - -} // namespace - -class AssignExecutableWorkloadsPass - : public ModulePass<AssignExecutableWorkloadsPass> { - public: - void runOnModule() override { - Builder builder(getModule()); - - // Find all dispatches and capture their workload information. - // We store this information by executable and then entry point ordinal. - auto executableWorkloadInfos = gatherExecutableWorkloadInfos(getModule()); - - // Process each executable with the workload information. 
- SymbolTable symbolTable(getModule()); - for (auto &executableIt : executableWorkloadInfos) { - auto executableOp = - symbolTable.lookup<ExecutableOp>(executableIt.first()); - for (auto &entryPointIt : executableIt.second) { - auto entryPointOp = executableOp.lookupSymbol(entryPointIt.first()); - if (failed(attributeExecutableEntryPointWorkload( - entryPointOp, entryPointIt.second))) { - return signalPassFailure(); - } - } - } - } -}; - -std::unique_ptr<OpPassBase<ModuleOp>> createAssignExecutableWorkloadsPass() { - return std::make_unique<AssignExecutableWorkloadsPass>(); -} - -static PassRegistration<AssignExecutableWorkloadsPass> pass( - "iree-flow-assign-executable-workloads", - "Assigns executable entrypoint workload attributes"); - -} // namespace Flow -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Transforms/BUILD b/iree/compiler/Dialect/Flow/Transforms/BUILD index 76c106a..a282d82 100644 --- a/iree/compiler/Dialect/Flow/Transforms/BUILD +++ b/iree/compiler/Dialect/Flow/Transforms/BUILD
@@ -20,7 +20,6 @@ cc_library( name = "Transforms", srcs = [ - "AssignExecutableWorkloads.cpp", "DispatchabilityAnalysis.cpp", "FlattenTuplesInCFG.cpp", "FoldCompatibleDispatchRegions.cpp", @@ -45,6 +44,7 @@ "//iree/compiler/Dialect/Flow/Conversion/StandardToFlow", "//iree/compiler/Dialect/Flow/IR", "//iree/compiler/Dialect/Flow/Utils", + "//iree/compiler/Dialect/Shape/IR", "//iree/compiler/Utils", "@llvm-project//llvm:support", "@llvm-project//mlir:Analysis",
diff --git a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt index 6ffa54b..59c7a41 100644 --- a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt +++ b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -20,7 +20,6 @@ HDRS "Passes.h" SRCS - "AssignExecutableWorkloads.cpp" "DispatchabilityAnalysis.cpp" "FlattenTuplesInCFG.cpp" "FoldCompatibleDispatchRegions.cpp"
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp index c7331ad..3db5627 100644 --- a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -78,7 +78,7 @@ passManager.addNestedPass<FuncOp>(createPostPartitioningConversionPass()); // Assign attributes and negotiate each executable's ABI signature. - passManager.addPass(IREE::Flow::createAssignExecutableWorkloadsPass()); + // passManager.addPass(IREE::Flow::createAssignExecutableWorkloadsPass()); // Form streams. passManager.addPass(IREE::Flow::createFormStreamsPass());
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.h b/iree/compiler/Dialect/Flow/Transforms/Passes.h index 1ac81b7..8667b2b 100644 --- a/iree/compiler/Dialect/Flow/Transforms/Passes.h +++ b/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -125,9 +125,6 @@ // Module Analysis and Finalization //===----------------------------------------------------------------------===// -// Assigns workload attributes to executable entry points based on dispatches. -std::unique_ptr<OpPassBase<ModuleOp>> createAssignExecutableWorkloadsPass(); - } // namespace Flow } // namespace IREE } // namespace iree_compiler
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/assign_executable_workloads.mlir b/iree/compiler/Dialect/Flow/Transforms/test/assign_executable_workloads.mlir deleted file mode 100644 index 9cd1007..0000000 --- a/iree/compiler/Dialect/Flow/Transforms/test/assign_executable_workloads.mlir +++ /dev/null
@@ -1,20 +0,0 @@ -// RUN: iree-opt -split-input-file -iree-flow-assign-executable-workloads %s | IreeFileCheck %s - -flow.executable @singleStaticWorkload_ex_dispatch_0 { - // CHECK-LABEL: flow.dispatch.entry @singleStaticWorkload_rgn_dispatch_0 - // CHECK-SAME: workload = dense<[4, 1, 1]> : vector<3xi32> - flow.dispatch.entry @singleStaticWorkload_rgn_dispatch_0 - module { - func @singleStaticWorkload_rgn_dispatch_0(%arg0: tensor<4xf32>) -> tensor<4xf32> { - %0 = addf %arg0, %arg0 : tensor<4xf32> - %1 = subf %0, %arg0 : tensor<4xf32> - %2 = mulf %1, %arg0 : tensor<4xf32> - return %2 : tensor<4xf32> - } - } -} -func @singleStaticWorkload(%arg0: tensor<4xf32>) -> tensor<4xf32> { - %cst = constant dense<[4, 1, 1]> : vector<3xi32> - %0 = flow.dispatch @singleStaticWorkload_ex_dispatch_0::@singleStaticWorkload_rgn_dispatch_0[%cst : vector<3xi32>](%arg0) : (tensor<4xf32>) -> tensor<4xf32> - return %0 : tensor<4xf32> -}
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir b/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir index ce154c2..2d6174b 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir
@@ -1,8 +1,8 @@ // RUN: iree-opt -split-input-file -iree-flow-fold-compatible-dispatch-regions %s | IreeFileCheck %s func @noFolding(%arg0 : tensor<4xf32>) -> tensor<4xf32> { - %cst = constant dense<[4, 1, 1]> : vector<3xi32> - %0 = flow.dispatch.region[%cst : vector<3xi32>](%arg1 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { + %cst = constant 4 : index + %0 = flow.dispatch.region[%cst : index](%arg1 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { %1 = xla_hlo.add %arg1, %arg1 : tensor<4xf32> flow.return %1 : tensor<4xf32> } @@ -10,8 +10,8 @@ } // CHECK-LABEL: func @noFolding -// CHECK-NEXT: %cst = constant dense<[4, 1, 1]> : vector<3xi32> -// CHECK-NEXT: %0 = flow.dispatch.region[%cst : vector<3xi32>](%arg1 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NEXT: %[[WORKLOAD0:.+]] = constant 4 : index +// CHECK-NEXT: %0 = flow.dispatch.region[%[[WORKLOAD0]] : index](%arg1 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %1 = xla_hlo.add %arg1, %arg1 : tensor<4xf32> // CHECK-NEXT: flow.return %1 : tensor<4xf32> // CHECK-NEXT: } @@ -20,16 +20,16 @@ // ----- func @elementwiseOps(%arg0 : tensor<4xf32>) -> tensor<4xf32> { - %cst = constant dense<[4, 1, 1]> : vector<3xi32> - %0 = flow.dispatch.region[%cst : vector<3xi32>](%arg1 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { + %cst = constant 4 : index + %0 = flow.dispatch.region[%cst : index](%arg1 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { %1 = xla_hlo.add %arg1, %arg1 : tensor<4xf32> flow.return %1 : tensor<4xf32> } - %2 = flow.dispatch.region[%cst : vector<3xi32>](%arg2 = %arg0 : tensor<4xf32>, %arg3 = %0 : tensor<4xf32>) -> tensor<4xf32> { + %2 = flow.dispatch.region[%cst : index](%arg2 = %arg0 : tensor<4xf32>, %arg3 = %0 : tensor<4xf32>) -> tensor<4xf32> { %3 = xla_hlo.sub %arg3, %arg2 : tensor<4xf32> flow.return %3 : tensor<4xf32> } - %4 = flow.dispatch.region[%cst : vector<3xi32>](%arg4 = %arg0 : tensor<4xf32>, %arg5 = %2 : tensor<4xf32>) -> tensor<4xf32> { + %4 = flow.dispatch.region[%cst : index](%arg4 = 
%arg0 : tensor<4xf32>, %arg5 = %2 : tensor<4xf32>) -> tensor<4xf32> { %5 = xla_hlo.mul %arg4, %arg5 : tensor<4xf32> flow.return %5 : tensor<4xf32> } @@ -37,30 +37,30 @@ } // CHECK-LABEL: func @elementwiseOps -// CHECK-NEXT: %cst = constant dense<[4, 1, 1]> -// CHECK-NEXT: %0 = flow.dispatch.region[%cst : vector<3xi32>](%arg1 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { +// CHECK: %[[WORKLOAD0:.+]] = constant 4 +// CHECK: %[[R0:.+]] = flow.dispatch.region[%[[WORKLOAD0]] : index](%arg1 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %1 = xla_hlo.add %arg1, %arg1 : tensor<4xf32> // CHECK-NEXT: %2 = xla_hlo.sub %1, %arg1 : tensor<4xf32> // CHECK-NEXT: %3 = xla_hlo.mul %arg1, %2 : tensor<4xf32> // CHECK-NEXT: flow.return %3 : tensor<4xf32> // CHECK-NEXT: } -// CHECK-NEXT: return %0 : tensor<4xf32> +// CHECK: return %[[R0]] : tensor<4xf32> // ----- func @interleavedDot(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { - %cst = constant dense<[4, 4, 1]> : vector<3xi32> - %0 = flow.dispatch.region[%cst : vector<3xi32>](%arg1 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { + %cst = constant 16 : index + %0 = flow.dispatch.region[%cst : index](%arg1 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { %3 = xla_hlo.add %arg1, %arg1 : tensor<4x4xf32> flow.return %3 : tensor<4x4xf32> } - %cst_0 = constant dense<[4, 4, 1]> : vector<3xi32> - %1 = flow.dispatch.region[%cst_0 : vector<3xi32>](%arg1 = %0 : tensor<4x4xf32>, %arg2 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { + %cst_0 = constant 16 : index + %1 = flow.dispatch.region[%cst_0 : index](%arg1 = %0 : tensor<4x4xf32>, %arg2 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { %3 = "xla_hlo.dot"(%arg1, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> flow.return %3 : tensor<4x4xf32> } - %cst_1 = constant dense<[4, 4, 1]> : vector<3xi32> - %2 = flow.dispatch.region[%cst_1 : vector<3xi32>](%arg1 = %1 : tensor<4x4xf32>, %arg2 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { + %cst_1 = constant 16 : index + %2 = 
flow.dispatch.region[%cst_1 : index](%arg1 = %1 : tensor<4x4xf32>, %arg2 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { %3 = xla_hlo.mul %arg1, %arg2 : tensor<4x4xf32> flow.return %3 : tensor<4x4xf32> } @@ -68,19 +68,19 @@ } // CHECK-LABEL: func @interleavedDot -// CHECK-NEXT: %cst = constant dense<[4, 4, 1]> : vector<3xi32> -// CHECK-NEXT: %0 = flow.dispatch.region[%cst : vector<3xi32>](%arg1 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { +// CHECK-NEXT: %[[WORKLOAD0:.+]] = constant 16 : index +// CHECK-NEXT: %[[R0:.+]] = flow.dispatch.region[%[[WORKLOAD0]] : index](%arg1 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { // CHECK-NEXT: %3 = xla_hlo.add %arg1, %arg1 : tensor<4x4xf32> // CHECK-NEXT: flow.return %3 : tensor<4x4xf32> // CHECK-NEXT: } -// CHECK-NEXT: %cst_0 = constant dense<[4, 4, 1]> : vector<3xi32> -// CHECK-NEXT: %1 = flow.dispatch.region[%cst_0 : vector<3xi32>](%arg1 = %0 : tensor<4x4xf32>, %arg2 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { +// CHECK-NEXT: %[[WORKLOAD1:.+]] = constant 16 : index +// CHECK-NEXT: %[[R1:.+]] = flow.dispatch.region[%[[WORKLOAD1]] : index](%arg1 = %[[R0]] : tensor<4x4xf32>, %arg2 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { // CHECK-NEXT: %3 = "xla_hlo.dot"(%arg1, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> // CHECK-NEXT: flow.return %3 : tensor<4x4xf32> // CHECK-NEXT: } -// CHECK-NEXT: %cst_1 = constant dense<[4, 4, 1]> : vector<3xi32> -// CHECK-NEXT: %2 = flow.dispatch.region[%cst_1 : vector<3xi32>](%arg1 = %1 : tensor<4x4xf32>, %arg2 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { +// CHECK-NEXT: %[[WORKLOAD2:.+]] = constant 16 : index +// CHECK-NEXT: %[[R2:.+]] = flow.dispatch.region[%[[WORKLOAD2]] : index](%arg1 = %[[R1]] : tensor<4x4xf32>, %arg2 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { // CHECK-NEXT: %3 = xla_hlo.mul %arg1, %arg2 : tensor<4x4xf32> // CHECK-NEXT: flow.return %3 : tensor<4x4xf32> // CHECK-NEXT: } -// CHECK-NEXT: return %2 : tensor<4x4xf32> +// CHECK-NEXT: return 
%[[R2]] : tensor<4x4xf32>
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/form_streams.mlir b/iree/compiler/Dialect/Flow/Transforms/test/form_streams.mlir index 7a9c1fe..791c7c2 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/form_streams.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/form_streams.mlir
@@ -13,15 +13,15 @@ } // CHECK-LABEL: func @outerOps func @outerOps(%arg0: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: %cst = constant dense<[4, 1, 1]> : vector<3xi32> - %cst = constant dense<[4, 1, 1]> : vector<3xi32> + // CHECK: %[[WORKLOAD0:.+]] = constant 4 : index + %cst = constant 4 : index // CHECK-NEXT: %0 = addf %arg0, %arg0 : tensor<4xf32> %0 = addf %arg0, %arg0 : tensor<4xf32> - // CHECK-NEXT: %1 = flow.ex.stream.fragment(%arg1 = %cst : vector<3xi32>, %arg2 = %0 : tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %3 = flow.dispatch @outerOps_ex_dispatch_0::@outerOps_rgn_dispatch_0[%arg1 : vector<3xi32>](%arg2) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %1 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD0]] : index, %arg2 = %0 : tensor<4xf32>) -> tensor<4xf32> { + // CHECK-NEXT: %3 = flow.dispatch @outerOps_ex_dispatch_0::@outerOps_rgn_dispatch_0[%arg1 : index](%arg2) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return %3 : tensor<4xf32> // CHECK-NEXT: } - %1 = flow.dispatch @outerOps_ex_dispatch_0::@outerOps_rgn_dispatch_0[%cst : vector<3xi32>](%0) : (tensor<4xf32>) -> tensor<4xf32> + %1 = flow.dispatch @outerOps_ex_dispatch_0::@outerOps_rgn_dispatch_0[%cst : index](%0) : (tensor<4xf32>) -> tensor<4xf32> // CHECK: %2 = addf %1, %1 : tensor<4xf32> %2 = addf %1, %1 : tensor<4xf32> // CHECK-NEXT: return %2 : tensor<4xf32> @@ -43,17 +43,17 @@ } // CHECK-LABEL: func @nondependentOuterOps( func @nondependentOuterOps(%arg0: tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %cst = constant dense<[4, 1, 1]> : vector<3xi32> - %cst = constant dense<[4, 1, 1]> : vector<3xi32> - %0 = flow.dispatch @nondependentOuterOps_ex_dispatch_0::@nondependentOuterOps_rgn_dispatch_0[%cst : vector<3xi32>](%arg0, %arg0) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index + %cst = constant 4 : index + %0 = flow.dispatch @nondependentOuterOps_ex_dispatch_0::@nondependentOuterOps_rgn_dispatch_0[%cst : 
index](%arg0, %arg0) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: %0 = addf %arg0, %arg0 : tensor<4xf32> %1 = addf %arg0, %arg0 : tensor<4xf32> - // CHECK-NEXT: %1 = flow.ex.stream.fragment(%arg1 = %cst : vector<3xi32>, %arg2 = %arg0 : tensor<4xf32>, %arg3 = %0 : tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %3 = flow.dispatch @nondependentOuterOps_ex_dispatch_0::@nondependentOuterOps_rgn_dispatch_0[%arg1 : vector<3xi32>](%arg2, %arg2) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - // CHECK-NEXT: %4 = flow.dispatch @nondependentOuterOps_ex_dispatch_0::@nondependentOuterOps_rgn_dispatch_0[%arg1 : vector<3xi32>](%3, %arg3) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %1 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %arg0 : tensor<4xf32>, %arg3 = %0 : tensor<4xf32>) -> tensor<4xf32> { + // CHECK-NEXT: %3 = flow.dispatch @nondependentOuterOps_ex_dispatch_0::@nondependentOuterOps_rgn_dispatch_0[%arg1 : index](%arg2, %arg2) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %4 = flow.dispatch @nondependentOuterOps_ex_dispatch_0::@nondependentOuterOps_rgn_dispatch_0[%arg1 : index](%3, %arg3) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return %4 : tensor<4xf32> // CHECK-NEXT: } - %2 = flow.dispatch @nondependentOuterOps_ex_dispatch_0::@nondependentOuterOps_rgn_dispatch_0[%cst : vector<3xi32>](%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %2 = flow.dispatch @nondependentOuterOps_ex_dispatch_0::@nondependentOuterOps_rgn_dispatch_0[%cst : index](%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: %2 = addf %1, %arg0 : tensor<4xf32> %3 = addf %2, %arg0 : tensor<4xf32> // CHECK-NEXT: return %2 : tensor<4xf32> @@ -75,20 +75,20 @@ } // CHECK-LABEL: func @interleavedOuterOps( func @interleavedOuterOps(%arg0: tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %cst = constant dense<[4, 1, 1]> : vector<3xi32> - %cst = constant dense<[4, 
1, 1]> : vector<3xi32> - // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %cst : vector<3xi32>, %arg2 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %3 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%arg1 : vector<3xi32>](%arg2) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index + %cst = constant 4 : index + // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { + // CHECK-NEXT: %3 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%arg1 : index](%arg2) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return %3 : tensor<4xf32> // CHECK-NEXT: } - %0 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%cst : vector<3xi32>](%arg0) : (tensor<4xf32>) -> tensor<4xf32> + %0 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%cst : index](%arg0) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: %1 = addf %0, %0 : tensor<4xf32> %1 = addf %0, %0 : tensor<4xf32> - // CHECK-NEXT: %2 = flow.ex.stream.fragment(%arg1 = %cst : vector<3xi32>, %arg2 = %1 : tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %3 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%arg1 : vector<3xi32>](%arg2) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %2 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %1 : tensor<4xf32>) -> tensor<4xf32> { + // CHECK-NEXT: %3 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%arg1 : index](%arg2) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return %3 : tensor<4xf32> // CHECK-NEXT: } - %2 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%cst : vector<3xi32>](%1) : (tensor<4xf32>) -> tensor<4xf32> + %2 = flow.dispatch 
@interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%cst : index](%1) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: return %2 : tensor<4xf32> return %2 : tensor<4xf32> } @@ -105,13 +105,13 @@ } // CHECK-LABEL: func @independentOps( func @independentOps(%arg0: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { - // CHECK-NEXT: %cst = constant dense<[4, 1, 1]> : vector<3xi32> - %cst = constant dense<[4, 1, 1]> : vector<3xi32> - // CHECK-NEXT: %0:2 = flow.ex.stream.fragment(%arg1 = %cst : vector<3xi32>, %arg2 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { - // CHECK-DAG: = flow.dispatch @independentOps_ex_dispatch_0::@independentOps_rgn_dispatch_0[%arg1 : vector<3xi32>](%arg2) - %0 = flow.dispatch @independentOps_ex_dispatch_0::@independentOps_rgn_dispatch_0[%cst : vector<3xi32>](%arg0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK-DAG: = flow.dispatch @independentOps_ex_dispatch_0::@independentOps_rgn_dispatch_1[%arg1 : vector<3xi32>](%arg2) - %1 = flow.dispatch @independentOps_ex_dispatch_0::@independentOps_rgn_dispatch_1[%cst : vector<3xi32>](%arg0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index + %cst = constant 4 : index + // CHECK-NEXT: %0:2 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { + // CHECK-DAG: = flow.dispatch @independentOps_ex_dispatch_0::@independentOps_rgn_dispatch_0[%arg1 : index](%arg2) + %0 = flow.dispatch @independentOps_ex_dispatch_0::@independentOps_rgn_dispatch_0[%cst : index](%arg0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-DAG: = flow.dispatch @independentOps_ex_dispatch_0::@independentOps_rgn_dispatch_1[%arg1 : index](%arg2) + %1 = flow.dispatch @independentOps_ex_dispatch_0::@independentOps_rgn_dispatch_1[%cst : index](%arg0) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return {{%.+}}, {{%.+}} // CHECK-NEXT: } // CHECK-NEXT: return {{%.+}}, {{%.+}} @@ -155,17 +155,17 @@ 
} // CHECK-LABEL: func @interleavedDot( func @interleavedDot(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { - // CHECK-NEXT: %cst = constant dense<[4, 4, 1]> : vector<3xi32> - %cst = constant dense<[4, 4, 1]> : vector<3xi32> - // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %cst : vector<3xi32>, %arg2 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { - // CHECK-NEXT: %1 = flow.dispatch @interleavedDot_ex_dispatch_0::@interleavedDot_rgn_dispatch_0[%arg1 : vector<3xi32>](%arg2) : (tensor<4x4xf32>) -> tensor<4x4xf32> - // CHECK-NEXT: %2 = flow.dispatch @interleavedDot_ex_dispatch_1::@interleavedDot_rgn_dispatch_1[%arg1 : vector<3xi32>](%1, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> - // CHECK-NEXT: %3 = flow.dispatch @interleavedDot_ex_dispatch_2::@interleavedDot_rgn_dispatch_2[%arg1 : vector<3xi32>](%2, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 16 : index + %cst = constant 16 : index + // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { + // CHECK-NEXT: %1 = flow.dispatch @interleavedDot_ex_dispatch_0::@interleavedDot_rgn_dispatch_0[%arg1 : index](%arg2) : (tensor<4x4xf32>) -> tensor<4x4xf32> + // CHECK-NEXT: %2 = flow.dispatch @interleavedDot_ex_dispatch_1::@interleavedDot_rgn_dispatch_1[%arg1 : index](%1, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + // CHECK-NEXT: %3 = flow.dispatch @interleavedDot_ex_dispatch_2::@interleavedDot_rgn_dispatch_2[%arg1 : index](%2, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> // CHECK-NEXT: flow.return %3 : tensor<4x4xf32> // CHECK-NEXT: } - %0 = flow.dispatch @interleavedDot_ex_dispatch_0::@interleavedDot_rgn_dispatch_0[%cst : vector<3xi32>](%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32> - %1 = flow.dispatch @interleavedDot_ex_dispatch_1::@interleavedDot_rgn_dispatch_1[%cst : vector<3xi32>](%0, %arg0) : (tensor<4x4xf32>, tensor<4x4xf32>) 
-> tensor<4x4xf32> - %2 = flow.dispatch @interleavedDot_ex_dispatch_2::@interleavedDot_rgn_dispatch_2[%cst : vector<3xi32>](%1, %arg0) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + %0 = flow.dispatch @interleavedDot_ex_dispatch_0::@interleavedDot_rgn_dispatch_0[%cst : index](%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32> + %1 = flow.dispatch @interleavedDot_ex_dispatch_1::@interleavedDot_rgn_dispatch_1[%cst : index](%0, %arg0) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + %2 = flow.dispatch @interleavedDot_ex_dispatch_2::@interleavedDot_rgn_dispatch_2[%cst : index](%1, %arg0) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> // CHECK-NEXT: return %0 : tensor<4x4xf32> return %2 : tensor<4x4xf32> } @@ -196,20 +196,20 @@ } // CHECK-LABEL: func @caller( func @caller(%arg0: tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %cst = constant dense<[4, 1, 1]> : vector<3xi32> - %cst = constant dense<[4, 1, 1]> : vector<3xi32> - // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %cst : vector<3xi32>, %arg2 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %3 = flow.dispatch @caller_ex_dispatch_0::@caller_rgn_dispatch_0[%arg1 : vector<3xi32>](%arg2) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index + %cst = constant 4 : index + // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { + // CHECK-NEXT: %3 = flow.dispatch @caller_ex_dispatch_0::@caller_rgn_dispatch_0[%arg1 : index](%arg2) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return %3 : tensor<4xf32> // CHECK-NEXT: } - %0 = flow.dispatch @caller_ex_dispatch_0::@caller_rgn_dispatch_0[%cst : vector<3xi32>](%arg0) : (tensor<4xf32>) -> tensor<4xf32> + %0 = flow.dispatch @caller_ex_dispatch_0::@caller_rgn_dispatch_0[%cst : index](%arg0) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: %1 = call @callee(%0) : (tensor<4xf32>) -> tensor<4xf32> %1 = call @callee(%0) : 
(tensor<4xf32>) -> tensor<4xf32> - // CHECK-NEXT: %2 = flow.ex.stream.fragment(%arg1 = %cst : vector<3xi32>, %arg2 = %arg0 : tensor<4xf32>, %arg3 = %1 : tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %3 = flow.dispatch @caller_ex_dispatch_1::@caller_rgn_dispatch_1[%arg1 : vector<3xi32>](%arg2, %arg3) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %2 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %arg0 : tensor<4xf32>, %arg3 = %1 : tensor<4xf32>) -> tensor<4xf32> { + // CHECK-NEXT: %3 = flow.dispatch @caller_ex_dispatch_1::@caller_rgn_dispatch_1[%arg1 : index](%arg2, %arg3) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return %3 : tensor<4xf32> // CHECK-NEXT: } - %2 = flow.dispatch @caller_ex_dispatch_1::@caller_rgn_dispatch_1[%cst : vector<3xi32>](%arg0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %2 = flow.dispatch @caller_ex_dispatch_1::@caller_rgn_dispatch_1[%cst : index](%arg0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: return %2 : tensor<4xf32> return %2 : tensor<4xf32> } @@ -224,13 +224,13 @@ } // CHECK-LABEL: func @callee( func @callee(%arg0: tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %cst = constant dense<[4, 1, 1]> : vector<3xi32> - %cst = constant dense<[4, 1, 1]> : vector<3xi32> - // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %cst : vector<3xi32>, %arg2 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %1 = flow.dispatch @callee_ex_dispatch_0::@callee_rgn_dispatch_0[%arg1 : vector<3xi32>](%arg2) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index + %cst = constant 4 : index + // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { + // CHECK-NEXT: %1 = flow.dispatch @callee_ex_dispatch_0::@callee_rgn_dispatch_0[%arg1 : index](%arg2) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return %1 : tensor<4xf32> // 
CHECK-NEXT: } - %0 = flow.dispatch @callee_ex_dispatch_0::@callee_rgn_dispatch_0[%cst : vector<3xi32>](%arg0) : (tensor<4xf32>) -> tensor<4xf32> + %0 = flow.dispatch @callee_ex_dispatch_0::@callee_rgn_dispatch_0[%cst : index](%arg0) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: return %0 : tensor<4xf32> return %0 : tensor<4xf32> }
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions.mlir b/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions.mlir index 5c654f2..c2cebbf 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions.mlir
@@ -10,15 +10,15 @@ // CHECK-LABEL: @simpleMath func @simpleMath(%arg0 : tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: constant dense<[4, 1, 1]> - // CHECK-NEXT: %0 = flow.dispatch.region - // CHECK-SAME: [%cst : vector<3xi32>] + // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 + // CHECK-NEXT: %[[R1:.+]] = flow.dispatch.region + // CHECK-SAME: [%[[WORKLOAD]] : index] // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %1 = xla_hlo.add %arg1, %arg1 : tensor<4xf32> %0 = xla_hlo.add %arg0, %arg0 : tensor<4xf32> // CHECK-NEXT: flow.return %1 : tensor<4xf32> // CHECK-NEXT: } - // CHECK-NEXT: return %0 : tensor<4xf32> + // CHECK-NEXT: return %[[R1]] : tensor<4xf32> return %0 : tensor<4xf32> } @@ -26,9 +26,9 @@ // CHECK-LABEL: @stdElementwiseOps func @stdElementwiseOps(%arg0 : tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: constant dense<[4, 1, 1]> - // CHECK-NEXT: %0 = flow.dispatch.region - // CHECK-SAME: [%cst : vector<3xi32>] + // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 + // CHECK-NEXT: %[[R1:.+]] = flow.dispatch.region + // CHECK-SAME: [%[[WORKLOAD]] : index] // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %1 = addf %arg1, %arg1 : tensor<4xf32> %0 = addf %arg0, %arg0 : tensor<4xf32> @@ -38,7 +38,7 @@ %2 = mulf %1, %arg0 : tensor<4xf32> // CHECK-NEXT: flow.return %3 : tensor<4xf32> // CHECK-NEXT: } - // CHECK-NEXT: return %0 : tensor<4xf32> + // CHECK-NEXT: return %[[R1]] : tensor<4xf32> return %2 : tensor<4xf32> } @@ -46,9 +46,9 @@ // CHECK-LABEL: @hloElementwiseOps func @hloElementwiseOps(%arg0 : tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: constant dense<[4, 1, 1]> + // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 // CHECK-NEXT: %0 = flow.dispatch.region - // CHECK-SAME: [%cst : vector<3xi32>] + // CHECK-SAME: [%[[WORKLOAD]] : index] // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %1 = xla_hlo.add %arg1, %arg1 : tensor<4xf32> %0 = xla_hlo.add %arg0, %arg0 : 
tensor<4xf32> @@ -66,31 +66,33 @@ // CHECK-LABEL: @interleavedDot func @interleavedDot(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { - // CHECK-NEXT: %cst = constant dense<[4, 4, 1]> - // CHECK-NEXT: %0 = flow.dispatch.region - // CHECK-SAME: [%cst : vector<3xi32>] + // NOTE: Fragile ordering. Workload constants are emitted in order at the + // top of the block. + // CHECK: %[[WORKLOAD0:.+]] = constant 16 : index + // CHECK: %[[WORKLOAD1:.+]] = constant 16 : index + // CHECK: %[[WORKLOAD2:.+]] = constant 16 : index + // CHECK: %[[R0:.+]] = flow.dispatch.region + // CHECK-SAME: [%[[WORKLOAD0]] : index] // CHECK-SAME: (%arg1 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { // CHECK-NEXT: %3 = xla_hlo.add %arg1, %arg1 : tensor<4x4xf32> %0 = xla_hlo.add %arg0, %arg0 : tensor<4x4xf32> // CHECK-NEXT: flow.return %3 : tensor<4x4xf32> // CHECK-NEXT: } - // CHECK-NEXT: %cst_0 = constant dense<[4, 4, 1]> : vector<3xi32> - // CHECK-NEXT: %1 = flow.dispatch.region - // CHECK-SAME: [%cst_0 : vector<3xi32>] - // CHECK-SAME: (%arg1 = %0 : tensor<4x4xf32>, %arg2 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { + // CHECK: %[[R1:.+]] = flow.dispatch.region + // CHECK-SAME: [%[[WORKLOAD1]] : index] + // CHECK-SAME: (%arg1 = %[[R0]] : tensor<4x4xf32>, %arg2 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { // CHECK-NEXT: %3 = "xla_hlo.dot"(%arg1, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> %1 = "xla_hlo.dot"(%0, %arg0) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> // CHECK-NEXT: flow.return %3 : tensor<4x4xf32> // CHECK-NEXT: } - // CHECK-NEXT: %cst_1 = constant dense<[4, 4, 1]> : vector<3xi32> - // CHECK-NEXT: %2 = flow.dispatch.region - // CHECK-SAME: [%cst_1 : vector<3xi32>] - // CHECK-SAME: (%arg1 = %1 : tensor<4x4xf32>, %arg2 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { + // CHECK: %[[R2:.+]] = flow.dispatch.region + // CHECK-SAME: [%[[WORKLOAD2]] : index] + // CHECK-SAME: (%arg1 = %[[R1]] : tensor<4x4xf32>, %arg2 = %arg0 : tensor<4x4xf32>) ->
tensor<4x4xf32> { // CHECK-NEXT: %3 = xla_hlo.mul %arg1, %arg2 : tensor<4x4xf32> %2 = xla_hlo.mul %1, %arg0 : tensor<4x4xf32> // CHECK-NEXT: flow.return %3 : tensor<4x4xf32> // CHECK-NEXT: } - // CHECK-NEXT: return %2 : tensor<4x4xf32> + // CHECK-NEXT: return %[[R2]] : tensor<4x4xf32> return %2 : tensor<4x4xf32> } @@ -98,9 +100,9 @@ // CHECK-LABEL: func @caller func @caller(%arg0 : tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: constant dense<[4, 1, 1]> - // CHECK-NEXT: %0 = flow.dispatch.region - // CHECK-SAME: [%cst : vector<3xi32>] + // CHECK-NEXT: %[[WORKLOAD0:.+]] = constant 4 : index + // CHECK-NEXT: %[[R0:.+]] = flow.dispatch.region + // CHECK-SAME: [%[[WORKLOAD0]] : index] // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %1 = xla_hlo.add %arg1, %arg1 : tensor<4xf32> %0 = xla_hlo.add %arg0, %arg0 : tensor<4xf32> @@ -110,20 +112,20 @@ %2 = xla_hlo.mul %1, %arg0 : tensor<4xf32> // CHECK-NEXT: flow.return %3 : tensor<4xf32> // CHECK-NEXT: } - // CHECK-NEXT: return %0 : tensor<4xf32> + // CHECK-NEXT: return %[[R0]] : tensor<4xf32> return %2 : tensor<4xf32> } // CHECK-LABEL: func @callee func @callee(%arg0 : tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: constant dense<[4, 1, 1]> - // CHECK-NEXT: %0 = flow.dispatch.region - // CHECK-SAME: [%cst : vector<3xi32>] + // CHECK: %[[WORKLOAD0:.+]] = constant 4 : index + // CHECK: %[[R0:.+]] = flow.dispatch.region + // CHECK-SAME: [%[[WORKLOAD0]] : index] // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %1 = xla_hlo.mul %arg1, %arg1 : tensor<4xf32> %0 = xla_hlo.mul %arg0, %arg0 : tensor<4xf32> // CHECK-NEXT: flow.return %1 : tensor<4xf32> // CHECK-NEXT: } - // CHECK-NEXT: return %0 : tensor<4xf32> + // CHECK: return %[[R0]] : tensor<4xf32> return %0 : tensor<4xf32> } @@ -131,12 +133,12 @@ // CHECK-LABEL: @single_reduction func @single_reduction(%arg0 : tensor<4x8xf32>) -> tensor<4xf32> { - // CHECK-DAG: [[INITIAL:%.+]] = constant dense<0.000000e+00> 
+ // CHECK-DAG: %[[INITIAL:.+]] = constant dense<0.000000e+00> %0 = constant dense<0.000000e+00> : tensor<f32> - // CHECK-DAG: constant dense<[4, 1, 1]> - // CHECK-NEXT: [[RESULT:%.+]] = flow.dispatch.region - // CHECK-SAME: [%cst_0 : vector<3xi32>] - // CHECK-SAME: (%arg1 = %arg0 : tensor<4x8xf32>, %arg2 = [[INITIAL]] : tensor<f32>) -> tensor<4xf32> + // CHECK-DAG: %[[WORKLOAD0:.+]] = constant 4 : index + // CHECK: %[[RESULT:.+]] = flow.dispatch.region + // CHECK-SAME: [%[[WORKLOAD0]] : index] + // CHECK-SAME: (%arg1 = %arg0 : tensor<4x8xf32>, %arg2 = %[[INITIAL]] : tensor<f32>) -> tensor<4xf32> // CHECK-NEXT: = "xla_hlo.reduce"(%arg1, %arg2) %1 = "xla_hlo.reduce"(%arg0, %0) ( { ^bb0(%arg1 : tensor<f32>, %arg2 : tensor<f32>): @@ -144,7 +146,7 @@ "xla_hlo.return"(%2) : (tensor<f32>) -> () }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32> // CHECK: flow.return - // CHECK: return [[RESULT]] : tensor<4xf32> + // CHECK: return %[[RESULT]] : tensor<4xf32> return %1 : tensor<4xf32> } @@ -152,14 +154,14 @@ // CHECK-LABEL: @multi_reduction func @multi_reduction(%arg0 : tensor<4x8xf32>, %arg1 : tensor<4x8xf32>) -> (tensor<4xf32>, tensor<4xf32>) { - // CHECK-DAG: [[INITIALA:%.+]] = constant dense<0.000000e+00> + // CHECK-DAG: %[[INITIALA:.+]] = constant dense<0.000000e+00> %0 = constant dense<0.000000e+00> : tensor<f32> - // CHECK-DAG: [[INITIALB:%.+]] = constant dense<1.000000e+00> + // CHECK-DAG: %[[INITIALB:.+]] = constant dense<1.000000e+00> %1 = constant dense<1.000000e+00> : tensor<f32> - // CHECK: constant dense<[4, 1, 1]> - // CHECK-NEXT: [[RESULT:%.+]]:2 = flow.dispatch.region - // CHECK-SAME: [%cst_1 : vector<3xi32>] - // CHECK-SAME: (%arg2 = %arg0 : tensor<4x8xf32>, %arg3 = %arg1 : tensor<4x8xf32>, %arg4 = [[INITIALA]] : tensor<f32>, %arg5 = [[INITIALB]] : tensor<f32>) -> (tensor<4xf32>, tensor<4xf32>) + // CHECK-DAG: %[[WORKLOAD0:.+]] = constant 4 : index + // CHECK: %[[RESULT:.+]]:2 = flow.dispatch.region + // 
CHECK-SAME: [%[[WORKLOAD0]] : index] + // CHECK-SAME: (%arg2 = %arg0 : tensor<4x8xf32>, %arg3 = %arg1 : tensor<4x8xf32>, %arg4 = %[[INITIALA]] : tensor<f32>, %arg5 = %[[INITIALB]] : tensor<f32>) -> (tensor<4xf32>, tensor<4xf32>) // CHECK-NEXT: = "xla_hlo.reduce"(%arg2, %arg3, %arg4, %arg5) %2, %3 = "xla_hlo.reduce"(%arg0, %arg1, %0, %1) ( { ^bb0(%arg0_lhs : tensor<f32>, %arg1_lhs : tensor<f32>, %arg0_rhs : tensor<f32>, %arg1_rhs : tensor<f32>): @@ -168,7 +170,7 @@ "xla_hlo.return"(%4, %5) : (tensor<f32>, tensor<f32>) -> () }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<4x8xf32>, tensor<f32>, tensor<f32>) -> (tensor<4xf32>, tensor<4xf32>) // CHECK: flow.return - // CHECK: return [[RESULT]]#0, [[RESULT]]#1 : tensor<4xf32>, tensor<4xf32> + // CHECK: return %[[RESULT]]#0, %[[RESULT]]#1 : tensor<4xf32>, tensor<4xf32> return %2, %3 : tensor<4xf32>, tensor<4xf32> }
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/legalize_input_types.mlir b/iree/compiler/Dialect/Flow/Transforms/test/legalize_input_types.mlir index c14e262..192eb03 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/legalize_input_types.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/legalize_input_types.mlir
@@ -40,15 +40,6 @@ // ----- -// CHECK-LABEL: func @typesIndex -// CHECK-SAME: (%arg0: i32) -> i32 -func @typesIndex(%arg0 : index) -> index { - // CHECK-NEXT: return %arg0 : i32 - return %arg0 : index -} - -// ----- - // CHECK-LABEL: func @typesI64 // CHECK-SAME: (%arg0: i32) -> i32 func @typesI64(%arg0 : i64) -> i64 {
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/rematerialize_dispatch_constants.mlir b/iree/compiler/Dialect/Flow/Transforms/test/rematerialize_dispatch_constants.mlir index d56cf01..bf2716a 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/rematerialize_dispatch_constants.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/rematerialize_dispatch_constants.mlir
@@ -2,12 +2,13 @@ // CHECK-LABEL: func @rematerializeSmall func @rematerializeSmall(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { - %cst = constant dense<[4, 4, 1]> : vector<3xi32> + // CHECK: %[[WORKLOAD0:.+]] = constant 16 : index + %cst = constant 16 : index %small = constant dense<1.23> : tensor<4x4xf32> - // CHECK: %0 = flow.dispatch.region[%cst : vector<3xi32>](%arg1 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { - %0 = flow.dispatch.region[%cst : vector<3xi32>](%arg1 = %arg0 : tensor<4x4xf32>, %arg2 = %small : tensor<4x4xf32>) -> tensor<4x4xf32> { - // CHECK-NEXT: %cst_0 = constant dense<1.230000e+00> : tensor<4x4xf32> - // CHECK-NEXT: %1 = xla_hlo.add %arg1, %cst_0 : tensor<4x4xf32> + // CHECK: %[[R0:.+]] = flow.dispatch.region[%[[WORKLOAD0]] : index](%arg1 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { + %0 = flow.dispatch.region[%cst : index](%arg1 = %arg0 : tensor<4x4xf32>, %arg2 = %small : tensor<4x4xf32>) -> tensor<4x4xf32> { + // CHECK-NEXT: %[[REMAT_SMALL:.+]] = constant dense<1.230000e+00> : tensor<4x4xf32> + // CHECK-NEXT: %1 = xla_hlo.add %arg1, %[[REMAT_SMALL]] : tensor<4x4xf32> %3 = xla_hlo.add %arg1, %arg2 : tensor<4x4xf32> flow.return %3 : tensor<4x4xf32> } @@ -18,11 +19,12 @@ // CHECK-LABEL: func @noRematerializeLarge func @noRematerializeLarge(%arg0 : tensor<4096x4xf32>) -> tensor<4096x4xf32> { - %cst = constant dense<[4, 4, 1]> : vector<3xi32> - // CHECK: %cst_0 = constant dense<1.230000e+00> : tensor<4096x4xf32> + // CHECK-DAG: %[[WORKLOAD0:.+]] = constant 16 : index + // CHECK-DAG: %[[CST:.+]] = constant dense<1.230000e+00> : tensor<4096x4xf32> + %cst = constant 16 : index %large = constant dense<1.23> : tensor<4096x4xf32> - // CHECK-NEXT: %0 = flow.dispatch.region[%cst : vector<3xi32>](%arg1 = %arg0 : tensor<4096x4xf32>, %arg2 = %cst_0 : tensor<4096x4xf32>) -> tensor<4096x4xf32> { - %0 = flow.dispatch.region[%cst : vector<3xi32>](%arg1 = %arg0 : tensor<4096x4xf32>, %arg2 = %large : tensor<4096x4xf32>) -> tensor<4096x4xf32> { + // 
CHECK-NEXT: %[[R0:.+]] = flow.dispatch.region[%[[WORKLOAD0]] : index](%arg1 = %arg0 : tensor<4096x4xf32>, %arg2 = %[[CST]] : tensor<4096x4xf32>) -> tensor<4096x4xf32> { + %0 = flow.dispatch.region[%cst : index](%arg1 = %arg0 : tensor<4096x4xf32>, %arg2 = %large : tensor<4096x4xf32>) -> tensor<4096x4xf32> { // CHECK-NEXT: %1 = xla_hlo.add %arg1, %arg2 : tensor<4096x4xf32> %3 = xla_hlo.add %arg1, %arg2 : tensor<4096x4xf32> flow.return %3 : tensor<4096x4xf32> @@ -34,11 +36,12 @@ // CHECK-LABEL: func @noRematerializeIntoDot func @noRematerializeIntoDot(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { - %cst = constant dense<[4, 4, 1]> : vector<3xi32> - // CHECK: %cst_0 = constant dense<1.230000e+00> : tensor<4x4xf32> + // CHECK-DAG: %[[WORKLOAD0:.+]] = constant 16 : index + // CHECK-DAG: %[[SMALL:.+]] = constant dense<1.230000e+00> : tensor<4x4xf32> + %cst = constant 16 : index %small = constant dense<1.23> : tensor<4x4xf32> - // CHECK-NEXT: %0 = flow.dispatch.region[%cst : vector<3xi32>](%arg1 = %arg0 : tensor<4x4xf32>, %arg2 = %cst_0 : tensor<4x4xf32>) -> tensor<4x4xf32> { - %0 = flow.dispatch.region[%cst : vector<3xi32>](%arg1 = %arg0 : tensor<4x4xf32>, %arg2 = %small : tensor<4x4xf32>) -> tensor<4x4xf32> { + // CHECK-NEXT: %[[R0:.+]] = flow.dispatch.region[%[[WORKLOAD0]] : index](%arg1 = %arg0 : tensor<4x4xf32>, %arg2 = %[[SMALL]] : tensor<4x4xf32>) -> tensor<4x4xf32> { + %0 = flow.dispatch.region[%cst : index](%arg1 = %arg0 : tensor<4x4xf32>, %arg2 = %small : tensor<4x4xf32>) -> tensor<4x4xf32> { // CHECK-NEXT: %1 = "xla_hlo.dot"(%arg1, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> %3 = "xla_hlo.dot"(%arg1, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> flow.return %3 : tensor<4x4xf32>
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/transformation.mlir b/iree/compiler/Dialect/Flow/Transforms/test/transformation.mlir index 4b1dee9..75fd8c3 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/transformation.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/transformation.mlir
@@ -14,9 +14,7 @@ } // CHECK-LABEL: flow.executable @simpleMath_ex_dispatch_0 { -// CHECK-NEXT: flow.dispatch.entry @simpleMath_ex_dispatch_0 attributes { -// CHECK-SAME: workload = dense<[4, 1, 1]> : vector<3xi32> -// CHECK-SAME: } +// CHECK-NEXT: flow.dispatch.entry @simpleMath_ex_dispatch_0 // CHECK-NEXT: module { // CHECK-NEXT: func @simpleMath_ex_dispatch_0(%arg0: tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %0 = xla_hlo.add %arg0, %arg0 : tensor<4xf32> @@ -25,9 +23,9 @@ // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: func @simpleMath(%arg0: tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NEXT: %cst = constant dense<[4, 1, 1]> : vector<3xi32> -// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %cst : vector<3xi32>, %arg2 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NEXT: %1 = flow.dispatch @simpleMath_ex_dispatch_0::@simpleMath_ex_dispatch_0[%arg1 : vector<3xi32>](%arg2) : (tensor<4xf32>) -> tensor<4xf32> +// CHECK-NEXT: %[[WORKLOAD0:.+]] = constant 4 : index +// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD0]] : index, %arg2 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NEXT: %1 = flow.dispatch @simpleMath_ex_dispatch_0::@simpleMath_ex_dispatch_0[%arg1 : index](%arg2) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return %1 : tensor<4xf32> // CHECK-NEXT: } // CHECK-NEXT: return %0 : tensor<4xf32> @@ -43,9 +41,7 @@ } // CHECK-LABEL: flow.executable @stdElementwiseOps_ex_dispatch_0 { -// CHECK-NEXT: flow.dispatch.entry @stdElementwiseOps_ex_dispatch_0 attributes { -// CHECK-SAME: workload = dense<[4, 1, 1]> : vector<3xi32> -// CHECK-SAME: } +// CHECK-NEXT: flow.dispatch.entry @stdElementwiseOps_ex_dispatch_0 // CHECK-NEXT: module { // CHECK-NEXT: func @stdElementwiseOps_ex_dispatch_0(%arg0: tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %0 = addf %arg0, %arg0 : tensor<4xf32> @@ -56,9 +52,9 @@ // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: func @stdElementwiseOps(%arg0: tensor<4xf32>) -> tensor<4xf32> { -// 
CHECK-NEXT: %cst = constant dense<[4, 1, 1]> : vector<3xi32> -// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %cst : vector<3xi32>, %arg2 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NEXT: %1 = flow.dispatch @stdElementwiseOps_ex_dispatch_0::@stdElementwiseOps_ex_dispatch_0[%arg1 : vector<3xi32>](%arg2) : (tensor<4xf32>) -> tensor<4xf32> +// CHECK-NEXT: %[[WORKLOAD0:.+]] = constant 4 : index +// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD0]] : index, %arg2 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NEXT: %1 = flow.dispatch @stdElementwiseOps_ex_dispatch_0::@stdElementwiseOps_ex_dispatch_0[%arg1 : index](%arg2) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return %1 : tensor<4xf32> // CHECK-NEXT: } // CHECK-NEXT: return %0 : tensor<4xf32> @@ -74,9 +70,7 @@ } // CHECK-LABEL: flow.executable @hloElementwiseOps_ex_dispatch_0 { -// CHECK-NEXT: flow.dispatch.entry @hloElementwiseOps_ex_dispatch_0 attributes { -// CHECK-SAME: workload = dense<[4, 1, 1]> : vector<3xi32> -// CHECK-SAME: } +// CHECK-NEXT: flow.dispatch.entry @hloElementwiseOps_ex_dispatch_0 // CHECK-NEXT: module { // CHECK-NEXT: func @hloElementwiseOps_ex_dispatch_0(%arg0: tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %0 = xla_hlo.add %arg0, %arg0 : tensor<4xf32> @@ -87,9 +81,9 @@ // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: func @hloElementwiseOps(%arg0: tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NEXT: %cst = constant dense<[4, 1, 1]> : vector<3xi32> -// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %cst : vector<3xi32>, %arg2 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NEXT: %1 = flow.dispatch @hloElementwiseOps_ex_dispatch_0::@hloElementwiseOps_ex_dispatch_0[%arg1 : vector<3xi32>](%arg2) : (tensor<4xf32>) -> tensor<4xf32> +// CHECK-NEXT: %[[WORKLOAD0:.+]] = constant 4 : index +// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD0]] : index, %arg2 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NEXT: %1 = flow.dispatch 
@hloElementwiseOps_ex_dispatch_0::@hloElementwiseOps_ex_dispatch_0[%arg1 : index](%arg2) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return %1 : tensor<4xf32> // CHECK-NEXT: } // CHECK-NEXT: return %0 : tensor<4xf32> @@ -105,9 +99,7 @@ } // CHECK-LABEL: flow.executable @interleavedDot_ex_dispatch_0 { -// CHECK-NEXT: flow.dispatch.entry @interleavedDot_ex_dispatch_0 attributes { -// CHECK-SAME: workload = dense<[4, 4, 1]> : vector<3xi32> -// CHECK-SAME: } +// CHECK-NEXT: flow.dispatch.entry @interleavedDot_ex_dispatch_0 // CHECK-NEXT: module { // CHECK-NEXT: func @interleavedDot_ex_dispatch_0(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { // CHECK-NEXT: %0 = xla_hlo.add %arg0, %arg0 : tensor<4x4xf32> @@ -116,9 +108,7 @@ // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: flow.executable @interleavedDot_ex_dispatch_1 { -// CHECK-NEXT: flow.dispatch.entry @interleavedDot_ex_dispatch_1 attributes { -// CHECK-SAME: workload = dense<[4, 4, 1]> : vector<3xi32> -// CHECK-SAME: } +// CHECK-NEXT: flow.dispatch.entry @interleavedDot_ex_dispatch_1 // CHECK-NEXT: module { // CHECK-NEXT: func @interleavedDot_ex_dispatch_1(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32> { // CHECK-NEXT: %0 = "xla_hlo.dot"(%arg0, %arg1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> @@ -127,9 +117,7 @@ // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: flow.executable @interleavedDot_ex_dispatch_2 { -// CHECK-NEXT: flow.dispatch.entry @interleavedDot_ex_dispatch_2 attributes { -// CHECK-SAME: workload = dense<[4, 4, 1]> : vector<3xi32> -// CHECK-SAME: } +// CHECK-NEXT: flow.dispatch.entry @interleavedDot_ex_dispatch_2 // CHECK-NEXT: module { // CHECK-NEXT: func @interleavedDot_ex_dispatch_2(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32> { // CHECK-NEXT: %0 = xla_hlo.mul %arg0, %arg1 : tensor<4x4xf32> @@ -138,11 +126,11 @@ // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: func @interleavedDot(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { -// 
CHECK-NEXT: %cst = constant dense<[4, 4, 1]> : vector<3xi32> -// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %cst : vector<3xi32>, %arg2 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { -// CHECK-NEXT: %1 = flow.dispatch @interleavedDot_ex_dispatch_0::@interleavedDot_ex_dispatch_0[%arg1 : vector<3xi32>](%arg2) : (tensor<4x4xf32>) -> tensor<4x4xf32> -// CHECK-NEXT: %2 = flow.dispatch @interleavedDot_ex_dispatch_1::@interleavedDot_ex_dispatch_1[%arg1 : vector<3xi32>](%1, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> -// CHECK-NEXT: %3 = flow.dispatch @interleavedDot_ex_dispatch_2::@interleavedDot_ex_dispatch_2[%arg1 : vector<3xi32>](%2, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> +// CHECK-NEXT: %[[WORKLOAD0:.+]] = constant 16 : index +// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD0]] : index, %arg2 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { +// CHECK-NEXT: %1 = flow.dispatch @interleavedDot_ex_dispatch_0::@interleavedDot_ex_dispatch_0[%arg1 : index](%arg2) : (tensor<4x4xf32>) -> tensor<4x4xf32> +// CHECK-NEXT: %2 = flow.dispatch @interleavedDot_ex_dispatch_1::@interleavedDot_ex_dispatch_1[%arg1 : index](%1, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> +// CHECK-NEXT: %3 = flow.dispatch @interleavedDot_ex_dispatch_2::@interleavedDot_ex_dispatch_2[%arg1 : index](%2, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> // CHECK-NEXT: flow.return %3 : tensor<4x4xf32> // CHECK-NEXT: } // CHECK-NEXT: return %0 : tensor<4x4xf32> @@ -161,7 +149,7 @@ } // CHECK-LABEL: flow.executable @reduction_ex_dispatch_0 { -// CHECK-NEXT: flow.dispatch.entry @reduction_ex_dispatch_0 attributes {workload = dense<[4, 1, 1]> : vector<3xi32>} +// CHECK-NEXT: flow.dispatch.entry @reduction_ex_dispatch_0 // CHECK-NEXT: module { // CHECK-NEXT: func @reduction_ex_dispatch_0(%arg0: tensor<4x8xf32>) -> tensor<4xf32> { // CHECK-NEXT: %cst = constant dense<0.000000e+00> : tensor<f32> @@ -175,9 +163,9 @@ // 
CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: func @reduction(%arg0: tensor<4x8xf32>) -> tensor<4xf32> { -// CHECK-NEXT: %cst = constant dense<[4, 1, 1]> : vector<3xi32> -// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %cst : vector<3xi32>, %arg2 = %arg0 : tensor<4x8xf32>) -> tensor<4xf32> { -// CHECK-NEXT: %1 = flow.dispatch @reduction_ex_dispatch_0::@reduction_ex_dispatch_0[%arg1 : vector<3xi32>](%arg2) : (tensor<4x8xf32>) -> tensor<4xf32> +// CHECK-NEXT: %[[WORKLOAD0:.+]] = constant 4 : index +// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD0]] : index, %arg2 = %arg0 : tensor<4x8xf32>) -> tensor<4xf32> { +// CHECK-NEXT: %1 = flow.dispatch @reduction_ex_dispatch_0::@reduction_ex_dispatch_0[%arg1 : index](%arg2) : (tensor<4x8xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return %1 : tensor<4xf32> // CHECK-NEXT: } // CHECK-NEXT: return %0 : tensor<4xf32> @@ -192,7 +180,7 @@ } // CHECK-LABEL: flow.executable @dynamicUpdateSlice_ex_dispatch_0 { -// CHECK-NEXT: flow.dispatch.entry @dynamicUpdateSlice_ex_dispatch_0 attributes {workload = dense<[4, 2, 1]> : vector<3xi32>} +// CHECK-NEXT: flow.dispatch.entry @dynamicUpdateSlice_ex_dispatch_0 // CHECK-NEXT: module { // CHECK-NEXT: func @dynamicUpdateSlice_ex_dispatch_0(%arg0: tensor<2x4xi32>, %arg1: tensor<2x4xi32>) -> tensor<2x4xi32> { // CHECK-NEXT: %0 = xla_hlo.add %arg0, %arg1 : tensor<2x4xi32> @@ -201,12 +189,12 @@ // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: func @dynamicUpdateSlice(%arg0: tensor<2x4xi32>, %arg1: tensor<1x1xi32>, %arg2: tensor<i32>, %arg3: tensor<i32>) -> tensor<2x4xi32> { -// CHECK-NEXT: %cst = constant dense<[4, 2, 1]> : vector<3xi32> +// CHECK-NEXT: %[[WORKLOAD0:.+]] = constant 8 : index // CHECK-NEXT: %0 = flow.tensor.load %arg2 : tensor<i32> // CHECK-NEXT: %1 = flow.tensor.load %arg3 : tensor<i32> -// CHECK-NEXT: %2 = flow.ex.stream.fragment(%arg4 = %arg1 : tensor<1x1xi32>, %arg5 = %arg0 : tensor<2x4xi32>, %arg6 = %0 : i32, %arg7 = %1 : i32, %arg8 = %cst : vector<3xi32>) -> 
tensor<2x4xi32> { +// CHECK-NEXT: %2 = flow.ex.stream.fragment(%arg4 = %arg1 : tensor<1x1xi32>, %arg5 = %arg0 : tensor<2x4xi32>, %arg6 = %0 : i32, %arg7 = %1 : i32, %arg8 = %[[WORKLOAD0]] : index) -> tensor<2x4xi32> { // CHECK-NEXT: %3 = flow.tensor.update %arg4, %arg5[%arg6, %arg7] : tensor<1x1xi32> -> tensor<2x4xi32> -// CHECK-NEXT: %4 = flow.dispatch @dynamicUpdateSlice_ex_dispatch_0::@dynamicUpdateSlice_ex_dispatch_0[%arg8 : vector<3xi32>](%arg5, %3) : (tensor<2x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> +// CHECK-NEXT: %4 = flow.dispatch @dynamicUpdateSlice_ex_dispatch_0::@dynamicUpdateSlice_ex_dispatch_0[%arg8 : index](%arg5, %3) : (tensor<2x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> // CHECK-NEXT: flow.return %4 : tensor<2x4xi32> // CHECK-NEXT: } // CHECK-NEXT: return %2 : tensor<2x4xi32>
diff --git a/iree/compiler/Dialect/Flow/Utils/BUILD b/iree/compiler/Dialect/Flow/Utils/BUILD index 2c3ba09..1fb34ad 100644 --- a/iree/compiler/Dialect/Flow/Utils/BUILD +++ b/iree/compiler/Dialect/Flow/Utils/BUILD
@@ -29,6 +29,7 @@ ], deps = [ "//iree/compiler/Dialect/Flow/IR", + "//iree/compiler/Dialect/Shape/IR", "@llvm-project//llvm:support", "@llvm-project//mlir:IR", "@llvm-project//mlir:StandardOps",
diff --git a/iree/compiler/Dialect/Flow/Utils/WorkloadUtils.cpp b/iree/compiler/Dialect/Flow/Utils/WorkloadUtils.cpp index dfb0fb0..0ec15b6 100644 --- a/iree/compiler/Dialect/Flow/Utils/WorkloadUtils.cpp +++ b/iree/compiler/Dialect/Flow/Utils/WorkloadUtils.cpp
@@ -15,7 +15,11 @@ #include "iree/compiler/Dialect/Flow/Utils/WorkloadUtils.h" #include <array> +#include <limits> +#include "iree/compiler/Dialect/Shape/IR/Builders.h" +#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h" +#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -25,59 +29,55 @@ namespace mlir { namespace iree_compiler { + +using Shape::buildOrFindRankedShapeForValue; +using Shape::RankedDimOp; + namespace IREE { namespace Flow { Value calculateWorkload(Operation *op, Value baseOperand) { - OpBuilder builder(op); - - std::array<int32_t, 3> workload = {1, 1, 1}; - - // TODO(b/139353314): lookup/calculate based on type/etc. + OpBuilder builder(op->getContext()); auto baseOperandType = baseOperand.getType().cast<ShapedType>(); - if (!baseOperandType.hasStaticShape()) { - op->emitOpError() << "Dynamic shapes not yet supported"; + if (baseOperandType.hasRank() && baseOperandType.hasStaticShape()) { + // Just a constant (note this also covers rank0). + int64_t numElements = baseOperandType.getNumElements(); + if (numElements > std::numeric_limits<int32_t>::max()) { + return (op->emitOpError() + << "total element count > 32bit integer capacity"), + nullptr; + } + builder.setInsertionPointToStart(op->getBlock()); + return builder.create<ConstantOp>( + op->getLoc(), builder.getIndexType(), + builder.getIntegerAttr(builder.getIndexType(), numElements)); + } else if (baseOperandType.hasRank()) { + // Materialize a ranked shape and compute. + auto rankedShape = buildOrFindRankedShapeForValue( + op->getLoc(), baseOperand, builder.getIndexType(), builder); + if (!rankedShape) return nullptr; + // Ensure to emit with proper dominance. + // TODO(laurenzo): Need to overhaul insertion points generally in + // dispatch region formation. 
+    if (rankedShape.getDefiningOp()) {
+      builder.setInsertionPointAfter(rankedShape.getDefiningOp());
+    }
+    Value numElements;
+    for (int64_t i = 0, e = baseOperandType.getRank(); i < e; ++i) {
+      auto dim = builder.create<RankedDimOp>(op->getLoc(), rankedShape, i);
+      if (!numElements) {
+        numElements = dim;
+        continue;
+      }
+      numElements = builder.create<MulIOp>(op->getLoc(), numElements, dim);
+    }
+
+    return numElements;
+  } else {
+    op->emitOpError()
+        << "unranked shapes not supported for workload calculation";
+    return nullptr;
+  }
-  auto shape = baseOperandType.getShape();
-  if (auto conv = dyn_cast_or_null<xla_hlo::ConvOp>(op)) {
-    workload[2] =
-        shape[conv.dimension_numbers().output_batch_dimension().getInt()];
-    int i = 0;
-    for (const auto &dim :
-         conv.dimension_numbers().output_spatial_dimensions().getIntValues()) {
-      if (i > 1) {
-        break;
-      }
-      workload[1 - i++] = shape[dim.getSExtValue()];
-    }
-  } else {
-    // Drop the trailing ones from the shape.
-    while (shape.size() > 1 && shape.back() == 1) {
-      shape = shape.drop_back();
-    }
-    if (shape.size() <= 3) {
-      // Maps to XYZ (possibly with 1's for unused dimensions).
-      for (auto dim : enumerate(shape)) {
-        workload[shape.size() - 1 - dim.index()] = dim.value();
-      }
-    } else {
-      // Need to flatten the shape to fit XYZ. For now we just squash from LHS.
-      workload[2] = 1;
-      for (int i = 0; i < shape.size() - 2; ++i) {
-        workload[2] *= shape[i];
-      }
-      workload[1] = shape[shape.size() - 2];
-      workload[0] = shape.back();
-    }
-  }
-
-  // TODO(b/139353314): optimize workload layout.
-
-  auto constantType = VectorType::get({3}, builder.getIntegerType(32));
-  return builder.create<ConstantOp>(
-      op->getLoc(), constantType,
-      DenseIntElementsAttr::get(constantType, workload));
 }

 }  // namespace Flow
diff --git a/iree/compiler/Dialect/Flow/Utils/WorkloadUtils.h b/iree/compiler/Dialect/Flow/Utils/WorkloadUtils.h index 70147c2..5edf468 100644 --- a/iree/compiler/Dialect/Flow/Utils/WorkloadUtils.h +++ b/iree/compiler/Dialect/Flow/Utils/WorkloadUtils.h
@@ -24,7 +24,7 @@ namespace Flow { // Calculates the workload for |op| based on the given operation operand. -// Returns a vector<3xi32> containing the X, Y, Z workload parameters. +// Returns an index representing the total number of invocations required. // // The |baseOperand| is usually one of the results of a dispatch that signifies // how many invocations are ideal for writing the result. Later on in the
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp index 159b442..e2f60ea 100644 --- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp +++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp
@@ -19,6 +19,7 @@ #include "iree/compiler/Dialect/HAL/Utils/TypeUtils.h" #include "iree/compiler/Dialect/IREE/IR/IREETypes.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/Support/Debug.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -27,6 +28,8 @@ #include "mlir/IR/Module.h" #include "mlir/Transforms/DialectConversion.h" +#define DEBUG_TYPE "iree-hal" + namespace mlir { namespace iree_compiler { namespace { @@ -163,32 +166,39 @@ } // Returns a the (x, y, z) workgroup counts calculated from the given |workload| -// and the workgroup size of the dispatch |entryPointOp|. +// (invocation count) and the workgroup size of the dispatch |entryPointOp|. static std::array<Value, 3> getDispatchWorkgroupCounts( IREE::HAL::ExecutableEntryPointOp entryPointOp, Value workload, ConversionPatternRewriter &rewriter) { std::array<Value, 3> result; auto loc = entryPointOp.getLoc(); + auto i32Type = rewriter.getIntegerType(32); + auto constantOne = rewriter.createOrFold<mlir::ConstantOp>( + loc, rewriter.getI32IntegerAttr(1)); + workload = rewriter.createOrFold<mlir::IndexCastOp>(loc, i32Type, workload); for (int i = 0; i < 3; ++i) { // Round up: (workload + workgroup_size - 1) / workgroup_size; - auto workloadI = rewriter.createOrFold<ExtractElementOp>( - loc, workload, - rewriter.createOrFold<mlir::ConstantOp>( - loc, IntegerAttr::get(rewriter.getIndexType(), i))); auto workgroupSizeI = rewriter.createOrFold<mlir::ConstantOp>( loc, rewriter.getI32IntegerAttr( entryPointOp.workgroup_size().getValue<int32_t>( {static_cast<uint64_t>(i)}))); + auto rounded = rewriter.createOrFold<mlir::SubIOp>( + loc, rewriter.createOrFold<mlir::AddIOp>(loc, workload, workgroupSizeI), + constantOne); auto workgroupCountI = rewriter.createOrFold<mlir::UnsignedDivIOp>( - loc, - rewriter.createOrFold<mlir::SubIOp>( - loc, - rewriter.createOrFold<mlir::AddIOp>(loc, workloadI, workgroupSizeI), - rewriter.createOrFold<mlir::ConstantOp>( - loc, 
rewriter.getI32IntegerAttr(1))),
-        workgroupSizeI);
-
+        loc, rounded, workgroupSizeI);
     result[i] = workgroupCountI;
+
+    // Multiply back out and subtract from invocations.
+    workload = rewriter.createOrFold<SubIOp>(
+        loc, workload,
+        rewriter.createOrFold<MulIOp>(loc, workgroupCountI, workgroupSizeI));
+
+    // Ensure > 0.
+    auto workloadGreaterZero =
+        rewriter.create<CmpIOp>(loc, CmpIPredicate::sge, workload, constantOne);
+    workload = rewriter.create<SelectOp>(loc, workloadGreaterZero, workload,
+                                         constantOne);
   }
   return result;
 }
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir index 39ae8cb..86c92e4 100644 --- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir +++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir
@@ -18,14 +18,14 @@ // CHECK-DAG: [[C1:%.+]] = constant 1 // CHECK-DAG: [[C4:%.+]] = constant 4 // CHECK-DAG: [[C128:%.+]] = constant 128 - %cst = constant dense<[128, 1, 1]> : vector<3xi32> + %cst = constant 128 : index // CHECK: [[RET_BUF:%.+]] = hal.allocator.allocate {{.+}}, "HostVisible|DeviceVisible|DeviceLocal", "Constant|Transfer|Mapping|Dispatch" // CHECK-NEXT: hal.ex.defer_release [[RET_BUF]] // CHECK: [[TMP_BUF:%.+]] = hal.allocator.allocate {{.+}}, "DeviceVisible|DeviceLocal", "Transfer|Dispatch" // CHECK-NEXT: hal.ex.defer_release [[TMP_BUF]] // CHECK: [[CMD:%.+]] = hal.command_buffer.create {{.+}}, "OneShot", "Transfer|Dispatch" // CHECK-NEXT: hal.command_buffer.begin [[CMD]] - %0 = flow.ex.stream.fragment(%arg1 = %cst : vector<3xi32>, %arg2 = %arg0 : tensor<128xf32>) -> tensor<128xf32> { + %0 = flow.ex.stream.fragment(%arg1 = %cst : index, %arg2 = %arg0 : tensor<128xf32>) -> tensor<128xf32> { // CHECK: [[EXE:%.+]] = hal.ex.cache_executable {{.+}}, @ex0 : !hal.executable // CHECK-NEXT: hal.ex.push_binding [[CMD]], 0, %arg0, shape = [ // CHECK-SAME: [[C128]] @@ -39,7 +39,7 @@ // CHECK-SAME: [[C4]], [[C1]], [[C1]] // CHECK-SAME: ] // CHECK: hal.command_buffer.execution_barrier - %1 = flow.dispatch @ex0::@entry0[%arg1 : vector<3xi32>](%arg2) : (tensor<128xf32>) -> tensor<128xf32> + %1 = flow.dispatch @ex0::@entry0[%arg1 : index](%arg2) : (tensor<128xf32>) -> tensor<128xf32> // CHECK: hal.ex.push_binding [[CMD]], 0, [[TMP_BUF]], shape = [ // CHECK-SAME: [[C128]] // CHECK-SAME: ], element_type = 50331680 @@ -52,7 +52,7 @@ // CHECK-SAME: [[C4]], [[C1]], [[C1]] // CHECK-SAME: ] // CHECK: hal.command_buffer.execution_barrier - %2 = flow.dispatch @ex0::@entry0[%arg1 : vector<3xi32>](%1) : (tensor<128xf32>) -> tensor<128xf32> + %2 = flow.dispatch @ex0::@entry0[%arg1 : index](%1) : (tensor<128xf32>) -> tensor<128xf32> flow.return %2 : tensor<128xf32> } // CHECK: hal.command_buffer.end [[CMD]]
diff --git a/iree/compiler/Dialect/Shape/IR/ShapeOps.cpp b/iree/compiler/Dialect/Shape/IR/ShapeOps.cpp index 388ad6f..523bb37 100644 --- a/iree/compiler/Dialect/Shape/IR/ShapeOps.cpp +++ b/iree/compiler/Dialect/Shape/IR/ShapeOps.cpp
@@ -349,6 +349,7 @@ void ConstRankedShapeOp::build(Builder *builder, OperationState &result, Type type) { + assert(type.cast<RankedShapeType>().isFullyStatic()); result.types.push_back(type); }
diff --git a/iree/compiler/Dialect/VM/Conversion/BUILD b/iree/compiler/Dialect/VM/Conversion/BUILD index c629880..ebd1918 100644 --- a/iree/compiler/Dialect/VM/Conversion/BUILD +++ b/iree/compiler/Dialect/VM/Conversion/BUILD
@@ -34,6 +34,7 @@ "//iree/compiler/Dialect/IREE/IR", "//iree/compiler/Dialect/Shape/IR", "//iree/compiler/Dialect/VM/IR", + "@llvm-project//llvm:support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:StandardOps",
diff --git a/iree/compiler/Dialect/VM/Conversion/StandardToVM/BUILD b/iree/compiler/Dialect/VM/Conversion/StandardToVM/BUILD index 17cc5ac..ec10c21 100644 --- a/iree/compiler/Dialect/VM/Conversion/StandardToVM/BUILD +++ b/iree/compiler/Dialect/VM/Conversion/StandardToVM/BUILD
@@ -16,11 +16,9 @@ "//iree/compiler/Dialect/IREE/IR", "//iree/compiler/Dialect/VM/Conversion", "//iree/compiler/Dialect/VM/IR", - "@llvm-project//llvm:support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:StandardOps", - "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ],
diff --git a/iree/compiler/Dialect/VM/Conversion/TypeConverter.cpp b/iree/compiler/Dialect/VM/Conversion/TypeConverter.cpp index 462f74d..e402560 100644 --- a/iree/compiler/Dialect/VM/Conversion/TypeConverter.cpp +++ b/iree/compiler/Dialect/VM/Conversion/TypeConverter.cpp
@@ -15,10 +15,14 @@ #include "iree/compiler/Dialect/VM/Conversion/TypeConverter.h" #include "iree/compiler/Dialect/IREE/IR/IREETypes.h" +#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h" #include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h" #include "iree/compiler/Dialect/VM/IR/VMOps.h" +#include "llvm/Support/Debug.h" #include "mlir/IR/StandardTypes.h" +#define DEBUG_TYPE "iree-vm" + namespace mlir { namespace iree_compiler { @@ -56,7 +60,7 @@ } return IREE::PtrType::get(targetType); }); - // Convert ranked shape types. + // Convert ranked shape types (expanding all dims). addConversion( [](Shape::RankedShapeType rankedShape, SmallVectorImpl<Type> &results) { for (int i = 0; i < rankedShape.getRank(); ++i) { @@ -72,6 +76,11 @@ Type resultType, ArrayRef<Value> inputs, Location loc) { + LLVM_DEBUG(llvm::dbgs() << "MATERIALIZE CONVERSION: " << resultType << "\n"); + if (auto rsType = resultType.dyn_cast<Shape::RankedShapeType>()) { + return rewriter.create<Shape::MakeRankedShapeOp>(loc, rsType, inputs); + } + // TODO(b/145876978): materialize conversion when this is called. llvm_unreachable("unhandled materialization"); return nullptr;
diff --git a/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp b/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp index d21827e..49404f1 100644 --- a/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp +++ b/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp
@@ -15,6 +15,7 @@ #include "iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.h" #include "iree/compiler/Dialect/IREE/IR/IREETypes.h" +#include "iree/compiler/Dialect/Shape/IR/Builders.h" #include "iree/compiler/Dialect/Shape/IR/ShapeOps.h" #include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h" #include "iree/compiler/Dialect/VMLA/Conversion/TypeConverter.h" @@ -29,6 +30,8 @@ namespace mlir { namespace iree_compiler { +using Shape::buildOrFindRankedShapeForValue; + VMLAConversionTarget::VMLAConversionTarget(MLIRContext *context, TypeConverter &typeConverter) : ConversionTarget(*context), @@ -188,12 +191,8 @@ Value VMLAConversionTarget::getTensorShape( Location loc, Value originalValue, TypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { - // TODO(benvanik): use tie_shape to find the ranked shape to use for the - // originalValue tensor. - auto originalType = originalValue.getType().cast<ShapedType>(); - return rewriter.createOrFold<Shape::ConstRankedShapeOp>( - loc, Shape::RankedShapeType::get(originalType.getShape(), - rewriter.getIntegerType(32))); + return buildOrFindRankedShapeForValue(loc, originalValue, + rewriter.getIntegerType(32), rewriter); } // static @@ -224,6 +223,9 @@ VMLATypeConverter::getRoundedElementByteWidth(elementType))); auto shape = getTensorShape(loc, tensorValue, typeConverter, rewriter); + if (!shape) { + return nullptr; + } Value offset = rewriter.createOrFold<mlir::ConstantOp>( loc, rewriter.getIntegerType(32), rewriter.getI32IntegerAttr(0)); for (int i = 0; i < tensorType.getRank(); ++i) { @@ -250,6 +252,7 @@ VMLATypeConverter::getRoundedElementByteWidth(elementType))); auto shape = getTensorShape(loc, tensorValue, typeConverter, rewriter); + if (!shape) return nullptr; auto dims = rewriter.create<Shape::RankedDimsOp>(loc, shape); Value length = elementSize; for (auto dim : dims.getResults()) {
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/ConvertHALToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/ConvertHALToVMLA.cpp index c4f4568..8634f12 100644 --- a/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/ConvertHALToVMLA.cpp +++ b/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/ConvertHALToVMLA.cpp
@@ -70,6 +70,7 @@ interfaceArg, bindingOp.set(), bindingOp.binding()); auto byteLengthValue = VMLAConversionTarget::getBufferLength( loadOp.getLoc(), loadOp.result(), typeConverter, rewriter); + if (!byteLengthValue) return matchFailure(); rewriter.replaceOpWithNewOp<IREE::VMLA::BufferViewOp>( loadOp, IREE::VMLA::BufferType::get(loadOp.getContext()), bufferOp.result(), newOperands.offset(), byteLengthValue);