[Util] Add structural conversion pattern for cf.switch (#14729)
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.cpp b/compiler/src/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.cpp index 13945ea..3437050 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.cpp
@@ -190,6 +190,19 @@ } }; +struct ConvertSwitchOp : public OpConversionPattern<mlir::cf::SwitchOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(mlir::cf::SwitchOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp<mlir::cf::SwitchOp>( + op, adaptor.getFlag(), op.getDefaultDestination(), + adaptor.getDefaultOperands(), op.getCaseValuesAttr(), + op.getCaseDestinations(), adaptor.getCaseOperands()); + return success(); + } +}; + struct ConvertSelectOp : public OpConversionPattern<mlir::arith::SelectOp> { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -260,13 +273,15 @@ addGenericLegalOp<func::ReturnOp>(conversionTarget, typeConverter); addGenericLegalOp<cf::BranchOp>(conversionTarget, typeConverter); addGenericLegalOp<cf::CondBranchOp>(conversionTarget, typeConverter); + addGenericLegalOp<cf::SwitchOp>(conversionTarget, typeConverter); addGenericLegalOp<arith::SelectOp>(conversionTarget, typeConverter); addGenericLegalOp<scf::IfOp>(conversionTarget, typeConverter); addGenericLegalOp<scf::YieldOp>(conversionTarget, typeConverter); - patterns.insert<ConvertInitializerOp, ConvertFuncOp, ConvertCallOp, - ConvertReturnOp, ConvertBranchOp, ConvertCondBranchOp, - ConvertSelectOp, ConvertIfOp, ConvertYieldOp>(typeConverter, - context); + patterns + .insert<ConvertInitializerOp, ConvertFuncOp, ConvertCallOp, + ConvertReturnOp, ConvertBranchOp, ConvertCondBranchOp, + ConvertSwitchOp, ConvertSelectOp, ConvertIfOp, ConvertYieldOp>( + typeConverter, context); } } // namespace iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/test/structural_ops.mlir b/compiler/src/iree/compiler/Dialect/Util/Conversion/test/structural_ops.mlir index 5afcba5..b6d910b 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Conversion/test/structural_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/test/structural_ops.mlir
@@ -65,6 +65,25 @@ // ----- +// CHECK-LABEL: @switchOp +// CHECK-SAME: (%[[FLAG:.+]]: i32, %[[ARG0:.+]]: !util.buffer, %[[ARG1:.+]]: !util.buffer) -> !util.buffer +func.func @switchOp(%flag: i32, %arg0: memref<?xi8>, %arg1: memref<?xi8>) -> memref<?xi8> { + // CHECK: cf.switch %[[FLAG]] : i32, [ + // CHECK: default: ^bb1(%[[ARG0]] : !util.buffer), + // CHECK: 0: ^bb1(%[[ARG1]] : !util.buffer) + // CHECK: ] + cf.switch %flag : i32, [ + default: ^bb1(%arg0 : memref<?xi8>), + 0: ^bb1(%arg1 : memref<?xi8>) + ] +// CHECK: ^bb1(%[[BB1_ARG0:.+]]: !util.buffer): +^bb1(%bb1_arg0 : memref<?xi8>): + // CHECK: return %[[BB1_ARG0]] : !util.buffer + return %bb1_arg0 : memref<?xi8> +} + +// ----- + // CHECK-LABEL: @selectOp // CHECK-SAME: (%[[COND:.+]]: i1, %[[ARG0:.+]]: !util.buffer, %[[ARG1:.+]]: !util.buffer) -> !util.buffer func.func @selectOp(%cond: i1, %arg0: memref<?xi8>, %arg1: memref<?xi8>) -> memref<?xi8> {