@@ -20,6 +20,15 @@ public final class NIOHTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableC
20
20
public typealias OutboundOut = HTTPClientRequestPart
21
21
public typealias InboundIn = HTTPClientResponsePart
22
22
23
+ /// Whether we've already seen the first request.
24
+ private var seenFirstRequest = false
25
+ private var bufferedWrittenMessages : MarkedCircularBuffer < BufferedWrite >
26
+
27
+ struct BufferedWrite {
28
+ var data : NIOAny
29
+ var promise : EventLoopPromise < Void > ?
30
+ }
31
+
23
32
private enum State {
24
33
// transitions to `.connectSent` or `.failed`
25
34
case initialized
@@ -39,7 +48,7 @@ public final class NIOHTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableC
39
48
private let targetPort : Int
40
49
private let headers : HTTPHeaders
41
50
private let deadline : NIODeadline
42
- private let promise : EventLoopPromise < Void >
51
+ private let promise : EventLoopPromise < Void > ?
43
52
44
53
/// Creates a new ``NIOHTTP1ProxyConnectHandler`` that issues a CONNECT request to a proxy server
45
54
/// and instructs the server to connect to `targetHost`.
@@ -59,8 +68,52 @@ public final class NIOHTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableC
59
68
self . headers = headers
60
69
self . deadline = deadline
61
70
self . promise = promise
71
+
72
+ self . bufferedWrittenMessages = MarkedCircularBuffer ( initialCapacity: 16 ) // matches CircularBuffer default
73
+ }
74
+
75
+ public func write( context: ChannelHandlerContext , data: NIOAny , promise: EventLoopPromise < Void > ? ) {
76
+ switch self . state {
77
+ case . initialized, . connectSent, . headReceived, . completed:
78
+ self . bufferedWrittenMessages. append ( BufferedWrite ( data: data, promise: promise) )
79
+ case . failed( let error) :
80
+ promise? . fail ( error)
81
+ }
82
+ }
83
+
84
+ public func flush( context: ChannelHandlerContext ) {
85
+ self . bufferedWrittenMessages. mark ( )
86
+ }
87
+
88
+ public func removeHandler( context: ChannelHandlerContext , removalToken: ChannelHandlerContext . RemovalToken ) {
89
+ // We have been formally removed from the pipeline. We should send any buffered data we have.
90
+ switch self . state {
91
+ case . initialized, . connectSent, . headReceived, . failed:
92
+ self . failWithError ( . noResult( ) , context: context)
93
+
94
+ case . completed:
95
+ let hadMark = self . bufferedWrittenMessages. hasMark
96
+ while self . bufferedWrittenMessages. hasMark {
97
+ // write until mark
98
+ let bufferedPart = self . bufferedWrittenMessages. removeFirst ( )
99
+ context. write ( bufferedPart. data, promise: bufferedPart. promise)
100
+ }
101
+
102
+ // flush any messages up to the mark
103
+ if hadMark {
104
+ context. flush ( )
105
+ }
106
+
107
+ // write remainder
108
+ while let bufferedPart = self . bufferedWrittenMessages. popFirst ( ) {
109
+ context. write ( bufferedPart. data, promise: bufferedPart. promise)
110
+ }
111
+ }
112
+
113
+ context. leavePipeline ( removalToken: removalToken)
62
114
}
63
115
116
+
64
117
public func handlerAdded( context: ChannelHandlerContext ) {
65
118
if context. channel. isActive {
66
119
self . sendConnect ( context: context)
@@ -70,10 +123,11 @@ public final class NIOHTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableC
70
123
public func handlerRemoved( context: ChannelHandlerContext ) {
71
124
switch self . state {
72
125
case . failed, . completed:
126
+ // we don't expect there to be any buffered messages in these states
127
+ assert ( self . bufferedWrittenMessages. isEmpty)
73
128
break
74
129
case . initialized, . connectSent, . headReceived:
75
- self . state = . failed( Error . noResult ( ) )
76
- self . promise. fail ( Error . noResult ( ) )
130
+ self . failWithError ( Error . noResult ( ) , context: context)
77
131
}
78
132
}
79
133
@@ -96,10 +150,6 @@ public final class NIOHTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableC
96
150
context. fireChannelInactive ( )
97
151
}
98
152
99
- public func write( context: ChannelHandlerContext , data: NIOAny , promise: EventLoopPromise < Void > ? ) {
100
- preconditionFailure ( " We don't support outgoing traffic during HTTP Proxy update. " )
101
- }
102
-
103
153
public func channelRead( context: ChannelHandlerContext , data: NIOAny ) {
104
154
switch self . unwrapInboundIn ( data) {
105
155
case . head( let head) :
@@ -187,22 +237,33 @@ public final class NIOHTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableC
187
237
case . headReceived( let timeout) :
188
238
timeout. cancel ( )
189
239
self . state = . completed
190
- self . promise. succeed ( ( ) )
191
-
192
240
case . failed:
193
241
// ran into an error before... ignore this one
194
- break
242
+ return
195
243
case . initialized, . connectSent, . completed:
196
244
preconditionFailure ( " Invalid state: \( self . state) " )
197
245
}
246
+
247
+ // Ok, we've set up the proxy connection. We can now remove ourselves, which should happen synchronously.
248
+ context. pipeline. removeHandler ( context: context, promise: nil )
249
+
250
+ self . promise? . succeed ( ( ) )
198
251
}
199
252
200
253
private func failWithError( _ error: Error , context: ChannelHandlerContext , closeConnection: Bool = true ) {
201
- self . state = . failed( error)
202
- self . promise. fail ( error)
203
- context. fireErrorCaught ( error)
204
- if closeConnection {
205
- context. close ( mode: . all, promise: nil )
254
+ switch self . state {
255
+ case . failed:
256
+ return
257
+ case . initialized, . connectSent, . headReceived, . completed:
258
+ self . state = . failed( error)
259
+ self . promise? . fail ( error)
260
+ context. fireErrorCaught ( error)
261
+ if closeConnection {
262
+ context. close ( mode: . all, promise: nil )
263
+ }
264
+ for bufferedWrite in self . bufferedWrittenMessages {
265
+ bufferedWrite. promise? . fail ( error)
266
+ }
206
267
}
207
268
}
208
269
@@ -217,15 +278,6 @@ public final class NIOHTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableC
217
278
case noResult
218
279
}
219
280
220
- fileprivate enum Kind : String , Equatable , Hashable {
221
- case proxyAuthenticationRequired
222
- case invalidProxyResponseHead
223
- case invalidProxyResponse
224
- case remoteConnectionClosed
225
- case httpProxyHandshakeTimeout
226
- case noResult
227
- }
228
-
229
281
final class Storage : Sendable {
230
282
fileprivate let details : Details
231
283
public let file : String
@@ -273,54 +325,75 @@ public final class NIOHTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableC
273
325
public static func noResult( file: String = #file, line: UInt = #line) -> Error {
274
326
Error ( error: . noResult, file: file, line: line)
275
327
}
328
+
329
+ fileprivate var errorCode : Int {
330
+ switch self . store. details {
331
+ case . proxyAuthenticationRequired:
332
+ return 0
333
+ case . invalidProxyResponseHead:
334
+ return 1
335
+ case . invalidProxyResponse:
336
+ return 2
337
+ case . remoteConnectionClosed:
338
+ return 3
339
+ case . httpProxyHandshakeTimeout:
340
+ return 4
341
+ case . noResult:
342
+ return 5
343
+ }
344
+ }
276
345
}
277
346
278
347
}
279
348
280
349
extension NIOHTTP1ProxyConnectHandler . Error : Hashable {
281
- public static func == ( lhs: NIOHTTP1ProxyConnectHandler . Error , rhs: NIOHTTP1ProxyConnectHandler . Error ) -> Bool {
282
- // ignore *where* the error was thrown
283
- lhs. store. details == rhs. store. details
350
+ // compare only the kind of error, not the associated response head
351
+ public static func == ( lhs: Self , rhs: Self ) -> Bool {
352
+ switch ( lhs. store. details, rhs. store. details) {
353
+ case ( . proxyAuthenticationRequired, . proxyAuthenticationRequired) :
354
+ return true
355
+ case ( . invalidProxyResponseHead, . invalidProxyResponseHead) :
356
+ return true
357
+ case ( . invalidProxyResponse, . invalidProxyResponse) :
358
+ return true
359
+ case ( . remoteConnectionClosed, . remoteConnectionClosed) :
360
+ return true
361
+ case ( . httpProxyHandshakeTimeout, . httpProxyHandshakeTimeout) :
362
+ return true
363
+ case ( . noResult, . noResult) :
364
+ return true
365
+ default :
366
+ return false
367
+ }
284
368
}
285
369
286
370
public func hash( into hasher: inout Hasher ) {
287
- hasher. combine ( self . store . details )
371
+ hasher. combine ( self . errorCode )
288
372
}
289
373
}
290
374
291
- extension NIOHTTP1ProxyConnectHandler . Error . Details : Hashable {
292
- // compare only the kind of error, not the associated response head
293
- @inlinable
294
- static func == ( lhs: Self , rhs: Self ) -> Bool {
295
- NIOHTTP1ProxyConnectHandler . Error. Kind ( lhs) == NIOHTTP1ProxyConnectHandler . Error. Kind ( rhs)
296
- }
297
375
298
- @ inlinable
299
- public func hash ( into hasher : inout Hasher ) {
300
- hasher . combine ( NIOHTTP1ProxyConnectHandler . Error . Kind ( self ) )
376
+ extension NIOHTTP1ProxyConnectHandler . Error : CustomStringConvertible {
377
+ public var description : String {
378
+ self . store . details . description
301
379
}
302
380
}
303
381
304
- extension NIOHTTP1ProxyConnectHandler . Error . Kind {
305
- init ( _ details : NIOHTTP1ProxyConnectHandler . Error . Details ) {
306
- switch details {
382
+ extension NIOHTTP1ProxyConnectHandler . Error . Details : CustomStringConvertible {
383
+ public var description : String {
384
+ switch self {
307
385
case . proxyAuthenticationRequired:
308
- self = . proxyAuthenticationRequired
386
+ return " Proxy Authentication Required "
309
387
case . invalidProxyResponseHead:
310
- self = . invalidProxyResponseHead
388
+ return " Invalid Proxy Response Head "
311
389
case . invalidProxyResponse:
312
- self = . invalidProxyResponse
390
+ return " Invalid Proxy Response "
313
391
case . remoteConnectionClosed:
314
- self = . remoteConnectionClosed
392
+ return " Remote Connection Closed "
315
393
case . httpProxyHandshakeTimeout:
316
- self = . httpProxyHandshakeTimeout
394
+ return " HTTP Proxy Handshake Timeout "
317
395
case . noResult:
318
- self = . noResult
396
+ return " No Result "
319
397
}
320
398
}
321
399
}
322
-
323
-
324
- extension NIOHTTP1ProxyConnectHandler . Error : CustomStringConvertible {
325
- public var description : String { return NIOHTTP1ProxyConnectHandler . Error. Kind ( store. details) . rawValue }
326
- }
0 commit comments