@@ -3,23 +3,28 @@ package gost
3
3
import (
4
4
"bufio"
5
5
"bytes"
6
+ "context"
6
7
"crypto/tls"
7
8
"fmt"
8
9
"io"
10
+ "io/ioutil"
9
11
"net"
12
+ "net/http"
13
+ "net/url"
10
14
"strings"
11
15
"sync"
12
16
"time"
13
17
14
18
"github.com/go-log/log"
15
19
"github.com/miekg/dns"
20
+ "golang.org/x/net/http2"
16
21
)
17
22
18
23
var (
19
24
// DefaultResolverTimeout is the default timeout for name resolution.
20
- DefaultResolverTimeout = 30 * time .Second
25
+ DefaultResolverTimeout = 5 * time .Second
21
26
// DefaultResolverTTL is the default cache TTL for name resolution.
22
- DefaultResolverTTL = 60 * time .Second
27
+ DefaultResolverTTL = 1 * time .Hour
23
28
)
24
29
25
30
// Resolver is a name resolver for domain name.
@@ -39,9 +44,73 @@ type ReloadResolver interface {
39
44
// NameServer is a name server.
40
45
// Currently supported protocol: TCP, UDP and TLS.
41
46
type NameServer struct {
42
- Addr string
43
- Protocol string
44
- Hostname string // for TLS handshake verification
47
+ Addr string
48
+ Protocol string
49
+ Hostname string // for TLS handshake verification
50
+ Timeout time.Duration
51
+ exchanger Exchanger
52
+ }
53
+
54
+ // Init initializes the name server.
55
+ func (ns * NameServer ) Init () error {
56
+ switch strings .ToLower (ns .Protocol ) {
57
+ case "tcp" :
58
+ ns .exchanger = & dnsExchanger {
59
+ endpoint : ns .Addr ,
60
+ client : & dns.Client {
61
+ Net : "tcp" ,
62
+ Timeout : ns .Timeout ,
63
+ },
64
+ }
65
+ case "tls" :
66
+ cfg := & tls.Config {
67
+ ServerName : ns .Hostname ,
68
+ }
69
+ if cfg .ServerName == "" {
70
+ cfg .InsecureSkipVerify = true
71
+ }
72
+
73
+ ns .exchanger = & dnsExchanger {
74
+ endpoint : ns .Addr ,
75
+ client : & dns.Client {
76
+ Net : "tcp-tls" ,
77
+ Timeout : ns .Timeout ,
78
+ TLSConfig : cfg ,
79
+ },
80
+ }
81
+ case "https" :
82
+ u , err := url .Parse (ns .Addr )
83
+ if err != nil {
84
+ return err
85
+ }
86
+ cfg := & tls.Config {ServerName : u .Hostname ()}
87
+ transport := & http.Transport {
88
+ TLSClientConfig : cfg ,
89
+ DisableCompression : true ,
90
+ MaxIdleConns : 1 ,
91
+ }
92
+ http2 .ConfigureTransport (transport )
93
+
94
+ ns .exchanger = & dohExchanger {
95
+ endpoint : u ,
96
+ client : & http.Client {
97
+ Transport : transport ,
98
+ Timeout : ns .Timeout ,
99
+ },
100
+ }
101
+ case "udp" :
102
+ fallthrough
103
+ default :
104
+ ns .exchanger = & dnsExchanger {
105
+ endpoint : ns .Addr ,
106
+ client : & dns.Client {
107
+ Net : "udp" ,
108
+ Timeout : ns .Timeout ,
109
+ },
110
+ }
111
+ }
112
+
113
+ return nil
45
114
}
46
115
47
116
func (ns NameServer ) String () string {
@@ -62,26 +131,19 @@ type resolverCacheItem struct {
62
131
}
63
132
64
133
type resolver struct {
65
- Resolver * net.Resolver
66
- Servers []NameServer
67
- mCache * sync.Map
68
- Timeout time.Duration
69
- TTL time.Duration
70
- period time.Duration
71
- domain string
72
- stopped chan struct {}
73
- mux sync.RWMutex
134
+ Servers []NameServer
135
+ mCache * sync.Map
136
+ Timeout time.Duration
137
+ TTL time.Duration
138
+ period time.Duration
139
+ domain string
140
+ stopped chan struct {}
141
+ mux sync.RWMutex
74
142
}
75
143
76
144
// NewResolver create a new Resolver with the given name servers and resolution timeout.
77
145
func NewResolver (timeout , ttl time.Duration , servers ... NameServer ) ReloadResolver {
78
- r := & resolver {
79
- Servers : servers ,
80
- Timeout : timeout ,
81
- TTL : ttl ,
82
- mCache : & sync.Map {},
83
- stopped : make (chan struct {}),
84
- }
146
+ r := newResolver (timeout , ttl , servers ... )
85
147
86
148
if r .Timeout <= 0 {
87
149
r .Timeout = DefaultResolverTimeout
@@ -92,6 +154,16 @@ func NewResolver(timeout, ttl time.Duration, servers ...NameServer) ReloadResolv
92
154
return r
93
155
}
94
156
157
+ func newResolver (timeout , ttl time.Duration , servers ... NameServer ) * resolver {
158
+ return & resolver {
159
+ Servers : servers ,
160
+ Timeout : timeout ,
161
+ TTL : ttl ,
162
+ mCache : & sync.Map {},
163
+ stopped : make (chan struct {}),
164
+ }
165
+ }
166
+
95
167
func (r * resolver ) copyServers () []NameServer {
96
168
var servers []NameServer
97
169
for i := range r .Servers {
@@ -107,12 +179,11 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
107
179
}
108
180
109
181
var domain string
110
- var timeout , ttl time.Duration
182
+ var ttl time.Duration
111
183
var servers []NameServer
112
184
113
185
r .mux .RLock ()
114
186
domain = r .domain
115
- timeout = r .Timeout
116
187
ttl = r .TTL
117
188
servers = r .copyServers ()
118
189
r .mux .RUnlock ()
@@ -133,7 +204,7 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
133
204
}
134
205
135
206
for _ , ns := range servers {
136
- ips , err = r .resolve (ns , host , timeout )
207
+ ips , err = r .resolve (ns . exchanger , host )
137
208
if err != nil {
138
209
log .Logf ("[resolver] %s via %s : %s" , host , ns , err )
139
210
continue
@@ -151,36 +222,14 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
151
222
return
152
223
}
153
224
154
- func (* resolver ) resolve (ns NameServer , host string , timeout time.Duration ) (ips []net.IP , err error ) {
155
- addr := ns .Addr
156
- if _ , port , _ := net .SplitHostPort (addr ); port == "" {
157
- addr = net .JoinHostPort (addr , "53" )
158
- }
159
-
160
- client := dns.Client {
161
- Timeout : timeout ,
162
- }
163
- switch strings .ToLower (ns .Protocol ) {
164
- case "tcp" :
165
- client .Net = "tcp"
166
- case "tls" :
167
- cfg := & tls.Config {
168
- ServerName : ns .Hostname ,
169
- }
170
- if cfg .ServerName == "" {
171
- cfg .InsecureSkipVerify = true
172
- }
173
- client .Net = "tcp-tls"
174
- client .TLSConfig = cfg
175
- case "udp" :
176
- fallthrough
177
- default :
178
- client .Net = "udp"
225
+ func (* resolver ) resolve (ex Exchanger , host string ) (ips []net.IP , err error ) {
226
+ if ex == nil {
227
+ return
179
228
}
180
229
181
- m := dns.Msg {}
182
- m .SetQuestion (dns .Fqdn (host ), dns .TypeA )
183
- mr , _ , err := client .Exchange (& m , addr )
230
+ query := dns.Msg {}
231
+ query .SetQuestion (dns .Fqdn (host ), dns .TypeA )
232
+ mr , err := ex .Exchange (context . Background (), & query )
184
233
if err != nil {
185
234
return
186
235
}
@@ -223,7 +272,7 @@ func (r *resolver) Reload(rd io.Reader) error {
223
272
var domain string
224
273
var nss []NameServer
225
274
226
- if r .Stopped () {
275
+ if rd == nil || r .Stopped () {
227
276
return nil
228
277
}
229
278
@@ -293,7 +342,15 @@ func (r *resolver) Reload(rd io.Reader) error {
293
342
ns .Protocol = ss [1 ]
294
343
ns .Hostname = ss [2 ]
295
344
}
296
- nss = append (nss , ns )
345
+
346
+ ns .Timeout = timeout
347
+ if timeout <= 0 {
348
+ ns .Timeout = DefaultResolverTimeout
349
+ }
350
+
351
+ if err := ns .Init (); err == nil {
352
+ nss = append (nss , ns )
353
+ }
297
354
}
298
355
}
299
356
@@ -359,3 +416,80 @@ func (r *resolver) String() string {
359
416
}
360
417
return b .String ()
361
418
}
419
+
420
+ // Exchanger is an interface for DNS synchronous query.
421
+ type Exchanger interface {
422
+ Exchange (ctx context.Context , query * dns.Msg ) (* dns.Msg , error )
423
+ }
424
+
425
+ type dnsExchanger struct {
426
+ endpoint string
427
+ client * dns.Client
428
+ }
429
+
430
+ func (ex * dnsExchanger ) Exchange (ctx context.Context , query * dns.Msg ) (* dns.Msg , error ) {
431
+ ep := ex .endpoint
432
+ if _ , port , _ := net .SplitHostPort (ep ); port == "" {
433
+ ep = net .JoinHostPort (ep , "53" )
434
+ }
435
+ mr , _ , err := ex .client .Exchange (query , ep )
436
+ return mr , err
437
+ }
438
+
439
+ type dohExchanger struct {
440
+ endpoint * url.URL
441
+ client * http.Client
442
+ }
443
+
444
+ // reference: https://github.com/cloudflare/cloudflared/blob/master/tunneldns/https_upstream.go#L54
445
+ func (ex * dohExchanger ) Exchange (ctx context.Context , query * dns.Msg ) (* dns.Msg , error ) {
446
+ queryBuf , err := query .Pack ()
447
+ if err != nil {
448
+ return nil , fmt .Errorf ("failed to pack DNS query: %s" , err )
449
+ }
450
+
451
+ // No content negotiation for now, use DNS wire format
452
+ buf , backendErr := ex .exchangeWireformat (queryBuf )
453
+ if backendErr == nil {
454
+ response := & dns.Msg {}
455
+ if err := response .Unpack (buf ); err != nil {
456
+ return nil , fmt .Errorf ("failed to unpack DNS response from body: %s" , err )
457
+ }
458
+
459
+ response .Id = query .Id
460
+ return response , nil
461
+ }
462
+
463
+ return nil , backendErr
464
+ }
465
+
466
+ // Perform message exchange with the default UDP wireformat defined in current draft
467
+ // https://datatracker.ietf.org/doc/draft-ietf-doh-dns-over-https
468
+ func (ex * dohExchanger ) exchangeWireformat (msg []byte ) ([]byte , error ) {
469
+ req , err := http .NewRequest ("POST" , ex .endpoint .String (), bytes .NewBuffer (msg ))
470
+ if err != nil {
471
+ return nil , fmt .Errorf ("failed to create an HTTPS request: %s" , err )
472
+ }
473
+
474
+ req .Header .Add ("Content-Type" , "application/dns-udpwireformat" )
475
+ req .Host = ex .endpoint .Hostname ()
476
+
477
+ resp , err := ex .client .Do (req )
478
+ if err != nil {
479
+ return nil , fmt .Errorf ("failed to perform an HTTPS request: %s" , err )
480
+ }
481
+
482
+ // Check response status code
483
+ defer resp .Body .Close ()
484
+ if resp .StatusCode != http .StatusOK {
485
+ return nil , fmt .Errorf ("returned status code %d" , resp .StatusCode )
486
+ }
487
+
488
+ // Read wireformat response from the body
489
+ buf , err := ioutil .ReadAll (resp .Body )
490
+ if err != nil {
491
+ return nil , fmt .Errorf ("failed to read the response body: %s" , err )
492
+ }
493
+
494
+ return buf , nil
495
+ }
0 commit comments