Skip to content

Commit a7d49f0

Browse files
committed
DNS resolver support DoH #335
1 parent 02f1d09 commit a7d49f0

File tree

4 files changed

+431
-58
lines changed

4 files changed

+431
-58
lines changed

cmd/gost/.config/dns.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ reload 10s
1010
# ip[:port] [protocol] [hostname]
1111

1212
1.1.1.1:853 tls cloudflare-dns.com
13+
https://1.0.0.1/dns-query https
1314
8.8.8.8
1415
8.8.8.8 tcp
1516
1.1.1.1 udp

cmd/gost/cfg.go

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -207,17 +207,34 @@ func parseResolver(cfg string) gost.Resolver {
207207
if s == "" {
208208
continue
209209
}
210+
if strings.HasPrefix(s, "https") {
211+
ns := gost.NameServer{
212+
Addr: s,
213+
Protocol: "https",
214+
}
215+
if err := ns.Init(); err == nil {
216+
nss = append(nss, ns)
217+
}
218+
continue
219+
}
220+
210221
ss := strings.Split(s, "/")
211222
if len(ss) == 1 {
212-
nss = append(nss, gost.NameServer{
223+
ns := gost.NameServer{
213224
Addr: ss[0],
214-
})
225+
}
226+
if err := ns.Init(); err == nil {
227+
nss = append(nss, ns)
228+
}
215229
}
216230
if len(ss) == 2 {
217-
nss = append(nss, gost.NameServer{
231+
ns := gost.NameServer{
218232
Addr: ss[0],
219233
Protocol: ss[1],
220-
})
234+
}
235+
if err := ns.Init(); err == nil {
236+
nss = append(nss, ns)
237+
}
221238
}
222239
}
223240
return gost.NewResolver(timeout, ttl, nss...)

resolver.go

Lines changed: 188 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,28 @@ package gost
33
import (
44
"bufio"
55
"bytes"
6+
"context"
67
"crypto/tls"
78
"fmt"
89
"io"
10+
"io/ioutil"
911
"net"
12+
"net/http"
13+
"net/url"
1014
"strings"
1115
"sync"
1216
"time"
1317

1418
"github.com/go-log/log"
1519
"github.com/miekg/dns"
20+
"golang.org/x/net/http2"
1621
)
1722

1823
var (
1924
// DefaultResolverTimeout is the default timeout for name resolution.
20-
DefaultResolverTimeout = 30 * time.Second
25+
DefaultResolverTimeout = 5 * time.Second
2126
// DefaultResolverTTL is the default cache TTL for name resolution.
22-
DefaultResolverTTL = 60 * time.Second
27+
DefaultResolverTTL = 1 * time.Hour
2328
)
2429

2530
// Resolver is a name resolver for domain name.
@@ -39,9 +44,73 @@ type ReloadResolver interface {
3944
// NameServer is a name server.
4045
// Currently supported protocol: TCP, UDP and TLS.
4146
type NameServer struct {
42-
Addr string
43-
Protocol string
44-
Hostname string // for TLS handshake verification
47+
Addr string
48+
Protocol string
49+
Hostname string // for TLS handshake verification
50+
Timeout time.Duration
51+
exchanger Exchanger
52+
}
53+
54+
// Init initializes the name server.
55+
func (ns *NameServer) Init() error {
56+
switch strings.ToLower(ns.Protocol) {
57+
case "tcp":
58+
ns.exchanger = &dnsExchanger{
59+
endpoint: ns.Addr,
60+
client: &dns.Client{
61+
Net: "tcp",
62+
Timeout: ns.Timeout,
63+
},
64+
}
65+
case "tls":
66+
cfg := &tls.Config{
67+
ServerName: ns.Hostname,
68+
}
69+
if cfg.ServerName == "" {
70+
cfg.InsecureSkipVerify = true
71+
}
72+
73+
ns.exchanger = &dnsExchanger{
74+
endpoint: ns.Addr,
75+
client: &dns.Client{
76+
Net: "tcp-tls",
77+
Timeout: ns.Timeout,
78+
TLSConfig: cfg,
79+
},
80+
}
81+
case "https":
82+
u, err := url.Parse(ns.Addr)
83+
if err != nil {
84+
return err
85+
}
86+
cfg := &tls.Config{ServerName: u.Hostname()}
87+
transport := &http.Transport{
88+
TLSClientConfig: cfg,
89+
DisableCompression: true,
90+
MaxIdleConns: 1,
91+
}
92+
http2.ConfigureTransport(transport)
93+
94+
ns.exchanger = &dohExchanger{
95+
endpoint: u,
96+
client: &http.Client{
97+
Transport: transport,
98+
Timeout: ns.Timeout,
99+
},
100+
}
101+
case "udp":
102+
fallthrough
103+
default:
104+
ns.exchanger = &dnsExchanger{
105+
endpoint: ns.Addr,
106+
client: &dns.Client{
107+
Net: "udp",
108+
Timeout: ns.Timeout,
109+
},
110+
}
111+
}
112+
113+
return nil
45114
}
46115

47116
func (ns NameServer) String() string {
@@ -62,26 +131,19 @@ type resolverCacheItem struct {
62131
}
63132

64133
type resolver struct {
65-
Resolver *net.Resolver
66-
Servers []NameServer
67-
mCache *sync.Map
68-
Timeout time.Duration
69-
TTL time.Duration
70-
period time.Duration
71-
domain string
72-
stopped chan struct{}
73-
mux sync.RWMutex
134+
Servers []NameServer
135+
mCache *sync.Map
136+
Timeout time.Duration
137+
TTL time.Duration
138+
period time.Duration
139+
domain string
140+
stopped chan struct{}
141+
mux sync.RWMutex
74142
}
75143

76144
// NewResolver create a new Resolver with the given name servers and resolution timeout.
77145
func NewResolver(timeout, ttl time.Duration, servers ...NameServer) ReloadResolver {
78-
r := &resolver{
79-
Servers: servers,
80-
Timeout: timeout,
81-
TTL: ttl,
82-
mCache: &sync.Map{},
83-
stopped: make(chan struct{}),
84-
}
146+
r := newResolver(timeout, ttl, servers...)
85147

86148
if r.Timeout <= 0 {
87149
r.Timeout = DefaultResolverTimeout
@@ -92,6 +154,16 @@ func NewResolver(timeout, ttl time.Duration, servers ...NameServer) ReloadResolv
92154
return r
93155
}
94156

157+
func newResolver(timeout, ttl time.Duration, servers ...NameServer) *resolver {
158+
return &resolver{
159+
Servers: servers,
160+
Timeout: timeout,
161+
TTL: ttl,
162+
mCache: &sync.Map{},
163+
stopped: make(chan struct{}),
164+
}
165+
}
166+
95167
func (r *resolver) copyServers() []NameServer {
96168
var servers []NameServer
97169
for i := range r.Servers {
@@ -107,12 +179,11 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
107179
}
108180

109181
var domain string
110-
var timeout, ttl time.Duration
182+
var ttl time.Duration
111183
var servers []NameServer
112184

113185
r.mux.RLock()
114186
domain = r.domain
115-
timeout = r.Timeout
116187
ttl = r.TTL
117188
servers = r.copyServers()
118189
r.mux.RUnlock()
@@ -133,7 +204,7 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
133204
}
134205

135206
for _, ns := range servers {
136-
ips, err = r.resolve(ns, host, timeout)
207+
ips, err = r.resolve(ns.exchanger, host)
137208
if err != nil {
138209
log.Logf("[resolver] %s via %s : %s", host, ns, err)
139210
continue
@@ -151,36 +222,14 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
151222
return
152223
}
153224

154-
func (*resolver) resolve(ns NameServer, host string, timeout time.Duration) (ips []net.IP, err error) {
155-
addr := ns.Addr
156-
if _, port, _ := net.SplitHostPort(addr); port == "" {
157-
addr = net.JoinHostPort(addr, "53")
158-
}
159-
160-
client := dns.Client{
161-
Timeout: timeout,
162-
}
163-
switch strings.ToLower(ns.Protocol) {
164-
case "tcp":
165-
client.Net = "tcp"
166-
case "tls":
167-
cfg := &tls.Config{
168-
ServerName: ns.Hostname,
169-
}
170-
if cfg.ServerName == "" {
171-
cfg.InsecureSkipVerify = true
172-
}
173-
client.Net = "tcp-tls"
174-
client.TLSConfig = cfg
175-
case "udp":
176-
fallthrough
177-
default:
178-
client.Net = "udp"
225+
func (*resolver) resolve(ex Exchanger, host string) (ips []net.IP, err error) {
226+
if ex == nil {
227+
return
179228
}
180229

181-
m := dns.Msg{}
182-
m.SetQuestion(dns.Fqdn(host), dns.TypeA)
183-
mr, _, err := client.Exchange(&m, addr)
230+
query := dns.Msg{}
231+
query.SetQuestion(dns.Fqdn(host), dns.TypeA)
232+
mr, err := ex.Exchange(context.Background(), &query)
184233
if err != nil {
185234
return
186235
}
@@ -223,7 +272,7 @@ func (r *resolver) Reload(rd io.Reader) error {
223272
var domain string
224273
var nss []NameServer
225274

226-
if r.Stopped() {
275+
if rd == nil || r.Stopped() {
227276
return nil
228277
}
229278

@@ -293,7 +342,15 @@ func (r *resolver) Reload(rd io.Reader) error {
293342
ns.Protocol = ss[1]
294343
ns.Hostname = ss[2]
295344
}
296-
nss = append(nss, ns)
345+
346+
ns.Timeout = timeout
347+
if timeout <= 0 {
348+
ns.Timeout = DefaultResolverTimeout
349+
}
350+
351+
if err := ns.Init(); err == nil {
352+
nss = append(nss, ns)
353+
}
297354
}
298355
}
299356

@@ -359,3 +416,80 @@ func (r *resolver) String() string {
359416
}
360417
return b.String()
361418
}
419+
420+
// Exchanger is an interface for DNS synchronous query.
421+
type Exchanger interface {
422+
Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error)
423+
}
424+
425+
type dnsExchanger struct {
426+
endpoint string
427+
client *dns.Client
428+
}
429+
430+
func (ex *dnsExchanger) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
431+
ep := ex.endpoint
432+
if _, port, _ := net.SplitHostPort(ep); port == "" {
433+
ep = net.JoinHostPort(ep, "53")
434+
}
435+
mr, _, err := ex.client.Exchange(query, ep)
436+
return mr, err
437+
}
438+
439+
type dohExchanger struct {
440+
endpoint *url.URL
441+
client *http.Client
442+
}
443+
444+
// reference: https://github.com/cloudflare/cloudflared/blob/master/tunneldns/https_upstream.go#L54
445+
func (ex *dohExchanger) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
446+
queryBuf, err := query.Pack()
447+
if err != nil {
448+
return nil, fmt.Errorf("failed to pack DNS query: %s", err)
449+
}
450+
451+
// No content negotiation for now, use DNS wire format
452+
buf, backendErr := ex.exchangeWireformat(queryBuf)
453+
if backendErr == nil {
454+
response := &dns.Msg{}
455+
if err := response.Unpack(buf); err != nil {
456+
return nil, fmt.Errorf("failed to unpack DNS response from body: %s", err)
457+
}
458+
459+
response.Id = query.Id
460+
return response, nil
461+
}
462+
463+
return nil, backendErr
464+
}
465+
466+
// Perform message exchange with the default UDP wireformat defined in current draft
467+
// https://datatracker.ietf.org/doc/draft-ietf-doh-dns-over-https
468+
func (ex *dohExchanger) exchangeWireformat(msg []byte) ([]byte, error) {
469+
req, err := http.NewRequest("POST", ex.endpoint.String(), bytes.NewBuffer(msg))
470+
if err != nil {
471+
return nil, fmt.Errorf("failed to create an HTTPS request: %s", err)
472+
}
473+
474+
req.Header.Add("Content-Type", "application/dns-udpwireformat")
475+
req.Host = ex.endpoint.Hostname()
476+
477+
resp, err := ex.client.Do(req)
478+
if err != nil {
479+
return nil, fmt.Errorf("failed to perform an HTTPS request: %s", err)
480+
}
481+
482+
// Check response status code
483+
defer resp.Body.Close()
484+
if resp.StatusCode != http.StatusOK {
485+
return nil, fmt.Errorf("returned status code %d", resp.StatusCode)
486+
}
487+
488+
// Read wireformat response from the body
489+
buf, err := ioutil.ReadAll(resp.Body)
490+
if err != nil {
491+
return nil, fmt.Errorf("failed to read the response body: %s", err)
492+
}
493+
494+
return buf, nil
495+
}

0 commit comments

Comments
 (0)