Skip to content

Commit da3ea0c

Browse files
committed
writes are buffered, other review changes
- writes issued whilst the CONNECT is ongoing are now buffered rather than triggering a failure - Error is restructured to do away with `Kind` - failure logic is consolidated in `failWithError`
1 parent 313e224 commit da3ea0c

File tree

2 files changed

+300
-79
lines changed

2 files changed

+300
-79
lines changed

Sources/NIOExtras/HTTP1ProxyConnectHandler.swift

Lines changed: 124 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,15 @@ public final class NIOHTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableC
2020
public typealias OutboundOut = HTTPClientRequestPart
2121
public typealias InboundIn = HTTPClientResponsePart
2222

23+
/// Whether we've already seen the first request.
24+
private var seenFirstRequest = false
25+
private var bufferedWrittenMessages: MarkedCircularBuffer<BufferedWrite>
26+
27+
struct BufferedWrite {
28+
var data: NIOAny
29+
var promise: EventLoopPromise<Void>?
30+
}
31+
2332
private enum State {
2433
// transitions to `.connectSent` or `.failed`
2534
case initialized
@@ -39,7 +48,7 @@ public final class NIOHTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableC
3948
private let targetPort: Int
4049
private let headers: HTTPHeaders
4150
private let deadline: NIODeadline
42-
private let promise: EventLoopPromise<Void>
51+
private let promise: EventLoopPromise<Void>?
4352

4453
/// Creates a new ``NIOHTTP1ProxyConnectHandler`` that issues a CONNECT request to a proxy server
4554
/// and instructs the server to connect to `targetHost`.
@@ -59,8 +68,52 @@ public final class NIOHTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableC
5968
self.headers = headers
6069
self.deadline = deadline
6170
self.promise = promise
71+
72+
self.bufferedWrittenMessages = MarkedCircularBuffer(initialCapacity: 16) // matches CircularBuffer default
73+
}
74+
75+
public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
76+
switch self.state {
77+
case .initialized, .connectSent, .headReceived, .completed:
78+
self.bufferedWrittenMessages.append(BufferedWrite(data: data, promise: promise))
79+
case .failed(let error):
80+
promise?.fail(error)
81+
}
82+
}
83+
84+
public func flush(context: ChannelHandlerContext) {
85+
self.bufferedWrittenMessages.mark()
86+
}
87+
88+
public func removeHandler(context: ChannelHandlerContext, removalToken: ChannelHandlerContext.RemovalToken) {
89+
// We have been formally removed from the pipeline. We should send any buffered data we have.
90+
switch self.state {
91+
case .initialized, .connectSent, .headReceived, .failed:
92+
self.failWithError(.noResult(), context: context)
93+
94+
case .completed:
95+
let hadMark = self.bufferedWrittenMessages.hasMark
96+
while self.bufferedWrittenMessages.hasMark {
97+
// write until mark
98+
let bufferedPart = self.bufferedWrittenMessages.removeFirst()
99+
context.write(bufferedPart.data, promise: bufferedPart.promise)
100+
}
101+
102+
// flush any messages up to the mark
103+
if hadMark {
104+
context.flush()
105+
}
106+
107+
// write remainder
108+
while let bufferedPart = self.bufferedWrittenMessages.popFirst() {
109+
context.write(bufferedPart.data, promise: bufferedPart.promise)
110+
}
111+
}
112+
113+
context.leavePipeline(removalToken: removalToken)
62114
}
63115

116+
64117
public func handlerAdded(context: ChannelHandlerContext) {
65118
if context.channel.isActive {
66119
self.sendConnect(context: context)
@@ -70,10 +123,11 @@ public final class NIOHTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableC
70123
public func handlerRemoved(context: ChannelHandlerContext) {
71124
switch self.state {
72125
case .failed, .completed:
126+
// we don't expect there to be any buffered messages in these states
127+
assert(self.bufferedWrittenMessages.isEmpty)
73128
break
74129
case .initialized, .connectSent, .headReceived:
75-
self.state = .failed(Error.noResult())
76-
self.promise.fail(Error.noResult())
130+
self.failWithError(Error.noResult(), context: context)
77131
}
78132
}
79133

@@ -96,10 +150,6 @@ public final class NIOHTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableC
96150
context.fireChannelInactive()
97151
}
98152

99-
public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
100-
preconditionFailure("We don't support outgoing traffic during HTTP Proxy update.")
101-
}
102-
103153
public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
104154
switch self.unwrapInboundIn(data) {
105155
case .head(let head):
@@ -187,22 +237,33 @@ public final class NIOHTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableC
187237
case .headReceived(let timeout):
188238
timeout.cancel()
189239
self.state = .completed
190-
self.promise.succeed(())
191-
192240
case .failed:
193241
// ran into an error before... ignore this one
194-
break
242+
return
195243
case .initialized, .connectSent, .completed:
196244
preconditionFailure("Invalid state: \(self.state)")
197245
}
246+
247+
// Ok, we've set up the proxy connection. We can now remove ourselves, which should happen synchronously.
248+
context.pipeline.removeHandler(context: context, promise: nil)
249+
250+
self.promise?.succeed(())
198251
}
199252

200253
private func failWithError(_ error: Error, context: ChannelHandlerContext, closeConnection: Bool = true) {
201-
self.state = .failed(error)
202-
self.promise.fail(error)
203-
context.fireErrorCaught(error)
204-
if closeConnection {
205-
context.close(mode: .all, promise: nil)
254+
switch self.state {
255+
case .failed:
256+
return
257+
case .initialized, .connectSent, .headReceived, .completed:
258+
self.state = .failed(error)
259+
self.promise?.fail(error)
260+
context.fireErrorCaught(error)
261+
if closeConnection {
262+
context.close(mode: .all, promise: nil)
263+
}
264+
for bufferedWrite in self.bufferedWrittenMessages {
265+
bufferedWrite.promise?.fail(error)
266+
}
206267
}
207268
}
208269

@@ -217,15 +278,6 @@ public final class NIOHTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableC
217278
case noResult
218279
}
219280

220-
fileprivate enum Kind: String, Equatable, Hashable {
221-
case proxyAuthenticationRequired
222-
case invalidProxyResponseHead
223-
case invalidProxyResponse
224-
case remoteConnectionClosed
225-
case httpProxyHandshakeTimeout
226-
case noResult
227-
}
228-
229281
final class Storage: Sendable {
230282
fileprivate let details: Details
231283
public let file: String
@@ -273,54 +325,75 @@ public final class NIOHTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableC
273325
public static func noResult(file: String = #file, line: UInt = #line) -> Error {
274326
Error(error: .noResult, file: file, line: line)
275327
}
328+
329+
fileprivate var errorCode: Int {
330+
switch self.store.details {
331+
case .proxyAuthenticationRequired:
332+
return 0
333+
case .invalidProxyResponseHead:
334+
return 1
335+
case .invalidProxyResponse:
336+
return 2
337+
case .remoteConnectionClosed:
338+
return 3
339+
case .httpProxyHandshakeTimeout:
340+
return 4
341+
case .noResult:
342+
return 5
343+
}
344+
}
276345
}
277346

278347
}
279348

280349
extension NIOHTTP1ProxyConnectHandler.Error: Hashable {
281-
public static func == (lhs: NIOHTTP1ProxyConnectHandler.Error, rhs: NIOHTTP1ProxyConnectHandler.Error) -> Bool {
282-
// ignore *where* the error was thrown
283-
lhs.store.details == rhs.store.details
350+
// compare only the kind of error, not the associated response head
351+
public static func == (lhs: Self, rhs: Self) -> Bool {
352+
switch (lhs.store.details, rhs.store.details) {
353+
case (.proxyAuthenticationRequired, .proxyAuthenticationRequired):
354+
return true
355+
case (.invalidProxyResponseHead, .invalidProxyResponseHead):
356+
return true
357+
case (.invalidProxyResponse, .invalidProxyResponse):
358+
return true
359+
case (.remoteConnectionClosed, .remoteConnectionClosed):
360+
return true
361+
case (.httpProxyHandshakeTimeout, .httpProxyHandshakeTimeout):
362+
return true
363+
case (.noResult, .noResult):
364+
return true
365+
default:
366+
return false
367+
}
284368
}
285369

286370
public func hash(into hasher: inout Hasher) {
287-
hasher.combine(self.store.details)
371+
hasher.combine(self.errorCode)
288372
}
289373
}
290374

291-
extension NIOHTTP1ProxyConnectHandler.Error.Details: Hashable {
292-
// compare only the kind of error, not the associated response head
293-
@inlinable
294-
static func == (lhs: Self, rhs: Self) -> Bool {
295-
NIOHTTP1ProxyConnectHandler.Error.Kind(lhs) == NIOHTTP1ProxyConnectHandler.Error.Kind(rhs)
296-
}
297375

298-
@inlinable
299-
public func hash(into hasher: inout Hasher) {
300-
hasher.combine(NIOHTTP1ProxyConnectHandler.Error.Kind(self))
376+
extension NIOHTTP1ProxyConnectHandler.Error: CustomStringConvertible {
377+
public var description: String {
378+
self.store.details.description
301379
}
302380
}
303381

304-
extension NIOHTTP1ProxyConnectHandler.Error.Kind {
305-
init(_ details: NIOHTTP1ProxyConnectHandler.Error.Details) {
306-
switch details {
382+
extension NIOHTTP1ProxyConnectHandler.Error.Details: CustomStringConvertible {
383+
public var description: String {
384+
switch self {
307385
case .proxyAuthenticationRequired:
308-
self = .proxyAuthenticationRequired
386+
return "Proxy Authentication Required"
309387
case .invalidProxyResponseHead:
310-
self = .invalidProxyResponseHead
388+
return "Invalid Proxy Response Head"
311389
case .invalidProxyResponse:
312-
self = .invalidProxyResponse
390+
return "Invalid Proxy Response"
313391
case .remoteConnectionClosed:
314-
self = .remoteConnectionClosed
392+
return "Remote Connection Closed"
315393
case .httpProxyHandshakeTimeout:
316-
self = .httpProxyHandshakeTimeout
394+
return "HTTP Proxy Handshake Timeout"
317395
case .noResult:
318-
self = .noResult
396+
return "No Result"
319397
}
320398
}
321399
}
322-
323-
324-
extension NIOHTTP1ProxyConnectHandler.Error: CustomStringConvertible {
325-
public var description: String { return NIOHTTP1ProxyConnectHandler.Error.Kind(store.details).rawValue }
326-
}

0 commit comments

Comments
 (0)