diff --git a/Sources/SwiftFormatPrettyPrint/TokenStreamCreator.swift b/Sources/SwiftFormatPrettyPrint/TokenStreamCreator.swift index f97497262..1818146a3 100644 --- a/Sources/SwiftFormatPrettyPrint/TokenStreamCreator.swift +++ b/Sources/SwiftFormatPrettyPrint/TokenStreamCreator.swift @@ -412,6 +412,21 @@ fileprivate final class TokenStreamCreator: SyntaxVisitor { return .visitChildren } + override func visit(_ node: AccessorEffectSpecifiersSyntax) -> SyntaxVisitorContinueKind { + arrangeEffectSpecifiers(node) + return .visitChildren + } + + override func visit(_ node: FunctionEffectSpecifiersSyntax) -> SyntaxVisitorContinueKind { + arrangeEffectSpecifiers(node) + return .visitChildren + } + + override func visit(_ node: TypeEffectSpecifiersSyntax) -> SyntaxVisitorContinueKind { + arrangeEffectSpecifiers(node) + return .visitChildren + } + /// Applies formatting tokens to the tokens in the given function or function-like declaration /// node (e.g., initializers, deinitiailizers, and subscripts). private func arrangeFunctionLikeDecl( @@ -434,6 +449,17 @@ fileprivate final class TokenStreamCreator: SyntaxVisitor { after(node.lastToken(viewMode: .sourceAccurate), tokens: .close) } + /// Arranges the `async` and `throws` effect specifiers of a function or accessor declaration. + private func arrangeEffectSpecifiers(_ node: Node) { + before(node.asyncSpecifier, tokens: .break) + before(node.throwsSpecifier, tokens: .break) + // Keep them together if both `async` and `throws` are present. + if let asyncSpecifier = node.asyncSpecifier, let throwsSpecifier = node.throwsSpecifier { + before(asyncSpecifier, tokens: .open) + after(throwsSpecifier, tokens: .close) + } + } + // MARK: - Property and subscript accessor block nodes override func visit(_ node: AccessorListSyntax) -> SyntaxVisitorContinueKind { @@ -449,22 +475,6 @@ fileprivate final class TokenStreamCreator: SyntaxVisitor { override func visit(_ node: AccessorDeclSyntax) -> SyntaxVisitorContinueKind { arrangeAttributeList(node.attributes) - - if let asyncKeyword = node.effectSpecifiers?.asyncSpecifier { - if node.effectSpecifiers?.throwsSpecifier != nil { - before(asyncKeyword, tokens: .break, .open) - } else { - before(asyncKeyword, tokens: .break) - } - } - - if let throwsKeyword = node.effectSpecifiers?.throwsSpecifier { - before(node.effectSpecifiers?.throwsSpecifier, tokens: .break) - if node.effectSpecifiers?.asyncSpecifier != nil { - after(throwsKeyword, tokens: .close) - } - } - arrangeBracesAndContents(of: node.body, contentsKeyPath: \.statements) return .visitChildren } @@ -1160,13 +1170,6 @@ fileprivate final class TokenStreamCreator: SyntaxVisitor { } } - before(node.effectSpecifiers?.asyncSpecifier, tokens: .break) - before(node.effectSpecifiers?.throwsSpecifier, tokens: .break) - if let asyncKeyword = node.effectSpecifiers?.asyncSpecifier, let throwsTok = node.effectSpecifiers?.throwsSpecifier { - before(asyncKeyword, tokens: .open) - after(throwsTok, tokens: .close) - } - before(node.output?.arrow, tokens: .break) after(node.lastToken(viewMode: .sourceAccurate), tokens: .close) before(node.inTok, tokens: .break(.same)) @@ -1607,8 +1610,6 @@ fileprivate final class TokenStreamCreator: SyntaxVisitor { override func visit(_ node: FunctionTypeSyntax) -> SyntaxVisitorContinueKind { after(node.leftParen, tokens: .break(.open, size: 0), .open) before(node.rightParen, tokens: .break(.close, size: 0), .close) - before(node.effectSpecifiers?.asyncSpecifier, tokens: .break) - before(node.effectSpecifiers?.throwsSpecifier, tokens: .break) return .visitChildren } @@ -1833,14 +1834,6 @@ fileprivate final class TokenStreamCreator: SyntaxVisitor { } override func visit(_ node: FunctionSignatureSyntax) -> SyntaxVisitorContinueKind { - before(node.effectSpecifiers?.asyncSpecifier, tokens: .break) - before(node.effectSpecifiers?.throwsSpecifier, tokens: .break) - if let asyncOrReasyncKeyword = node.effectSpecifiers?.asyncSpecifier, - let throwsOrRethrowsKeyword = node.effectSpecifiers?.throwsSpecifier - { - before(asyncOrReasyncKeyword, tokens: .open) - after(throwsOrRethrowsKeyword, tokens: .close) - } before(node.output?.firstToken(viewMode: .sourceAccurate), tokens: .break) return .visitChildren } @@ -1873,6 +1866,14 @@ fileprivate final class TokenStreamCreator: SyntaxVisitor { } let binOp = node.operatorOperand + if binOp.is(ArrowExprSyntax.self) { + // `ArrowExprSyntax` nodes occur when a function type is written in an expression context; + // for example, `let x = [(Int) throws -> Void]()`. We want to treat those consistently like + // we do other function return clauses and not treat them as regular binary operators, so + // handle that behavior there instead. + return .visitChildren + } + let rhs = node.rightOperand maybeGroupAroundSubexpression(rhs, combiningOperator: binOp) @@ -1986,9 +1987,8 @@ fileprivate final class TokenStreamCreator: SyntaxVisitor { } override func visit(_ node: ArrowExprSyntax) -> SyntaxVisitorContinueKind { - // The break before the `throws` keyword is inserted at the `InfixOperatorExpr` level so that it - // is placed in the correct relative position to the group surrounding the "operator". - after(node.effectSpecifiers?.throwsSpecifier, tokens: .break) + before(node.arrowToken, tokens: .break) + after(node.arrowToken, tokens: .space) return .visitChildren } diff --git a/Tests/SwiftFormatPrettyPrintTests/ArrayDeclTests.swift b/Tests/SwiftFormatPrettyPrintTests/ArrayDeclTests.swift index c0e8d6576..03b092915 100644 --- a/Tests/SwiftFormatPrettyPrintTests/ArrayDeclTests.swift +++ b/Tests/SwiftFormatPrettyPrintTests/ArrayDeclTests.swift @@ -70,17 +70,70 @@ final class ArrayDeclTests: PrettyPrintTestCase { let input = """ let A = [(Int, Double) -> Bool]() + let A = [(Int, Double) async -> Bool]() let A = [(Int, Double) throws -> Bool]() + let A = [(Int, Double) async throws -> Bool]() """ - let expected = + let expected46 = """ let A = [(Int, Double) -> Bool]() + let A = [(Int, Double) async -> Bool]() let A = [(Int, Double) throws -> Bool]() + let A = [(Int, Double) async throws -> Bool]() """ + assertPrettyPrintEqual(input: input, expected: expected46, linelength: 46) - assertPrettyPrintEqual(input: input, expected: expected, linelength: 45) + let expected43 = + """ + let A = [(Int, Double) -> Bool]() + let A = [(Int, Double) async -> Bool]() + let A = [(Int, Double) throws -> Bool]() + let A = [ + (Int, Double) async throws -> Bool + ]() + + """ + assertPrettyPrintEqual(input: input, expected: expected43, linelength: 43) + + let expected35 = + """ + let A = [(Int, Double) -> Bool]() + let A = [ + (Int, Double) async -> Bool + ]() + let A = [ + (Int, Double) throws -> Bool + ]() + let A = [ + (Int, Double) async throws + -> Bool + ]() + + """ + assertPrettyPrintEqual(input: input, expected: expected35, linelength: 35) + + let expected27 = + """ + let A = [ + (Int, Double) -> Bool + ]() + let A = [ + (Int, Double) async + -> Bool + ]() + let A = [ + (Int, Double) throws + -> Bool + ]() + let A = [ + (Int, Double) + async throws -> Bool + ]() + + """ + assertPrettyPrintEqual(input: input, expected: expected27, linelength: 27) } func testNoTrailingCommasInTypes() { diff --git a/Tests/SwiftFormatPrettyPrintTests/FunctionTypeTests.swift b/Tests/SwiftFormatPrettyPrintTests/FunctionTypeTests.swift index d3e78ee4c..8ee6370ee 100644 --- a/Tests/SwiftFormatPrettyPrintTests/FunctionTypeTests.swift +++ b/Tests/SwiftFormatPrettyPrintTests/FunctionTypeTests.swift @@ -60,6 +60,127 @@ final class FunctionTypeTests: PrettyPrintTestCase { assertPrettyPrintEqual(input: input, expected: expected, linelength: 60) } + func testFunctionTypeAsync() { + let input = + """ + func f(g: (_ somevalue: Int) async -> String?) { + let a = 123 + let b = "abc" + } + func f(g: (currentLevel: Int) async -> String?) { + let a = 123 + let b = "abc" + } + func f(g: (currentLevel: inout Int) async -> String?) { + let a = 123 + let b = "abc" + } + func f(g: (variable1: Int, variable2: Double, variable3: Bool) async -> Double) { + let a = 123 + let b = "abc" + } + func f(g: (variable1: Int, variable2: Double, variable3: Bool, variable4: String) async -> Double) { + let a = 123 + let b = "abc" + } + """ + + let expected = + """ + func f(g: (_ somevalue: Int) async -> String?) { + let a = 123 + let b = "abc" + } + func f(g: (currentLevel: Int) async -> String?) { + let a = 123 + let b = "abc" + } + func f(g: (currentLevel: inout Int) async -> String?) { + let a = 123 + let b = "abc" + } + func f( + g: (variable1: Int, variable2: Double, variable3: Bool) async -> + Double + ) { + let a = 123 + let b = "abc" + } + func f( + g: ( + variable1: Int, variable2: Double, variable3: Bool, + variable4: String + ) async -> Double + ) { + let a = 123 + let b = "abc" + } + + """ + + assertPrettyPrintEqual(input: input, expected: expected, linelength: 66) + } + + func testFunctionTypeAsyncThrows() { + let input = + """ + func f(g: (_ somevalue: Int) async throws -> String?) { + let a = 123 + let b = "abc" + } + func f(g: (currentLevel: Int) async throws -> String?) { + let a = 123 + let b = "abc" + } + func f(g: (currentLevel: inout Int) async throws -> String?) { + let a = 123 + let b = "abc" + } + func f(g: (variable1: Int, variable2: Double, variable3: Bool) async throws -> Double) { + let a = 123 + let b = "abc" + } + func f(g: (variable1: Int, variable2: Double, variable3: Bool, variable4: String) async throws -> Double) { + let a = 123 + let b = "abc" + } + """ + + let expected = + """ + func f(g: (_ somevalue: Int) async throws -> String?) { + let a = 123 + let b = "abc" + } + func f(g: (currentLevel: Int) async throws -> String?) { + let a = 123 + let b = "abc" + } + func f(g: (currentLevel: inout Int) async throws -> String?) { + let a = 123 + let b = "abc" + } + func f( + g: (variable1: Int, variable2: Double, variable3: Bool) async throws -> + Double + ) { + let a = 123 + let b = "abc" + } + func f( + g: ( + variable1: Int, variable2: Double, variable3: Bool, variable4: String + ) async throws -> Double + ) { + let a = 123 + let b = "abc" + } + + """ + + assertPrettyPrintEqual(input: input, expected: expected, linelength: 73) + } + func testFunctionTypeThrows() { let input = """ @@ -84,7 +205,7 @@ final class FunctionTypeTests: PrettyPrintTestCase { let b = "abc" } """ - + let expected = """ func f(g: (_ somevalue: Int) throws -> String?) { @@ -117,7 +238,7 @@ final class FunctionTypeTests: PrettyPrintTestCase { } """ - + assertPrettyPrintEqual(input: input, expected: expected, linelength: 67) }