Merge pull request #6830 from hanhanW:main-to-google PiperOrigin-RevId: 392697643
diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py index 065e603..f877e5e 100644 --- a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py +++ b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
@@ -12,6 +12,11 @@ "//build_tools:default_linkopts": [], "//build_tools:dl": ["${CMAKE_DL_LIBS}"], + # IREE llvm-external-projects + "//llvm-external-projects/iree-dialects:IREEDialect": [ + "IREEDialectsIREEDialect" + ], + # LLVM "@llvm-project//llvm:IPO": ["LLVMipo"], # MLIR
diff --git a/integrations/tensorflow/WORKSPACE b/integrations/tensorflow/WORKSPACE index b7b0b6c..b7e8507 100644 --- a/integrations/tensorflow/WORKSPACE +++ b/integrations/tensorflow/WORKSPACE
@@ -27,35 +27,6 @@ load("@bazel_skylib//lib:paths.bzl", "paths") ################################################################################ -################################## TensorFlow ################################## -maybe( - local_repository, - name = "org_tensorflow", - path = paths.join(IREE_PATH, "third_party/tensorflow"), -) - -# Import all of the tensorflow dependencies. Note that we are deliberately -# letting TensorFlow take control of all the dependencies it sets up, whereas -# ours are initialized with `maybe`. Actually tracking this with Bazel is PITA -# and for now this gets TF stuff building. This includes, for instance, -# @llvm-project and @com_google_absl. -load("@org_tensorflow//tensorflow:workspace3.bzl", "tf_workspace3") - -tf_workspace3() - -load("@org_tensorflow//tensorflow:workspace2.bzl", "tf_workspace2") - -tf_workspace2() - -load("@org_tensorflow//tensorflow:workspace1.bzl", "tf_workspace1") - -tf_workspace1() - -load("@org_tensorflow//tensorflow:workspace0.bzl", "tf_workspace0") - -tf_workspace0() -################################################################################ - ##################################### IREE ##################################### # We need a shim here to make this use the version of mlir-hlo present in # TensorFlow. This shim is just alias rules that forward to the TF rule. 
Note @@ -79,4 +50,43 @@ iree_path = IREE_PATH, iree_repo_alias = "@iree", ) + +new_local_repository( + name = "llvm-raw", + build_file_content = "# empty", + path = paths.join(IREE_PATH, "third_party/llvm-project"), +) + +load("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure", "llvm_disable_optional_support_deps") + +llvm_configure(name = "llvm-project") + +llvm_disable_optional_support_deps() + +################################################################################ + +################################## TensorFlow ################################## +maybe( + local_repository, + name = "org_tensorflow", + path = paths.join(IREE_PATH, "third_party/tensorflow"), +) + +# Import all of the tensorflow dependencies. Note that any repository previously +# defined (e.g. from IREE submodules) will be skipped. +load("@org_tensorflow//tensorflow:workspace3.bzl", "tf_workspace3") + +tf_workspace3() + +load("@org_tensorflow//tensorflow:workspace2.bzl", "tf_workspace2") + +tf_workspace2() + +load("@org_tensorflow//tensorflow:workspace1.bzl", "tf_workspace1") + +tf_workspace1() + +load("@org_tensorflow//tensorflow:workspace0.bzl", "tf_workspace0") + +tf_workspace0() ################################################################################
diff --git a/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp b/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp index f00eca4..586b289 100644 --- a/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp +++ b/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp
@@ -41,7 +41,7 @@ } StringRef getDescription() const override { - return " Wraps all entry points in a function that is compatible with the " + return "Wraps all entry points in a function that is compatible with the " "expected invocation semantics of bindings following the native " "IREE ABI."; } @@ -51,7 +51,9 @@ SmallVector<FuncOp, 4> entryFuncOps; for (auto funcOp : moduleOp.getOps<FuncOp>()) { - if (funcOp.isPublic()) entryFuncOps.push_back(funcOp); + if (funcOp.isPublic() && !funcOp->hasAttr("iree.abi.stub")) { + entryFuncOps.push_back(funcOp); + } } // Create a wrapper function for each entry point.
diff --git a/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir b/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir index 9ab791a..78e47bb 100644 --- a/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir +++ b/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir
@@ -26,3 +26,13 @@ %1 = "mhlo.add"(%0, %arg0) : (tensor<?x8x8x3xf32>, tensor<?x8x8x3xf32>) -> tensor<?x8x8x3xf32> return %0, %1 : tensor<?x8x8x3xf32>, tensor<?x8x8x3xf32> } + +// ----- + +// CHECK-LABEL: func @wrappedAlready +// CHECK-SAME: (%arg0: !hal.buffer_view) -> !hal.buffer_view +// CHECK-SAME: attributes {iree.abi.stub} +func @wrappedAlready(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} { + return %arg0 : !hal.buffer_view +} +// CHECK-NOT: func @_wrappedAlready
diff --git a/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp b/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp index 1c7cdbb..981cf82 100644 --- a/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp +++ b/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp
@@ -44,7 +44,9 @@ SmallVector<FuncOp, 4> entryFuncOps; for (auto funcOp : moduleOp.getOps<FuncOp>()) { - if (funcOp.isPublic()) entryFuncOps.push_back(funcOp); + if (funcOp.isPublic() && !funcOp->hasAttr("iree.abi.stub")) { + entryFuncOps.push_back(funcOp); + } } if (entryFuncOps.size() == 0) { moduleOp.emitError()
diff --git a/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp index 8e9d20c..a069bed 100644 --- a/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp +++ b/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -103,7 +103,7 @@ {tileX / workgroupSize[1], tileY / workgroupSize[0]}); tileSizes.push_back(invocationLevelTs); // Thread level. return setOpConfigAndEntryPointFnTranslation( - entryPoint, op, tileSizes, /*nativeVectorSize=*/ArrayRef<int64_t>{}, + entryPoint, op, tileSizes, /*nativeVectorSizes=*/ArrayRef<int64_t>{}, IREE::HAL::DispatchLoweringPassPipeline::LLVMGPUMatmulSimt, workgroupSize); } @@ -184,7 +184,7 @@ tileSizes.push_back({}); // Subgroup level. tileSizes.emplace_back(std::move(threadTileSizes)); // Thread level return setOpConfigAndEntryPointFnTranslation( - entryPoint, op, tileSizes, /*nativeVectorSize=*/ArrayRef<int64_t>{}, + entryPoint, op, tileSizes, /*nativeVectorSizes=*/ArrayRef<int64_t>{}, IREE::HAL::DispatchLoweringPassPipeline::LLVMGPUVectorize, workgroupSize); }
diff --git a/iree/compiler/Codegen/Passes.h b/iree/compiler/Codegen/Passes.h index 84cfd52..01fd686 100644 --- a/iree/compiler/Codegen/Passes.h +++ b/iree/compiler/Codegen/Passes.h
@@ -254,53 +254,51 @@ std::unique_ptr<OperationPass<FuncOp>> createLLVMGPUPipeliningPass(); //------------------------------------------------------------------------------ -// SPIRV Passes +// SPIR-V Passes //------------------------------------------------------------------------------ -/// Pass pipeline to lower executable obtained from Linalg tile + distribute to -/// scalar + vector code. Does distribution to threads (no vectorization). -void addSPIRVDistributePassPipeline(OpPassManager &pm); +/// Pass pipeline to lower IREE HAL executables with workgroup tiled and +/// distributed Linalg ops to SPIR-V scalar code. Additionally performs +/// distribution to threads without vectorization. +void addSPIRVTileAndDistributePassPipeline(OpPassManager &pm); -/// Pass pipeline to lower executables that contain operations that are not -/// tiled + distributed. -void addSPIRVDistributeToGlobalIDPipeline(OpPassManager &pm); +/// Pass pipeline to lower IREE HAL executables that contain Linalg ops that are +/// not tiled/distributed. Performs distribution to global invocations. +void addSPIRVDistributeToGlobalIDPassPipeline(OpPassManager &pm); -/// pipeline to lower executable obtained from Linalg tile + distribute to -/// scalar + vector code. Does distribution to threads and vectorization. -void addSPIRVVectorizationPassPipeline(OpPassManager &pm); +/// Pass pipeline to lower IREE HAL executables with workgroup tiled and +/// distributed Linalg ops to SPIR-V scalar and vector code. Additionally +/// performs distribution to threads with vectorization. +void addSPIRVTileAndVectorizePassPipeline(OpPassManager &pm); /// Pass to perform the final conversion to SPIR-V dialect. +/// /// This pass converts remaining interface ops into SPIR-V global variables, /// GPU processor ID ops into SPIR-V global variables, loop/standard ops into /// corresponding SPIR-V ops. 
std::unique_ptr<OperationPass<ModuleOp>> createConvertToSPIRVPass(); -/// Pass to add the synchronizations and attributes needed to lower from PLoops -/// to GPU dialect. -std::unique_ptr<OperationPass<FuncOp>> createSPIRVConvertToGPUPass(); +/// Pass to distribute Linalg ops with buffer semantics to global invocations. +std::unique_ptr<OperationPass<FuncOp>> createSPIRVDistributeToGlobalIDPass(); /// Creates a pass to fold processor ID uses where possible. std::unique_ptr<OperationPass<FuncOp>> createSPIRVFoldProcessorIDUsesPass(); -/// Main pass to lower executables to scalar + vector code on SPIR-V -/// path. Invokes one of the pass pipelines that translate the executable to +/// Main pass to lower executables to scalar + vector code on SPIR-V path. +/// Invokes one of the pass pipelines that translate the executable to /// scalar + vector code. std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>> createSPIRVLowerExecutableTargetPass(); -/// Pass to remove loop generated at Flow for tile + distribute when the loop is -/// known to have a single trip count. NOTE: DO NOT USE. This is a legacy pass -/// that is to be deprecated. +/// Pass to remove loop generated at flow for tiled and distributed Linalg ops +/// when the loop is known to have a single trip count. +/// WARNING: DO NOT USE. This is a legacy pass that is to be deprecated. std::unique_ptr<OperationPass<FuncOp>> createSPIRVRemoveOneTripTiledLoopPass(); -/// Pass to tile and distribute Linalg operations on buffers in a single -/// workgroup. +/// Pass to tile and distribute Linalg ops with buffer semantics to subgroups +/// and invocations. std::unique_ptr<OperationPass<FuncOp>> createSPIRVTileAndDistributePass(); -/// Pass to tile and vectorize Linalg operations on buffers in a single -/// workgroup. 
-std::unique_ptr<OperationPass<FuncOp>> createSPIRVTileAndVectorizePass(); - /// Pass to convert vector read/write/arithmetic operations to the corresponding /// cooperative matrix ops when possible. std::unique_ptr<OperationPass<FuncOp>> @@ -309,6 +307,9 @@ /// Pass to lower linalg.copy for copying data to workgroup memory. std::unique_ptr<OperationPass<FuncOp>> createSPIRVCopyToWorkgroupMemoryPass(); +/// Pass to vectorize Linalg ops with buffer semantics. +std::unique_ptr<OperationPass<FuncOp>> createSPIRVVectorizePass(); + /// Converts memref of scalar to memref of vector of efficent size. This will /// allow to convert memory accesses to vector load/store in SPIR-V without /// having pointer bitcast.
diff --git a/iree/compiler/Codegen/Passes.td b/iree/compiler/Codegen/Passes.td index c322788..ce1a556 100644 --- a/iree/compiler/Codegen/Passes.td +++ b/iree/compiler/Codegen/Passes.td
@@ -210,56 +210,60 @@ let summary = "Pass to do software pipelining."; let constructor = "mlir::iree_compiler::createLLVMGPUPipeliningPass()"; } + //------------------------------------------------------------------------------ -// SPIRV +// SPIR-V //------------------------------------------------------------------------------ // TODO: Rename argument to be fully qualified. -def ConvertToSPIRV : - Pass<"iree-convert-to-spirv", "ModuleOp"> { - let summary = "Perform final conversion to SPIR-V dialect"; +def ConvertToSPIRV : Pass<"iree-convert-to-spirv", "ModuleOp"> { + let summary = "Perform the final conversion to SPIR-V dialect"; let constructor = "mlir::iree_compiler::createConvertToSPIRVPass()"; } // TODO: Rename argument to be fully qualified. -def SPIRVConvertToGPU : Pass<"iree-spirv-convert-to-gpu", "FuncOp"> { - let summary = "Map tiled linalg and loop ops to GPU"; - let constructor = "mlir::iree_compiler::createSPIRVConvertToGPUPass()"; +def SPIRVDistributeToGlobalID : + Pass<"iree-spirv-distribute-to-global-id", "FuncOp"> { + let summary = "Distribute Linalg ops with buffer semantics to global " + "invocations"; + let constructor = + "mlir::iree_compiler::createSPIRVDistributeToGlobalIDPass()"; } // TODO: Rename argument to be fully qualified. -// TODO: Does not appear used? 
-def SPIRVFoldProcessorIDUses : Pass<"iree-spirv-fold-gpu-procid-uses", "FuncOp"> { +def SPIRVFoldProcessorIDUses : + Pass<"iree-spirv-fold-gpu-procid-uses", "FuncOp"> { let summary = "Fold GPU processor ID uses where possible"; let constructor = "mlir::iree_compiler::createSPIRVFoldProcessorIDUsesPass()"; } def SPIRVLowerExecutableTarget : - Pass<"iree-spirv-lower-executable-target-pass", "mlir::iree_compiler::IREE::HAL::ExecutableVariantOp"> { - let summary = "Perform lowering of executable target using one of the IREE::HAL::DispatchLoweringPassPipeline"; - let constructor = "mlir::iree_compiler::createSPIRVLowerExecutableTargetPass()"; + Pass<"iree-spirv-lower-executable-target-pass", + "mlir::iree_compiler::IREE::HAL::ExecutableVariantOp"> { + let summary = "Lower the executable target to SPIR-V using one of the " + "IREE::HAL::DispatchLoweringPassPipeline"; + let constructor = + "mlir::iree_compiler::createSPIRVLowerExecutableTargetPass()"; } def SPIRVRemoveOneTripTiledLoop : Pass<"iree-spirv-remove-one-trip-tiled-loop", "FuncOp"> { - let summary = "Remove one trip tiled loop. ---- Legacy Pass! Do not use ---"; - let constructor = "mlir::iree_compiler::createSPIRVRemoveOneTripTiledLoopPass()"; -} - -// TODO: Rename argument to be fully qualified. -def SPIRVTileAndVectorize : Pass<"iree-spirv-tile-and-vectorize", "FuncOp"> { - let summary = - "Tile and vectorize Linalg operations on buffers in one workgroup"; + let summary = "Remove one-trip tiled loop for convolution " + "(LEGACY PASS; DO NOT USE!)"; let constructor = - "mlir::iree_compiler::createSPIRVTileAndVectorizePass()"; + "mlir::iree_compiler::createSPIRVRemoveOneTripTiledLoopPass()"; } // TODO: Rename argument to be fully qualified. 
def SPIRVTileAndDistribute : Pass<"iree-spirv-tile-and-distribute", "FuncOp"> { - let summary = - "Tile and distribute Linalg operations on buffers in one workgroup"; - let constructor = - "mlir::iree_compiler::createSPIRVTileAndDistributePass()"; + let summary = "Tile and distribute Linalg ops with buffer semantics to " + "subgroups and invocations"; + let constructor = "mlir::iree_compiler::createSPIRVTileAndDistributePass()"; +} + +def SPIRVVectorize : Pass<"iree-spirv-vectorize", "FuncOp"> { + let summary = "Vectorize Linalg ops with buffer semantics"; + let constructor = "mlir::iree_compiler::createSPIRVVectorizePass()"; } // TODO: Rename argument to be fully qualified. @@ -272,7 +276,8 @@ // TODO: Rename argument to be fully qualified. def SPIRVVectorToCooperativeMatrix : Pass<"iree-spirv-vector-to-cooperative-matrix", "FuncOp"> { - let summary = "Generate cooperative matrix ops when possible"; + let summary = "Convert vector ops to SPIR-V cooperative matrix ops " + "when possible"; let constructor = "mlir::iree_compiler::createSPIRVVectorToCooperativeMatrixPass()"; } @@ -280,8 +285,9 @@ // TODO: Rename argument to be fully qualified. def SPIRVCopyToWorkgroupMemory : Pass<"iree-spirv-copy-to-workgroup-memory", "FuncOp"> { - let summary = "Convert vector dialect to gpu subgroup level GPU instructions"; - let constructor = "mlir::iree_compiler::createSPIRVCopyToWorkgroupMemoryPass()"; + let summary = "Lower linalg.copy for copying data to workgroup memory"; + let constructor = + "mlir::iree_compiler::createSPIRVCopyToWorkgroupMemoryPass()"; } //------------------------------------------------------------------------------
diff --git a/iree/compiler/Codegen/SPIRV/BUILD b/iree/compiler/Codegen/SPIRV/BUILD index 63662d0..af8ec54 100644 --- a/iree/compiler/Codegen/SPIRV/BUILD +++ b/iree/compiler/Codegen/SPIRV/BUILD
@@ -14,21 +14,21 @@ name = "SPIRV", srcs = [ "ConvertToSPIRVPass.cpp", - "KernelDispatchUtils.cpp", + "KernelConfig.cpp", "Passes.cpp", - "SPIRVConvertToGPU.cpp", "SPIRVCopyToWorkgroupMemory.cpp", + "SPIRVDistributeToGlobalID.cpp", "SPIRVFoldGPUProcessorIDUses.cpp", "SPIRVLowerExecutableTargetPass.cpp", "SPIRVRemoveOneTripTiledLoops.cpp", "SPIRVTileAndDistribute.cpp", - "SPIRVTileAndVectorize.cpp", "SPIRVVectorToCooperativeMatrix.cpp", + "SPIRVVectorize.cpp", "SPIRVVectorizeLoadStore.cpp", "Utils.cpp", ], hdrs = [ - "KernelDispatchUtils.h", + "KernelConfig.h", "MemorySpace.h", "Utils.h", ], @@ -74,6 +74,7 @@ "@llvm-project//mlir:TosaDialect", "@llvm-project//mlir:TosaToStandard", "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:VectorInterfaces", "@llvm-project//mlir:VectorOps", "@llvm-project//mlir:VectorToSPIRV", "@mlir-hlo//:hlo",
diff --git a/iree/compiler/Codegen/SPIRV/CMakeLists.txt b/iree/compiler/Codegen/SPIRV/CMakeLists.txt index f55dd2a..65f615d 100644 --- a/iree/compiler/Codegen/SPIRV/CMakeLists.txt +++ b/iree/compiler/Codegen/SPIRV/CMakeLists.txt
@@ -14,21 +14,21 @@ NAME SPIRV HDRS - "KernelDispatchUtils.h" + "KernelConfig.h" "MemorySpace.h" "Utils.h" SRCS "ConvertToSPIRVPass.cpp" - "KernelDispatchUtils.cpp" + "KernelConfig.cpp" "Passes.cpp" - "SPIRVConvertToGPU.cpp" "SPIRVCopyToWorkgroupMemory.cpp" + "SPIRVDistributeToGlobalID.cpp" "SPIRVFoldGPUProcessorIDUses.cpp" "SPIRVLowerExecutableTargetPass.cpp" "SPIRVRemoveOneTripTiledLoops.cpp" "SPIRVTileAndDistribute.cpp" - "SPIRVTileAndVectorize.cpp" "SPIRVVectorToCooperativeMatrix.cpp" + "SPIRVVectorize.cpp" "SPIRVVectorizeLoadStore.cpp" "Utils.cpp" DEPS @@ -61,6 +61,7 @@ MLIRTosaToStandard MLIRTransforms MLIRVector + MLIRVectorInterfaces MLIRVectorToSPIRV iree::compiler::Codegen::Common iree::compiler::Codegen::PassHeaders
diff --git a/iree/compiler/Codegen/SPIRV/KernelDispatchUtils.cpp b/iree/compiler/Codegen/SPIRV/KernelConfig.cpp similarity index 85% rename from iree/compiler/Codegen/SPIRV/KernelDispatchUtils.cpp rename to iree/compiler/Codegen/SPIRV/KernelConfig.cpp index 7963338..888993a 100644 --- a/iree/compiler/Codegen/SPIRV/KernelDispatchUtils.cpp +++ b/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -13,7 +13,7 @@ // //===----------------------------------------------------------------------===// -#include "iree/compiler/Codegen/SPIRV/KernelDispatchUtils.h" +#include "iree/compiler/Codegen/SPIRV/KernelConfig.h" #include "iree/compiler/Codegen/Passes.h" #include "iree/compiler/Codegen/SPIRV/Utils.h" @@ -148,7 +148,7 @@ batchTs[2] / pair.workgroupSize[0], batchTs[3]}; tileSizes.emplace_back(invocationLevelTs); return setOpConfigAndEntryPointFnTranslation( - entryPoint, op, tileSizes, /*nativeVectorSize=*/ArrayRef<int64_t>{}, + entryPoint, op, tileSizes, /*nativeVectorSizes=*/ArrayRef<int64_t>{}, IREE::HAL::DispatchLoweringPassPipeline::SPIRVVectorize, pair.workgroupSize); } @@ -185,7 +185,7 @@ tileSizes.emplace_back(); // subgroup level tileSizes.emplace_back(std::move(invocationLevel)); return setOpConfigAndEntryPointFnTranslation( - entryPoint, op, tileSizes, /*nativeVectorSize=*/ArrayRef<int64_t>{}, + entryPoint, op, tileSizes, /*nativeVectorSizes=*/ArrayRef<int64_t>{}, IREE::HAL::DispatchLoweringPassPipeline::SPIRVDistribute, workgroupSize); } @@ -269,7 +269,7 @@ numVecMatmulPerSubgroupX * (*coopMatmulSize)[1]}; tileSizes.emplace_back(std::move(subgroupTs)); return setOpConfigAndEntryPointFnTranslation( - entryPoint, op, tileSizes, /*nativeVectorSize=*/ArrayRef<int64_t>{}, + entryPoint, op, tileSizes, /*nativeVectorSizes=*/ArrayRef<int64_t>{}, IREE::HAL::DispatchLoweringPassPipeline::SPIRVVectorize, workgroupSize); } @@ -282,7 +282,7 @@ // Serialized computation. 
return setOpConfigAndEntryPointFnTranslation( entryPoint, op, /*tileSizes =*/TileSizesListType{{}}, - /*nativeVectorSize=*/ArrayRef<int64_t>{}, + /*nativeVectorSizes=*/ArrayRef<int64_t>{}, IREE::HAL::DispatchLoweringPassPipeline::SPIRVVectorize, {1, 1, 1}); } @@ -357,7 +357,7 @@ return setOpConfigAndEntryPointFnTranslation( entryPoint, op, tileSizes, - /*nativeVectorSize =*/ArrayRef<int64_t>{}, pipeline, workgroupSize); + /*nativeVectorSizes =*/ArrayRef<int64_t>{}, pipeline, workgroupSize); } /// Launch configuration for different known GPU configuration. @@ -400,7 +400,7 @@ tileSizes.emplace_back(invocationLevelTs); return setOpConfigAndEntryPointFnTranslation( entryPoint, op, tileSizes, - /*nativeVectorSize =*/ArrayRef<int64_t>{}, + /*nativeVectorSizes =*/ArrayRef<int64_t>{}, IREE::HAL::DispatchLoweringPassPipeline::SPIRVVectorize, pair.workgroupSize); } @@ -445,7 +445,7 @@ tileSizes.emplace_back(); // subgroup level tileSizes.emplace_back(std::move(invocationLevel)); return setOpConfigAndEntryPointFnTranslation( - entryPoint, op, tileSizes, /*nativeVectorSize =*/ArrayRef<int64_t>{}, + entryPoint, op, tileSizes, /*nativeVectorSizes =*/ArrayRef<int64_t>{}, IREE::HAL::DispatchLoweringPassPipeline::SPIRVDistribute, workgroupSize); } @@ -505,7 +505,7 @@ tileSizes.emplace_back(fourthLevel); if (failed(setOpConfigAndEntryPointFnTranslation( - entryFn, op, tileSizes, /*nativeVectorSize=*/ArrayRef<int64_t>{}, + entryFn, op, tileSizes, /*nativeVectorSizes=*/ArrayRef<int64_t>{}, IREE::HAL::DispatchLoweringPassPipeline::SPIRVVectorize, workgroupSize))) return failure(); @@ -593,7 +593,7 @@ tileSizes.emplace_back(fourthLevel); if (failed(setOpConfigAndEntryPointFnTranslation( - entryFn, op, tileSizes, /*nativeVectorSize=*/ArrayRef<int64_t>{}, + entryFn, op, tileSizes, /*nativeVectorSizes=*/ArrayRef<int64_t>{}, IREE::HAL::DispatchLoweringPassPipeline::SPIRVVectorize, workgroupSize))) return failure(); @@ -753,124 +753,5 @@ return success(); } -template <typename OpTy> -static 
Optional<SmallVector<int64_t, 4>> getOpNativeVectorSize(OpTy op) { - return llvm::None; -} - -template <> -Optional<SmallVector<int64_t, 4>> getOpNativeVectorSize<vector::ContractionOp>( - vector::ContractionOp op) { - auto targetEnvAttr = spirv::lookupTargetEnv(op); - auto targetEnv = spirv::TargetEnv(targetEnvAttr); - if (targetEnv.allows(spirv::Capability::CooperativeMatrixNV) && - targetEnv.allows(spirv::Extension::SPV_NV_cooperative_matrix)) { - return getCooperativeMatmulSubgroupSize( - targetEnv.getResourceLimits(), op.getLhsType().getElementType(), - op.getRhsType().getElementType(), - op.getAccType().cast<VectorType>().getElementType(), - op.getResultType().cast<VectorType>().getElementType()); - } else { - unsigned lastParalleldim = 0; - for (auto it : llvm::enumerate(op.iterator_types())) { - if (isParallelIterator(it.value())) lastParalleldim = it.index(); - } - SmallVector<int64_t, 4> nativeSize(op.iterator_types().size(), 1); - nativeSize[lastParalleldim] = 4; - // Map to vec4 fma operations. - return nativeSize; - } -} - -template <> -Optional<SmallVector<int64_t, 4>> getOpNativeVectorSize<vector::FMAOp>( - vector::FMAOp op) { - SmallVector<int64_t, 4> size(op.getType().getRank(), 1); - size.back() = 4; - return size; -} - -template <> -Optional<SmallVector<int64_t, 4>> getOpNativeVectorSize<vector::TransferReadOp>( - vector::TransferReadOp op) { - auto targetEnv = spirv::TargetEnv(spirv::lookupTargetEnv(op)); - if (targetEnv.allows(spirv::Capability::CooperativeMatrixNV) && - targetEnv.allows(spirv::Extension::SPV_NV_cooperative_matrix)) { - // Unroll cooperative martrix load based on the size of the contract. 
- VectorType dstVec; - for (Operation *users : op->getUsers()) { - auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users); - if (!extract) return llvm::None; - auto vecType = extract.getResult().getType().cast<VectorType>(); - if (dstVec && dstVec != vecType) return llvm::None; - dstVec = vecType; - } - return SmallVector<int64_t, 4>(dstVec.getShape().begin(), - dstVec.getShape().end()); - } - - // Map to load4. - auto rank = op.getVectorType().getRank(); - SmallVector<int64_t, 4> nativeSize(rank, 1); - // Load 4 elements on the most inner dimension. - for (auto dim : llvm::enumerate(op.permutation_map().getResults())) { - if (auto dimExpr = dim.value().dyn_cast<AffineDimExpr>()) { - if (dimExpr.getPosition() == op.permutation_map().getNumDims() - 1) - nativeSize[dim.index()] = 4; - } - } - return nativeSize; -} - -template <> -Optional<SmallVector<int64_t, 4>> -getOpNativeVectorSize<vector::TransferWriteOp>(vector::TransferWriteOp op) { - auto targetEnv = spirv::TargetEnv(spirv::lookupTargetEnv(op)); - if (targetEnv.allows(spirv::Capability::CooperativeMatrixNV) && - targetEnv.allows(spirv::Extension::SPV_NV_cooperative_matrix)) { - // Unroll cooperative martrix store based on the size of the contract. - auto insert = op.vector().getDefiningOp<vector::InsertStridedSliceOp>(); - if (!insert) return llvm::None; - ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape(); - return SmallVector<int64_t, 4>(shape.begin(), shape.end()); - } - - // Map to store4. - auto rank = op.getVectorType().getRank(); - SmallVector<int64_t, 4> nativeSize(rank, 1); - // Store 4 elements on the most inner dimension. 
- for (auto dim : llvm::enumerate(op.permutation_map().getResults())) { - if (auto dimExpr = dim.value().dyn_cast<AffineDimExpr>()) { - if (dimExpr.getPosition() == op.permutation_map().getNumDims() - 1) - nativeSize[dim.index()] = 4; - } - } - return nativeSize; -} - -Optional<SmallVector<int64_t, 4>> getSPIRVNativeVectorSize(Operation *op) { -#define DISPATCH(opname) \ - if (isa<opname>(op)) { \ - return getOpNativeVectorSize(cast<opname>(op)); \ - } - - DISPATCH(vector::ContractionOp) - DISPATCH(vector::FMAOp) - DISPATCH(vector::TransferReadOp) - DISPATCH(vector::TransferWriteOp) - -#undef DISPATCH - - if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) { - if (auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>()) { - // Map elementwise ops to vec4. - SmallVector<int64_t, 4> nativeSize(vecType.getRank() - 1, 1); - nativeSize.push_back(4); - return nativeSize; - } - } - return llvm::None; -} - } // namespace iree_compiler } // namespace mlir
diff --git a/iree/compiler/Codegen/SPIRV/KernelConfig.h b/iree/compiler/Codegen/SPIRV/KernelConfig.h new file mode 100644 index 0000000..22a1a8e --- /dev/null +++ b/iree/compiler/Codegen/SPIRV/KernelConfig.h
@@ -0,0 +1,31 @@ +// Copyright 2020 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===- KernelConfig.h - Kernel Generation Configurations ------------------===// +// +// This file declares utility functions for configuring SPIR-V kernel +// generation, e.g., tiling schemes and workgroup size for important +// Linalg named ops. +// +//===----------------------------------------------------------------------===// + +#ifndef IREE_COMPILER_CODEGEN_SPIRV_KERNELCONFIG_H_ +#define IREE_COMPILER_CODEGEN_SPIRV_KERNELCONFIG_H_ + +#include "mlir/IR/BuiltinOps.h" + +namespace mlir { +namespace iree_compiler { + +/// Attaches the `translation.info` attribute to entry points in `moduleOp` and +/// `lowering.config` attributes to all root ops in `moduleOp`'s region. +/// These attributes are used to drive the CodeGen pipeline. +LogicalResult initSPIRVLaunchConfig(ModuleOp moduleOp); + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_CODEGEN_SPIRV_KERNELCONFIG_H_
diff --git a/iree/compiler/Codegen/SPIRV/KernelDispatchUtils.h b/iree/compiler/Codegen/SPIRV/KernelDispatchUtils.h deleted file mode 100644 index 3465e7d..0000000 --- a/iree/compiler/Codegen/SPIRV/KernelDispatchUtils.h +++ /dev/null
@@ -1,46 +0,0 @@ -// Copyright 2020 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -//===- KernelDispatchUtils.h - Utilities for generating dispatch info -----===// -// -// This file declares utility functions that can be used to create information -// the dispatch on the host side needs to execute an entry point function, like -// the number of workgroups to use for launch, etc. -// -//===----------------------------------------------------------------------===// - -#ifndef IREE_COMPILER_CODEGEN_SPIRV_KERNELDISPATCHUTILS_H_ -#define IREE_COMPILER_CODEGEN_SPIRV_KERNELDISPATCHUTILS_H_ - -#include <array> - -#include "iree/compiler/Codegen/Passes.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringMap.h" -#include "llvm/Support/FormatVariadic.h" -#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/Types.h" -#include "mlir/IR/Value.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" - -namespace mlir { -namespace iree_compiler { - -LogicalResult initSPIRVLaunchConfig(ModuleOp moduleOp); - -/// Returns the size of instruction in `vector` dialect that maps directly to -/// the hardware. -Optional<SmallVector<int64_t, 4>> getSPIRVNativeVectorSize(Operation *op); - -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_CODEGEN_SPIRV_DISPATCHUTILS_H_
diff --git a/iree/compiler/Codegen/SPIRV/Passes.cpp b/iree/compiler/Codegen/SPIRV/Passes.cpp index 4fa7ee2..af8b525 100644 --- a/iree/compiler/Codegen/SPIRV/Passes.cpp +++ b/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -4,9 +4,10 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -//===- Passes.cpp - Pipeline from HLO to Linalg to SPIR-V -----------------===// +//===- Passes.cpp - Pipelines from Linalg ops to SPIR-V -------------------===// // -// Implementation of conversion from XLA-HLO to Linalg to SPIR-V dialect. +// This file contains various pipelines to lower IREE HAL executables containing +// Linalg ops to SPIR-V. // //===----------------------------------------------------------------------===// @@ -57,74 +58,51 @@ return builder.create<memref::AllocOp>(loc, allocType, dynamicSizes); } -void addSPIRVVectorizationPassPipeline(OpPassManager &pm) { - //===--------------------------------------------------------------------===// - // Initial clean up. - //===--------------------------------------------------------------------===// +void addSPIRVTileAndVectorizePassPipeline(OpPassManager &pm) { pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); + pm.addNestedPass<FuncOp>(createSPIRVRemoveOneTripTiledLoopPass()); // Tile and distribute to GPU subgroups/invocations and vectorize. - pm.addNestedPass<FuncOp>(createSPIRVTileAndVectorizePass()); + pm.addNestedPass<FuncOp>(createSPIRVTileAndDistributePass()); + pm.addNestedPass<FuncOp>(createSPIRVVectorizePass()); pm.addPass(createCanonicalizerPass()); - // Handle ops that cannot go through the previous tiling, distribution, and - // vectorization flow. Only perform one level of distribution to map them to - // GPU global invocation IDs for distribution. - // TODO(antiagainst): Handle all the cases uniformly and remove this pass. 
pm.addNestedPass<FuncOp>(createSPIRVCopyToWorkgroupMemoryPass()); + pm.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass()); pm.addPass(createLowerAffinePass()); pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); - //===--------------------------------------------------------------------===// - // Optimizations and cleanups - //===--------------------------------------------------------------------===// - // Perform various vector-level cross-op optimizations like load-store // forwarding, shape casting and casting op cancelling. pm.addNestedPass<FuncOp>(createOptimizeVectorTransferPass()); } -void addSPIRVDistributePassPipeline(OpPassManager &pm) { - //===--------------------------------------------------------------------===// - // Initial clean up. - //===--------------------------------------------------------------------===// +void addSPIRVTileAndDistributePassPipeline(OpPassManager &pm) { pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); - // Tile and distribute to GPU subgroups/invocations and vectorize. + + // Tile and distribute to GPU subgroups/invocations. pm.addNestedPass<FuncOp>(createSPIRVTileAndDistributePass()); pm.addPass(createCanonicalizerPass()); - // Handle ops that cannot go through the previous tiling, distribution, and - // vectorization flow. Only perform one level of distribution to map them to - // GPU global invocation IDs for distribution. - // TODO(antiagainst): Handle all the cases uniformly and remove this pass. 
pm.addNestedPass<FuncOp>(createSPIRVCopyToWorkgroupMemoryPass()); + pm.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass()); pm.addPass(createLowerAffinePass()); pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); - //===--------------------------------------------------------------------===// - // Optimizations and cleanups - //===--------------------------------------------------------------------===// // Perform various vector-level cross-op optimizations like load-store // forwarding, shape casting and casting op cancelling. pm.addNestedPass<FuncOp>(createOptimizeVectorTransferPass()); } -void addSPIRVDistributeToGlobalIDPipeline(OpPassManager &pm) { - // Handle ops that cannot go through the previous tiling, distribution, and - // vectorization flow. Only perform one level of distribution to map them to - // GPU global invocation IDs for distribution. - // TODO(antiagainst): Handle all the cases uniformly and remove this pass. - pm.addNestedPass<FuncOp>(createSPIRVConvertToGPUPass()); +void addSPIRVDistributeToGlobalIDPassPipeline(OpPassManager &pm) { + pm.addNestedPass<FuncOp>(createSPIRVDistributeToGlobalIDPass()); pm.addPass(createLowerAffinePass()); pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); - //===--------------------------------------------------------------------===// - // Optimizations and cleanups - //===--------------------------------------------------------------------===// // Perform various vector-level cross-op optimizations like load-store // forwarding, shape casting and casting op cancelling. @@ -159,10 +137,6 @@ pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); - //===--------------------------------------------------------------------===// - // SPIR-V conversions - //===--------------------------------------------------------------------===// - // Finally convert everything to SPIR-V. pm.addPass(createConvertToSPIRVPass());
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVConvertToGPU.cpp b/iree/compiler/Codegen/SPIRV/SPIRVDistributeToGlobalID.cpp similarity index 94% rename from iree/compiler/Codegen/SPIRV/SPIRVConvertToGPU.cpp rename to iree/compiler/Codegen/SPIRV/SPIRVDistributeToGlobalID.cpp index 5613f85..b8df3b4 100644 --- a/iree/compiler/Codegen/SPIRV/SPIRVConvertToGPU.cpp +++ b/iree/compiler/Codegen/SPIRV/SPIRVDistributeToGlobalID.cpp
@@ -4,9 +4,9 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -//===- SPIRVConvertToGPUPass.cpp ------------------------------------------===// +//===- SPIRVDistributeToGlobalIDPass.cpp ----------------------------------===// // -// Partition computation within dispatch function to workgroups/workitems. +// This pass distributes Linalg ops with buffer semantics to global invocations. // //===----------------------------------------------------------------------===// @@ -15,7 +15,6 @@ #include "iree/compiler/Codegen/PassDetail.h" #include "iree/compiler/Codegen/Passes.h" -#include "iree/compiler/Codegen/SPIRV/KernelDispatchUtils.h" #include "iree/compiler/Codegen/SPIRV/Utils.h" #include "iree/compiler/Codegen/Transforms/Transforms.h" #include "iree/compiler/Codegen/Utils/MarkerUtils.h" @@ -121,8 +120,8 @@ namespace { /// Pass to convert from tiled and fused linalg ops into gpu.func. -struct SPIRVConvertToGPUPass - : public SPIRVConvertToGPUBase<SPIRVConvertToGPUPass> { +struct SPIRVDistributeToGlobalIDPass + : public SPIRVDistributeToGlobalIDBase<SPIRVDistributeToGlobalIDPass> { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert<AffineDialect, gpu::GPUDialect, memref::MemRefDialect, scf::SCFDialect, ShapeDialect>(); @@ -186,9 +185,10 @@ } // namespace -void SPIRVConvertToGPUPass::runOnOperation() { +void SPIRVDistributeToGlobalIDPass::runOnOperation() { FuncOp funcOp = getOperation(); if (!isEntryPoint(funcOp)) return; + MLIRContext *context = &getContext(); ConversionTarget target(*context); // After this pass Linalg and scf.parallel ops should be gone. 
@@ -217,8 +217,8 @@ return signalPassFailure(); } -std::unique_ptr<OperationPass<FuncOp>> createSPIRVConvertToGPUPass() { - return std::make_unique<SPIRVConvertToGPUPass>(); +std::unique_ptr<OperationPass<FuncOp>> createSPIRVDistributeToGlobalIDPass() { + return std::make_unique<SPIRVDistributeToGlobalIDPass>(); } } // namespace iree_compiler
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp b/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp index 91e499d..5413c48 100644 --- a/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp +++ b/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
@@ -6,7 +6,7 @@ #include "iree/compiler/Codegen/PassDetail.h" #include "iree/compiler/Codegen/Passes.h" -#include "iree/compiler/Codegen/SPIRV/KernelDispatchUtils.h" +#include "iree/compiler/Codegen/SPIRV/KernelConfig.h" #include "iree/compiler/Codegen/Utils/Utils.h" #include "iree/compiler/Dialect/HAL/IR/HALDialect.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" @@ -25,13 +25,16 @@ namespace iree_compiler { namespace { -/// Lowers an hal.executable.variant operation to scalar/native-vector +/// Lowers a hal.executable.variant inner module to SPIR-V scalar/native-vector /// code. Invokes different compilation pipeline to -/// - first lower to scalar/native-vector code +/// - first lower to scalar/native-vector code, /// - then convert to SPIRV dialect. class SPIRVLowerExecutableTargetPass : public SPIRVLowerExecutableTargetBase<SPIRVLowerExecutableTargetPass> { public: + SPIRVLowerExecutableTargetPass() = default; + SPIRVLowerExecutableTargetPass(const SPIRVLowerExecutableTargetPass &pass) {} + void getDependentDialects(DialectRegistry ®istry) const override { registry.insert<AffineDialect, gpu::GPUDialect, IREE::HAL::HALDialect, linalg::LinalgDialect, linalg_ext::LinalgExtDialect, @@ -39,19 +42,14 @@ spirv::SPIRVDialect, vector::VectorDialect>(); } - SPIRVLowerExecutableTargetPass() = default; - SPIRVLowerExecutableTargetPass(const SPIRVLowerExecutableTargetPass &pass){}; - void runOnOperation() override; private: Option<bool> testLoweringConfiguration{ *this, "test-lowering-configuration", - llvm::cl::desc( - "Flag used for lit-testing the default configuration set for root " - "ops in hal.executable.variants. Defaults to false and is set to " - "true " - "for lit tests. Not for general usage"), + llvm::cl::desc("Flag used for lit-testing the configuration set for root " + "ops in hal.executable.variants. Defaults to false. 
Set " + "to true for lit tests; not for general usage"), llvm::cl::init(false)}; }; } // namespace @@ -99,13 +97,13 @@ OpPassManager &nestedModulePM = executableLoweringPipeline.nest<ModuleOp>(); switch (*passPipeline) { case IREE::HAL::DispatchLoweringPassPipeline::SPIRVDistribute: - addSPIRVDistributePassPipeline(nestedModulePM); + addSPIRVTileAndDistributePassPipeline(nestedModulePM); break; case IREE::HAL::DispatchLoweringPassPipeline::SPIRVDistributeToGlobalID: - addSPIRVDistributeToGlobalIDPipeline(nestedModulePM); + addSPIRVDistributeToGlobalIDPassPipeline(nestedModulePM); break; case IREE::HAL::DispatchLoweringPassPipeline::SPIRVVectorize: - addSPIRVVectorizationPassPipeline(nestedModulePM); + addSPIRVTileAndVectorizePassPipeline(nestedModulePM); break; default: llvm_unreachable("Unsupported pipeline on GPU target.");
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp b/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp index 8fcbea0..0befc77 100644 --- a/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp +++ b/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp
@@ -4,38 +4,31 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -//===- SPIRVTileAndDistribute.cpp -//------------------------------------------===// +//===- SPIRVTileAndDistribute.cpp -----------------------------------------===// // -// This pass tiles and vectorizes Linalg ops on buffers within in a single -// workgroup. +// This pass tiles and distributes Linalg ops with buffer semantics to subgroups +// and invocations. // //===----------------------------------------------------------------------===// #include "iree/compiler/Codegen/PassDetail.h" #include "iree/compiler/Codegen/Passes.h" -#include "iree/compiler/Codegen/SPIRV/KernelDispatchUtils.h" -#include "iree/compiler/Codegen/SPIRV/MemorySpace.h" #include "iree/compiler/Codegen/SPIRV/Utils.h" #include "iree/compiler/Codegen/Transforms/Transforms.h" #include "iree/compiler/Codegen/Utils/MarkerUtils.h" #include "iree/compiler/Codegen/Utils/Utils.h" -#include "iree/compiler/Dialect/HAL/IR/HALDialect.h" -#include "iree/compiler/Dialect/HAL/IR/HALOps.h" -#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/GPU/GPUDialect.h" -#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Vector/VectorTransforms.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Identifier.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" @@ -43,7 +36,7 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/LoopUtils.h" -#define DEBUG_TYPE 
"iree-spirv-tile-and-vectorize" +#define DEBUG_TYPE "iree-spirv-tile-and-distribute" namespace mlir { namespace iree_compiler { @@ -72,29 +65,7 @@ } //===----------------------------------------------------------------------===// -// Main pass -//===----------------------------------------------------------------------===// - -namespace { -/// Function pass that implements tiling and fusion in Linalg on buffers. -class SPIRVTileAndDistributePass - : public SPIRVTileAndDistributeBase<SPIRVTileAndDistributePass> { - public: - SPIRVTileAndDistributePass() = default; - SPIRVTileAndDistributePass(const SPIRVTileAndDistributePass &pass) = default; - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert<AffineDialect, IREE::HAL::HALDialect, gpu::GPUDialect, - linalg::LinalgDialect, memref::MemRefDialect, - scf::SCFDialect, ShapeDialect, vector::VectorDialect>(); - } - - void runOnOperation() override; -}; -} // namespace - -//===----------------------------------------------------------------------===// -// Patterns to tile computation to map to subgroups +// Subgroup tiling patterns //===----------------------------------------------------------------------===// /// Computes the Value for subgroupID along each dimension given number of @@ -169,7 +140,7 @@ } //===----------------------------------------------------------------------===// -// Patterns and methods for thread tiling. +// Invocation tiling patterns //===----------------------------------------------------------------------===// /// Patterns for third level tiling to target invocations. 
@@ -184,8 +155,8 @@ })); }; - auto getThreadProcInfoFn = [&](OpBuilder &builder, Location loc, - ArrayRef<Range> parallelLoopRanges) { + auto getThreadProcInfoFn = [](OpBuilder &builder, Location loc, + ArrayRef<Range> parallelLoopRanges) { return getGPUProcessorIdsAndCounts<gpu::ThreadIdOp, gpu::BlockDimOp>( builder, loc, parallelLoopRanges.size()); }; @@ -205,7 +176,7 @@ linalg::LinalgTilingPattern<linalg::FillOp>, linalg::LinalgTilingPattern<linalg::BatchMatmulOp>, linalg::LinalgTilingPattern<linalg::Conv1DNwcWcfOp>, - linalg::LinalgTilingPattern<linalg::Conv2DNhwcHwcfOp>, + linalg::LinalgTilingPattern<linalg::Conv3DNdhwcDhwcfOp>, linalg::LinalgTilingPattern<linalg::DepthwiseConv2DNhwcOp>, linalg::LinalgTilingPattern<linalg::GenericOp>, linalg::LinalgTilingPattern<linalg::PoolingNhwcMaxOp>, @@ -246,7 +217,7 @@ } //====---------------------------------------------------------------------===// -// Patterns to tile convolution window dimensions +// Convolution filter tiling patterns //====---------------------------------------------------------------------===// static void populateTilingConvFilterPatterns( @@ -274,29 +245,28 @@ context, tilingOptions, marker); } -//====---------------------------------------------------------------------===// -// Patterns to lower linalg ops to loops -//====---------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// Main pass +//===----------------------------------------------------------------------===// -template <typename OpTy> -struct LowerToLoops final : public OpRewritePattern<OpTy> { - using OpRewritePattern<OpTy>::OpRewritePattern; +namespace { +/// Function pass that implements tiling and distributing Linalg ops with +/// buffer semantics. 
+class SPIRVTileAndDistributePass + : public SPIRVTileAndDistributeBase<SPIRVTileAndDistributePass> { + public: + SPIRVTileAndDistributePass() = default; + SPIRVTileAndDistributePass(const SPIRVTileAndDistributePass &pass) = default; - LogicalResult matchAndRewrite(OpTy op, - PatternRewriter &rewriter) const override { - // Only handle the cases where tiling to invocations was done, where tiling - // convolution filters or vectorization is expected. - if (!hasMarker(op, {getConvFilterTileMarker(), getVectorizeMarker()})) - return failure(); - - if (linalg::linalgOpToLoops(rewriter, op)) { - rewriter.eraseOp(op); - return success(); - } - - return failure(); + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<AffineDialect, gpu::GPUDialect, linalg::LinalgDialect, + memref::MemRefDialect, scf::SCFDialect, + vector::VectorDialect>(); } + + void runOnOperation() override; }; +} // namespace //====---------------------------------------------------------------------===// // Main pass implementation @@ -309,16 +279,40 @@ if (!entryPointOp) return; { - RewritePatternSet thirdLevelTilingPatterns(&getContext()); - populateTilingToInvocationPatterns(context, thirdLevelTilingPatterns); + RewritePatternSet subgroupTilingPatterns(&getContext()); + populateTilingToSubgroupPatterns(context, subgroupTilingPatterns); (void)applyPatternsAndFoldGreedily(funcOp, - std::move(thirdLevelTilingPatterns)); + std::move(subgroupTilingPatterns)); + + RewritePatternSet canonicalizationPatterns = + linalg::getLinalgTilingCanonicalizationPatterns(context); + populateAffineMinCanonicalizationPattern(canonicalizationPatterns); + (void)applyPatternsAndFoldGreedily(funcOp, + std::move(canonicalizationPatterns)); + promoteSingleIterationLoops(funcOp); + + LLVM_DEBUG({ + llvm::dbgs() << "--- After tiling to subgroups ---\n"; + funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); + llvm::dbgs() << "\n\n"; + }); + } + + { + RewritePatternSet 
invocationTilingPatterns(&getContext()); + populateTilingToInvocationPatterns(context, invocationTilingPatterns); + (void)applyPatternsAndFoldGreedily(funcOp, + std::move(invocationTilingPatterns)); // Remove trip-one loops created during cyclic loop distribution if we can // prove the tiling was perfect. RewritePatternSet canoncalizationPatterns(context); populateAffineMinSCFCanonicalizationPattern(canoncalizationPatterns); - auto workgroupSize = getWorkgroupSize(entryPointOp); + SmallVector<int64_t> workgroupSize = getWorkgroupSize(entryPointOp); + if (workgroupSize.empty()) { + entryPointOp.emitError("expected to have workgroup_size attribute"); + return signalPassFailure(); + } auto getThreadRangeFn = [workgroupSize](Value processorValue, SmallVectorImpl<Value> &dims, SmallVectorImpl<Value> &symbols) { @@ -345,20 +339,21 @@ } { - RewritePatternSet tilingPatterns(&getContext()); + RewritePatternSet convFilterTilingPatterns(&getContext()); auto marker = getLinalgMatchAndReplaceMarker(getConvFilterTileMarker(), getVectorizeMarker(), context); - populateTilingConvFilterPatterns(context, tilingPatterns, marker); - populateFoldGPUProcessorIDUsesPatterns(context, tilingPatterns); - tilingPatterns.insert<linalg::AffineMinSCFCanonicalizationPattern>(context); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPatterns)); + populateTilingConvFilterPatterns(context, convFilterTilingPatterns, marker); + populateFoldGPUProcessorIDUsesPatterns(context, convFilterTilingPatterns); + convFilterTilingPatterns + .insert<linalg::AffineMinSCFCanonicalizationPattern>(context); + (void)applyPatternsAndFoldGreedily(funcOp, + std::move(convFilterTilingPatterns)); - RewritePatternSet convTilingCanonicalizationPatterns = + RewritePatternSet canonicalizationPatterns = linalg::getLinalgTilingCanonicalizationPatterns(context); - populateAffineMinCanonicalizationPattern( - convTilingCanonicalizationPatterns); - (void)applyPatternsAndFoldGreedily( - funcOp, 
std::move(convTilingCanonicalizationPatterns)); + populateAffineMinCanonicalizationPattern(canonicalizationPatterns); + (void)applyPatternsAndFoldGreedily(funcOp, + std::move(canonicalizationPatterns)); LLVM_DEBUG({ llvm::dbgs() << "--- After tiling convolution filter ---\n"; @@ -366,27 +361,6 @@ llvm::dbgs() << "\n\n"; }); } - - // Lower ops that were tiled to invocations but not vectorized to loops. - // TODO(antiagainst): This is here now to simplify the interaction with - // ConvertToGPUPass, where we finally lower away all Linalg ops. Once that - // pass is cleaned up, we can invoke createConvertLinalgToLoopsPass - // directly. - { - RewritePatternSet patterns(context); - patterns.add<LowerToLoops<linalg::BatchMatmulOp>, - LowerToLoops<linalg::Conv1DNwcWcfOp>, - LowerToLoops<linalg::Conv2DNhwcHwcfOp>, - LowerToLoops<linalg::Conv3DNdhwcDhwcfOp>, - LowerToLoops<linalg::DepthwiseConv2DNhwOp>, - LowerToLoops<linalg::DepthwiseConv2DNhwcOp>, - LowerToLoops<linalg::FillOp>, LowerToLoops<linalg::GenericOp>, - LowerToLoops<linalg::MatmulOp>, - LowerToLoops<linalg::PoolingNhwcMaxOp>, - LowerToLoops<linalg::PoolingNhwcMinOp>, - LowerToLoops<linalg::PoolingNhwcSumOp>>(context); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); - } } //===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorize.cpp b/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorize.cpp deleted file mode 100644 index 2d62d27..0000000 --- a/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorize.cpp +++ /dev/null
@@ -1,649 +0,0 @@ -// Copyright 2020 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -//===- SPIRVTileAndVectorize.cpp ------------------------------------------===// -// -// This pass tiles and vectorizes Linalg ops on buffers within in a single -// workgroup. -// -//===----------------------------------------------------------------------===// - -#include "iree/compiler/Codegen/PassDetail.h" -#include "iree/compiler/Codegen/Passes.h" -#include "iree/compiler/Codegen/SPIRV/KernelDispatchUtils.h" -#include "iree/compiler/Codegen/SPIRV/MemorySpace.h" -#include "iree/compiler/Codegen/SPIRV/Utils.h" -#include "iree/compiler/Codegen/Transforms/Transforms.h" -#include "iree/compiler/Codegen/Utils/MarkerUtils.h" -#include "iree/compiler/Codegen/Utils/Utils.h" -#include "iree/compiler/Dialect/HAL/IR/HALDialect.h" -#include "iree/compiler/Dialect/HAL/IR/HALOps.h" -#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/Support/Debug.h" -#include "mlir/Dialect/GPU/GPUDialect.h" -#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" -#include "mlir/Dialect/Linalg/Transforms/Hoisting.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/Dialect/Vector/VectorTransforms.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Identifier.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/FoldUtils.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Transforms/LoopUtils.h" - -#define DEBUG_TYPE "iree-spirv-tile-and-vectorize" - -namespace mlir { -namespace iree_compiler 
{ - -//===----------------------------------------------------------------------===// -// Utility functions -//===----------------------------------------------------------------------===// - -/// Returns a Linalg marker that matches any of the `matchMarkers` and replaces -/// it with `replaceMarker`. -static linalg::LinalgTransformationFilter getLinalgMatchAndReplaceMarker( - ArrayRef<StringRef> matchMarkers, StringRef replaceMarker, - MLIRContext *context) { - SmallVector<Identifier, 2> markers; - markers.reserve(matchMarkers.size()); - for (StringRef marker : matchMarkers) { - markers.emplace_back(Identifier::get(marker, context)); - } - return linalg::LinalgTransformationFilter( - markers, Identifier::get(replaceMarker, context)); -} - -/// Converts a symbolic GPU processor dimension to its numeric one. -static unsigned dimToIndex(StringRef dim) { - return StringSwitch<unsigned>(dim).Case("x", 0).Case("y", 1).Case("z", 2); -} - -//===----------------------------------------------------------------------===// -// Main pass -//===----------------------------------------------------------------------===// - -namespace { -/// Function pass that implements tiling and fusion in Linalg on buffers. 
-class SPIRVTileAndVectorizePass - : public SPIRVTileAndVectorizeBase<SPIRVTileAndVectorizePass> { - public: - SPIRVTileAndVectorizePass() = default; - SPIRVTileAndVectorizePass(const SPIRVTileAndVectorizePass &pass) = default; - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert<AffineDialect, IREE::HAL::HALDialect, gpu::GPUDialect, - linalg::LinalgDialect, memref::MemRefDialect, - scf::SCFDialect, ShapeDialect, vector::VectorDialect>(); - } - - void runOnOperation() override; -}; -} // namespace - -//===----------------------------------------------------------------------===// -// Patterns to promote subviews to workgroup memory -//===----------------------------------------------------------------------===// - -namespace { -/// Pattern to promote matmul operands to workgroup memory. -struct PromoteMatmulSubviewsPattern - : public linalg::LinalgPromotionPattern<linalg::MatmulOp> { - PromoteMatmulSubviewsPattern(MLIRContext *context, - linalg::LinalgPromotionOptions options, - linalg::LinalgTransformationFilter marker, - PatternBenefit benefit = 1) - : linalg::LinalgPromotionPattern<linalg::MatmulOp>( - context, - options.setOperandsToPromote({0, 1}).setUseFullTileBuffers( - {false, false}), - marker, benefit) {} -}; - -/// Patterns to promote convolution operands to workgroup memory. -// TODO(ravishankarm): This pattern is only promoting the image subview to -// workgroup memory. In reality we should also be able to promote the filter -// subview to workgroup memory as well. Since none of the loops used to access -// the filter are tiled, this would mean the entire filter is moved to workgroup -// memory. Two reasons this is not done right now: -// 1) Linalg when tiling doesnt create a subview for the filter (since none of -// its dimensions are tiled. This needs to be relaxed (maybe by using an -// option). 
-// 2) Maybe there are better alternatives for handling filter like using -// different storage classes, since for inference workloads these are model -// constants. This is TBD. -template <typename ConvOpTy> -struct PromoteConvSubviewsPattern - : public linalg::LinalgPromotionPattern<ConvOpTy> { - PromoteConvSubviewsPattern(MLIRContext *context, - linalg::LinalgPromotionOptions options, - linalg::LinalgTransformationFilter marker, - PatternBenefit benefit = 1) - : linalg::LinalgPromotionPattern<ConvOpTy>( - context, - options.setOperandsToPromote({0}).setUseFullTileBuffers( - {false, false}), - marker, benefit) {} -}; -} // namespace - -static void populatePromotionPatterns(MLIRContext *context, - RewritePatternSet &patterns) { - patterns.insert<PromoteMatmulSubviewsPattern, - PromoteConvSubviewsPattern<linalg::Conv2DNhwcHwcfOp>>( - context, - linalg::LinalgPromotionOptions() - .setAllocationDeallocationFns(allocateWorkgroupMemory, - deallocateWorkgroupMemory) - .setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory), - getLinalgMatchAndReplaceMarker(getWorkgroupMarker(), - getWorkgroupMemoryMarker(), context)); -} - -//===----------------------------------------------------------------------===// -// Patterns to tile computation to map to subgroups -//===----------------------------------------------------------------------===// - -/// Computes the Value for subgroupID along each dimension given number of -/// subgroups `numSubGroups` along each dimension (x-first, y-second, z-third). 
-static SmallVector<linalg::ProcInfo, 2> getSubgroupIdsAndCounts( - OpBuilder &builder, Location loc, ArrayRef<int64_t> numSubgroups) { - Type indexType = builder.getIndexType(); - Value subgroupId = builder.create<gpu::SubgroupIdOp>(loc, indexType); - SmallVector<linalg::ProcInfo, 2> procInfo(numSubgroups.size()); - - // subgroupID - // = id.z * nsubgroups.y * nsubgroups.x + id.y * nsubgroups.x + id.x - for (size_t i = 0, e = numSubgroups.size(); i != e; ++i) { - Value nprocs = builder.create<ConstantIndexOp>(loc, numSubgroups[i]); - AffineExpr d0 = getAffineDimExpr(0, builder.getContext()); - AffineExpr s0 = getAffineSymbolExpr(0, builder.getContext()); - Value procId = - makeComposedAffineApply(builder, loc, d0 % s0, {subgroupId, nprocs}); - procInfo[e - i - 1] = linalg::ProcInfo{procId, nprocs}; - subgroupId = builder.create<SignedDivIOp>(loc, subgroupId, nprocs); - } - return procInfo; -} - -namespace { -/// Pattern to tile linalg.matmul for subgroups. -struct TileMatmulSubgroupPattern - : public linalg::LinalgTilingPattern<linalg::MatmulOp> { - using Base = linalg::LinalgTilingPattern<linalg::MatmulOp>; - TileMatmulSubgroupPattern(MLIRContext *context, - linalg::LinalgTilingOptions options, - linalg::LinalgTransformationFilter marker, - PatternBenefit benefit = 1) - : Base(context, options, marker, benefit) {} -}; -} // namespace - -/// Patterns for second level tiling to target subgroups. 
-static void populateTilingToSubgroupPatterns(MLIRContext *context, - RewritePatternSet &patterns) { - auto getInnerTileSizeFn = [&](OpBuilder &builder, - Operation *operation) -> SmallVector<Value, 4> { - SmallVector<int64_t> tileSizes = getTileSizes(operation, 1); - return llvm::to_vector<4>( - llvm::map_range(tileSizes, [&](int64_t v) -> Value { - return builder.create<ConstantIndexOp>(operation->getLoc(), v); - })); - }; - - auto getSubgroupProcInfoFn = [&](OpBuilder &builder, Location loc, - ArrayRef<Range> parallelLoopRanges) { - // TODO(ravishankarm): For now assume that there is always a single subgroup - std::array<int64_t, 3> numSubgroups = {1, 1, 1}; - return getSubgroupIdsAndCounts(builder, loc, numSubgroups); - }; - - linalg::LinalgLoopDistributionOptions subgroupDistributionOptions; - subgroupDistributionOptions.procInfo = getSubgroupProcInfoFn; - subgroupDistributionOptions.distributionMethod = { - {linalg::DistributionMethod::CyclicNumProcsEqNumIters, - linalg::DistributionMethod::CyclicNumProcsEqNumIters}}; - - patterns.insert<TileMatmulSubgroupPattern>( - context, - linalg::LinalgTilingOptions() - .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops) - .setTileSizeComputationFunction(getInnerTileSizeFn) - .setDistributionOptions(subgroupDistributionOptions), - getLinalgMatchAndReplaceMarker( - {getWorkgroupMemoryMarker(), getWorkgroupMarker()}, - getVectorizeMarker(), context)); -} - -//===----------------------------------------------------------------------===// -// Patterns and methods for thread tiling. -//===----------------------------------------------------------------------===// - -/// Patterns for third level tiling to target invocations. 
-static void populateTilingToInvocationPatterns(MLIRContext *context, - RewritePatternSet &patterns) { - linalg::TileSizeComputationFunction getInnerTileSizeFn = - [&](OpBuilder &builder, Operation *operation) { - SmallVector<int64_t> tileSizes = getTileSizes(operation, 2); - return llvm::to_vector<4>( - llvm::map_range(tileSizes, [&](int64_t v) -> Value { - return builder.create<ConstantIndexOp>(operation->getLoc(), v); - })); - }; - - auto getThreadProcInfoFn = [](OpBuilder &builder, Location loc, - ArrayRef<Range> parallelLoopRanges) { - return getGPUProcessorIdsAndCounts<gpu::ThreadIdOp, gpu::BlockDimOp>( - builder, loc, parallelLoopRanges.size()); - }; - linalg::LinalgLoopDistributionOptions invocationDistributionOptions; - invocationDistributionOptions.procInfo = getThreadProcInfoFn; - invocationDistributionOptions.distributionMethod = { - {linalg::DistributionMethod::Cyclic, linalg::DistributionMethod::Cyclic, - linalg::DistributionMethod::Cyclic}}; - - auto tilingOptions = - linalg::LinalgTilingOptions() - .setLoopType(linalg::LinalgTilingLoopType::Loops) - .setTileSizeComputationFunction(getInnerTileSizeFn) - .setDistributionOptions(invocationDistributionOptions); - - patterns.insert<linalg::LinalgTilingPattern<linalg::MatmulOp>, - linalg::LinalgTilingPattern<linalg::FillOp>, - linalg::LinalgTilingPattern<linalg::BatchMatmulOp>, - linalg::LinalgTilingPattern<linalg::Conv1DNwcWcfOp>, - linalg::LinalgTilingPattern<linalg::Conv3DNdhwcDhwcfOp>, - linalg::LinalgTilingPattern<linalg::DepthwiseConv2DNhwcOp>, - linalg::LinalgTilingPattern<linalg::GenericOp>, - linalg::LinalgTilingPattern<linalg::PoolingNhwcMaxOp>, - linalg::LinalgTilingPattern<linalg::PoolingNhwcMinOp>, - linalg::LinalgTilingPattern<linalg::PoolingNhwcSumOp>>( - context, tilingOptions, - getLinalgMatchAndReplaceMarker( - {getWorkgroupMemoryMarker(), getWorkgroupMarker()}, - getVectorizeMarker(), context)); - - patterns.insert<linalg::LinalgTilingPattern<linalg::Conv2DNhwcHwcfOp>, - 
linalg::LinalgTilingPattern<linalg::DepthwiseConv2DNhwOp>>( - context, tilingOptions, - getLinalgMatchAndReplaceMarker( - {getWorkgroupMemoryMarker(), getWorkgroupMarker()}, - getConvFilterTileMarker(), context)); -} - -/// Returns the corresponding range for the given `processorValue` is a GPU -/// thread id or block dim. -static Optional<std::pair<AffineExpr, AffineExpr>> getThreadRange( - Value processorValue, SmallVectorImpl<Value> & /*dims*/, - SmallVectorImpl<Value> & /*symbols*/, ArrayRef<int64_t> workgroupSize) { - if (auto idOp = processorValue.getDefiningOp<gpu::ThreadIdOp>()) { - OpBuilder builder(processorValue.getContext()); - unsigned index = dimToIndex(idOp.dimension()); - AffineExpr zero = builder.getAffineConstantExpr(0); - AffineExpr ubExpr = builder.getAffineConstantExpr(workgroupSize[index]); - return std::make_pair(zero, ubExpr - 1); - } - if (auto dimOp = processorValue.getDefiningOp<gpu::BlockDimOp>()) { - OpBuilder builder(processorValue.getContext()); - unsigned index = dimToIndex(dimOp.dimension()); - AffineExpr bound = builder.getAffineConstantExpr(workgroupSize[index]); - return std::make_pair(bound, bound); - } - return llvm::None; -} - -//====---------------------------------------------------------------------===// -// Patterns for vectorization -//====---------------------------------------------------------------------===// - -static void populateVectorizationPatterns(MLIRContext *context, - RewritePatternSet &patterns) { - linalg::insertVectorizationPatterns<linalg::FillOp, linalg::GenericOp, - linalg::ContractionOpInterface>( - patterns, linalg::LinalgVectorizationOptions(), - linalg::LinalgTransformationFilter( - Identifier::get(getVectorizeMarker(), context))); -} - -//====---------------------------------------------------------------------===// -// Patterns for unrolling vectors -//====---------------------------------------------------------------------===// - -static void populateVectorUnrollPatterns(MLIRContext *context, - 
RewritePatternSet &patterns) { - vector::populateVectorUnrollPatterns( - patterns, - vector::UnrollVectorOptions().setNativeShapeFn(getSPIRVNativeVectorSize)); -} - -namespace { - -/// Workaround SPIR-V backend limitations. SPIR-V vetorization pass relies on -/// unrolling to reduce instructions to a vector size we can convert to SPIR-V. -/// When vectorization creates transpose those block unrolling and result in -/// large vector we currently cannot lower. For now we always merge the -/// transpose into the contract op so that it can be unrolled. -// TODO(thomasraoux): Make transpose work with the current unrolling mechanism -// or replace unrolling. -class CombineContractTranspose final - : public OpRewritePattern<vector::ContractionOp> { - public: - using OpRewritePattern<vector::ContractionOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::ContractionOp op, - PatternRewriter &rewriter) const override { - // Perform lhs + rhs transpositions to conform to matmul row-major - // semantics. Bail out if the contraction cannot be put in this form. 
- MLIRContext *ctx = op.getContext(); - Location loc = op.getLoc(); - bool foundTranspose = false; - std::array<Value, 3> sources = {op.lhs(), op.rhs(), op.acc()}; - SmallVector<AffineMap> newMaps; - SmallVector<Value> newSources; - for (auto source : llvm::enumerate(sources)) { - auto map = - op.indexing_maps()[source.index()].cast<AffineMapAttr>().getValue(); - auto tranposeOp = source.value().getDefiningOp<vector::TransposeOp>(); - if (!tranposeOp) { - newSources.push_back(source.value()); - newMaps.push_back(map); - continue; - } - SmallVector<int64_t, 3> perm; - tranposeOp.getTransp(perm); - SmallVector<AffineExpr> exprs(perm.size()); - for (auto remap : llvm::enumerate(perm)) { - exprs[remap.value()] = map.getResult(remap.index()); - } - newMaps.push_back( - AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs, ctx)); - newSources.push_back(tranposeOp.vector()); - foundTranspose = true; - } - if (!foundTranspose) return failure(); - - Value res = rewriter.create<vector::ContractionOp>( - loc, newSources[0], newSources[1], newSources[2], - rewriter.getAffineMapArrayAttr(newMaps), op.iterator_types()); - rewriter.replaceOp(op, res); - return success(); - } -}; - -} // namespace - -//====---------------------------------------------------------------------===// -// Vector patterns -//====---------------------------------------------------------------------===// - -static void applyVectorTransformation(FuncOp funcOp) { - auto targetEnv = spirv::TargetEnv(spirv::lookupTargetEnv(funcOp)); - bool useCooperativeMatrix = - targetEnv.allows(spirv::Capability::CooperativeMatrixNV) && - targetEnv.allows(spirv::Extension::SPV_NV_cooperative_matrix); - { - { - RewritePatternSet vectorUnrollPatterns(funcOp.getContext()); - populateVectorUnrollPatterns(funcOp.getContext(), vectorUnrollPatterns); - (void)applyPatternsAndFoldGreedily(funcOp, - std::move(vectorUnrollPatterns)); - } - { - linalg::hoistRedundantVectorTransfers(funcOp); - - LLVM_DEBUG({ - llvm::dbgs() << 
"--- After hoisting vector transfers ---\n"; - funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); - llvm::dbgs() << "\n\n"; - }); - } - { - RewritePatternSet canonicalizationPatterns2(funcOp.getContext()); - vector::populateVectorTransferPermutationMapLoweringPatterns( - canonicalizationPatterns2); - (void)applyPatternsAndFoldGreedily(funcOp, - std::move(canonicalizationPatterns2)); - - if (useCooperativeMatrix) { - // When using cooperative matrix we don't want to lower the contract, - // instead we want to merge contract and transpose so that they can be - // converted to cooperative matrix matmul op. - // TODO(thomasraoux): remove that once we support cooperative matrix - // lowering in MLIR core. - RewritePatternSet combineTransposePatterns(funcOp.getContext()); - combineTransposePatterns.add<CombineContractTranspose>( - funcOp.getContext()); - (void)applyPatternsAndFoldGreedily(funcOp, - std::move(combineTransposePatterns)); - } else { - RewritePatternSet contractLoweringPatterns(funcOp.getContext()); - vector::populateVectorContractLoweringPatterns( - contractLoweringPatterns, - vector::VectorTransformsOptions().setVectorTransformsOptions( - vector::VectorContractLowering::OuterProduct)); - (void)applyPatternsAndFoldGreedily(funcOp, - std::move(contractLoweringPatterns)); - } - } - LLVM_DEBUG({ - llvm::dbgs() << "--- After unrolling vector ---\n"; - funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); - llvm::dbgs() << "\n\n"; - }); - } -} - -//====---------------------------------------------------------------------===// -// Patterns to tile convolution window dimensions -//====---------------------------------------------------------------------===// - -static void populateTilingConvFilterPatterns( - MLIRContext *context, RewritePatternSet &patterns, - linalg::LinalgTransformationFilter marker) { - auto getTileSizeFn = [&](OpBuilder &builder, Operation *op) { - SmallVector<Value, 4> tileSizes; - SmallVector<int64_t, 4> fourthLevel = 
getTileSizes(op, 3); - tileSizes.reserve(fourthLevel.size()); - - Location loc = op->getLoc(); - for (int64_t size : fourthLevel) { - tileSizes.push_back(builder.create<ConstantIndexOp>(loc, size)); - } - return tileSizes; - }; - - auto tilingOptions = linalg::LinalgTilingOptions() - .setLoopType(linalg::LinalgTilingLoopType::Loops) - .setTileSizeComputationFunction(getTileSizeFn); - - patterns.insert<linalg::LinalgTilingPattern<linalg::Conv2DNhwcHwcfOp>, - linalg::LinalgTilingPattern<linalg::DepthwiseConv2DNhwcOp>, - linalg::LinalgTilingPattern<linalg::DepthwiseConv2DNhwOp>>( - context, tilingOptions, marker); -} - -//====---------------------------------------------------------------------===// -// Patterns to lower linalg ops to loops -//====---------------------------------------------------------------------===// - -template <typename OpTy> -struct LowerToLoops final : public OpRewritePattern<OpTy> { - using OpRewritePattern<OpTy>::OpRewritePattern; - - LogicalResult matchAndRewrite(OpTy op, - PatternRewriter &rewriter) const override { - // Only handle the cases where tiling to invocations was done, where tiling - // convolution filters or vectorization is expected. - if (!hasMarker(op, {getConvFilterTileMarker(), getVectorizeMarker()})) - return failure(); - - if (linalg::linalgOpToLoops(rewriter, op)) { - rewriter.eraseOp(op); - return success(); - } - - return failure(); - } -}; - -//====---------------------------------------------------------------------===// -// Main pass implementation -//====---------------------------------------------------------------------===// - -void SPIRVTileAndVectorizePass::runOnOperation() { - MLIRContext *context = &getContext(); - FuncOp funcOp = getOperation(); - auto entryPointOp = getEntryPoint(funcOp); - if (!entryPointOp) return; - - // TODO(thomasraoux, antiagainst): Tiling to subgroups shouldn't be - // controlled by vectorization. This is needed due to historical reasons. 
- // Change the second level tiling to cyclic to loops and remove this. - RewritePatternSet secondLevelTilingPatterns(&getContext()); - populateTilingToSubgroupPatterns(context, secondLevelTilingPatterns); - (void)applyPatternsAndFoldGreedily(funcOp, - std::move(secondLevelTilingPatterns)); - - RewritePatternSet secondLevelTilingCanonicalizationPatterns = - linalg::getLinalgTilingCanonicalizationPatterns(context); - populateAffineMinCanonicalizationPattern( - secondLevelTilingCanonicalizationPatterns); - (void)applyPatternsAndFoldGreedily( - funcOp, std::move(secondLevelTilingCanonicalizationPatterns)); - promoteSingleIterationLoops(funcOp); - - LLVM_DEBUG({ - llvm::dbgs() << "--- After tiling to subgroups ---\n"; - funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); - llvm::dbgs() << "\n\n"; - }); - - { - RewritePatternSet thirdLevelTilingPatterns(&getContext()); - populateTilingToInvocationPatterns(context, thirdLevelTilingPatterns); - (void)applyPatternsAndFoldGreedily(funcOp, - std::move(thirdLevelTilingPatterns)); - - // Remove trip-one loops created during cyclic loop distribution if we can - // prove the tiling was perfect. - RewritePatternSet canoncalizationPatterns(context); - populateAffineMinSCFCanonicalizationPattern(canoncalizationPatterns); - SmallVector<int64_t> workgroupSize = getWorkgroupSize(entryPointOp); - if (workgroupSize.empty()) { - entryPointOp.emitError("expected to have workgroup_size attribute"); - return signalPassFailure(); - } - auto getThreadRangeFn = [workgroupSize](Value processorValue, - SmallVectorImpl<Value> &dims, - SmallVectorImpl<Value> &symbols) { - return getThreadRange(processorValue, dims, symbols, workgroupSize); - }; - populateRemoveSingleIterationLoopPattern(canoncalizationPatterns, - getThreadRangeFn); - (void)applyPatternsAndFoldGreedily(funcOp, - std::move(canoncalizationPatterns)); - - // Perform generic canonicalization. 
- RewritePatternSet threadLevelTilingCanonicalizationPatterns = - linalg::getLinalgTilingCanonicalizationPatterns(context); - populateAffineMinCanonicalizationPattern( - threadLevelTilingCanonicalizationPatterns); - (void)applyPatternsAndFoldGreedily( - funcOp, std::move(threadLevelTilingCanonicalizationPatterns)); - - LLVM_DEBUG({ - llvm::dbgs() << "--- After tiling to invocations ---\n"; - funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); - llvm::dbgs() << "\n\n"; - }); - } - - { - RewritePatternSet tilingPatterns(&getContext()); - auto marker = getLinalgMatchAndReplaceMarker(getConvFilterTileMarker(), - getVectorizeMarker(), context); - populateTilingConvFilterPatterns(context, tilingPatterns, marker); - populateFoldGPUProcessorIDUsesPatterns(context, tilingPatterns); - tilingPatterns.insert<linalg::AffineMinSCFCanonicalizationPattern>(context); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPatterns)); - - RewritePatternSet convTilingCanonicalizationPatterns = - linalg::getLinalgTilingCanonicalizationPatterns(context); - populateAffineMinCanonicalizationPattern( - convTilingCanonicalizationPatterns); - (void)applyPatternsAndFoldGreedily( - funcOp, std::move(convTilingCanonicalizationPatterns)); - - LLVM_DEBUG({ - llvm::dbgs() << "--- After tiling convolution filter ---\n"; - funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); - llvm::dbgs() << "\n\n"; - }); - } - - { - RewritePatternSet vectorizationPatterns(&getContext()); - populateVectorizationPatterns(context, vectorizationPatterns); - populateLinalgToVectorVectorizeConvPatterns(context, vectorizationPatterns); - (void)applyPatternsAndFoldGreedily(funcOp, - std::move(vectorizationPatterns)); - LLVM_DEBUG({ - llvm::dbgs() << "--- After vectorization ---\n"; - funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); - llvm::dbgs() << "\n\n"; - }); - } - - // TODO: This should be a folding of Add into Contract in core but while - // they live in different dialects, 
it is not possible without unnatural - // dependencies. - funcOp.walk([&](Operation *op) { - if (auto contract = canonicalizeContractionAdd(op)) - op->replaceAllUsesWith(contract); - }); - - applyVectorTransformation(funcOp); - - // Lower ops that were tiled to invocations but not vectorized to loops. - // TODO(antiagainst): This is here now to simplify the interaction with - // ConvertToGPUPass, where we finally lower away all Linalg ops. Once that - // pass is cleaned up, we can invoke createConvertLinalgToLoopsPass - // directly. - { - RewritePatternSet patterns(context); - patterns.add<LowerToLoops<linalg::BatchMatmulOp>, - LowerToLoops<linalg::Conv1DNwcWcfOp>, - LowerToLoops<linalg::Conv2DNhwcHwcfOp>, - LowerToLoops<linalg::Conv3DNdhwcDhwcfOp>, - LowerToLoops<linalg::DepthwiseConv2DNhwcOp>, - LowerToLoops<linalg::DepthwiseConv2DNhwOp>, - LowerToLoops<linalg::FillOp>, LowerToLoops<linalg::GenericOp>, - LowerToLoops<linalg::MatmulOp>, - LowerToLoops<linalg::PoolingNhwcMaxOp>, - LowerToLoops<linalg::PoolingNhwcMinOp>, - LowerToLoops<linalg::PoolingNhwcSumOp>>(context); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); - } -} - -//===----------------------------------------------------------------------===// -// Pass entry point and registration -//===----------------------------------------------------------------------===// - -std::unique_ptr<OperationPass<FuncOp>> createSPIRVTileAndVectorizePass() { - return std::make_unique<SPIRVTileAndVectorizePass>(); -} - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp b/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp index 108bf96..3ea5b03 100644 --- a/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp +++ b/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
@@ -6,8 +6,297 @@ //===- SPIRVVectorize.cpp -------------------------------------------------===// // -// This pass vectorizes Linalg ops on buffers within in a single workgroup. +// This pass vectorizes Linalg ops with buffer semantics. // //===----------------------------------------------------------------------===// +#include "iree/compiler/Codegen/PassDetail.h" +#include "iree/compiler/Codegen/Passes.h" +#include "iree/compiler/Codegen/Transforms/Transforms.h" +#include "iree/compiler/Codegen/Utils/MarkerUtils.h" +#include "iree/compiler/Codegen/Utils/Utils.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/Transforms/Hoisting.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/Dialect/Vector/VectorTransforms.h" +#include "mlir/Interfaces/VectorInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + #define DEBUG_TYPE "iree-spirv-vectorize" + +namespace mlir { +namespace iree_compiler { +namespace { + +/// Returns the cooperative matrix (M, N, K) sizes that are supported by the +/// target environment and match the given parameters. 
+static Optional<SmallVector<int64_t, 4>> getCooperativeMatmulSubgroupSize( + spirv::ResourceLimitsAttr resourceLimits, Type lhsType, Type rhsType, + Type initType, Type resultType) { + auto range = resourceLimits.cooperative_matrix_properties_nv() + .getAsRange<spirv::CooperativeMatrixPropertiesNVAttr>(); + for (auto coopMatmulProperties : range) { + if (coopMatmulProperties.a_type().getValue() == lhsType && + coopMatmulProperties.b_type().getValue() == rhsType && + coopMatmulProperties.c_type().getValue() == initType && + coopMatmulProperties.result_type().getValue() == resultType && + coopMatmulProperties.scope().getValue() == spirv::Scope::Subgroup) { + return SmallVector<int64_t, 4>{ + coopMatmulProperties.m_size().getValue().getSExtValue(), + coopMatmulProperties.n_size().getValue().getSExtValue(), + coopMatmulProperties.k_size().getValue().getSExtValue()}; + } + } + return llvm::None; +} + +/// Returns true if the target environment attached to `op`'s ancestor op +/// supports cooperative matrix. +bool useCooperativeMatrix(Operation *op) { + auto targetEnv = spirv::TargetEnv(spirv::lookupTargetEnv(op)); + return targetEnv.allows(spirv::Capability::CooperativeMatrixNV) && + targetEnv.allows(spirv::Extension::SPV_NV_cooperative_matrix); +} + +Optional<SmallVector<int64_t, 4>> getSPIRVNativeVectorSize(Operation *op) { + auto targetEnv = spirv::TargetEnv(spirv::lookupTargetEnv(op)); + bool useCoopMatrix = useCooperativeMatrix(op); + + if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) { + if (auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>()) { + // Use 4-element vectors for elementwise ops. + SmallVector<int64_t, 4> nativeSize(vecType.getRank(), 1); + nativeSize.back() = 4; + return nativeSize; + } + } else if (auto vtOp = dyn_cast<VectorTransferOpInterface>(op)) { + if (useCoopMatrix) { + if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) { + // Unroll cooperative matrix store based on the size of the contract. 
+ auto insert = + writeOp.vector().getDefiningOp<vector::InsertStridedSliceOp>(); + if (!insert) return llvm::None; + ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape(); + return SmallVector<int64_t, 4>(shape.begin(), shape.end()); + } else if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) { + // Unroll cooperative matrix load based on the size of the contract. + VectorType dstVec; + for (Operation *users : op->getUsers()) { + auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users); + if (!extract) return llvm::None; + auto vecType = extract.getResult().getType().cast<VectorType>(); + if (dstVec && dstVec != vecType) return llvm::None; + dstVec = vecType; + } + return SmallVector<int64_t, 4>(dstVec.getShape().begin(), + dstVec.getShape().end()); + } + } else { + auto rank = vtOp.getVectorType().getRank(); + SmallVector<int64_t, 4> nativeSize(rank, 1); + for (auto dim : llvm::enumerate(vtOp.permutation_map().getResults())) { + if (auto dimExpr = dim.value().dyn_cast<AffineDimExpr>()) { + if (dimExpr.getPosition() == vtOp.permutation_map().getNumDims() - 1) + nativeSize[dim.index()] = 4; + } + } + return nativeSize; + } + } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) { + if (useCoopMatrix) { + return getCooperativeMatmulSubgroupSize( + targetEnv.getResourceLimits(), + contractOp.getLhsType().getElementType(), + contractOp.getRhsType().getElementType(), + contractOp.getAccType().cast<VectorType>().getElementType(), + contractOp.getResultType().cast<VectorType>().getElementType()); + } else { + unsigned lastParalleldim = 0; + for (auto it : llvm::enumerate(contractOp.iterator_types())) { + if (isParallelIterator(it.value())) lastParalleldim = it.index(); + } + SmallVector<int64_t, 4> nativeSize(contractOp.iterator_types().size(), 1); + nativeSize[lastParalleldim] = 4; + // Map to vec4 fma operations. + return nativeSize; + } + } + return llvm::None; +} + +/// Add patterns to vectorize Linalg ops with vectorization marker. 
+void populateVectorizationPatterns(MLIRContext *context, + RewritePatternSet &patterns) { + linalg::insertVectorizationPatterns<linalg::FillOp, linalg::GenericOp, + linalg::ContractionOpInterface>( + patterns, linalg::LinalgVectorizationOptions(), + linalg::LinalgTransformationFilter( + Identifier::get(getVectorizeMarker(), context))); +} + +/// Adds patterns to unroll vector ops to SPIR-V native vector size. +void populateVectorUnrollPatterns(MLIRContext *context, + RewritePatternSet &patterns) { + vector::populateVectorUnrollPatterns( + patterns, + vector::UnrollVectorOptions().setNativeShapeFn(getSPIRVNativeVectorSize)); +} + +/// Workaround SPIR-V backend limitations. SPIR-V vectorization pass relies on +/// unrolling to reduce instructions to a vector size we can convert to SPIR-V. +/// When vectorization creates transposes, those block unrolling and result in +/// large vectors we currently cannot lower. For now we always merge the +/// transpose into the contract op so that it can be unrolled. +// TODO(thomasraoux): Make transpose work with the current unrolling mechanism +// or replace unrolling. +class CombineContractTranspose final + : public OpRewritePattern<vector::ContractionOp> { + public: + using OpRewritePattern<vector::ContractionOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override { + // Perform lhs + rhs transpositions to conform to matmul row-major + // semantics. Bail out if the contraction cannot be put in this form. 
+ MLIRContext *ctx = op.getContext(); + Location loc = op.getLoc(); + bool foundTranspose = false; + std::array<Value, 3> sources = {op.lhs(), op.rhs(), op.acc()}; + SmallVector<AffineMap> newMaps; + SmallVector<Value> newSources; + for (auto source : llvm::enumerate(sources)) { + auto map = + op.indexing_maps()[source.index()].cast<AffineMapAttr>().getValue(); + auto tranposeOp = source.value().getDefiningOp<vector::TransposeOp>(); + if (!tranposeOp) { + newSources.push_back(source.value()); + newMaps.push_back(map); + continue; + } + SmallVector<int64_t, 3> perm; + tranposeOp.getTransp(perm); + SmallVector<AffineExpr> exprs(perm.size()); + for (auto remap : llvm::enumerate(perm)) { + exprs[remap.value()] = map.getResult(remap.index()); + } + newMaps.push_back( + AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs, ctx)); + newSources.push_back(tranposeOp.vector()); + foundTranspose = true; + } + if (!foundTranspose) return failure(); + + Value res = rewriter.create<vector::ContractionOp>( + loc, newSources[0], newSources[1], newSources[2], + rewriter.getAffineMapArrayAttr(newMaps), op.iterator_types()); + rewriter.replaceOp(op, res); + return success(); + } +}; + +/// Vectorizes Linalg ops on buffer semantics. 
+class SPIRVVectorizePass : public SPIRVVectorizeBase<SPIRVVectorizePass> { + public: + SPIRVVectorizePass() = default; + SPIRVVectorizePass(const SPIRVVectorizePass &pass) = default; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<linalg::LinalgDialect, vector::VectorDialect>(); + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + FuncOp funcOp = getOperation(); + + auto entryPointOp = getEntryPoint(funcOp); + if (!entryPointOp) return; + + { + RewritePatternSet vectorizationPatterns(&getContext()); + populateVectorizationPatterns(context, vectorizationPatterns); + populateLinalgToVectorVectorizeConvPatterns(context, + vectorizationPatterns); + (void)applyPatternsAndFoldGreedily(funcOp, + std::move(vectorizationPatterns)); + + LLVM_DEBUG({ + llvm::dbgs() << "--- After vectorization ---\n"; + funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); + llvm::dbgs() << "\n\n"; + }); + } + + // TODO: This should be a folding of Add into Contract in core but while + // they live in different dialects, it is not possible without unnatural + // dependencies. 
+ funcOp.walk([&](Operation *op) { + if (auto contract = canonicalizeContractionAdd(op)) + op->replaceAllUsesWith(contract); + }); + + { + RewritePatternSet vectorUnrollPatterns(funcOp.getContext()); + populateVectorUnrollPatterns(funcOp.getContext(), vectorUnrollPatterns); + (void)applyPatternsAndFoldGreedily(funcOp, + std::move(vectorUnrollPatterns)); + } + + { + linalg::hoistRedundantVectorTransfers(funcOp); + + LLVM_DEBUG({ + llvm::dbgs() << "--- After hoisting vector transfers ---\n"; + funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); + llvm::dbgs() << "\n\n"; + }); + } + + { + RewritePatternSet canonicalizationPatterns(funcOp.getContext()); + vector::populateVectorTransferPermutationMapLoweringPatterns( + canonicalizationPatterns); + (void)applyPatternsAndFoldGreedily(funcOp, + std::move(canonicalizationPatterns)); + + if (useCooperativeMatrix(funcOp)) { + // When using cooperative matrix we don't want to lower the contract, + // instead we want to merge contract and transpose so that they can be + // converted to cooperative matrix matmul op. + // TODO(thomasraoux): remove that once we support cooperative matrix + // lowering in MLIR core. 
+ RewritePatternSet combineTransposePatterns(funcOp.getContext()); + combineTransposePatterns.add<CombineContractTranspose>( + funcOp.getContext()); + (void)applyPatternsAndFoldGreedily(funcOp, + std::move(combineTransposePatterns)); + } else { + RewritePatternSet contractLoweringPatterns(funcOp.getContext()); + vector::populateVectorContractLoweringPatterns( + contractLoweringPatterns, + vector::VectorTransformsOptions().setVectorTransformsOptions( + vector::VectorContractLowering::OuterProduct)); + (void)applyPatternsAndFoldGreedily(funcOp, + std::move(contractLoweringPatterns)); + } + } + + LLVM_DEBUG({ + llvm::dbgs() << "--- After unrolling vector ---\n"; + funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); + llvm::dbgs() << "\n\n"; + }); + } +}; + +} // namespace + +std::unique_ptr<OperationPass<FuncOp>> createSPIRVVectorizePass() { + return std::make_unique<SPIRVVectorizePass>(); +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/Codegen/SPIRV/Utils.h b/iree/compiler/Codegen/SPIRV/Utils.h index b286642..a7f899b 100644 --- a/iree/compiler/Codegen/SPIRV/Utils.h +++ b/iree/compiler/Codegen/SPIRV/Utils.h
@@ -4,11 +4,12 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -//===- Utils.h - Utility functions used in Linalg to SPIR-V lowering ------===// +//===- Utils.h - Utility functions for lowering Linalg to SPIR-V ----------===// // -// Utility functions used while lowering from Linalg to SPIRV. +// Utility functions used while lowering from Linalg to SPIR-V. // //===----------------------------------------------------------------------===// + #ifndef IREE_COMPILER_CODEGEN_SPIRV_UTILS_H_ #define IREE_COMPILER_CODEGEN_SPIRV_UTILS_H_
diff --git a/iree/compiler/Codegen/SPIRV/test/BUILD b/iree/compiler/Codegen/SPIRV/test/BUILD index 0f3bc85..2624a6a 100644 --- a/iree/compiler/Codegen/SPIRV/test/BUILD +++ b/iree/compiler/Codegen/SPIRV/test/BUILD
@@ -19,8 +19,8 @@ name = "lit", srcs = enforce_glob( [ - "convert_to_gpu.mlir", "convert_to_spirv.mlir", + "distribute_to_global_id.mlir", "fold_gpu_procid_uses.mlir", "pipeline_matmul_cooperative_matrix.mlir", "pipeline_matmul_vectorization.mlir",
diff --git a/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt b/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt index cc4083a..9f99d11 100644 --- a/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt +++ b/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
@@ -14,8 +14,8 @@ NAME lit SRCS - "convert_to_gpu.mlir" "convert_to_spirv.mlir" + "distribute_to_global_id.mlir" "fold_gpu_procid_uses.mlir" "pipeline_matmul_cooperative_matrix.mlir" "pipeline_matmul_vectorization.mlir"
diff --git a/iree/compiler/Codegen/SPIRV/test/convert_to_gpu.mlir b/iree/compiler/Codegen/SPIRV/test/distribute_to_global_id.mlir similarity index 99% rename from iree/compiler/Codegen/SPIRV/test/convert_to_gpu.mlir rename to iree/compiler/Codegen/SPIRV/test/distribute_to_global_id.mlir index 729557e..8ba2fd0 100644 --- a/iree/compiler/Codegen/SPIRV/test/convert_to_gpu.mlir +++ b/iree/compiler/Codegen/SPIRV/test/distribute_to_global_id.mlir
@@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(builtin.func(iree-spirv-convert-to-gpu))))' -canonicalize -cse %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(builtin.func(iree-spirv-distribute-to-global-id))))' -canonicalize -cse %s | IreeFileCheck %s #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> hal.executable @parallel_4D attributes {sym_visibility = "private"} {
diff --git a/iree/compiler/Codegen/SPIRV/test/promote_workgroup_memory.mlir b/iree/compiler/Codegen/SPIRV/test/promote_workgroup_memory.mlir index b257355..adda0d2 100644 --- a/iree/compiler/Codegen/SPIRV/test/promote_workgroup_memory.mlir +++ b/iree/compiler/Codegen/SPIRV/test/promote_workgroup_memory.mlir
@@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(builtin.func(iree-spirv-tile-and-vectorize,canonicalize,cse))))' +// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(builtin.func(iree-spirv-tile-and-distribute,iree-spirv-vectorize,canonicalize,cse))))' // TODO(antiagainst): Fix promotion to workgroup and enable the test. // | IreeFileCheck %s
diff --git a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize.mlir b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize.mlir index 78d41ab..0ab91cb 100644 --- a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize.mlir +++ b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize.mlir
@@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(builtin.func(iree-spirv-tile-and-vectorize,canonicalize,cse))))' %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(builtin.func(iree-spirv-tile-and-distribute,iree-spirv-vectorize,canonicalize,cse))))' %s | IreeFileCheck %s #map0 = affine_map<()[s0] -> (s0 * 8)> #map1 = affine_map<()[s0, s1] -> (8, s1 - s0 * 8)> @@ -56,7 +56,7 @@ %18 = memref.subview %arg2[%3, %10] [%15, %17] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map3> linalg.matmul {__internal_linalg_transform__ = "workgroup", lowering.config = #config} ins(%7, %13 : memref<?x?xf32, #map3>, memref<?x?xf32, #map3>) - outs(%18 : memref<?x?xf32, #map3>) + outs(%18 : memref<?x?xf32, #map3>) } return } @@ -98,7 +98,7 @@ translation.info = {passPipeline = 6 : i32, workloadPerWorkgroup = [32, 4, 1]} } module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative], [SPV_KHR_storage_buffer_storage_class]>, SwiftShader:CPU, {cooperative_matrix_properties_nv = [], max_compute_shared_memory_size = 16384 : i32, max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>, subgroup_size = 4 : i32}>} { - func @conv_1d() attributes {spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}} { + func @conv_1d() { %cst = constant 0.000000e+00 : f32 %c0 = constant 0 : index %0 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<3x6x1xf32> @@ -132,9 +132,7 @@ } // CHECK-LABEL: func @conv_1d -// CHECK-DAG: %[[C0:.+]] = constant 0 : index -// CHECK-DAG: %[[C1:.+]] = constant 1 : index -// CHECK-DAG: %[[C3:.+]] = constant 3 : index +// CHECK: %[[C0:.+]] = constant 0 : index // CHECK: %[[RET:.+]] = 
hal.interface.binding.subspan @io::@ret0 // CHECK: %[[ARG0:.+]] = hal.interface.binding.subspan @io::@arg0 // CHECK: %[[ARG1:.+]] = hal.interface.binding.subspan @io::@arg1 @@ -151,11 +149,10 @@ // CHECK: %[[ARG0SV2:.+]] = memref.subview %[[ARG0SV1]][%[[TIDZ]], %[[IV0]], 0] [1, %{{.+}}, 1] // CHECK: %[[ARG1SV2:.+]] = memref.subview %[[ARG1SV1]][0, 0, %[[IV1]]] [3, 1, 1] // CHECK: %[[RETSV2:.+]] = memref.subview %[[RETSV1]][%[[TIDZ]], %[[IV0]], %[[IV1]]] [1, 1, 1] -// CHECK: scf.for %[[IV2:.+]] = %[[C0]] to %[[C3]] step %[[C1]] -// CHECK: memref.load %[[ARG0SV2]][%[[C0]], %[[IV2]], %[[C0]]] -// CHECK: memref.load %[[ARG1SV2]][%[[IV2]], %[[C0]], %[[C0]]] -// CHECK: memref.load %[[RETSV2]][%[[C0]], %[[C0]], %[[C0]]] -// CHECK: memref.store %{{.+}}, %[[RETSV2]][%[[C0]], %[[C0]], %[[C0]]] +// CHECK: linalg.conv_1d_nwc_wcf +// CHECK-SAME: __internal_linalg_transform__ = "vectorize" +// CHECK-SAME: ins(%[[ARG0SV2]], %[[ARG1SV2]] +// CHECK-SAME: outs(%[[RETSV2]] // ----- @@ -270,22 +267,24 @@ // CHECK: %[[BSTEPY:.+]] = affine.apply #[[MAP0]]()[%[[NBLOCKSY]]] // CHECK: %[[BOFFSETX:.+]] = affine.apply #[[MAP1]]()[%[[BIDX]]] // CHECK: %[[BSTEPX:.+]] = affine.apply #[[MAP1]]()[%[[NBLOCKSX]]] -// CHECK: scf.for %[[IV3:.+]] = %[[BIDZ]] to %[[N]] step %[[NBLOCKSZ]] -// CHECK: scf.for %[[IV4:.+]] = %[[BOFFSETY]] to %[[P]] step %[[BSTEPY]] -// CHECK: scf.for %[[IV5:.+]] = %[[BOFFSETX]] to %[[Q]] step %[[BSTEPX]] -// CHECK: %[[SV1:.+]] = memref.subview %[[ARG1]][%[[IV3]], %[[IV4]], %[[IV5]], 0] -// CHECK: %[[SV2:.+]] = memref.subview %[[RET0]][%[[IV3]], %[[IV4]], %[[IV5]], 0] +// CHECK: scf.for %[[IV0:.+]] = %[[BIDZ]] to %[[N]] step %[[NBLOCKSZ]] +// CHECK: scf.for %[[IV1:.+]] = %[[BOFFSETY]] to %[[P]] step %[[BSTEPY]] +// CHECK: scf.for %[[IV2:.+]] = %[[BOFFSETX]] to %[[Q]] step %[[BSTEPX]] +// CHECK: %[[SV1:.+]] = memref.subview %[[ARG1]][%[[IV0]], %[[IV1]], %[[IV2]], 0] +// CHECK: %[[SV2:.+]] = memref.subview %[[RET0]][%[[IV0]], %[[IV1]], %[[IV2]], 0] // CHECK-DAG: %[[TIDX:.+]] = 
"gpu.thread_id"() {dimension = "x"} // CHECK-DAG: %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"} // CHECK-DAG: %[[TIDZ:.+]] = "gpu.thread_id"() {dimension = "z"} // CHECK-DAG: %[[BDIMX:.+]] = "gpu.block_dim"() {dimension = "x"} // CHECK-DAG: %[[BDIMY:.+]] = "gpu.block_dim"() {dimension = "y"} // CHECK-DAG: %[[BDIMZ:.+]] = "gpu.block_dim"() {dimension = "z"} -// CHECK: scf.for %{{.+}} = %[[TIDZ]] to %{{.*}} step %[[BDIMZ]] -// CHECK: scf.for %{{.+}} = %[[TIDY]] to %{{.*}} step %[[BDIMY]] -// CHECK: scf.for %{{.+}} = %[[TIDX]] to %{{.*}} step %[[BDIMX]] -// CHECK-COUNT-3: scf.for -// CHECK-NOT: linalg.conv_2d_nhwc_hwcf +// CHECK: scf.for %[[IV3:.+]] = %[[TIDZ]] to %{{.*}} step %[[BDIMZ]] +// CHECK: scf.for %[[IV4:.+]] = %[[TIDY]] to %{{.*}} step %[[BDIMY]] +// CHECK: scf.for %[[IV5:.+]] = %[[TIDX]] to %{{.*}} step %[[BDIMX]] +// CHECK: %[[OUT:.+]] = memref.subview %[[SV2]][0, %[[IV3]], %[[IV4]], %[[IV5]]] +// CHECK: linalg.conv_2d_nhwc_hwcf +// CHECK-SAME: __internal_linalg_transform__ = "tile_conv_filter" +// CHECK-SAME: outs(%[[OUT]] // ----- @@ -304,7 +303,7 @@ translation.info = {passPipeline = 6 : i32, workloadPerWorkgroup = [32, 4, 1]} } module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative], [SPV_KHR_storage_buffer_storage_class]>, SwiftShader:CPU, {cooperative_matrix_properties_nv = [], max_compute_shared_memory_size = 16384 : i32, max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>, subgroup_size = 4 : i32}>} { - func @conv_3d() attributes {spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}} { + func @conv_3d() { %cst = constant 0.000000e+00 : f32 %c0 = constant 0 : index %0 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<2x7x7x7x2xf32> @@ -343,11 +342,13 @@ // CHECK-DAG: %[[BDIMX:.+]] = 
"gpu.block_dim"() {dimension = "x"} // CHECK-DAG: %[[BDIMY:.+]] = "gpu.block_dim"() {dimension = "y"} // CHECK-DAG: %[[BDIMZ:.+]] = "gpu.block_dim"() {dimension = "z"} -// CHECK: scf.for %{{.+}} = %[[TIDZ]] to %{{.*}} step %[[BDIMZ]] -// CHECK: scf.for %{{.+}} = %[[TIDY]] to %{{.*}} step %[[BDIMY]] -// CHECK: scf.for %{{.+}} = %[[TIDX]] to %{{.*}} step %[[BDIMX]] -// CHECK-COUNT-5: scf.for -// CHECK-NOT: linalg.conv_3d_ndhwc_dhwcf +// CHECK: scf.for %[[IV0:.+]] = %[[TIDZ]] to %{{.*}} step %[[BDIMZ]] +// CHECK: scf.for %[[IV1:.+]] = %[[TIDY]] to %{{.*}} step %[[BDIMY]] +// CHECK: scf.for %[[IV2:.+]] = %[[TIDX]] to %{{.*}} step %[[BDIMX]] +// CHECK: %[[OUT:.+]] = memref.subview %{{.+}}[0, 0, %[[IV0]], %[[IV1]], %[[IV2]]] +// CHECK: linalg.conv_3d_ndhwc_dhwcf +// CHECK-SAME: __internal_linalg_transform__ = "vectorize" +// CHECK-SAME: outs(%[[OUT]] // ----- @@ -376,7 +377,7 @@ translation.info = {passPipeline = 6 : i32, workloadPerWorkgroup = [32, 4, 1]} } module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} { - func @pooling_nhwc_max() attributes {spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}} { + func @pooling_nhwc_max() { %c0 = constant 0 : index %0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<2x16x16x6xf32> %1 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<3x4xf32> @@ -422,8 +423,12 @@ // CHECK-DAG: %[[BDIMX:.+]] = "gpu.block_dim"() {dimension = "x"} // CHECK-DAG: %[[BDIMY:.+]] = "gpu.block_dim"() {dimension = "y"} // CHECK-DAG: %[[BDIMZ:.+]] = "gpu.block_dim"() {dimension = "z"} -// CHECK: scf.for %{{.+}} = %[[TIDZ]] to %{{.*}} step %[[BDIMZ]] -// CHECK: scf.for %{{.+}} = %[[TIDY]] to %{{.*}} step %[[BDIMY]] -// CHECK: scf.for %{{.+}} = %[[TIDX]] to %{{.*}} step %[[BDIMX]] -// CHECK-COUNT-3: scf.for -// CHECK-NOT: linalg.pooling_nhwc_max 
+// CHECK: scf.for %[[IV0:.+]] = %[[TIDZ]] to %{{.*}} step %[[BDIMZ]] +// CHECK: scf.for %[[IV1:.+]] = %[[TIDY]] to %{{.*}} step %[[BDIMY]] +// CHECK: scf.for %[[IV2:.+]] = %[[TIDX]] to %{{.*}} step %[[BDIMX]] +// CHECK: %[[IN:.+]] = memref.subview %[[SV1]][%[[IV0]], %[[IV1]], %[[IV2]], 0] [1, %{{.+}}, %{{.+}}, 6] +// CHECK: %[[OUT:.+]] = memref.subview %[[SV2]][%[[IV0]], %[[IV1]], %[[IV2]], 0] [1, 1, 1, 6] +// CHECK: linalg.pooling_nhwc_max +// CHECK-SAME: __internal_linalg_transform__ = "vectorize" +// CHECK-SAME: ins(%[[IN]], %[[ARG1]] +// CHECK-SAME: outs(%[[OUT]]
diff --git a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_batch_matmul.mlir b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_batch_matmul.mlir index 1cbb881..5e10ef7 100644 --- a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_batch_matmul.mlir +++ b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_batch_matmul.mlir
@@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-set-num-workgroups,builtin.module(builtin.func(iree-spirv-tile-and-vectorize))))' -canonicalize -cse %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-set-num-workgroups,builtin.module(builtin.func(iree-spirv-tile-and-distribute,iree-spirv-vectorize))))' -canonicalize -cse %s | IreeFileCheck %s #config = {tileSizes = [[1, 8, 64, 4], [], [1, 8, 4, 4]]}
diff --git a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir index cf389cc..dbb19ad 100644 --- a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir +++ b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir
@@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-set-num-workgroups,builtin.module(builtin.func(canonicalize,iree-spirv-remove-one-trip-tiled-loop,iree-spirv-tile-and-vectorize))))' -canonicalize -cse %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-set-num-workgroups,builtin.module(builtin.func(canonicalize,iree-spirv-remove-one-trip-tiled-loop,iree-spirv-tile-and-distribute,iree-spirv-vectorize))))' -canonicalize -cse %s | IreeFileCheck %s #config = {tileSizes = [[0, 4, 4, 16], [], [0, 4, 1, 4], [0, 0, 0, 0, 1, 1, 4]]}
diff --git a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_matmul.mlir b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_matmul.mlir index e3c3076..06227f3 100644 --- a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_matmul.mlir +++ b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_matmul.mlir
@@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-set-num-workgroups,builtin.module(builtin.func(iree-spirv-tile-and-vectorize))))' -canonicalize -cse %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-set-num-workgroups,builtin.module(builtin.func(iree-spirv-tile-and-distribute,iree-spirv-vectorize))))' -canonicalize -cse %s | IreeFileCheck %s #config = {tileSizes = [[8, 64, 4], [], [8, 4, 4]]}
diff --git a/iree/compiler/Codegen/SPIRV/test/vectorize_elementwise_ops.mlir b/iree/compiler/Codegen/SPIRV/test/vectorize_elementwise_ops.mlir index d84bb33..bb5d1f1 100644 --- a/iree/compiler/Codegen/SPIRV/test/vectorize_elementwise_ops.mlir +++ b/iree/compiler/Codegen/SPIRV/test/vectorize_elementwise_ops.mlir
@@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(builtin.func(iree-spirv-tile-and-vectorize))))' %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(builtin.func(iree-spirv-tile-and-distribute,iree-spirv-vectorize))))' %s | IreeFileCheck %s // CHECK-LABEL: func @elementwise_static_shape // CHECK: vector.transfer_read %{{.+}}[%c0], {{.+}} memref<4xf32, #{{.+}}>, vector<4xf32>
diff --git a/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir b/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir index f343f6b..ffa8a7c 100644 --- a/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir +++ b/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
@@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(builtin.func(iree-spirv-tile-and-vectorize,canonicalize,cse))))' %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(builtin.func(iree-spirv-tile-and-distribute,iree-spirv-vectorize,canonicalize,cse))))' %s | IreeFileCheck %s // TODO(antiagainst): Fix promotion to workgroup and enable the test. // | IreeFileCheck %s -check-prefix=PROMOTE
diff --git a/iree/compiler/Codegen/Utils/Utils.cpp b/iree/compiler/Codegen/Utils/Utils.cpp index 7fec29c..28a28b2 100644 --- a/iree/compiler/Codegen/Utils/Utils.cpp +++ b/iree/compiler/Codegen/Utils/Utils.cpp
@@ -82,11 +82,11 @@ LogicalResult setOpConfigAndEntryPointFnTranslation( FuncOp entryPointFn, Operation *op, TileSizesListTypeRef tileSizes, - ArrayRef<int64_t> nativeVectorSize, + ArrayRef<int64_t> nativeVectorSizes, IREE::HAL::DispatchLoweringPassPipeline passPipeline, ArrayRef<int64_t> workgroupSize) { IREE::HAL::LoweringConfig config = - buildConfigAttr(tileSizes, nativeVectorSize, op->getContext()); + buildConfigAttr(tileSizes, nativeVectorSizes, op->getContext()); setLoweringConfig(op, config); auto partitionedLoops = getPartitionedLoops(op); SmallVector<int64_t, 3> workloadPerWorkgroup;
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/constant_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/constant_ops.mlir index dbd1893..4cda45e 100644 --- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/constant_ops.mlir +++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/constant_ops.mlir
@@ -11,8 +11,13 @@ hal.constant_storage @_storage1 = dense<[6, 7, 8, 0]> : vector<4xi8> } -// CHECK: vm.global.ref private @pool_storage0_buffer initializer(@pool_storage0_buffer_initializer) : !vm.ref<!hal.buffer> +// CHECK: vm.global.ref private @pool_storage0_buffer : !vm.ref<!hal.buffer> util.global private @pool_storage0_buffer initializer(@pool_storage0_buffer_initializer) : !hal.buffer +// CHECK-NEXT: vm.initializer { +// CHECK-NEXT: %[[REF:.+]] = vm.call @pool_storage0_buffer_initializer() : () -> !vm.ref<!hal.buffer> +// CHECK-NEXT: vm.global.store.ref %[[REF]], @pool_storage0_buffer : !vm.ref<!hal.buffer> +// CHECK-NEXT: vm.return +// CHECK-NEXT: } // CHECK: vm.func private @pool_storage0_buffer_initializer() -> !vm.ref<!hal.buffer> func private @pool_storage0_buffer_initializer() -> !hal.buffer { %c0 = constant 0 : index @@ -29,11 +34,11 @@ return %mapped : !hal.buffer } -// CHECK: vm.global.ref private @pool_storage1_buffer initializer(@pool_storage1_buffer_initializer) : !vm.ref<!hal.buffer> +// CHECK: vm.global.ref private @pool_storage1_buffer : !vm.ref<!hal.buffer> util.global private @pool_storage1_buffer initializer(@pool_storage1_buffer_initializer) : !hal.buffer func private @pool_storage1_buffer_initializer() -> !hal.buffer -// CHECK: vm.global.ref private @pool_splats initializer(@pool_splats_initializer) : !vm.ref<!hal.buffer> +// CHECK: vm.global.ref private @pool_splats : !vm.ref<!hal.buffer> util.global private @pool_splats initializer(@pool_splats_initializer) : !hal.buffer // CHECK: vm.func private @pool_splats_initializer() -> !vm.ref<!hal.buffer> func private @pool_splats_initializer() -> !hal.buffer {
diff --git a/iree/compiler/Dialect/Util/IR/UtilOps.cpp b/iree/compiler/Dialect/Util/IR/UtilOps.cpp index 11a6c11..f847694 100644 --- a/iree/compiler/Dialect/Util/IR/UtilOps.cpp +++ b/iree/compiler/Dialect/Util/IR/UtilOps.cpp
@@ -90,12 +90,15 @@ void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type, Attribute attr) { + bool needsSpace = false; if (!attr || attr.getType() != type.getValue()) { - p << " : "; + p << ": "; p.printAttribute(type); + needsSpace = true; // subsequent attr value needs a space separator } if (attr) { - p << " = "; + if (needsSpace) p << ' '; + p << "= "; p.printAttribute(attr); } }
diff --git a/iree/compiler/Dialect/Util/IR/UtilOps.td b/iree/compiler/Dialect/Util/IR/UtilOps.td index 1b8eb82..5a733fe 100644 --- a/iree/compiler/Dialect/Util/IR/UtilOps.td +++ b/iree/compiler/Dialect/Util/IR/UtilOps.td
@@ -178,7 +178,7 @@ (`mutable` $is_mutable^)? $sym_name attr-dict - (`initializer` `(` $initializer^ `)`):(``)? + (`initializer` `(` $initializer^ `)`)? custom<TypeOrAttr>($type, $initial_value) }];
diff --git a/iree/compiler/Dialect/VM/Analysis/ValueLiveness.cpp b/iree/compiler/Dialect/VM/Analysis/ValueLiveness.cpp index c716d57..9ffe386 100644 --- a/iree/compiler/Dialect/VM/Analysis/ValueLiveness.cpp +++ b/iree/compiler/Dialect/VM/Analysis/ValueLiveness.cpp
@@ -47,7 +47,8 @@ // Block names are their order in the function. DenseMap<Block *, int> blockOrdinals; for (auto &block : funcOp.getBlocks()) { - blockOrdinals[&block] = blockOrdinals.size(); + int ordinal = blockOrdinals.size(); + blockOrdinals[std::addressof(block)] = ordinal; } // Keep asm state to make getting the SSA value names fast.
diff --git a/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertGlobalOps.cpp b/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertGlobalOps.cpp index d967703..f1411db 100644 --- a/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertGlobalOps.cpp +++ b/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertGlobalOps.cpp
@@ -22,67 +22,103 @@ LogicalResult matchAndRewrite( IREE::Util::GlobalOp op, llvm::ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override { + Operation *newOp = nullptr; auto convertedType = typeConverter.convertType(op.type()); if (convertedType.isa<IREE::VM::RefType>() || IREE::VM::RefType::isCompatible(convertedType)) { - auto newOp = rewriter.replaceOpWithNewOp<IREE::VM::GlobalRefOp>( - op, op.sym_name(), op.is_mutable(), convertedType, op.initializer(), - op.initial_value(), llvm::to_vector<4>(op->getDialectAttrs())); - newOp.setVisibility(op.getVisibility()); - return success(); + newOp = rewriter.replaceOpWithNewOp<IREE::VM::GlobalRefOp>( + op, op.sym_name(), op.is_mutable(), convertedType, op.initial_value(), + llvm::to_vector<4>(op->getDialectAttrs())); } else if (convertedType.isInteger(32)) { - auto convertedValue = - op.initial_value().hasValue() - ? rewriter.getI32IntegerAttr(static_cast<int32_t>( - op.initial_value().getValue().cast<IntegerAttr>().getInt())) - : Attribute{}; - auto newOp = rewriter.replaceOpWithNewOp<IREE::VM::GlobalI32Op>( - op, op.sym_name(), op.is_mutable(), convertedType, op.initializer(), - convertedValue, llvm::to_vector<4>(op->getDialectAttrs())); - newOp.setVisibility(op.getVisibility()); - return success(); + llvm::Optional<Attribute> convertedValue = llvm::None; + if (op.initial_value().hasValue()) { + convertedValue = rewriter.getI32IntegerAttr(static_cast<int32_t>( + op.initial_value().getValue().cast<IntegerAttr>().getInt())); + } + newOp = rewriter.replaceOpWithNewOp<IREE::VM::GlobalI32Op>( + op, op.sym_name(), op.is_mutable(), convertedType, convertedValue, + llvm::to_vector<4>(op->getDialectAttrs())); } else if (convertedType.isInteger(64)) { - auto convertedValue = - op.initial_value().hasValue() - ? 
rewriter.getI64IntegerAttr( - op.initial_value().getValue().cast<IntegerAttr>().getInt()) - : Attribute{}; - auto newOp = rewriter.replaceOpWithNewOp<IREE::VM::GlobalI64Op>( - op, op.sym_name(), op.is_mutable(), convertedType, op.initializer(), - convertedValue, llvm::to_vector<4>(op->getDialectAttrs())); - newOp.setVisibility(op.getVisibility()); - return success(); + llvm::Optional<Attribute> convertedValue = llvm::None; + if (op.initial_value().hasValue()) { + convertedValue = rewriter.getI64IntegerAttr( + op.initial_value().getValue().cast<IntegerAttr>().getInt()); + } + newOp = rewriter.replaceOpWithNewOp<IREE::VM::GlobalI64Op>( + op, op.sym_name(), op.is_mutable(), convertedType, convertedValue, + llvm::to_vector<4>(op->getDialectAttrs())); } else if (convertedType.isF32()) { - auto convertedValue = op.initial_value().hasValue() - ? rewriter.getF32FloatAttr(static_cast<float>( - op.initial_value() - .getValue() - .cast<FloatAttr>() - .getValueAsDouble())) - : Attribute{}; - auto newOp = rewriter.replaceOpWithNewOp<IREE::VM::GlobalF32Op>( - op, op.sym_name(), op.is_mutable(), convertedType, op.initializer(), - convertedValue, llvm::to_vector<4>(op->getDialectAttrs())); - newOp.setVisibility(op.getVisibility()); - return success(); + llvm::Optional<Attribute> convertedValue = llvm::None; + if (op.initial_value().hasValue()) { + convertedValue = rewriter.getF32FloatAttr( + static_cast<float>(op.initial_value() + .getValue() + .cast<FloatAttr>() + .getValueAsDouble())); + } + newOp = rewriter.replaceOpWithNewOp<IREE::VM::GlobalF32Op>( + op, op.sym_name(), op.is_mutable(), convertedType, convertedValue, + llvm::to_vector<4>(op->getDialectAttrs())); } else if (convertedType.isF64()) { - auto convertedValue = - op.initial_value().hasValue() - ? 
rewriter.getF64FloatAttr(op.initial_value() - .getValue() - .cast<FloatAttr>() - .getValueAsDouble()) - : Attribute{}; - auto newOp = rewriter.replaceOpWithNewOp<IREE::VM::GlobalF64Op>( - op, op.sym_name(), op.is_mutable(), convertedType, op.initializer(), - convertedValue, llvm::to_vector<4>(op->getDialectAttrs())); - newOp.setVisibility(op.getVisibility()); - return success(); + llvm::Optional<Attribute> convertedValue = llvm::None; + if (op.initial_value().hasValue()) { + convertedValue = rewriter.getF64FloatAttr( + op.initial_value().getValue().cast<FloatAttr>().getValueAsDouble()); + } + newOp = rewriter.replaceOpWithNewOp<IREE::VM::GlobalF64Op>( + op, op.sym_name(), op.is_mutable(), convertedType, convertedValue, + llvm::to_vector<4>(op->getDialectAttrs())); + } else { + return op.emitOpError("unsupported global type"); } - return op.emitOpError("unsupported global type"); + + // New global carries the same visibility as the original. + cast<SymbolOpInterface>(newOp).setVisibility(op.getVisibility()); + + // If there was an initializer function specified we turn that into a + // vm.initializer now. 
+ if (op.initializer()) { + auto initializerOp = + rewriter.create<IREE::VM::InitializerOp>(op.getLoc()); + auto ip = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointToStart(initializerOp.addEntryBlock()); + SmallVector<Type> resultTypes; + resultTypes.push_back(convertedType); + auto callOp = rewriter.create<IREE::VM::CallOp>( + op.getLoc(), op.initializer().getValue(), resultTypes, + /*operands=*/ValueRange{}); + storeToGlobal(callOp.getResult(0), newOp, rewriter); + rewriter.create<IREE::VM::ReturnOp>(op.getLoc()); + rewriter.restoreInsertionPoint(ip); + } + + return success(); } private: + void storeToGlobal(Value value, Operation *globalOp, + ConversionPatternRewriter &rewriter) const { + auto globalName = cast<SymbolOpInterface>(globalOp).getName(); + if (value.getType().isa<IREE::VM::RefType>()) { + rewriter.create<IREE::VM::GlobalStoreRefOp>(globalOp->getLoc(), value, + globalName); + } else if (value.getType().isInteger(32)) { + rewriter.create<IREE::VM::GlobalStoreI32Op>(globalOp->getLoc(), value, + globalName); + } else if (value.getType().isInteger(64)) { + rewriter.create<IREE::VM::GlobalStoreI64Op>(globalOp->getLoc(), value, + globalName); + } else if (value.getType().isF32()) { + rewriter.create<IREE::VM::GlobalStoreF32Op>(globalOp->getLoc(), value, + globalName); + } else if (value.getType().isF64()) { + rewriter.create<IREE::VM::GlobalStoreF64Op>(globalOp->getLoc(), value, + globalName); + } else { + llvm_unreachable("unhandled vm type"); + } + } + TypeConverter &typeConverter; };
diff --git a/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/global_ops.mlir b/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/global_ops.mlir index fdc41b6..57fcd33 100644 --- a/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/global_ops.mlir +++ b/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/global_ops.mlir
@@ -8,8 +8,14 @@ // ----- -// CHECK: vm.global.ref public @v_initialized initializer(@initializer) : !vm.ref<!hal.buffer> +// CHECK: vm.global.ref public @v_initialized : !vm.ref<!hal.buffer> util.global public @v_initialized initializer(@initializer) : !hal.buffer +// CHECK-NEXT: vm.initializer { +// CHECK-NEXT: %[[REF:.+]] = vm.call @initializer() : () -> !vm.ref<!hal.buffer> +// CHECK-NEXT: vm.global.store.ref %[[REF]], @v_initialized : !vm.ref<!hal.buffer> +// CHECK-NEXT: vm.return +// CHECK-NEXT: } +// CHECK-NEXT: vm.func private @initializer() -> !vm.ref<!hal.buffer> func private @initializer() -> !hal.buffer // -----
diff --git a/iree/compiler/Dialect/VM/IR/VMBase.td b/iree/compiler/Dialect/VM/IR/VMBase.td index 2aba56d..b58482a 100644 --- a/iree/compiler/Dialect/VM/IR/VMBase.td +++ b/iree/compiler/Dialect/VM/IR/VMBase.td
@@ -218,14 +218,13 @@ }]>, InterfaceMethod<[{}], "StringRef", "getSymbolName", (ins)>, InterfaceMethod<[{}], "bool", "isMutable", (ins)>, - InterfaceMethod<[{}], "Optional<StringRef>", "getInitializerAttr", (ins)>, InterfaceMethod<[{}], "Optional<Attribute>", "getInitialValueAttr", (ins)>, + InterfaceMethod<[{}], "void", "setInitialValue", (ins "Attribute":$value)>, InterfaceMethod<[{}], "Optional<IntegerAttr>", "getOrdinalAttr", (ins)>, InterfaceMethod<[{}], "int", "getOrdinal", (ins), [{ return $_self.getOrdinalAttr().getValue().template cast<IntegerAttr>().getInt(); }]>, InterfaceMethod<[{}], "void", "makeMutable", (ins)>, - InterfaceMethod<[{}], "void", "clearInitializer", (ins)>, InterfaceMethod<[{}], "void", "clearInitialValue", (ins)>, ]; }
diff --git a/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp b/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp index df29a52..7cd8a6f 100644 --- a/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp +++ b/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
@@ -58,39 +58,78 @@ // Structural ops //===----------------------------------------------------------------------===// +namespace { + +// Deletes empty vm.initializer ops. +struct DropEmptyInitializerOp : public OpRewritePattern<InitializerOp> { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(InitializerOp op, + PatternRewriter &rewriter) const override { + if (op.body().getBlocks().size() != 1) return failure(); + auto &block = op.body().front(); + if (block.empty() || isa<ReturnOp>(block.front())) { + rewriter.eraseOp(op); + return success(); + } + return failure(); + } +}; + +// Inlines constant stores from initializers into the global initializer. +// This is not strictly required but can help our initialization code perform +// more efficient initialization of large numbers of primitive values. +struct InlineConstGlobalInitializer : public OpRewritePattern<InitializerOp> { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(InitializerOp op, + PatternRewriter &rewriter) const override { + SmallVector<Operation *> deadOps; + op.walk([&](Operation *op) { + if (!isGlobalStoreOp(op)) return; + auto value = op->getOperand(0); + Attribute valueAttr; + if (!matchPattern(value, m_Constant(&valueAttr))) return; + auto globalRefAttr = op->getAttrOfType<SymbolRefAttr>("global"); + assert(globalRefAttr); + auto globalOp = + SymbolTable::lookupNearestSymbolFrom<IREE::VM::VMGlobalOp>( + op, globalRefAttr); + if (valueAttr && !valueAttr.isa<UnitAttr>()) { + globalOp.setInitialValue(valueAttr); + } else { + globalOp.clearInitialValue(); + } + deadOps.push_back(op); + }); + if (deadOps.empty()) return failure(); + for (auto deadOp : deadOps) rewriter.eraseOp(deadOp); + return success(); + } + + bool isGlobalStoreOp(Operation *op) const { + // TODO(benvanik): trait/interface to make this more generic? 
+ return isa<IREE::VM::GlobalStoreI32Op>(op) || + isa<IREE::VM::GlobalStoreI64Op>(op) || + isa<IREE::VM::GlobalStoreF32Op>(op) || + isa<IREE::VM::GlobalStoreF64Op>(op) || + isa<IREE::VM::GlobalStoreRefOp>(op); + } +}; + +} // namespace + +void InitializerOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert<DropEmptyInitializerOp, InlineConstGlobalInitializer>(context); +} + //===----------------------------------------------------------------------===// // Globals //===----------------------------------------------------------------------===// namespace { -/// Converts global initializer functions that evaluate to a constant to a -/// specified initial value. -template <typename T> -struct InlineConstGlobalOpInitializer : public OpRewritePattern<T> { - using OpRewritePattern<T>::OpRewritePattern; - - LogicalResult matchAndRewrite(T op, - PatternRewriter &rewriter) const override { - if (!op.initializer()) return failure(); - auto initializer = dyn_cast_or_null<FuncOp>( - SymbolTable::lookupNearestSymbolFrom(op, op.initializer().getValue())); - if (!initializer) return failure(); - if (initializer.getBlocks().size() == 1 && - initializer.getBlocks().front().getOperations().size() == 2 && - isa<ReturnOp>(initializer.getBlocks().front().getOperations().back())) { - auto &primaryOp = initializer.getBlocks().front().getOperations().front(); - Attribute constResult; - if (matchPattern(primaryOp.getResult(0), m_Constant(&constResult))) { - rewriter.replaceOpWithNewOp<T>(op, op.sym_name(), op.is_mutable(), - op.type(), constResult); - return success(); - } - } - return failure(); - } -}; - /// Drops initial_values from globals where the value is 0, as by default all /// globals are zero-initialized upon module load. 
template <typename T> @@ -116,32 +155,26 @@ void GlobalI32Op::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.insert<InlineConstGlobalOpInitializer<GlobalI32Op>, - DropDefaultConstGlobalOpInitializer<GlobalI32Op>>(context); + results.insert<DropDefaultConstGlobalOpInitializer<GlobalI32Op>>(context); } void GlobalI64Op::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.insert<InlineConstGlobalOpInitializer<GlobalI64Op>, - DropDefaultConstGlobalOpInitializer<GlobalI64Op>>(context); + results.insert<DropDefaultConstGlobalOpInitializer<GlobalI64Op>>(context); } void GlobalF32Op::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.insert<InlineConstGlobalOpInitializer<GlobalF32Op>, - DropDefaultConstGlobalOpInitializer<GlobalF32Op>>(context); + results.insert<DropDefaultConstGlobalOpInitializer<GlobalF32Op>>(context); } void GlobalF64Op::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.insert<InlineConstGlobalOpInitializer<GlobalF64Op>, - DropDefaultConstGlobalOpInitializer<GlobalF64Op>>(context); + results.insert<DropDefaultConstGlobalOpInitializer<GlobalF64Op>>(context); } void GlobalRefOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert<InlineConstGlobalOpInitializer<GlobalRefOp>>(context); -} + MLIRContext *context) {} namespace {
diff --git a/iree/compiler/Dialect/VM/IR/VMOps.cpp b/iree/compiler/Dialect/VM/IR/VMOps.cpp index 92c5ae6..9dd7c86 100644 --- a/iree/compiler/Dialect/VM/IR/VMOps.cpp +++ b/iree/compiler/Dialect/VM/IR/VMOps.cpp
@@ -281,6 +281,47 @@ return success(); } +void InitializerOp::build(OpBuilder &builder, OperationState &result, + ArrayRef<NamedAttribute> attrs) { + result.addAttribute( + "type", TypeAttr::get(FunctionType::get(builder.getContext(), {}, {}))); + result.addRegion(); + result.attributes.append(attrs.begin(), attrs.end()); +} + +static ParseResult parseInitializerOp(OpAsmParser &parser, + OperationState *result) { + result->addAttribute( + "type", TypeAttr::get(FunctionType::get(result->getContext(), {}, {}))); + if (parser.parseOptionalAttrDictWithKeyword(result->attributes)) { + return failure(); + } + auto &body = *result->addRegion(); + if (failed(parser.parseRegion(body))) { + return failure(); + } + return success(); +} + +static void printInitializerOp(OpAsmPrinter &p, InitializerOp &op) { + p << "vm.initializer"; + p.printOptionalAttrDictWithKeyword(op->getAttrs(), /*elidedAttrs=*/{"type"}); + p.printRegion(op.body()); +} + +Block *InitializerOp::addEntryBlock() { + assert(empty() && "function already has an entry block"); + auto *entry = new Block(); + push_back(entry); + return entry; +} + +Block *InitializerOp::addBlock() { + assert(!empty() && "function should at least have an entry block"); + push_back(new Block()); + return &back(); +} + //===----------------------------------------------------------------------===// // Globals //===----------------------------------------------------------------------===// @@ -375,13 +416,13 @@ auto *globalOp = op->getParentOfType<VM::ModuleOp>().lookupSymbol(globalAttr.getValue()); if (!globalOp) { - return op->emitOpError() << "Undefined global: " << globalAttr; + return op->emitOpError() << "undefined global: " << globalAttr; } auto globalType = globalOp->getAttrOfType<TypeAttr>("type"); auto loadType = op->getResult(0).getType(); if (globalType.getValue() != loadType) { return op->emitOpError() - << "Global type mismatch; global " << globalAttr << " is " + << "global type mismatch; global " << globalAttr << " is " 
<< globalType << " but load is " << loadType; } return success(); @@ -392,18 +433,21 @@ auto *globalOp = op->getParentOfType<VM::ModuleOp>().lookupSymbol(globalAttr.getValue()); if (!globalOp) { - return op->emitOpError() << "Undefined global: " << globalAttr; + return op->emitOpError() << "undefined global: " << globalAttr; } auto globalType = globalOp->getAttrOfType<TypeAttr>("type"); auto storeType = op->getOperand(0).getType(); if (globalType.getValue() != storeType) { return op->emitOpError() - << "Global type mismatch; global " << globalAttr << " is " + << "global type mismatch; global " << globalAttr << " is " << globalType << " but store is " << storeType; } if (!globalOp->getAttrOfType<UnitAttr>("is_mutable")) { - return op->emitOpError() << "Global " << globalAttr - << " is not mutable and cannot be stored to"; + // Allow stores to immutable globals in initializers. + if (!op->getParentOfType<IREE::VM::InitializerOp>()) { + return op->emitOpError() << "global " << globalAttr + << " is not mutable and cannot be stored to"; + } } return success(); }
diff --git a/iree/compiler/Dialect/VM/IR/VMOps.td b/iree/compiler/Dialect/VM/IR/VMOps.td index 6abda0c..d445dbb 100644 --- a/iree/compiler/Dialect/VM/IR/VMOps.td +++ b/iree/compiler/Dialect/VM/IR/VMOps.td
@@ -44,6 +44,7 @@ custom<SymbolVisibility>($sym_visibility) $sym_name attr-dict-with-keyword + `` regions }]; @@ -241,6 +242,50 @@ }]; } +def VM_InitializerOp : VM_Op<"initializer", [ + IsolatedFromAbove, + HasParent<"IREE::VM::ModuleOp">, + NativeOpTrait<"FunctionLike">, + CallableOpInterface, + ]> { + let summary = [{global initialization function}]; + let description = [{ + A function that is called in definition order upon module initialization. + Must not load any globals that are defined or initialized after it in the + module. + }]; + + let arguments = (ins + TypeAttr:$type + ); + + let regions = (region AnyRegion:$body); + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<(ins + CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs + )>, + ]; + + let extraClassDeclaration = [{ + /// Add an entry block to an empty function and set up the block arguments + /// to match the signature of the function. + Block *addEntryBlock(); + Block *addBlock(); + + unsigned getNumFuncArguments() { return 0; } + unsigned getNumFuncResults() { return 0; } + + LogicalResult verifyType() { return success(); } + + Region *getCallableRegion() { return &body(); } + ArrayRef<Type> getCallableResults() { return {}; } + }]; + + let hasCanonicalizer = 1; +} + //===----------------------------------------------------------------------===// // Globals //===----------------------------------------------------------------------===// @@ -259,7 +304,6 @@ SymbolNameAttr:$sym_name, TypeAttr:$type, UnitAttr:$is_mutable, - OptionalAttr<FlatSymbolRefAttr>:$initializer, OptionalAttr<attr_type>:$initial_value, OptionalAttr<VM_Ordinal>:$ordinal ); @@ -269,25 +313,23 @@ (`mutable` $is_mutable^)? $sym_name attr-dict - (`initializer` `(` $initializer^ `)`):(``)? 
custom<TypeOrAttr>($type, $initial_value) }]; let skipDefaultBuilders = 1; let builders = [ - OpBuilder<(ins "StringRef":$name, "bool":$isMutable, "Type":$type, - "Optional<StringRef>":$initializer, "Optional<Attribute>":$initialValue, - CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs), + OpBuilder<(ins + "StringRef":$name, "bool":$isMutable, "Type":$type, + "Optional<Attribute>":$initialValue, + CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs + ), [{ $_state.addAttribute(SymbolTable::getSymbolAttrName(), $_builder.getStringAttr(name)); if (isMutable) { $_state.addAttribute("is_mutable", $_builder.getUnitAttr()); } - if (initializer.hasValue()) { - $_state.addAttribute("initializer", - $_builder.getSymbolRefAttr(initializer.getValue())); - } else if (initialValue.hasValue() && + if (initialValue.hasValue() && (initialValue.getValue().isa<IntegerAttr>() || initialValue.getValue().isa<FloatAttr>())) { $_state.addAttribute("initial_value", initialValue.getValue()); @@ -295,25 +337,12 @@ $_state.addAttribute("type", TypeAttr::get(type)); $_state.attributes.append(attrs.begin(), attrs.end()); }]>, - OpBuilder<(ins "StringRef":$name, "bool":$isMutable, - "IREE::VM::FuncOp":$initializer, - CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs), + OpBuilder<(ins + "StringRef":$name, "bool":$isMutable, "Type":$type, + CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs + ), [{ - build($_builder, $_state, name, isMutable, - initializer.getType().getResult(0), initializer.getName(), - llvm::None, attrs); - }]>, - OpBuilder<(ins "StringRef":$name, "bool":$isMutable, "Type":$type, - "Attribute":$initialValue, CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs), - [{ - build($_builder, $_state, name, isMutable, type, llvm::None, initialValue, - attrs); - }]>, - OpBuilder<(ins "StringRef":$name, "bool":$isMutable, "Type":$type, - CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs), - [{ - build($_builder, $_state, name, isMutable, type, llvm::None, llvm::None, - attrs); + build($_builder, $_state, name, 
isMutable, type, llvm::None, attrs); }]>, ]; @@ -322,9 +351,8 @@ Type getStorageType() { return type(); } bool isMutable() { return is_mutable(); } void makeMutable() { (*this)->setAttr("is_mutable", UnitAttr::get(getContext())); } - Optional<StringRef> getInitializerAttr() { return initializer(); } - void clearInitializer() { (*this)->removeAttr("initializer"); } Optional<Attribute> getInitialValueAttr() { return initial_valueAttr(); } + void setInitialValue(Attribute value) { (*this)->setAttr("initial_value", (value)); } void clearInitialValue() { (*this)->removeAttr("initial_value"); } Optional<IntegerAttr> getOrdinalAttr() { return ordinalAttr(); } }]; @@ -3485,7 +3513,6 @@ def VM_ReturnOp : VM_Op<"return", [ DeclareOpInterfaceMethods<VM_SerializableOpInterface>, - HasParent<"IREE::VM::FuncOp">, Terminator, ]> { let summary = "return operation"; @@ -3493,7 +3520,7 @@ Represents a return operation within a function. ``` - vm.func @foo(%0, %1) : (i32, f8) { + vm.func @foo(%0: i32, %1: f8) -> (i32, f8) { vm.return %0, %1 : i32, f8 } ```
diff --git a/iree/compiler/Dialect/VM/IR/test/BUILD b/iree/compiler/Dialect/VM/IR/test/BUILD index 7890721..53303b2 100644 --- a/iree/compiler/Dialect/VM/IR/test/BUILD +++ b/iree/compiler/Dialect/VM/IR/test/BUILD
@@ -36,6 +36,7 @@ "list_op_verification.mlir", "list_ops.mlir", "shift_ops.mlir", + "structural_folding.mlir", "structural_ops.mlir", ], include = ["*.mlir"],
diff --git a/iree/compiler/Dialect/VM/IR/test/CMakeLists.txt b/iree/compiler/Dialect/VM/IR/test/CMakeLists.txt index 2604211..4a277a3 100644 --- a/iree/compiler/Dialect/VM/IR/test/CMakeLists.txt +++ b/iree/compiler/Dialect/VM/IR/test/CMakeLists.txt
@@ -33,6 +33,7 @@ "list_op_verification.mlir" "list_ops.mlir" "shift_ops.mlir" + "structural_folding.mlir" "structural_ops.mlir" DATA iree::tools::IreeFileCheck
diff --git a/iree/compiler/Dialect/VM/IR/test/global_folding.mlir b/iree/compiler/Dialect/VM/IR/test/global_folding.mlir index 3c2dd2e..8ae7a22 100644 --- a/iree/compiler/Dialect/VM/IR/test/global_folding.mlir +++ b/iree/compiler/Dialect/VM/IR/test/global_folding.mlir
@@ -5,10 +5,11 @@ // CHECK-LABEL: @global_i32_folds vm.module @global_i32_folds { // CHECK: vm.global.i32 public mutable @g0 = 123 : i32 - vm.global.i32 mutable @g0 initializer(@g0init) : i32 - vm.func @g0init() -> i32 { + vm.global.i32 mutable @g0 : i32 + vm.initializer { %c123 = vm.const.i32 123 : i32 - vm.return %c123 : i32 + vm.global.store.i32 %c123, @g0 : i32 + vm.return } // CHECK: vm.global.i32 public mutable @g1 : i32 @@ -17,10 +18,11 @@ vm.global.i32 @g2 = 0 : i32 // CHECK: vm.global.i32 public mutable @g3 : i32 - vm.global.i32 mutable @g3 initializer(@g3init) : i32 - vm.func @g3init() -> i32 { + vm.global.i32 mutable @g3 : i32 + vm.initializer { %c0 = vm.const.i32 0 : i32 - vm.return %c0 : i32 + vm.global.store.i32 %c0, @g3 : i32 + vm.return } } @@ -29,10 +31,11 @@ // CHECK-LABEL: @global_ref_folds_null vm.module @global_ref_folds_null { // CHECK: vm.global.ref public mutable @g0 : !vm.ref<?> - vm.global.ref mutable @g0 initializer(@g0init) : !vm.ref<?> - vm.func @g0init() -> !vm.ref<?> { + vm.global.ref mutable @g0 : !vm.ref<?> + vm.initializer { %null = vm.const.ref.zero : !vm.ref<?> - vm.return %null : !vm.ref<?> + vm.global.store.ref %null, @g0 : !vm.ref<?> + vm.return } }
diff --git a/iree/compiler/Dialect/VM/IR/test/structural_folding.mlir b/iree/compiler/Dialect/VM/IR/test/structural_folding.mlir new file mode 100644 index 0000000..1ca2d68 --- /dev/null +++ b/iree/compiler/Dialect/VM/IR/test/structural_folding.mlir
@@ -0,0 +1,9 @@ +// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(canonicalize)' %s | IreeFileCheck %s + +// CHECK-LABEL: @empty_initializer +vm.module @empty_initializer { + // CHECK-NOT: vm.initializer + vm.initializer { + vm.return + } +}
diff --git a/iree/compiler/Dialect/VM/IR/test/structural_ops.mlir b/iree/compiler/Dialect/VM/IR/test/structural_ops.mlir index 72242f5..234fb1e 100644 --- a/iree/compiler/Dialect/VM/IR/test/structural_ops.mlir +++ b/iree/compiler/Dialect/VM/IR/test/structural_ops.mlir
@@ -67,3 +67,34 @@ // CHECK-NEXT: vm.import @my.fn_varargs(%foo : vector<3xi32> ..., %bar : tuple<i32, i32> ...) -> i32 vm.import @my.fn_varargs(%foo : vector<3xi32> ..., %bar : tuple<i32, i32> ...) -> i32 } + +// ----- + +// CHECK-LABEL: @initializers +vm.module @initializers { + // CHECK-NEXT: vm.initializer { + // CHECK-NEXT: vm.return + // CHECK-NEXT: } + vm.initializer { + vm.return + } + + // CHECK-NEXT: vm.initializer attributes {foo} { + // CHECK-NEXT: vm.return + // CHECK-NEXT: } + vm.initializer attributes {foo} { + vm.return + } + + // CHECK-NEXT: vm.initializer { + vm.initializer { + // CHECK-NEXT: %zero = vm.const.i32 0 : i32 + %zero = vm.const.i32 0 : i32 + // CHECK-NEXT: vm.br ^bb1(%zero : i32) + vm.br ^bb1(%zero: i32) + // CHECK-NEXT: ^bb1(%0: i32): + ^bb1(%0: i32): + // CHECK-NEXT: vm.return + vm.return + } +}
diff --git a/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp b/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp index 0757d9b..c513c8a 100644 --- a/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp +++ b/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp
@@ -13,6 +13,7 @@ #include "mlir/Pass/PassRegistry.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/InliningUtils.h" #include "mlir/Transforms/Utils.h" namespace mlir { @@ -31,7 +32,6 @@ // point in the lowering though we cannot know that so we rely on dialects // providing their own initialization functions for those cases. // -// TODO(benvanik): add initializer functions to make dialect init possible. // TODO(benvanik): combine i32 initializers to store more efficiently. class GlobalInitializationPass : public PassWrapper<GlobalInitializationPass, OperationPass<ModuleOp>> { @@ -53,6 +53,7 @@ moduleBuilder.create<FuncOp>(moduleBuilder.getUnknownLoc(), "__init", moduleBuilder.getFunctionType({}, {})); OpBuilder initBuilder = OpBuilder::atBlockEnd(initFuncOp.addEntryBlock()); + auto deinitFuncOp = moduleBuilder.create<FuncOp>(moduleBuilder.getUnknownLoc(), "__deinit", moduleBuilder.getFunctionType({}, {})); @@ -64,6 +65,8 @@ // module op order). If we ever want to make this more deterministic we // could gather the ops, sort them (by some rule), and then build the // initialization function. + InlinerInterface inlinerInterface(&getContext()); + SmallVector<Operation *> deadOps; for (auto &op : getOperation().getBlock().getOperations()) { if (auto globalOp = dyn_cast<GlobalRefOp>(op)) { if (failed(appendRefInitialization(globalOp, initBuilder))) { @@ -75,8 +78,21 @@ globalOp.emitOpError() << "unable to be initialized"; return signalPassFailure(); } + } else if (auto initializerOp = dyn_cast<InitializerOp>(op)) { + if (failed(appendInitializer(initializerOp, inlinerInterface, + initBuilder))) { + initializerOp.emitOpError() << "unable to be initialized"; + return signalPassFailure(); + } + deadOps.push_back(initializerOp); } } + for (auto deadOp : deadOps) { + deadOp->erase(); + } + + // Correct mutability of all globals. 
+ fixupGlobalMutability(getOperation()); initBuilder.create<ReturnOp>(initBuilder.getUnknownLoc()); deinitBuilder.create<ReturnOp>(deinitBuilder.getUnknownLoc()); @@ -112,12 +128,6 @@ << "unable to create initializer constant for global"; } globalOp.clearInitialValue(); - } else if (globalOp.getInitializerAttr().hasValue()) { - auto callOp = builder.create<CallOp>( - globalOp.getLoc(), globalOp.getInitializerAttr().getValue(), - ArrayRef<Type>{globalOp.getStorageType()}, ArrayRef<Value>{}); - value = callOp.getResult(0); - globalOp.clearInitializer(); } if (!value) { // Globals are zero-initialized by default so we can just strip the @@ -195,17 +205,90 @@ LogicalResult appendRefInitialization(GlobalRefOp globalOp, OpBuilder &builder) { - if (globalOp.initializer().hasValue()) { - auto callOp = builder.create<CallOp>( - globalOp.getLoc(), globalOp.initializerAttr(), - ArrayRef<Type>{globalOp.type()}, ArrayRef<Value>{}); - builder.create<GlobalStoreRefOp>(globalOp.getLoc(), callOp.getResult(0), - globalOp.sym_name()); - globalOp.clearInitializer(); - globalOp.makeMutable(); - } + // NOTE: nothing yet, though if we had attribute initialization we'd do it + // here (for example, #vm.magic.initial.ref<foo>). return success(); } + + LogicalResult appendInitializer(InitializerOp initializerOp, + InlinerInterface &inlinerInterface, + OpBuilder &builder) { + // mlir::inlineRegion takes the op to inline _after_, which as we are + // building things doesn't exist yet. To work around this we create a dummy + // op, inline after it, and then delete it. 
+ auto dummyOp = + builder.create<IREE::VM::ConstI32ZeroOp>(builder.getUnknownLoc()); + auto result = mlir::inlineRegion( + inlinerInterface, &initializerOp.body(), dummyOp, + /*inlinedOperands=*/ValueRange{}, + /*resultsToReplace=*/ValueRange{}, /*inlineLoc=*/llvm::None, + /*shouldCloneInlinedRegion=*/false); + builder.setInsertionPointToEnd(dummyOp->getBlock()); + dummyOp.erase(); + return result; + } + + void fixupGlobalMutability(Operation *moduleOp) { + SymbolTable symbolTable(moduleOp); + SmallVector<Operation *> deadOps; + for (auto &op : moduleOp->getRegion(0).front()) { + auto globalOp = dyn_cast<IREE::VM::VMGlobalOp>(op); + if (!globalOp) continue; + if (!cast<SymbolOpInterface>(op).isPrivate()) { + // May be used outside the module; treat as used and mutable. + globalOp.makeMutable(); + continue; + } + auto uses = symbolTable.getSymbolUses(globalOp, moduleOp); + if (!uses.hasValue()) { + // No uses - erase the global entirely. + deadOps.push_back(globalOp); + continue; + } + bool isIndirect = false; + bool isLoaded = false; + bool isStored = false; + for (auto use : uses.getValue()) { + if (isa<IREE::VM::GlobalAddressOp>(use.getUser())) { + // Can't analyze indirect variables; assume mutated. + isLoaded = true; + isStored = true; + isIndirect = true; + break; + } else if (isGlobalLoadOp(use.getUser())) { + isLoaded = true; + } else if (isGlobalStoreOp(use.getUser())) { + isStored = true; + } + } + // NOTE: we could erase globals never loaded if we know that computing + // their value has no side effects. + if (isStored) { + globalOp.makeMutable(); + } + } + for (auto *deadOp : deadOps) { + deadOp->erase(); + } + } + + bool isGlobalLoadOp(Operation *op) const { + // TODO(benvanik): trait/interface to make this more generic? 
+ return isa<IREE::VM::GlobalLoadI32Op>(op) || + isa<IREE::VM::GlobalLoadI64Op>(op) || + isa<IREE::VM::GlobalLoadF32Op>(op) || + isa<IREE::VM::GlobalLoadF64Op>(op) || + isa<IREE::VM::GlobalLoadRefOp>(op); + } + + bool isGlobalStoreOp(Operation *op) const { + // TODO(benvanik): trait/interface to make this more generic? + return isa<IREE::VM::GlobalStoreI32Op>(op) || + isa<IREE::VM::GlobalStoreI64Op>(op) || + isa<IREE::VM::GlobalStoreF32Op>(op) || + isa<IREE::VM::GlobalStoreF64Op>(op) || + isa<IREE::VM::GlobalStoreRefOp>(op); + } }; std::unique_ptr<OperationPass<IREE::VM::ModuleOp>>
diff --git a/iree/compiler/Dialect/VM/Transforms/test/global_initialization.mlir b/iree/compiler/Dialect/VM/Transforms/test/global_initialization.mlir index b652d78..f691847 100644 --- a/iree/compiler/Dialect/VM/Transforms/test/global_initialization.mlir +++ b/iree/compiler/Dialect/VM/Transforms/test/global_initialization.mlir
@@ -9,22 +9,16 @@ // CHECK-LABEL: @initI32 vm.module @initI32 { - // CHECK: vm.global.i32 public mutable @g0 : i32 - vm.global.i32 mutable @g0 initializer(@g0init) : i32 - vm.func @g0init() -> i32 { - %c123 = vm.const.i32 123 : i32 - vm.return %c123 : i32 - } + // CHECK: vm.global.i32 private @g0 + vm.global.i32 private @g0 : i32 = 0 : i32 - // CHECK: vm.global.i32 public mutable @g1 : i32 - vm.global.i32 mutable @g1 = 123 : i32 + // CHECK: vm.global.i32 private mutable @g1 : i32 + vm.global.i32 private mutable @g1 = 123 : i32 - // CHECK: vm.global.i32 public mutable @g2 : i32 - vm.global.i32 @g2 = 123 : i32 + // CHECK: vm.global.i32 private mutable @g2 : i32 + vm.global.i32 private @g2 = 123 : i32 // CHECK: vm.func private @__init() { - // CHECK-NEXT: %0 = vm.call @g0init() - // CHECK-NEXT: vm.global.store.i32 %0, @g0 // CHECK-NEXT: %c123 = vm.const.i32 123 : i32 // CHECK-NEXT: vm.global.store.i32 %c123, @g1 // CHECK-NEXT: %c123_0 = vm.const.i32 123 : i32 @@ -37,22 +31,70 @@ // CHECK-LABEL: @initRef vm.module @initRef { - // CHECK: vm.global.ref public mutable @g0 : !vm.ref<?> - vm.global.ref mutable @g0 initializer(@g0init) : !vm.ref<?> - vm.func @g0init() -> !vm.ref<?> { - %null = vm.const.ref.zero : !vm.ref<?> - vm.return %null : !vm.ref<?> + // CHECK: vm.global.ref private mutable @g0 : !vm.ref<?> + vm.global.ref private mutable @g0 : !vm.ref<?> + + // CHECK: vm.global.ref private mutable @g1 : !vm.ref<?> + vm.global.ref private mutable @g1 : !vm.ref<?> + + // CHECK: vm.global.ref private @g2 : !vm.ref<?> + vm.global.ref private @g2 : !vm.ref<?> + + // CHECK-NOT: vm.func private @__init() +} + +// ----- + +// CHECK-LABEL: @initializers +vm.module @initializers { + // CHECK: vm.global.i32 private mutable @g0 : i32 + vm.global.i32 private @g0 : i32 + // CHECK-NOT: vm.initializer + vm.initializer { + %c123 = vm.const.i32 123 : i32 + vm.global.store.i32 %c123, @g0 : i32 + vm.return } - // CHECK: vm.global.ref public mutable @g1 : !vm.ref<?> - vm.global.ref mutable 
@g1 : !vm.ref<?> + // CHECK: vm.global.ref private mutable @g1 : !vm.ref<?> + vm.global.ref private mutable @g1 : !vm.ref<?> + // CHECK-NOT: vm.initializer + vm.initializer { + %null = vm.const.ref.zero : !vm.ref<?> + vm.global.store.ref %null, @g1 : !vm.ref<?> + vm.return + } - // CHECK: vm.global.ref public @g2 : !vm.ref<?> - vm.global.ref @g2 : !vm.ref<?> + // CHECK: vm.global.ref private mutable @g2 : !vm.ref<?> + vm.global.ref private mutable @g2 : !vm.ref<?> + // CHECK-NOT: vm.initializer + vm.initializer { + %g1 = vm.global.load.ref @g1 : !vm.ref<?> + vm.global.store.ref %g1, @g2 : !vm.ref<?> + vm.return + } - // CHECK: vm.func private @__init() { - // CHECK-NEXT: %ref = vm.call @g0init() - // CHECK-NEXT: vm.global.store.ref %ref, @g0 + // CHECK: vm.func private @__init() { + // CHECK-NEXT: %c123 = vm.const.i32 123 : i32 + // CHECK-NEXT: vm.global.store.i32 %c123, @g0 : i32 + // CHECK-NEXT: %null = vm.const.ref.zero : !vm.ref<?> + // CHECK-NEXT: vm.global.store.ref %null, @g1 : !vm.ref<?> + // CHECK-NEXT: %g1 = vm.global.load.ref @g1 : !vm.ref<?> + // CHECK-NEXT: vm.global.store.ref %g1, @g2 : !vm.ref<?> // CHECK-NEXT: vm.return // CHECK-NEXT: } } + +// ----- + +// CHECK-LABEL: @unused_globals +vm.module @unused_globals { + // CHECK: vm.global.i32 private mutable @used + vm.global.i32 private @used : i32 = 1 : i32 + // CHECK-NOT: vm.global.i32 private @unused + vm.global.i32 private @unused : i32 = 2 : i32 + vm.func @foo() { + %0 = vm.global.load.i32 @used : i32 + vm.return + } +}
diff --git a/iree/compiler/InputConversion/Common/BUILD b/iree/compiler/InputConversion/Common/BUILD index 7e22c00..efce5e6 100644 --- a/iree/compiler/InputConversion/Common/BUILD +++ b/iree/compiler/InputConversion/Common/BUILD
@@ -44,6 +44,7 @@ cc_library( name = "Common", srcs = [ + "IREEImportPublic.cpp", "Passes.cpp", "TopLevelSCFToCFG.cpp", ], @@ -53,6 +54,11 @@ deps = [ ":PassHeaders", ":PassesIncGen", + "//iree/compiler/Dialect/Flow/IR", + "//iree/compiler/Dialect/HAL/IR", + "//iree/compiler/Dialect/Util/IR", + "//llvm-external-projects/iree-dialects:IREEDialect", + "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgOps", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect",
diff --git a/iree/compiler/InputConversion/Common/CMakeLists.txt b/iree/compiler/InputConversion/Common/CMakeLists.txt index 7c2e3a8..c05b2c4 100644 --- a/iree/compiler/InputConversion/Common/CMakeLists.txt +++ b/iree/compiler/InputConversion/Common/CMakeLists.txt
@@ -39,17 +39,23 @@ HDRS "Passes.h" SRCS + "IREEImportPublic.cpp" "Passes.cpp" "TopLevelSCFToCFG.cpp" DEPS ::PassHeaders ::PassesIncGen + IREEDialectsIREEDialect + MLIRIR MLIRLinalg MLIRPass MLIRSCF MLIRSCFToStandard MLIRStandard MLIRTransforms + iree::compiler::Dialect::Flow::IR + iree::compiler::Dialect::HAL::IR + iree::compiler::Dialect::Util::IR PUBLIC )
diff --git a/iree/compiler/InputConversion/Common/IREEImportPublic.cpp b/iree/compiler/InputConversion/Common/IREEImportPublic.cpp new file mode 100644 index 0000000..7e1df67 --- /dev/null +++ b/iree/compiler/InputConversion/Common/IREEImportPublic.cpp
@@ -0,0 +1,308 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree-dialects/Dialect/IREE/IREEDialect.h" +#include "iree-dialects/Dialect/IREE/IREEOps.h" +#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" +#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h" +#include "iree/compiler/Dialect/HAL/IR/HALDialect.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "iree/compiler/Dialect/HAL/IR/HALTypes.h" +#include "iree/compiler/Dialect/Util/IR/UtilDialect.h" +#include "iree/compiler/Dialect/Util/IR/UtilOps.h" +#include "iree/compiler/Dialect/Util/IR/UtilTypes.h" +#include "iree/compiler/InputConversion/Common/PassDetail.h" +#include "iree/compiler/InputConversion/Common/Passes.h" +#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace IREEPublic = mlir::iree; + +namespace mlir { +namespace iree_compiler { + +namespace { + +// Allowlist of function attributes to retain when importing funcs. +constexpr const char *kRetainedAttributes[] = { + "iree.reflection", + "sym_visibility", + "noinline", +}; + +struct IREEImportPublicPass + : public IREEImportPublicBase<IREEImportPublicPass> { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<mlir::iree::IREEDialect, IREE::Flow::FlowDialect, + IREE::HAL::HALDialect, IREE::Util::UtilDialect>(); + } + void runOnOperation() override; +}; + +class IREETypeConverter : public TypeConverter { + public: + IREETypeConverter(); +}; + +// Generic 1:1 conversion pattern which effectively just renames an op. 
+// It does not support regions or ops with successors. +class OneToOneConverionPattern : public ConversionPattern { + public: + OneToOneConverionPattern(TypeConverter &converter, StringRef srcName, + StringRef targetName, MLIRContext *context, + PatternBenefit benefit) + : ConversionPattern(converter, srcName, benefit, context), + targetName(targetName) {} + LogicalResult matchAndRewrite( + Operation *srcOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + SmallVector<Type> resultTypes; + if (failed(typeConverter->convertTypes(srcOp->getResultTypes(), + resultTypes))) { + return srcOp->emitError() + << "could not convert result types to IREE internal types"; + } + + OperationState state(srcOp->getLoc(), targetName, operands, resultTypes, + srcOp->getAttrs()); + Operation *targetOp = rewriter.createOperation(state); + rewriter.replaceOp(srcOp, targetOp->getResults()); + return success(); + } + + private: + StringRef targetName; +}; + +class BufferViewToTensorPattern + : public OpConversionPattern<IREEPublic::BufferViewToTensorOp> { + using OpConversionPattern< + IREEPublic::BufferViewToTensorOp>::OpConversionPattern; + LogicalResult matchAndRewrite( + IREEPublic::BufferViewToTensorOp srcOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + IREEPublic::BufferViewToTensorOpAdaptor adaptor(operands); + Type resultType = typeConverter->convertType(srcOp.target().getType()); + if (!resultType) return failure(); + rewriter.replaceOpWithNewOp<IREE::HAL::TensorCastOp>( + srcOp, resultType, adaptor.source(), adaptor.target_dims()); + return success(); + } +}; + +class TensorToBufferViewPattern + : public OpConversionPattern<IREEPublic::TensorToBufferViewOp> { + using OpConversionPattern< + IREEPublic::TensorToBufferViewOp>::OpConversionPattern; + LogicalResult matchAndRewrite( + IREEPublic::TensorToBufferViewOp srcOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + 
IREEPublic::TensorToBufferViewOpAdaptor adaptor(operands); + Type resultType = typeConverter->convertType(srcOp.target().getType()); + if (!resultType) return failure(); + rewriter.replaceOpWithNewOp<IREE::HAL::TensorCastOp>( + srcOp, resultType, adaptor.source(), adaptor.source_dims()); + return success(); + } +}; + +class BuiltinFuncOpPattern : public OpConversionPattern<FuncOp> { + using OpConversionPattern<FuncOp>::OpConversionPattern; + LogicalResult matchAndRewrite( + FuncOp srcOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + FunctionType srcFuncType = srcOp.getType(); + TypeConverter::SignatureConversion signatureConversion( + srcOp.getNumArguments()); + + // Convert function arguments. + for (unsigned i = 0, e = srcFuncType.getNumInputs(); i < e; ++i) { + if (failed(getTypeConverter()->convertSignatureArg( + i, srcFuncType.getInput(i), signatureConversion))) { + return rewriter.notifyMatchFailure(srcOp, "argument failed to convert"); + } + } + + // Convert function results. + SmallVector<Type, 1> convertedResultTypes; + if (failed(getTypeConverter()->convertTypes(srcFuncType.getResults(), + convertedResultTypes))) { + return rewriter.notifyMatchFailure(srcOp, "results failed to convert"); + } + + // Create new function with converted argument and result types. + // Note that attributes are dropped. Consider preserving some if needed. + auto newFuncType = mlir::FunctionType::get( + srcOp.getContext(), signatureConversion.getConvertedTypes(), + convertedResultTypes); + auto newFuncOp = + rewriter.create<FuncOp>(srcOp.getLoc(), srcOp.getName(), newFuncType); + rewriter.inlineRegionBefore(srcOp.getBody(), newFuncOp.getBody(), + newFuncOp.end()); + + // Retain function attributes in the allowlist. 
+ auto retainedAttributes = ArrayRef<const char *>( + kRetainedAttributes, + sizeof(kRetainedAttributes) / sizeof(kRetainedAttributes[0])); + for (auto retainAttrName : retainedAttributes) { + StringRef attrName(retainAttrName); + Attribute attr = srcOp->getAttr(attrName); + if (attr) { + newFuncOp->setAttr(attrName, attr); + } + } + + // Tell the rewriter to convert the region signature. + TypeConverter &typeConverter = *getTypeConverter(); + if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter, + &signatureConversion))) { + return failure(); + } + + rewriter.replaceOp(srcOp, llvm::None); + return success(); + } +}; + +// Matches any op and generically converts types. Matches with benefit 0. +class GenericTypeConvert : public ConversionPattern { + public: + GenericTypeConvert(TypeConverter &converter, MLIRContext *context, + PatternBenefit benefit) + : ConversionPattern(converter, MatchAnyOpTypeTag(), benefit, context) {} + LogicalResult matchAndRewrite( + Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + llvm::SmallVector<NamedAttribute, 4> newAttr; + llvm::SmallVector<Type, 4> newResults; + (void)getTypeConverter()->convertTypes(op->getResultTypes(), newResults); + OperationState state(op->getLoc(), op->getName().getStringRef(), operands, + newResults, newAttr, op->getSuccessors()); + for (Region &r : op->getRegions()) { + Region *newRegion = state.addRegion(); + rewriter.inlineRegionBefore(r, *newRegion, newRegion->begin()); + TypeConverter::SignatureConversion result(newRegion->getNumArguments()); + (void)getTypeConverter()->convertSignatureArgs( + newRegion->getArgumentTypes(), result); + rewriter.applySignatureConversion(newRegion, result); + } + Operation *newOp = rewriter.createOperation(state); + rewriter.replaceOp(op, newOp->getResults()); + return success(); + } +}; + +} // namespace + +IREETypeConverter::IREETypeConverter() { + addConversion([](Type t) { return t; }); + 
addConversion([=](IREEPublic::BufferViewType t) { + return IREE::HAL::BufferViewType::get(t.getContext()); + }); + addConversion([=](IREEPublic::ListType t) -> IREE::Util::ListType { + auto subType = convertType(t.getElementType()); + if (!subType) return nullptr; + return IREE::Util::ListType::get(subType); + }); + addConversion([=](IREEPublic::PtrType t) -> IREE::Util::PtrType { + auto subType = convertType(t.getTargetType()); + if (!subType) return nullptr; + return IREE::Util::PtrType::get(subType); + }); + addConversion([](IREEPublic::VariantType t) { + return IREE::Util::VariantType::get(t.getContext()); + }); +} + +void IREEImportPublicPass::runOnOperation() { + auto &context = getContext(); + RewritePatternSet patterns(&getContext()); + ConversionTarget target(getContext()); + target.addLegalDialect<IREE::Flow::FlowDialect>(); + target.addLegalDialect<IREE::HAL::HALDialect>(); + target.addLegalDialect<IREE::Util::UtilDialect>(); + target.addIllegalDialect<IREEPublic::IREEDialect>(); + + auto ireeDialect = context.getOrLoadDialect<IREEPublic::IREEDialect>(); + auto isIllegalType = [&](Type t) { + return t.getDialect().getTypeID() == ireeDialect->getTypeID(); + }; + + target.addDynamicallyLegalOp<FuncOp>([&](FuncOp funcOp) { + for (Type type : funcOp.getType().getInputs()) { + if (isIllegalType(type)) return false; + } + for (Type type : funcOp.getType().getResults()) { + if (isIllegalType(type)) return false; + } + return true; + }); + + target.markUnknownOpDynamicallyLegal([&](Operation *op) { + for (Type type : op->getResultTypes()) { + if (isIllegalType(type)) return false; + } + for (Type type : op->getOperandTypes()) { + if (isIllegalType(type)) return false; + } + return true; + }); + + IREETypeConverter typeConverter; + PatternBenefit specific_benefit = 100; + patterns.insert<GenericTypeConvert>(typeConverter, &getContext(), 0); + patterns.insert<BuiltinFuncOpPattern>(typeConverter, &getContext(), + specific_benefit); + 
patterns.insert<BufferViewToTensorPattern>(typeConverter, &getContext(), + specific_benefit); + patterns.insert<TensorToBufferViewPattern>(typeConverter, &getContext(), + specific_benefit); + +#define ONETOONE(SrcOpTy, TargetOpTy) \ + patterns.insert<OneToOneConverionPattern>( \ + typeConverter, SrcOpTy::getOperationName(), \ + TargetOpTy::getOperationName(), &getContext(), specific_benefit) + + ONETOONE(IREEPublic::BufferViewRankOp, IREE::HAL::BufferViewRankOp); + ONETOONE(IREEPublic::BufferViewDimOp, IREE::HAL::BufferViewDimOp); + ONETOONE(IREEPublic::ListCreateOp, IREE::Util::ListCreateOp); + ONETOONE(IREEPublic::ListSizeOp, IREE::Util::ListSizeOp); + ONETOONE(IREEPublic::ListResizeOp, IREE::Util::ListResizeOp); + ONETOONE(IREEPublic::ListGetOp, IREE::Util::ListGetOp); + ONETOONE(IREEPublic::ListSetOp, IREE::Util::ListSetOp); + ONETOONE(IREEPublic::NullOp, IREE::Util::NullOp); + ONETOONE(IREEPublic::TensorCloneOp, IREE::Flow::TensorCloneOp); + ONETOONE(IREEPublic::TensorLoadOp, IREE::Flow::TensorLoadOp); + ONETOONE(IREEPublic::TensorReshapeOp, IREE::Flow::TensorReshapeOp); + ONETOONE(IREEPublic::TensorSliceOp, IREE::Flow::TensorSliceOp); + ONETOONE(IREEPublic::TensorSplatOp, IREE::Flow::TensorSplatOp); + ONETOONE(IREEPublic::TensorStoreOp, IREE::Flow::TensorStoreOp); + ONETOONE(IREEPublic::TensorUpdateOp, IREE::Flow::TensorUpdateOp); + ONETOONE(IREEPublic::TensorTraceOp, IREE::Flow::TensorTraceOp); + ONETOONE(IREEPublic::GlobalOp, IREE::Util::GlobalOp); + ONETOONE(IREEPublic::GlobalAddressOp, IREE::Util::GlobalAddressOp); + ONETOONE(IREEPublic::GlobalLoadOp, IREE::Util::GlobalLoadOp); + ONETOONE(IREEPublic::GlobalLoadIndirectOp, IREE::Util::GlobalLoadIndirectOp); + ONETOONE(IREEPublic::GlobalStoreOp, IREE::Util::GlobalStoreOp); + ONETOONE(IREEPublic::GlobalStoreIndirectOp, + IREE::Util::GlobalStoreIndirectOp); + + if (failed(applyFullConversion(getOperation(), target, std::move(patterns)))) + signalPassFailure(); +} + +std::unique_ptr<OperationPass<ModuleOp>> 
createIREEImportPublicPass() { + return std::make_unique<IREEImportPublicPass>(); +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/InputConversion/Common/Passes.cpp b/iree/compiler/InputConversion/Common/Passes.cpp index 84332fd..7c21a55 100644 --- a/iree/compiler/InputConversion/Common/Passes.cpp +++ b/iree/compiler/InputConversion/Common/Passes.cpp
@@ -6,6 +6,11 @@ #include "iree/compiler/InputConversion/Common/Passes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassOptions.h" +#include "mlir/Pass/PassRegistry.h" + namespace mlir { namespace iree_compiler { @@ -14,9 +19,20 @@ #include "iree/compiler/InputConversion/Common/Passes.h.inc" // IWYU pragma: export } // namespace +void buildCommonInputConversionPassPipeline(OpPassManager &passManager) { + passManager.addPass(createIREEImportPublicPass()); +} + void registerCommonInputConversionPasses() { // Generated passes. registerPasses(); + + PassPipelineRegistration<> mhlo( + "iree-common-input-transformation-pipeline", + "Runs the common input transformation pipeline", + [](OpPassManager &passManager) { + buildCommonInputConversionPassPipeline(passManager); + }); } } // namespace iree_compiler
diff --git a/iree/compiler/InputConversion/Common/Passes.h b/iree/compiler/InputConversion/Common/Passes.h index f67587a..d490d34 100644 --- a/iree/compiler/InputConversion/Common/Passes.h +++ b/iree/compiler/InputConversion/Common/Passes.h
@@ -14,10 +14,19 @@ namespace iree_compiler { //===----------------------------------------------------------------------===// +// Pipelines +//===----------------------------------------------------------------------===// + +// Performs common input legalization after specific input dialect conversions +// have taken place. +void buildCommonInputConversionPassPipeline(OpPassManager &passManager); + +//===----------------------------------------------------------------------===// // Passes //===----------------------------------------------------------------------===// std::unique_ptr<OperationPass<FuncOp>> createTopLevelSCFToCFGPass(); +std::unique_ptr<OperationPass<ModuleOp>> createIREEImportPublicPass(); //===----------------------------------------------------------------------===// // Register all Passes
diff --git a/iree/compiler/InputConversion/Common/Passes.td b/iree/compiler/InputConversion/Common/Passes.td index 7c225c6..be99a3f 100644 --- a/iree/compiler/InputConversion/Common/Passes.td +++ b/iree/compiler/InputConversion/Common/Passes.td
@@ -15,4 +15,10 @@ let constructor = "mlir::iree_compiler::createTopLevelSCFToCFGPass()"; } +def IREEImportPublic : + Pass<"iree-import-public", "ModuleOp"> { + let summary = "Imports IREE public dialect to internal implementation."; + let constructor = "mlir::iree_compiler::createIREEImportPublicPass()"; +} + #endif // IREE_COMPILER_INPUTCONVERSION_COMMON_PASSES
diff --git a/iree/compiler/InputConversion/Common/test/BUILD b/iree/compiler/InputConversion/Common/test/BUILD index d9e87bd..22eef63 100644 --- a/iree/compiler/InputConversion/Common/test/BUILD +++ b/iree/compiler/InputConversion/Common/test/BUILD
@@ -19,6 +19,7 @@ name = "lit", srcs = enforce_glob( [ + "iree_import_public.mlir", "top_level_scf_to_cfg.mlir", ], include = ["*.mlir"],
diff --git a/iree/compiler/InputConversion/Common/test/CMakeLists.txt b/iree/compiler/InputConversion/Common/test/CMakeLists.txt index ab43294..da81176 100644 --- a/iree/compiler/InputConversion/Common/test/CMakeLists.txt +++ b/iree/compiler/InputConversion/Common/test/CMakeLists.txt
@@ -14,6 +14,7 @@ NAME lit SRCS + "iree_import_public.mlir" "top_level_scf_to_cfg.mlir" DATA iree::tools::IreeFileCheck
diff --git a/iree/compiler/InputConversion/Common/test/iree_import_public.mlir b/iree/compiler/InputConversion/Common/test/iree_import_public.mlir new file mode 100644 index 0000000..a69d497 --- /dev/null +++ b/iree/compiler/InputConversion/Common/test/iree_import_public.mlir
@@ -0,0 +1,233 @@ +// RUN: iree-opt -split-input-file -iree-import-public %s | IreeFileCheck %s + +// CHECK-LABEL: func @bv_func +// CHECK-SAME: (%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> (!hal.buffer_view, !hal.buffer_view) +// CHECK: return %arg0, %arg1 : !hal.buffer_view, !hal.buffer_view +builtin.func @bv_func(%arg0 : !iree.buffer_view, %arg1 : !iree.buffer_view) -> (!iree.buffer_view, !iree.buffer_view) { + return %arg0, %arg1 : !iree.buffer_view, !iree.buffer_view +} + +// ----- +// CHECK-LABEL: func @list_func +// CHECK-SAME: (%arg0: !util.list<?>) -> !util.list<?> +builtin.func @list_func(%arg0 : !iree.list<!iree.variant>) -> !iree.list<!iree.variant> { + return %arg0 : !iree.list<!iree.variant> +} + +// ----- +// CHECK-LABEL: func @ptr_func +// CHECK-SAME: (%arg0: !util.ptr<!hal.buffer_view>) -> !util.ptr<!hal.buffer_view> +builtin.func @ptr_func(%arg0 : !iree.ptr<!iree.buffer_view>) -> !iree.ptr<!iree.buffer_view> { + return %arg0 : !iree.ptr<!iree.buffer_view> +} + +// ----- +// CHECK-LABEL: func @null_op +// CHECK: util.null : !util.variant +builtin.func @null_op() -> !iree.variant { + %0 = iree.null : !iree.variant + return %0 : !iree.variant +} + +// ----- +// CHECK-LABEL: func @tensor_to_buffer_view +// CHECK: hal.tensor.cast %arg0 : tensor<?x?x3xf32>{%arg1, %arg2} -> !hal.buffer_view +builtin.func @tensor_to_buffer_view(%arg0 : tensor<?x?x3xf32>, %arg1 : index, %arg2 : index) -> !iree.buffer_view { + %0 = iree.cast.tensor_to_buffer_view %arg0 : tensor<?x?x3xf32> {%arg1, %arg2} -> !iree.buffer_view + return %0 : !iree.buffer_view +} + +// ----- +// CHECK-LABEL: func @buffer_view_to_tensor +// CHECK: hal.tensor.cast %arg0 : !hal.buffer_view -> tensor<?x?x3xf32>{%arg1, %arg2} +builtin.func @buffer_view_to_tensor(%arg0 : !iree.buffer_view, %arg1 : index, %arg2 : index) -> tensor<?x?x3xf32> { + %0 = iree.cast.buffer_view_to_tensor %arg0 : !iree.buffer_view -> tensor<?x?x3xf32> {%arg1, %arg2} + return %0 : tensor<?x?x3xf32> +} + +// ----- +// 
CHECK-LABEL: func @buffer_view_rank +// CHECK: hal.buffer_view.rank<%arg0 : !hal.buffer_view> : index +builtin.func @buffer_view_rank(%arg0 : !iree.buffer_view) -> index { + %0 = iree.buffer_view.rank %arg0 : index + return %0 : index +} + +// ----- +// CHECK-LABEL: func @buffer_view_dim +// CHECK: hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index +builtin.func @buffer_view_dim(%arg0 : !iree.buffer_view) -> index { + %0 = iree.buffer_view.dim %arg0, 0 : index + return %0: index +} + +// ----- +// CHECK-LABEL: func @list_create +// CHECK: util.list.create %arg0 : !util.list<?> +builtin.func @list_create(%arg0 : index) -> !iree.list<!iree.variant> { + %0 = iree.list.create %arg0 : !iree.list<!iree.variant> + return %0 : !iree.list<!iree.variant> +} + +// ----- +// CHECK-LABEL: func @list_size +// CHECK: util.list.size %arg0 : !util.list<?> +builtin.func @list_size(%arg0 : !iree.list<!iree.variant>) -> index { + %0 = iree.list.size %arg0 : !iree.list<!iree.variant> + return %0 : index +} + +// ----- +// CHECK-LABEL: func @list_resize +// CHECK: util.list.resize %arg0, %arg1 : !util.list<?> +builtin.func @list_resize(%arg0 : !iree.list<!iree.variant>, %arg1 : index) { + iree.list.resize %arg0, %arg1 : !iree.list<!iree.variant> + return +} + +// ----- +// CHECK-LABEL: func @list_get +// CHECK: util.list.get %arg0[%arg1] : !util.list<?> +builtin.func @list_get(%arg0 : !iree.list<!iree.variant>, %arg1 : index) -> !iree.variant { + %0 = iree.list.get %arg0[%arg1] : !iree.list<!iree.variant> -> !iree.variant + return %0 : !iree.variant +} + +// ----- +// CHECK-LABEL: func @list_set +// CHECK: util.list.set %arg0[%arg1], %arg2 : !util.list<?> +builtin.func @list_set(%arg0 : !iree.list<!iree.variant>, %arg1 : index, %arg2 : !iree.variant) { + iree.list.set %arg0[%arg1], %arg2 : !iree.list<!iree.variant>, !iree.variant + return +} + +// ----- +// CHECK-LABEL: func @tensor_reshape +// CHECK: flow.tensor.reshape %arg0 : tensor<?x?xf32>{%arg1, %arg2} -> 
tensor<?x?xf32>{%arg2, %arg1} +builtin.func @tensor_reshape(%arg0 : tensor<?x?xf32>, %arg1 : index, %arg2 : index) -> tensor<?x?xf32> { + %0 = iree.tensor.reshape %arg0 : tensor<?x?xf32>{%arg1, %arg2} -> tensor<?x?xf32>{%arg2, %arg1} + return %0 : tensor<?x?xf32> +} + +// ----- +// CHECK-LABEL: func @tensor_load +// CHECK: flow.tensor.load %arg0[%arg2, %arg3] : tensor<?x3xf32>{%arg1} +builtin.func @tensor_load(%arg0 : tensor<?x3xf32>, %arg1 : index, %arg2 : index, %arg3 : index) -> f32 { + %0 = iree.tensor.load %arg0[%arg2, %arg3] : tensor<?x3xf32>{%arg1} + return %0 : f32 +} + +// ----- +// CHECK-LABEL: func @tensor_store +// CHECK: flow.tensor.store %arg4, %arg0[%arg2, %arg3] : tensor<?x3xf32>{%arg1} +builtin.func @tensor_store(%arg0 : tensor<?x3xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : f32) { + iree.tensor.store %arg4, %arg0[%arg2, %arg3] : tensor<?x3xf32>{%arg1} + return +} + +// ----- +// CHECK-LABEL: func @tensor_splat +// CHECK: flow.tensor.splat %arg0 : tensor<?x?xf32>{%arg1, %arg2} +builtin.func @tensor_splat(%arg0 : f32, %arg1 : index, %arg2 : index) -> tensor<?x?xf32> { + %0 = iree.tensor.splat %arg0 : tensor<?x?xf32>{%arg1, %arg2} + return %0 : tensor<?x?xf32> +} + +// ----- +// CHECK-LABEL: func @tensor_clone +// CHECK: flow.tensor.clone %arg0 : tensor<?x?xf32>{%arg1, %arg2} +builtin.func @tensor_clone(%arg0 : tensor<?x?xf32>, %arg1 : index, %arg2 : index) -> tensor<?x?xf32> { + %0 = iree.tensor.clone %arg0 : tensor<?x?xf32>{%arg1, %arg2} + return %0 : tensor<?x?xf32> +} + +// ----- +// CHECK-LABEL: func @tensor_slice +// CHECK: flow.tensor.slice %arg0[%arg1 for %arg2] : tensor<?xf32>{%arg3} -> tensor<?xf32>{%arg4} +builtin.func @tensor_slice(%arg0 : tensor<?xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> tensor<?xf32> { + %0 = iree.tensor.slice %arg0[%arg1 for %arg2] : tensor<?xf32>{%arg3} -> tensor<?xf32>{%arg4} + return %0 : tensor<?xf32> +} + +// ----- +// CHECK-LABEL: func @tensor_update +// CHECK: 
flow.tensor.update %arg3, %arg0[%arg1] : tensor<?xf32>{%arg2} -> %arg0 as tensor<?xf32>{%arg4} +builtin.func @tensor_update(%arg0 : tensor<?xf32>, %arg1 : index, %arg2 : index, %arg3 : tensor<?xf32>, %arg4 : index) -> tensor<?xf32> { + %0 = iree.tensor.update %arg3, %arg0[%arg1] : tensor<?xf32>{%arg2} -> tensor<?xf32>{%arg4} + return %0 : tensor<?xf32> +} + +// ----- +// CHECK-LABEL: func @tensor_trace +// CHECK: flow.tensor.trace {key = "FOOBAR"} %arg0, %arg1 : tensor<5xf32>, tensor<3xf32> +builtin.func @tensor_trace(%arg0 : tensor<5xf32>, %arg1 : tensor<3xf32>) { + iree.tensor.trace "FOOBAR" %arg0, %arg1 : tensor<5xf32>, tensor<3xf32> + return +} + +// ----- +// CHECK-LABEL: module @globals +builtin.module @globals { + // CHECK: util.global public mutable @global1 = 50 : i32 + iree.global mutable @global1 = 50 : i32 + // CHECK: util.global public mutable @global2 = 50 : i32 + iree.global public mutable @global2 = 50 : i32 + // CHECK: util.global private mutable @global3 = 50 : i32 + iree.global private mutable @global3 = 50 : i32 + // CHECK: util.global private @global4 = 50 : i32 + iree.global private @global4 = 50 : i32 + + // CHECK: util.global public @global5 initializer(@initializer) : tensor<4xi32> + iree.global @global5 initializer(@initializer) : tensor<4xi32> + builtin.func private @initializer() -> tensor<4xi32> +} + +// ----- +// CHECK-LABEL: module @global_load +builtin.module @global_load { + iree.global private @v_loaded : tensor<4xi32> + func @loaded() { + // CHECK: util.global.load @v_loaded : tensor<4xi32> + %0 = iree.global.load @v_loaded : tensor<4xi32> + return + } +} + +// ----- +// CHECK-LABEL: module @global_store +builtin.module @global_store { + iree.global private mutable @v_stored : tensor<4xi32> + func @stored() { + // CHECK: %[[CST:.*]] = constant + %cst = constant dense<5> : tensor<4xi32> + // CHECK: util.global.store %[[CST]], @v_stored : tensor<4xi32> + iree.global.store %cst, @v_stored : tensor<4xi32> + return + } +} + +// ----- 
+// CHECK-LABEL: module @global_load_indirect +builtin.module @global_load_indirect { + iree.global private @v_loaded : tensor<4xf32> + func @loaded_indirect() { + // CHECK: %[[ADDR:.*]] = util.global.address @v_loaded : !util.ptr<tensor<4xf32>> + %0 = iree.global.address @v_loaded : !iree.ptr<tensor<4xf32>> + // CHECK: util.global.load.indirect %[[ADDR]] : !util.ptr<tensor<4xf32>> -> tensor<4xf32> + %1 = iree.global.load.indirect %0 : !iree.ptr<tensor<4xf32>> -> tensor<4xf32> + return + } +} + +// ----- +// CHECK-LABEL: module @global_store_indirect +builtin.module @global_store_indirect { + iree.global private mutable @v_stored : tensor<4xf32> + func @stored_indirect(%arg0: tensor<4xf32>) { + // CHECK: %[[ADDR:.*]] = util.global.address @v_stored : !util.ptr<tensor<4xf32>> + %0 = iree.global.address @v_stored : !iree.ptr<tensor<4xf32>> + // CHECK: util.global.store.indirect %arg0, %ptr_v_stored : tensor<4xf32> -> !util.ptr<tensor<4xf32>> + iree.global.store.indirect %arg0, %0 : tensor<4xf32> -> !iree.ptr<tensor<4xf32>> + return + } +}
diff --git a/iree/compiler/Translation/BUILD b/iree/compiler/Translation/BUILD index 4d4bcd1..886a5a2 100644 --- a/iree/compiler/Translation/BUILD +++ b/iree/compiler/Translation/BUILD
@@ -29,6 +29,7 @@ "//iree/compiler/Dialect/VM/Conversion/StandardToVM", "//iree/compiler/Dialect/VM/Target/Bytecode", "//iree/compiler/Dialect/VM/Transforms", + "//iree/compiler/InputConversion/Common", "//iree/compiler/InputConversion/MHLO", "//iree/compiler/InputConversion/TOSA", "//iree/compiler/Utils",
diff --git a/iree/compiler/Translation/IREEVM.cpp b/iree/compiler/Translation/IREEVM.cpp index 4a6ef7b..20bbff2 100644 --- a/iree/compiler/Translation/IREEVM.cpp +++ b/iree/compiler/Translation/IREEVM.cpp
@@ -13,6 +13,7 @@ #include "iree/compiler/Dialect/Util/Transforms/Passes.h" #include "iree/compiler/Dialect/VM/Target/Bytecode/TranslationFlags.h" #include "iree/compiler/Dialect/VM/Transforms/Passes.h" +#include "iree/compiler/InputConversion/Common/Passes.h" #include "iree/compiler/InputConversion/MHLO/Passes.h" #include "iree/compiler/InputConversion/TOSA/Passes.h" #include "iree/compiler/Utils/TracingUtils.h" @@ -145,6 +146,7 @@ break; } + buildCommonInputConversionPassPipeline(passManager); IREE::Flow::buildFlowTransformPassPipeline(passManager); IREE::HAL::buildHALTransformPassPipeline(passManager, executableOptions); IREE::VM::buildVMTransformPassPipeline(passManager, targetOptions);
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREE/IREEDialect.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREE/IREEDialect.td index d9740ae..691359a 100644 --- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREE/IREEDialect.td +++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREE/IREEDialect.td
@@ -95,7 +95,7 @@ let parameters = (ins IREE_PtrTargetTypeParameter:$targetType); let printer = [{ - $_printer << "list<" << getTargetType() << ">"; + $_printer << "ptr<" << getTargetType() << ">"; }]; let parser = [{
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREE/IREEOps.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREE/IREEOps.td index 41774ff..f937ef7 100644 --- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREE/IREEOps.td +++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREE/IREEOps.td
@@ -30,7 +30,7 @@ // Casts //===----------------------------------------------------------------------===// -def IREE_TensorToBufferView : IREE_PureOp<"cast.tensor_to_buffer_view"> { +def IREE_TensorToBufferViewOp : IREE_PureOp<"cast.tensor_to_buffer_view"> { let summary = "Casts a tensor to a BufferView, capturing dynamic dims"; let arguments = (ins IREE_Tensor:$source, @@ -44,7 +44,7 @@ }]; } -def IREE_BufferViewToTensor : IREE_PureOp<"cast.buffer_view_to_tensor"> { +def IREE_BufferViewToTensorOp : IREE_PureOp<"cast.buffer_view_to_tensor"> { let summary = "Casts a BufferView to a tensor, providing dynamic dims"; let arguments = (ins IREE_BufferViewType:$source, @@ -81,15 +81,14 @@ OptionalAttr<AnyAttr>:$initial_value ); - // TODO(laurenzo): copy SymbolVisibility/TypeOrAttr from UtilOps.cpp. - // let assemblyFormat = [{ - // custom<SymbolVisibility>($sym_visibility) - // (`mutable` $is_mutable^)? - // $sym_name - // attr-dict - // (`initializer` `(` $initializer^ `)`):(``)? - // custom<TypeOrAttr>($type, $initial_value) - // }]; + let assemblyFormat = [{ + custom<SymbolVisibility>($sym_visibility) + (`mutable` $is_mutable^)? + $sym_name + attr-dict + (`initializer` `(` $initializer^ `)`):(``)? + custom<TypeOrAttr>($type, $initial_value) + }]; } def IREE_GlobalAddressOp : IREE_PureOp<"global.address"> { @@ -485,8 +484,7 @@ IREE_ShapeDynamicDims:$target_dims, Variadic<IREE_Dim>:$start_indices, IREE_Tensor:$update, - IREE_ShapeDynamicDims:$update_dims, - OptionalAttr<IREE_TiedOpStorageAttr>:$tied_operands + IREE_ShapeDynamicDims:$update_dims ); let results = (outs IREE_Tensor:$result @@ -495,7 +493,7 @@ let assemblyFormat = [{ $update `,` $target `[` $start_indices `]` `:` type($update) (`{` $update_dims^ `}`)? `->` - `(` type($result) `,` $target_dims `,` $tied_operands `)` + type($result) (`{` $target_dims^ `}`)? 
attr-dict-with-keyword }]; @@ -521,7 +519,7 @@ Variadic<IREE_Tensor>:$operands ); - let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; + let assemblyFormat = "$key attr-dict ($operands^ `:` type($operands))?"; } #endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREE_IREE_OPS_TD
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREE/IREEOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREE/IREEOps.cpp index 2a12751..a723b58 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/IREE/IREEOps.cpp +++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREE/IREEOps.cpp
@@ -15,5 +15,80 @@ using namespace mlir; using namespace mlir::iree; +//===----------------------------------------------------------------------===// +// custom<SymbolVisibility>($sym_visibility) +//===----------------------------------------------------------------------===// +// some.op custom<SymbolVisibility>($sym_visibility) $sym_name +// -> +// some.op @foo +// some.op private @foo + +static ParseResult parseSymbolVisibility(OpAsmParser &parser, + StringAttr &symVisibilityAttr) { + StringRef symVisibility; + parser.parseOptionalKeyword(&symVisibility, {"public", "private", "nested"}); + if (!symVisibility.empty()) { + symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility); + } + return success(); +} + +static void printSymbolVisibility(OpAsmPrinter &p, Operation *op, + StringAttr symVisibilityAttr) { + if (!symVisibilityAttr) { + p << "public"; + } else { + p << symVisibilityAttr.getValue(); + } +} + +//===----------------------------------------------------------------------===// +// custom<TypeOrAttr>($type, $attr) +//===----------------------------------------------------------------------===// +// some.op custom<TypeOrAttr>($type, $attr) +// -> +// some.op : i32 +// some.op = 42 : i32 +// some.op : i32 = 42 : index + +static ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr, + Attribute &attr) { + if (succeeded(parser.parseOptionalEqual())) { + if (failed(parser.parseAttribute(attr))) { + return parser.emitError(parser.getCurrentLocation()) + << "expected attribute"; + } + typeAttr = TypeAttr::get(attr.getType()); + return success(); + } + + Type type; + if (failed(parser.parseColonType(type))) { + return parser.emitError(parser.getCurrentLocation()) << "expected type"; + } + typeAttr = TypeAttr::get(type); + + if (succeeded(parser.parseOptionalEqual())) { + if (failed(parser.parseAttribute(attr))) { + return parser.emitError(parser.getCurrentLocation()) + << "expected attribute"; + } + } + + return success(); +} + +static 
void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type, + Attribute attr) { + if (!attr || attr.getType() != type.getValue()) { + p << " : "; + p.printAttribute(type); + } + if (attr) { + p << " = "; + p.printAttribute(attr); + } +} + #define GET_OP_CLASSES #include "iree-dialects/Dialect/IREE/IREEOps.cpp.inc"