Working around SimplifyGlobalAccesses not supporting SCF. Pass needs to be rewritten to use data flow analysis.
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD index b443240..78f4be7 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD
@@ -62,6 +62,7 @@ "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFToControlFlow", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", ],
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt index 3a606d7..a772874 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
@@ -43,6 +43,7 @@ MLIRIR MLIRPass MLIRSCFDialect + MLIRSCFToControlFlow MLIRSupport MLIRTransforms iree::compiler::Dialect::Flow::IR
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp index 6ec9288..ab1e088 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
@@ -13,6 +13,7 @@ #include "iree/compiler/Dialect/Util/Transforms/Passes.h" #include "iree/compiler/Utils/PassUtils.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Pass/PassRegistry.h" #include "mlir/Transforms/Passes.h" @@ -202,7 +203,11 @@ // Fixup workgroup count calculations that may have used the affine dialect. // Kind of random here but can happen if the benchmarking code does things. - passManager.addPass(createLowerAffinePass()); + passManager.addPass(mlir::createLowerAffinePass()); + + // TODO(benvanik): remove the need for this; some cleanup passes such as + // SimplifyGlobalAccesses are currently broken with scf present. + FunctionLikeNest(passManager).addPass(mlir::createConvertSCFToCFPass); // Combine the initializers we emitted during resource cache materialization. passManager.addPass(IREE::Util::createCombineInitializersPass()); @@ -222,7 +227,7 @@ // NOTE: symbol DCE will destroy executable target contents, so only run it // if we serialized things. - passManager.addPass(createSymbolDCEPass()); + passManager.addPass(mlir::createSymbolDCEPass()); } }
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp index b5bfa98..b2f6106 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp
@@ -79,7 +79,8 @@ static bool doesOpBlockMotion(Operation *op) { return isa<mlir::CallOpInterface>(op) || - op->hasTrait<OpTrait::IREE::Util::YieldPoint>(); + op->hasTrait<OpTrait::IREE::Util::YieldPoint>() || + op->hasTrait<OpTrait::IsTerminator>(); } static void moveOpUpInBlock(Block &block, Operation *op) { @@ -90,7 +91,7 @@ } static void moveOpDownInBlock(Block &block, Operation *op) { - while (op->getNextNode() != block.getTerminator()) { + while (op->getNextNode()) { if (doesOpBlockMotion(op->getNextNode())) break; op->moveAfter(op->getNextNode()); }