Pull in passes to convert from XLA HLO ops to SPIR-V. Pulls in passes to convert xla-hlo to linalg.generic on tensors, followed by conversion from linalg.generic on tensors to linalg.generic on buffers. Then the linalg to SPIR-V pass pipeline can be used to generate SPIR-V dialect. Also adds some patterns to IREELinalgTensorToBuffer to eliminate IREE::*Ops. PiperOrigin-RevId: 293252317
diff --git a/build_tools/third_party/tensorflow/tensorflow/compiler/mlir/xla/CMakeLists.txt b/build_tools/third_party/tensorflow/tensorflow/compiler/mlir/xla/CMakeLists.txt index 1bb1bee..4543971 100644 --- a/build_tools/third_party/tensorflow/tensorflow/compiler/mlir/xla/CMakeLists.txt +++ b/build_tools/third_party/tensorflow/tensorflow/compiler/mlir/xla/CMakeLists.txt
@@ -43,11 +43,13 @@ "transforms/lower_general_dot.cc" "transforms/materialize_broadcasts.cc" "transforms/unfuse_batch_norm.cc" + "transforms/xla_legalize_to_linalg.cc" HDRS "convert_op_folder.h" "ir/hlo_ops.h" "ir/hlo_utils.h" "ir/lhlo_ops.h" + "transforms/map_xla_to_scalar_op.h" "transforms/passes.h" "transforms/rewriters.h" COPTS
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD index eb51823..1c5ac34 100644 --- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD +++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD
@@ -30,6 +30,7 @@ "//iree/compiler/Dialect/HAL/Target:ExecutableTarget", "//iree/compiler/Dialect/IREE/IR", "//iree/compiler/Translation/SPIRV/EmbeddedKernels", + "//iree/compiler/Translation/SPIRV/LinalgToSPIRV", "//iree/compiler/Translation/SPIRV/XLAToSPIRV", "//iree/schemas:spirv_executable_def_cc_fbs", "@com_github_google_flatbuffers//:flatbuffers",
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt index 6c962e5..917dfa4 100644 --- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt
@@ -24,6 +24,7 @@ iree::compiler::Dialect::HAL::Target::ExecutableTarget iree::compiler::Dialect::IREE::IR iree::compiler::Translation::SPIRV::EmbeddedKernels + iree::compiler::Translation::SPIRV::LinalgToSPIRV iree::compiler::Translation::SPIRV::XLAToSPIRV iree::schemas::spirv_executable_def_cc_fbs flatbuffers
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp index 75a4a83..ac1d0d8 100644 --- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp +++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
@@ -20,6 +20,7 @@ #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "iree/compiler/Dialect/HAL/Target/LegacyUtil.h" #include "iree/compiler/Translation/SPIRV/EmbeddedKernels/EmbeddedKernels.h" +#include "iree/compiler/Translation/SPIRV/LinalgToSPIRV/LowerToSPIRV.h" #include "iree/compiler/Translation/SPIRV/XLAToSPIRV/IREEToSPIRVPass.h" #include "iree/schemas/spirv_executable_def_generated.h" #include "llvm/ADT/STLExtras.h" @@ -43,6 +44,12 @@ // static llvm::cl::OptionCategory halVulkanSPIRVOptionsCategory( // "IREE Vulkan/SPIR-V backend options"); +static llvm::cl::opt<bool> useLinalgPathForCodegen( + "iree-use-linalg-to-spirv-path", + llvm::cl::desc( + "Flag to use the XLA-HLO to Linalg To SPIR-V pass pipeline."), + llvm::cl::init(false)); + VulkanSPIRVTargetOptions getVulkanSPIRVTargetOptionsFromFlags() { VulkanSPIRVTargetOptions targetOptions; // TODO(benvanik): flags. @@ -140,7 +147,11 @@ // Lower module to spirv::ModuleOp. PassManager conversionPassManager(moduleOp.getContext()); - addIREEToSPIRVPasses(conversionPassManager); + if (useLinalgPathForCodegen) { + addLowerToSPIRVPasses(conversionPassManager); + } else { + addIREEToSPIRVPasses(conversionPassManager); + } if (failed(conversionPassManager.run(moduleOp))) { return moduleOp.emitError() << "failed to run conversion passes"; }
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/BUILD b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/BUILD index d45ecaf..d787c2e 100644 --- a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/BUILD +++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/BUILD
@@ -20,18 +20,26 @@ cc_library( name = "LinalgToSPIRV", srcs = [ - "LinalgToSPIRV.cpp", + "LowerToSPIRV.cpp", + ], + hdrs = [ + "LowerToSPIRV.h", ], deps = [ + "//iree/compiler/Translation/XLAToLinalg:IREELinalgTensorToBuffer", + "//iree/compiler/Utils", "@llvm-project//llvm:support", "@llvm-project//mlir:AffineOps", "@llvm-project//mlir:AffineToStandardTransforms", + "@llvm-project//mlir:EDSC", + "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:GPUToSPIRVTransforms", "@llvm-project//mlir:GPUTransforms", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgOps", "@llvm-project//mlir:LinalgTransforms", - "@llvm-project//mlir:LoopsToGPUPass", + "@llvm-project//mlir:LoopOps", + "@llvm-project//mlir:LoopsToGPU", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SPIRVDialect", "@llvm-project//mlir:SPIRVLowering", @@ -39,6 +47,8 @@ "@llvm-project//mlir:StandardToSPIRVConversions", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", + "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo", + "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_legalize_to_linalg", ], alwayslink = 1, )
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/CMakeLists.txt b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/CMakeLists.txt index 20f8392..6a13ae2 100644 --- a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/CMakeLists.txt +++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/CMakeLists.txt
@@ -17,17 +17,23 @@ iree_cc_library( NAME LinalgToSPIRV + HDRS + "LowerToSPIRV.h" SRCS - "LinalgToSPIRV.cpp" + "LowerToSPIRV.cpp" DEPS + iree::compiler::Translation::XLAToLinalg::IREELinalgTensorToBuffer + iree::compiler::Utils LLVMSupport MLIRAffineOps MLIRAffineToStandard + MLIREDSC MLIRGPUtoSPIRVTransforms MLIRGPU MLIRIR MLIRLinalgOps MLIRLinalgTransforms + MLIRLoopOps MLIRLoopsToGPU MLIRPass MLIRSPIRV @@ -35,6 +41,7 @@ MLIRStandardToSPIRVTransforms MLIRSupport MLIRTransforms + tensorflow::mlir_xla ALWAYSLINK PUBLIC )
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LinalgToSPIRV.cpp b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LinalgToSPIRV.cpp deleted file mode 100644 index 62e1986..0000000 --- a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LinalgToSPIRV.cpp +++ /dev/null
@@ -1,104 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//===- LinalgToSPIRV.cpp - Linalg dialect to SPIR-V dialect----------------===// -// -// Implementation of conversion from Linalg To SPIRV -// -//===----------------------------------------------------------------------===// -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/MemoryBuffer.h" -#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h" -#include "mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h" -#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h" -#include "mlir/Dialect/GPU/Passes.h" -#include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/SPIRV/Passes.h" -#include "mlir/Dialect/SPIRV/SPIRVOps.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Module.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Pass/PassOptions.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Transforms/Passes.h" - -namespace mlir { -namespace iree_compiler { - -namespace { -struct LinalgToSPIRVPassOptions - : public PassPipelineOptions<LinalgToSPIRVPassOptions> { - ListOption<int64_t> numWorkGroups{ - *this, "num-workgroups", - llvm::cl::desc( - "Number of workgroups in the SPIR-V module for x, followed by y, " - "followed by z dimension of the dispatch (others will be ignored)"), - llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; - ListOption<int64_t> workGroupSize{ - *this, 
"workgroup-size", - llvm::cl::desc( - "Workgroup Sizes in the SPIR-V module for x, followed by y, followed " - "by z dimension of the dispatch (others will be ignored)"), - llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; -}; -} // namespace - -static void addLinalgToSPIRVPasses(OpPassManager &pm, - const LinalgToSPIRVPassOptions &options) { - // TODO(ravishankarm): For now only evaluated with 2D tiling. So set the - // workgroup size and numworkgroups to size 2 - SmallVector<int64_t, 2> numWorkGroups, workGroupSize; - numWorkGroups.assign(options.numWorkGroups.begin(), - options.numWorkGroups.end()); - numWorkGroups.resize(2, 1); - workGroupSize.assign(options.workGroupSize.begin(), - options.workGroupSize.end()); - workGroupSize.resize(2, 1); - - // Linalg to loops. - pm.addPass(createLinalgTilingPass(workGroupSize)); - pm.addPass(createConvertLinalgToLoopsPass()); - pm.addPass(createLowerAffinePass()); - pm.addPass(createCanonicalizerPass()); - pm.addPass(createCSEPass()); - - // Loops to GPU. - pm.addPass(createLoopToGPUPass(numWorkGroups, workGroupSize)); - pm.addPass(createGpuKernelOutliningPass()); - pm.addPass(createCanonicalizerPass()); - pm.addPass(createCSEPass()); - pm.addPass(createLowerAffinePass()); - - // GPU to SPIR-V. - pm.addPass(createLegalizeStdOpsForSPIRVLoweringPass()); - pm.addPass(createCanonicalizerPass()); - pm.addPass(createCSEPass()); - pm.addPass(createConvertGPUToSPIRVPass(workGroupSize)); - - // SPIR-V passes for lowering attributes. 
- OpPassManager &spirvModulePM = pm.nest<spirv::ModuleOp>(); - spirvModulePM.addPass(spirv::createLowerABIAttributesPass()); - spirvModulePM.addPass(createCanonicalizerPass()); - spirvModulePM.addPass(createCSEPass()); -} - -static PassPipelineRegistration<LinalgToSPIRVPassOptions> linalgToSPIRVPipeline( - "iree-linalg-to-spirv", - "Runs the progressive lowering pipeline from Linalg to SPIR-V", - [](OpPassManager &passManager, const LinalgToSPIRVPassOptions &options) { - addLinalgToSPIRVPasses(passManager, options); - }); -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LowerToSPIRV.cpp b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LowerToSPIRV.cpp new file mode 100644 index 0000000..e84fd48 --- /dev/null +++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LowerToSPIRV.cpp
@@ -0,0 +1,289 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//===- LowerToSPIRV.cpp - Lower from XLA to Linalg to SPIR-V dialect-------===// +// +// Implementation of conversion from XLA-HLO to Linalg to SPIR-V dialect. +// +//===----------------------------------------------------------------------===// + +#include "iree/compiler/Translation/XLAToLinalg/LinalgTensorToBuffer.h" +#include "iree/compiler/Utils/IREECodegenUtils.h" +#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h" +#include "mlir/Conversion/LoopsToGPU/LoopsToGPU.h" +#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h" +#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h" +#include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/GPU/Passes.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/LoopOps/LoopOps.h" +#include "mlir/Dialect/SPIRV/Passes.h" +#include "mlir/Dialect/SPIRV/SPIRVLowering.h" +#include "mlir/Dialect/SPIRV/SPIRVOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/EDSC/Intrinsics.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassOptions.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Transforms/FoldUtils.h" +#include "mlir/Transforms/Passes.h" +#include 
"tensorflow/compiler/mlir/xla/transforms/passes.h" + +namespace mlir { +namespace iree_compiler { + +namespace { +/// These options are only for testing purposes. For actual execution with IREE, +/// these are computed by IREE/Backends automatically. +struct WorkGroupOptions : public PassPipelineOptions<WorkGroupOptions> { + ListOption<int64_t> workGroupSize{ + *this, "workgroup-size", + llvm::cl::desc( + "Number of workgroups to dispatch for the SPIR-V module; at most " + "three integers standarding for the x, y, and z dimension; " + "additional arguments will be ignored (used only for testing)"), + llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; +}; +} // namespace + +static DenseIntElementsAttr getDenseIntElementsAttrVal( + Builder *builder, ArrayRef<int64_t> value) { + SmallVector<int32_t, 3> vector; + vector.reserve(3); + for (auto val : value) { + vector.emplace_back(val); + } + vector.resize(3, 1); + return builder->getI32VectorAttr(vector); +} + +/// Helper function to create a std.constant of index type to initialize the +/// workgroup size as a SSA value. +static void createConstantsInFunc(FuncOp funcOp, ArrayRef<int64_t> intVal, + SmallVectorImpl<Value> &constVal) { + OpBuilder builder(funcOp.getBody()); + MLIRContext *context = funcOp.getContext(); + for (auto val : intVal) { + constVal.push_back(builder.create<ConstantOp>( + funcOp.getLoc(), IntegerAttr::get(IndexType::get(context), val))); + } +} + +namespace { + +/// To be able to use the workgroup size from the dispatch function attribute +/// within the linalg tiling pass, need to actually implement a pass to retrieve +/// the attribute value from the function and pass it along. +// TODO(ravishankarm): Move this into Linalg dialect. 
+struct IREETileLinalgPass : public FunctionPass<IREETileLinalgPass> { + void runOnFunction() override { + FuncOp funcOp = getFunction(); + SmallVector<int64_t, 3> workGroupSize; + workGroupSize.reserve(3); + if (failed(getLegacyWorkGroupSize(funcOp, workGroupSize))) { + return; + } + OpBuilder builder(funcOp); + OperationFolder folder(funcOp.getContext()); + funcOp.walk([&workGroupSize, &builder, &folder](linalg::LinalgOp op) { + if (!op.hasBufferSemantics()) { + return; + } + SmallVector<int64_t, 3> tileSizes; + auto nLoops = op.getNumLoops(); + tileSizes.assign(workGroupSize.begin(), workGroupSize.end()); + // Linalg convention is to use 0 for no tiling. If the workgroup size is + // 1, then dont tile along that dimension. So overriding 1 to 0. + for (auto &tileSize : tileSizes) { + if (tileSize == 1) tileSize = 0; + } + tileSizes.resize(nLoops, 0); + auto tiledOp = linalg::tileLinalgOp(builder, op, tileSizes, {}, &folder); + if (tiledOp) { + op.erase(); + } + }); + } +}; + +/// To be able to use the workgroup size from the dispatch function attribute to +/// convert loops to GPU kernel, need to actually implement a pass to retrieve +/// the attribute value from the function and pass it along. +// TODO(ravishankarm): Structure the Loops to GPU pass in MLIR so that we dont +// have to do this. Maybe make it an OpPassBase<loop::ForOp> ? +struct LoopsToGPUPass : public FunctionPass<LoopsToGPUPass> { + void runOnFunction() override { + // Get the workgroup size from the attributes. + FuncOp funcOp = getFunction(); + SmallVector<int64_t, 3> workGroupSize; + workGroupSize.reserve(3); + if (failed(getLegacyWorkGroupSize(funcOp, workGroupSize))) { + return; + } + // TODO(ravishankarm): Currently evaluating only 2D tiling. Generalize this. + workGroupSize.resize(2); + // The Loop To GPU pass expects the numWorkGroups only to create the + // host-side launch operation. We don't care about that, so just pass {1, 1, + // 1} for that. 
+ SmallVector<int64_t, 3> numWorkGroups(workGroupSize.size(), 1); + SmallVector<Value, 3> numWorkGroupsVal, workGroupSizeVal; + numWorkGroupsVal.reserve(3); + workGroupSizeVal.reserve(3); + createConstantsInFunc(funcOp, numWorkGroups, numWorkGroupsVal); + createConstantsInFunc(funcOp, workGroupSize, workGroupSizeVal); + for (Block &block : getFunction()) { + for (Operation &op : llvm::make_early_inc_range(block)) { + if (auto forOp = dyn_cast<loop::ForOp>(&op)) { + if (failed(convertLoopToGPULaunch(forOp, numWorkGroupsVal, + workGroupSizeVal))) { + return signalPassFailure(); + } + } + } + } + } +}; + +/// To be able to use the workgroup size from the dispatch function attribute to +/// convert GPU kernel into SPIR-V kernel, need to actually implement a pass to +/// retrieve the attribute value from the function and pass it along. +// TODO(ravishankarm): Move this into MLIR core. +struct IREEGPUToSPIRVPass : public ModulePass<IREEGPUToSPIRVPass> { + void runOnModule() { + MLIRContext *context = &getContext(); + ModuleOp moduleOp = getModule(); + FuncOp funcOp = nullptr; + auto walkResult = moduleOp.walk([&funcOp](FuncOp fOp) -> WalkResult { + if (fOp.getAttr("iree.executable.export")) { + if (funcOp) { + return WalkResult::interrupt(); + } + funcOp = fOp; + } + return WalkResult::advance(); + }); + if (!funcOp || walkResult.wasInterrupted()) { + moduleOp.emitError("expected a single dispatch function within module"); + return signalPassFailure(); + } + SmallVector<Operation *, 1> kernelModules; + OpBuilder builder(context); + builder.setInsertionPoint(funcOp.getOperation()); + + // Clone the GPU module into the funcop to convert into a SPIR-V module. 
+ funcOp.walk( + [&builder, &moduleOp, &kernelModules](gpu::LaunchFuncOp gpuLaunchOp) { + auto kernelModuleName = gpuLaunchOp.getKernelModuleName(); + auto gpuModuleOp = + moduleOp.lookupSymbol<gpu::GPUModuleOp>(kernelModuleName); + kernelModules.push_back(builder.clone(*gpuModuleOp.getOperation())); + }); + SPIRVTypeConverter typeConverter; + OwningRewritePatternList patterns; + SmallVector<int64_t, 3> workGroupSize; + if (failed(getLegacyWorkGroupSize(funcOp, workGroupSize))) { + return; + } + populateGPUToSPIRVPatterns(context, typeConverter, patterns, workGroupSize); + populateStandardToSPIRVPatterns(context, typeConverter, patterns); + + std::unique_ptr<ConversionTarget> target = + spirv::SPIRVConversionTarget::get( + spirv::lookupTargetEnvOrDefault(funcOp), context); + target->addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { + return typeConverter.isSignatureLegal(op.getType()); + }); + + if (failed(applyFullConversion(kernelModules, *target, patterns, + &typeConverter))) { + return signalPassFailure(); + } + } +}; + +/// Pass to override the workgroup_size attribute value of a dispatch function. +// TODO(ravishankarm): Use a more cohorent strategy than just setting it to {2, +// 2}. +struct UpdateWorkGroupSizePass : FunctionPass<UpdateWorkGroupSizePass> { + UpdateWorkGroupSizePass(ArrayRef<int64_t> workGroupSize) + : workGroupSize(workGroupSize.begin(), workGroupSize.end()) {} + void runOnFunction() { + FuncOp funcOp = getFunction(); + if (!funcOp.getAttr("iree.executable.export")) { + return; + } + if (workGroupSize.empty()) { + workGroupSize = {2, 2}; + } + workGroupSize.resize(3, 1); + OpBuilder builder(&getContext()); + funcOp.setAttr("iree.executable.workgroup_size", + getDenseIntElementsAttrVal(&builder, workGroupSize)); + } + + private: + SmallVector<int64_t, 3> workGroupSize; +}; +} // namespace + +static void addLinalgToSPIRVPasses(OpPassManager &pm) { + // Linalg to loops. 
+ pm.addPass(std::make_unique<IREETileLinalgPass>()); + pm.addPass(createConvertLinalgToLoopsPass()); + pm.addPass(createLowerAffinePass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + + pm.addPass(std::make_unique<LoopsToGPUPass>()); + pm.addPass(createGpuKernelOutliningPass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + pm.addPass(createLowerAffinePass()); + + // GPU to SPIR-V. + pm.addPass(createLegalizeStdOpsForSPIRVLoweringPass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + pm.addPass(std::make_unique<IREEGPUToSPIRVPass>()); + + // SPIR-V passes for lowering attributes. + OpPassManager &spirvModulePM = pm.nest<spirv::ModuleOp>(); + spirvModulePM.addPass(spirv::createLowerABIAttributesPass()); + spirvModulePM.addPass(createCanonicalizerPass()); + spirvModulePM.addPass(createCSEPass()); +} + +void addLowerToSPIRVPasses(OpPassManager &pm, ArrayRef<int64_t> workGroupSize) { + pm.addPass(xla_hlo::createLegalizeHloToLinalgPass()); + pm.addPass(createLinalgTensorToBufferConversionPass()); + pm.addPass(std::make_unique<UpdateWorkGroupSizePass>(workGroupSize)); + addLinalgToSPIRVPasses(pm); +} + +static PassPipelineRegistration<WorkGroupOptions> xlaToLinalgSPIRVPipeline( + "iree-xla-to-linalg-to-spirv", + "Runs the progressive lowering pipeline from XLA HLO to Linalg to SPIR-V", + [](OpPassManager &passManager, const WorkGroupOptions &options) { + SmallVector<int64_t, 2> workGroupSize; + workGroupSize.assign(options.workGroupSize.begin(), + options.workGroupSize.end()); + addLowerToSPIRVPasses(passManager, workGroupSize); + }); +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LowerToSPIRV.h b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LowerToSPIRV.h new file mode 100644 index 0000000..8dc6c25 --- /dev/null +++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LowerToSPIRV.h
@@ -0,0 +1,30 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_TRANSLATION_SPIRV_LINALGTOSPIRV_LOWERTOSPIRV_H +#define IREE_COMPILER_TRANSLATION_SPIRV_LINALGTOSPIRV_LOWERTOSPIRV_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace iree_compiler { + +/// Populates passes needed to lower an XLA HLO op to SPIR-V dialect. +void addLowerToSPIRVPasses(OpPassManager &pm, +                           ArrayRef<int64_t> workGroupSize = {}); + +} // namespace iree_compiler +} // namespace mlir + +#endif  // IREE_COMPILER_TRANSLATION_SPIRV_LINALGTOSPIRV_LOWERTOSPIRV_H
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/BUILD b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/BUILD index 1d8d07c..f63a472 100644 --- a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/BUILD +++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/BUILD
@@ -27,5 +27,6 @@ data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", + "//iree/tools:iree-run-mlir", ], )
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/CMakeLists.txt b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/CMakeLists.txt index 2fafb16..e10a865 100644 --- a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/CMakeLists.txt +++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/CMakeLists.txt
@@ -16,8 +16,10 @@ NAME lit SRCS - "single_pw_op.mlir" + "pw_add.mlir" + "pw_add_e2e.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt + iree::tools::iree-run-mlir )
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/pw_add.mlir b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/pw_add.mlir new file mode 100644 index 0000000..95812a6 --- /dev/null +++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/pw_add.mlir
@@ -0,0 +1,13 @@ +// RUN: iree-opt -pass-pipeline='iree-xla-to-linalg-to-spirv' %s | IreeFileCheck %s + +module { + func @simple_load_store(%arg0: memref<4x8xi32>, %arg1: memref<4x8xi32>, %arg2 : memref<4x8xi32>) + attributes {iree.executable.export, iree.executable.workload = dense<[8, 4, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[2, 2, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} { + // CHECK: spv.module + %0 = iree.load_input(%arg0 : memref<4x8xi32>) : tensor<4x8xi32> + %1 = iree.load_input(%arg1 : memref<4x8xi32>) : tensor<4x8xi32> + %2 = "xla_hlo.add"(%0, %1) : (tensor<4x8xi32>, tensor<4x8xi32>) -> tensor<4x8xi32> + iree.store_output(%2 : tensor<4x8xi32>, %arg2 : memref<4x8xi32>) + iree.return + } +}
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/single_pw_op.mlir b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/single_pw_op.mlir deleted file mode 100644 index d875f14..0000000 --- a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/single_pw_op.mlir +++ /dev/null
@@ -1,14 +0,0 @@ -// RUN: iree-opt -pass-pipeline='iree-linalg-to-spirv{workgroup-size=2,2 num-workgroups=2,2}' %s - -#map0 = affine_map<(d0, d1) -> (d0, d1)> - -module { - func @fmul(%arg0: memref<12x4xf32>, %arg1: memref<12x4xf32>, %arg2: memref<12x4xf32>) { - linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} %arg0, %arg1, %arg2 { - ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors - %0 = mulf %arg3, %arg4 : f32 - linalg.yield %0 : f32 - }: memref<12x4xf32>, memref<12x4xf32>, memref<12x4xf32> - return - } -}
diff --git a/iree/compiler/Translation/XLAToLinalg/LinalgTensorToBuffer.cpp b/iree/compiler/Translation/XLAToLinalg/LinalgTensorToBuffer.cpp index 9ba3583..1d0d3f9 100644 --- a/iree/compiler/Translation/XLAToLinalg/LinalgTensorToBuffer.cpp +++ b/iree/compiler/Translation/XLAToLinalg/LinalgTensorToBuffer.cpp
@@ -22,6 +22,17 @@ namespace iree_compiler { namespace { +/// Remove IREE::LoadInputOp operations +struct RemoveLoadInputOpPattern : OpConversionPattern<IREE::LoadInputOp> { + using OpConversionPattern<IREE::LoadInputOp>::OpConversionPattern; + PatternMatchResult matchAndRewrite( + IREE::LoadInputOp op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOp(op, op.getOperand()); + return matchSuccess(); + } +}; + /// Convert from a linalg.generic on tensors to linalg.generic on buffers. In /// IREE it is expected that each dispatch region will become a single /// linalg.generic op on tensors (after XLA-HLO -> Linalg conversion and @@ -41,11 +52,11 @@ }; /// Remove IREE::StoreOutputOp operations. -struct RemoveDeadStorePattern : OpConversionPattern<IREE::StoreOutputOp> { +struct RemoveStoreOutputOpPattern : OpConversionPattern<IREE::StoreOutputOp> { using OpConversionPattern<IREE::StoreOutputOp>::OpConversionPattern; PatternMatchResult matchAndRewrite( IREE::StoreOutputOp op, ArrayRef<Value> operands, - ConversionPatternRewriter &rewriter) const { + ConversionPatternRewriter &rewriter) const override { rewriter.eraseOp(op); return matchSuccess(); } @@ -61,7 +72,6 @@ return matchSuccess(); } }; - } // namespace PatternMatchResult LinalgTensorToBufferConverter::matchAndRewrite( @@ -69,17 +79,7 @@ ConversionPatternRewriter &rewriter) const { // TODO(ravishankarm): Find a way to write this using Matchers, but need to // figure out how to match operations with variadic operands. 
- SmallVector<Value, 2> memrefArgs; - for (auto arg : op.getOperands()) { - if (!arg.getType().isa<RankedTensorType>()) { - return matchFailure(); - } - auto definingOp = dyn_cast_or_null<IREE::LoadInputOp>(arg.getDefiningOp()); - if (!definingOp) { - return matchFailure(); - } - memrefArgs.push_back(definingOp.getOperand()); - } + SmallVector<Value, 2> memrefArgs(operands.begin(), operands.end()); // For result, check that there is a single use in an iree::store_output op. for (auto result : op.getResults()) { if (!result.hasOneUse()) { @@ -121,8 +121,9 @@ void populateLinalgTensorToBufferConversionPattern( MLIRContext *context, OwningRewritePatternList &patterns) { - patterns.insert<LinalgTensorToBufferConverter, RemoveDeadStorePattern, - IREEReturnOpLowering>(context); + patterns.insert<IREEReturnOpLowering, LinalgTensorToBufferConverter, + RemoveLoadInputOpPattern, RemoveStoreOutputOpPattern>( + context); } struct LinalgTensorToBufferConversionPass @@ -132,7 +133,6 @@ MLIRContext *context = &getContext(); ConversionTarget target(*context); target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>(); - target.addLegalOp<IREE::LoadInputOp>(); target.addLegalOp<FuncOp>(); target.addDynamicallyLegalOp<linalg::GenericOp>([&](linalg::GenericOp op) { return llvm::all_of(op.getOperands(),
diff --git a/iree/compiler/Translation/XLAToLinalg/LinalgTensorToBuffer.h b/iree/compiler/Translation/XLAToLinalg/LinalgTensorToBuffer.h index 60e5246..c1be661 100644 --- a/iree/compiler/Translation/XLAToLinalg/LinalgTensorToBuffer.h +++ b/iree/compiler/Translation/XLAToLinalg/LinalgTensorToBuffer.h
@@ -1,4 +1,4 @@ -// Copyright 2019 Google LLC +// Copyright 2020 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License.
diff --git a/iree/compiler/Utils/BUILD b/iree/compiler/Utils/BUILD index c254864..a76ef85 100644 --- a/iree/compiler/Utils/BUILD +++ b/iree/compiler/Utils/BUILD
@@ -35,6 +35,7 @@ "//iree/compiler/Dialect/IREE/IR", "@llvm-project//llvm:support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils",
diff --git a/iree/compiler/Utils/CMakeLists.txt b/iree/compiler/Utils/CMakeLists.txt index d8db64d..c361f20 100644 --- a/iree/compiler/Utils/CMakeLists.txt +++ b/iree/compiler/Utils/CMakeLists.txt
@@ -27,6 +27,7 @@ iree::compiler::Dialect::IREE::IR LLVMSupport MLIRIR + MLIRPass MLIRStandardOps MLIRSupport MLIRTransformUtils
diff --git a/iree/compiler/Utils/IREECodegenUtils.cpp b/iree/compiler/Utils/IREECodegenUtils.cpp index 98bd878..b0b1f3b 100644 --- a/iree/compiler/Utils/IREECodegenUtils.cpp +++ b/iree/compiler/Utils/IREECodegenUtils.cpp
@@ -48,8 +48,9 @@ } /// Gets the workgroup size. +template <typename intType> LogicalResult getLegacyWorkGroupSize(FuncOp funcOp, - SmallVectorImpl<int32_t> &workGroupSize) { + SmallVectorImpl<intType> &workGroupSize) { if (!funcOp.getAttr("iree.executable.export")) { return funcOp.emitError( "expected operation to be in dispatch function to get launch size"); @@ -68,5 +69,10 @@ return success(); } +template LogicalResult getLegacyWorkGroupSize<int32_t>( + FuncOp funcOp, SmallVectorImpl<int32_t> &workGroupSize); +template LogicalResult getLegacyWorkGroupSize<int64_t>( + FuncOp funcOp, SmallVectorImpl<int64_t> &workGroupSize); + } // namespace iree_compiler } // namespace mlir
diff --git a/iree/compiler/Utils/IREECodegenUtils.h b/iree/compiler/Utils/IREECodegenUtils.h index 3983156..908de04 100644 --- a/iree/compiler/Utils/IREECodegenUtils.h +++ b/iree/compiler/Utils/IREECodegenUtils.h
@@ -29,10 +29,10 @@ LogicalResult getLegacyLaunchSize(FuncOp funcOp, SmallVectorImpl<int64_t> &launchSize); -// TODO(ravishankarm): remove this; it is not safe for variable sizes. -/// Gets the workgroup size. +/// Gets the workgroup size. Has to be a static constant. +template <typename intType> LogicalResult getLegacyWorkGroupSize(FuncOp funcOp, - SmallVectorImpl<int32_t> &workGroupSize); + SmallVectorImpl<intType> &workGroupSize); } // namespace iree_compiler } // namespace mlir
diff --git a/iree/tools/CMakeLists.txt b/iree/tools/CMakeLists.txt index f6ddae0..67b08c8 100644 --- a/iree/tools/CMakeLists.txt +++ b/iree/tools/CMakeLists.txt
@@ -117,8 +117,8 @@ iree::compiler::Dialect::VM::Transforms iree::compiler::Translation::Interpreter::Transforms iree::compiler::Translation::IREEVM - iree::compiler::Translation::SPIRV::XLAToSPIRV iree::compiler::Translation::SPIRV::LinalgToSPIRV + iree::compiler::Translation::SPIRV::XLAToSPIRV iree::compiler::Translation::XLAToLinalg iree::compiler::Translation::XLAToLinalg::IREELinalgTensorToBuffer LLVMSupport