Skip to content

Commit df95fe3

Browse files
authored
Fix RequestBody.withInboundCloseHandler for HTTP2 streams (#699)
* RequestBody.iterate(iterator:source:) shouldnt throw errors * Pass on iterator.next() errors * Add new RequestBodyTests * Fix 5.10
1 parent 5629032 commit df95fe3

File tree

4 files changed

+193
-19
lines changed

4 files changed

+193
-19
lines changed

Sources/HummingbirdCore/Request/RequestBody+inboundClose.swift

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,9 @@ extension RequestBody {
119119
let unsafeIterator = UnsafeTransfer(iterator)
120120
let value = try await withThrowingTaskGroup(of: Void.self) { group in
121121
group.addTask {
122-
do {
123-
if try await self.iterate(iterator: unsafeIterator.wrappedValue, source: source) == .inboundClosed {
124-
onInboundClosed()
125-
}
126-
} catch is CancellationError {}
122+
if await self.iterate(iterator: unsafeIterator.wrappedValue, source: source) == .inboundClosed {
123+
onInboundClosed()
124+
}
127125
}
128126
let value = try await operation()
129127
group.cancelAll()
@@ -140,16 +138,30 @@ extension RequestBody {
140138
fileprivate func iterate<AsyncIterator: AsyncIteratorProtocol>(
141139
iterator: AsyncIterator,
142140
source: RequestBody.Source
143-
) async throws -> IterateResult where AsyncIterator.Element == HTTPRequestPart {
141+
) async -> IterateResult where AsyncIterator.Element == HTTPRequestPart {
144142
var iterator = iterator
145-
while let part = try await iterator.next() {
146-
switch part {
147-
case .head:
148-
return .nextRequestReady
149-
case .body(let buffer):
150-
try await source.yield(buffer)
151-
case .end:
152-
source.finish()
143+
var finished = false
144+
while true {
145+
do {
146+
guard let part = try await iterator.next() else { break }
147+
switch part {
148+
case .head:
149+
return .nextRequestReady
150+
case .body(let buffer):
151+
await source.yield(buffer)
152+
case .end:
153+
finished = true
154+
source.finish()
155+
}
156+
} catch {
157+
// if we are not finished receiving the request body pass error onto source
158+
if !finished {
159+
source.finish(error)
160+
}
161+
// we received an error on the inbound stream it is in effect closed. This
162+
// is of particular importance for HTTP2 streams where stream closure invokes
163+
// an error on the inbound stream of HTTP parts instead of just finishing it.
164+
return .inboundClosed
153165
}
154166
}
155167
return .inboundClosed

Sources/HummingbirdCore/Request/RequestBody.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ extension RequestBody {
267267
///
268268
/// - Parameter element: The element to yield to the inbound stream.
269269
@inlinable
270-
public func yield(_ element: ByteBuffer) async throws {
270+
public func yield(_ element: ByteBuffer) async {
271271
// if previous call indicated we should stop producing wait until the delegate
272272
// says we can start producing again
273273
await self.delegate.waitForProduceMore()
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the Hummingbird server framework project
4+
//
5+
// Copyright (c) 2023 the Hummingbird authors
6+
// Licensed under Apache License v2.0
7+
//
8+
// See LICENSE.txt for license information
9+
// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors
10+
//
11+
// SPDX-License-Identifier: Apache-2.0
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
import HummingbirdCore
16+
import NIOCore
17+
import NIOHTTPTypes
18+
import XCTest
19+
20+
final class RequestBodyTests: XCTestCase {
21+
func testSingleRequestBody() async throws {
22+
try await withThrowingTaskGroup(of: Void.self) { group in
23+
let (httpSource, httpStream) = NIOAsyncChannelInboundStream<HTTPRequestPart>.makeTestingStream()
24+
let httpSourceIterator = httpSource.makeAsyncIterator()
25+
let requestBody = RequestBody(nioAsyncChannelInbound: .init(iterator: httpSourceIterator))
26+
group.addTask {
27+
httpStream.yield(.body(ByteBuffer(string: "hello ")))
28+
httpStream.yield(.body(ByteBuffer(string: "world")))
29+
httpStream.yield(.end(nil))
30+
httpStream.finish()
31+
}
32+
group.addTask {
33+
let buffer = try await requestBody.collect(upTo: .max)
34+
XCTAssertEqual(String(buffer: buffer), "hello world")
35+
}
36+
try await group.waitForAll()
37+
}
38+
}
39+
40+
func testMultipleRequestBodies() async throws {
41+
try await withThrowingTaskGroup(of: Void.self) { group in
42+
let (httpSource, httpStream) = NIOAsyncChannelInboundStream<HTTPRequestPart>.makeTestingStream()
43+
let httpSourceIterator = httpSource.makeAsyncIterator()
44+
let requestBody = RequestBody(nioAsyncChannelInbound: .init(iterator: httpSourceIterator))
45+
group.addTask {
46+
httpStream.yield(.body(ByteBuffer(string: "hello ")))
47+
httpStream.yield(.body(ByteBuffer(string: "world")))
48+
httpStream.yield(.end(nil))
49+
httpStream.yield(.head(.init(method: .get, scheme: nil, authority: nil, path: "/test")))
50+
httpStream.yield(.end(nil))
51+
httpStream.finish()
52+
}
53+
group.addTask {
54+
let buffer = try await requestBody.collect(upTo: .max)
55+
XCTAssertEqual(String(buffer: buffer), "hello world")
56+
}
57+
try await group.waitForAll()
58+
}
59+
}
60+
61+
#if compiler(>=6.0)
62+
func testInboundClosureParsingStream() async throws {
63+
try await withThrowingTaskGroup(of: Void.self) { group in
64+
let (httpSource, httpStream) = NIOAsyncChannelInboundStream<HTTPRequestPart>.makeTestingStream()
65+
let httpSourceIterator = httpSource.makeAsyncIterator()
66+
let requestBody = RequestBody(nioAsyncChannelInbound: .init(iterator: httpSourceIterator))
67+
let (stream, cont) = AsyncStream.makeStream(of: Void.self)
68+
group.addTask {
69+
httpStream.yield(.body(ByteBuffer(string: "hello ")))
70+
httpStream.yield(.body(ByteBuffer(string: "world")))
71+
httpStream.yield(.end(nil))
72+
httpStream.finish()
73+
}
74+
group.addTask {
75+
try await requestBody.consumeWithInboundCloseHandler { requestBody in
76+
let buffer = try await requestBody.collect(upTo: .max)
77+
XCTAssertEqual(String(buffer: buffer), "hello world")
78+
await stream.first { _ in true }
79+
} onInboundClosed: {
80+
cont.yield()
81+
}
82+
}
83+
try await group.waitForAll()
84+
}
85+
}
86+
87+
func testInboundClosureWithoutParsingStream() async throws {
88+
try await withThrowingTaskGroup(of: Void.self) { group in
89+
let (httpSource, httpStream) = NIOAsyncChannelInboundStream<HTTPRequestPart>.makeTestingStream()
90+
let httpSourceIterator = httpSource.makeAsyncIterator()
91+
let requestBody = RequestBody(nioAsyncChannelInbound: .init(iterator: httpSourceIterator))
92+
let (stream, cont) = AsyncStream.makeStream(of: Void.self)
93+
group.addTask {
94+
httpStream.yield(.body(ByteBuffer(string: "hello ")))
95+
httpStream.yield(.body(ByteBuffer(string: "world")))
96+
httpStream.yield(.end(nil))
97+
httpStream.finish()
98+
}
99+
group.addTask {
100+
try await requestBody.consumeWithInboundCloseHandler { requestBody in
101+
await stream.first { _ in true }
102+
} onInboundClosed: {
103+
cont.yield()
104+
}
105+
}
106+
try await group.waitForAll()
107+
}
108+
}
109+
110+
func testInboundClosureWithStreamError() async throws {
111+
struct TestError: Error {}
112+
try await withThrowingTaskGroup(of: Void.self) { group in
113+
let (httpSource, httpStream) = NIOAsyncChannelInboundStream<HTTPRequestPart>.makeTestingStream()
114+
let httpSourceIterator = httpSource.makeAsyncIterator()
115+
let requestBody = RequestBody(nioAsyncChannelInbound: .init(iterator: httpSourceIterator))
116+
let (stream, cont) = AsyncStream.makeStream(of: Void.self)
117+
group.addTask {
118+
httpStream.yield(.body(ByteBuffer(string: "hello ")))
119+
httpStream.yield(.end(nil))
120+
httpStream.finish(throwing: TestError())
121+
}
122+
group.addTask {
123+
try await requestBody.consumeWithInboundCloseHandler { requestBody in
124+
await stream.first { _ in true }
125+
} onInboundClosed: {
126+
cont.yield()
127+
}
128+
}
129+
try await group.waitForAll()
130+
}
131+
}
132+
133+
func testInboundClosureWithStreamErrorIsPassedOn() async throws {
134+
struct TestError: Error {}
135+
try await withThrowingTaskGroup(of: Void.self) { group in
136+
let (httpSource, httpStream) = NIOAsyncChannelInboundStream<HTTPRequestPart>.makeTestingStream()
137+
let httpSourceIterator = httpSource.makeAsyncIterator()
138+
let requestBody = RequestBody(nioAsyncChannelInbound: .init(iterator: httpSourceIterator))
139+
let (stream, cont) = AsyncStream.makeStream(of: Void.self)
140+
group.addTask {
141+
httpStream.yield(.body(ByteBuffer(string: "hello ")))
142+
httpStream.yield(.body(ByteBuffer(string: "world")))
143+
httpStream.finish(throwing: TestError())
144+
}
145+
group.addTask {
146+
try await requestBody.consumeWithInboundCloseHandler { requestBody in
147+
do {
148+
_ = try await requestBody.collect(upTo: .max)
149+
XCTFail("Should not get here")
150+
} catch is TestError {
151+
//
152+
}
153+
await stream.first { _ in true }
154+
} onInboundClosed: {
155+
cont.yield()
156+
}
157+
}
158+
try await group.waitForAll()
159+
}
160+
}
161+
#endif // compiler(>=6.0)
162+
}

Tests/HummingbirdTests/ApplicationTests.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -844,7 +844,7 @@ final class ApplicationTests: XCTestCase {
844844
let (requestBody, source) = RequestBody.makeStream()
845845
group.addTask {
846846
for try await buffer in request.body {
847-
try await source.yield(buffer)
847+
await source.yield(buffer)
848848
}
849849
source.finish()
850850
}
@@ -880,17 +880,17 @@ final class ApplicationTests: XCTestCase {
880880
await withThrowingTaskGroup(of: Void.self) { group in
881881
group.addTask {
882882
for value in 0..<100 {
883-
try await source.yield(ByteBuffer(string: String(describing: value)))
883+
await source.yield(ByteBuffer(string: String(describing: value)))
884884
}
885885
}
886886
group.addTask {
887887
for value in 0..<100 {
888-
try await source.yield(ByteBuffer(string: String(describing: value)))
888+
await source.yield(ByteBuffer(string: String(describing: value)))
889889
}
890890
}
891891
group.addTask {
892892
for value in 0..<100 {
893-
try await source.yield(ByteBuffer(string: String(describing: value)))
893+
await source.yield(ByteBuffer(string: String(describing: value)))
894894
}
895895
}
896896
}

0 commit comments

Comments
 (0)