Skip to content

Commit 9606b38

Browse files
committed
add response filter
1 parent 8acaf21 commit 9606b38

File tree

6 files changed

+173
-51
lines changed

6 files changed

+173
-51
lines changed

direct.go

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
11
package main
22

33
import (
4+
"errors"
45
"fmt"
56
"io"
67
"net/http"
78
)
89

9-
type DirectPlugin struct {
10-
Plugin
10+
type DirectRequestPlugin struct {
11+
RequestPlugin
1112
}
1213

13-
func (p DirectPlugin) Handle(c *PluginContext, rw http.ResponseWriter, req *http.Request) {
14+
type DirectResponsePlugin struct {
15+
ResponsePlugin
16+
}
17+
18+
func (p DirectRequestPlugin) HandleRequest(c *PluginContext, rw http.ResponseWriter, req *http.Request) (*http.Response, error) {
1419
if req.Method != "CONNECT" {
1520
if !req.URL.IsAbs() {
1621
if req.TLS != nil {
@@ -29,14 +34,30 @@ func (p DirectPlugin) Handle(c *PluginContext, rw http.ResponseWriter, req *http
2934
if err != nil {
3035
rw.WriteHeader(502)
3136
fmt.Fprintf(rw, "Error: %s\n", err)
32-
return
37+
return nil, err
3338
}
3439
newReq.Header = req.Header
3540
res, err := c.H.Net.HttpClientDo(newReq)
36-
if err != nil {
41+
return res, err
42+
} else {
43+
c.H.Log.Printf("%s \"DIRECT %s %s %s\" - -", req.RemoteAddr, req.Method, req.Host, req.Proto)
44+
response := &http.Response{
45+
StatusCode: 200,
46+
ProtoMajor: 1,
47+
ProtoMinor: 1,
48+
Header: http.Header{},
49+
ContentLength: -1,
50+
}
51+
return response, nil
52+
}
53+
}
54+
55+
func (p DirectResponsePlugin) HandleResponse(c *PluginContext, rw http.ResponseWriter, req *http.Request, res *http.Response, resError error) error {
56+
if req.Method != "CONNECT" {
57+
if resError != nil {
3758
rw.WriteHeader(502)
38-
fmt.Fprintf(rw, "Error: %s\n", err)
39-
return
59+
fmt.Fprintf(rw, "Error: %s\n", resError)
60+
return resError
4061
}
4162
c.H.Log.Printf("%s \"DIRECT %s %s %s\" %d %s", req.RemoteAddr, req.Method, req.URL.String(), req.Proto, res.StatusCode, res.Header.Get("Content-Length"))
4263
rw.WriteHeader(res.StatusCode)
@@ -47,20 +68,27 @@ func (p DirectPlugin) Handle(c *PluginContext, rw http.ResponseWriter, req *http
4768
}
4869
io.Copy(rw, res.Body)
4970
} else {
50-
c.H.Log.Printf("%s \"DIRECT %s %s %s\" - -", req.RemoteAddr, req.Method, req.Host, req.Proto)
71+
if resError != nil {
72+
rw.WriteHeader(502)
73+
fmt.Fprintf(rw, "Error: %s\n", resError)
74+
c.H.Log.Printf("NetDialTimeout %s failed %s", req.Host, resError)
75+
return resError
76+
}
5177
remoteConn, err := c.H.Net.NetDialTimeout("tcp", req.Host, c.H.Net.GetTimeout())
5278
if err != nil {
53-
c.H.Log.Printf("NetDialTimeout %s failed %s", req.Host, err)
54-
return
79+
return err
5580
}
5681
hijacker, ok := rw.(http.Hijacker)
5782
if !ok {
58-
c.H.Log.Printf("http.ResponseWriter does not implments Hijacker")
59-
return
83+
resError = errors.New("http.ResponseWriter does not implments Hijacker")
84+
rw.WriteHeader(502)
85+
fmt.Fprintf(rw, "Error: %s\n", resError)
86+
return resError
6087
}
6188
localConn, _, err := hijacker.Hijack()
6289
localConn.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))
6390
go io.Copy(remoteConn, localConn)
6491
io.Copy(localConn, remoteConn)
6592
}
93+
return nil
6694
}

filter.go

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,27 @@ import (
44
"net/http"
55
)
66

7-
type DirectFilter struct {
7+
type DirectRequestFilter struct {
88
RequestFilter
99
}
1010

11-
func (d *DirectFilter) Filter(req *http.Request) (pluginName string, pluginArgs *http.Header, err error) {
11+
func (d *DirectRequestFilter) Filter(req *http.Request) (pluginName string, pluginArgs *http.Header, err error) {
1212
return "direct", nil, nil
1313
}
1414

15-
type StripFilter struct {
15+
type DirectResponseFilter struct {
1616
RequestFilter
1717
}
1818

19-
func (d *StripFilter) Filter(req *http.Request) (pluginName string, pluginArgs *http.Header, err error) {
19+
func (d *DirectResponseFilter) Filter(req *http.Request, res *http.Response) (pluginName string, pluginArgs *http.Header, err error) {
20+
return "direct", nil, nil
21+
}
22+
23+
type StripRequestFilter struct {
24+
RequestFilter
25+
}
26+
27+
func (d *StripRequestFilter) Filter(req *http.Request) (pluginName string, pluginArgs *http.Header, err error) {
2028
if req.Method == "CONNECT" {
2129
args := http.Header{
2230
"Foo": []string{"bar"},

goagent.go

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,19 @@ func main() {
2020
Listener: ln,
2121
Log: log.New(os.Stderr, "INFO - ", 3),
2222
Net: &SimpleNetwork{},
23-
Plugins: map[string]Plugin{
24-
"direct": &DirectPlugin{},
25-
"strip": &StripPlugin{},
23+
RequestPlugins: map[string]RequestPlugin{
24+
"direct": &DirectRequestPlugin{},
25+
"strip": &StripRequestPlugin{},
26+
},
27+
ResponsePlugins: map[string]ResponsePlugin{
28+
"direct": &DirectResponsePlugin{},
2629
},
2730
RequestFilters: []RequestFilter{
28-
&StripFilter{},
29-
&DirectFilter{},
31+
&StripRequestFilter{},
32+
&DirectRequestFilter{},
33+
},
34+
ResponseFilters: []ResponseFilter{
35+
&DirectResponseFilter{},
3036
},
3137
}
3238
s := &http.Server{

handler.go

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package main
22

33
import (
44
"crypto/tls"
5+
"io"
56
"log"
67
"net"
78
"net/http"
@@ -13,6 +14,7 @@ type Net2 interface {
1314
NetDialTimeout(network, address string, timeout time.Duration) (net.Conn, error)
1415
TlsDialTimeout(network, address string, config *tls.Config, timeout time.Duration) (*tls.Conn, error)
1516
HttpClientDo(req *http.Request) (*http.Response, error)
17+
CopyResponseBody(w io.Writer, res *http.Response) (int64, error)
1618
GetTimeout() time.Duration
1719
SetTimeout()
1820
GetAddressAlias(addr string) (alias string)
@@ -23,7 +25,7 @@ type RequestFilter interface {
2325
}
2426

2527
type ResponseFilter interface {
26-
Filter(req *http.Response) (newReq *http.Response, err error)
28+
Filter(req *http.Request, res *http.Response) (pluginName string, pluginArgs *http.Header, err error)
2729
}
2830

2931
type PushListener interface {
@@ -100,7 +102,8 @@ type Handler struct {
100102
Listener net.Listener
101103
Log *log.Logger
102104
Net Net2
103-
Plugins map[string]Plugin
105+
RequestPlugins map[string]RequestPlugin
106+
ResponsePlugins map[string]ResponsePlugin
104107
RequestFilters []RequestFilter
105108
ResponseFilters []ResponseFilter
106109
}
@@ -110,22 +113,53 @@ type PluginContext struct {
110113
Args *http.Header
111114
}
112115

116+
type RequestPlugin interface {
117+
HandleRequest(*PluginContext, http.ResponseWriter, *http.Request) (*http.Response, error)
118+
}
119+
120+
type ResponsePlugin interface {
121+
HandleResponse(*PluginContext, http.ResponseWriter, *http.Request, *http.Response, error) error
122+
}
123+
113124
type Plugin interface {
114-
Handle(*PluginContext, http.ResponseWriter, *http.Request)
125+
RequestPlugin
126+
ResponsePlugin
115127
}
116128

117129
func (h Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
118-
for _, f := range h.RequestFilters {
119-
name, args, err := f.Filter(req)
130+
for _, reqfilter := range h.RequestFilters {
131+
name, args, err := reqfilter.Filter(req)
120132
if err != nil {
121-
h.Log.Fatalf("ServeHTTP error: %v", err)
133+
h.Log.Printf("RequestFilter error: %v", err)
122134
}
123135
if name == "" {
124136
continue
125137
}
126-
if plugin, ok := h.Plugins[name]; ok {
127-
context := &PluginContext{&h, args}
128-
plugin.Handle(context, rw, req)
138+
if reqplugin, ok := h.RequestPlugins[name]; ok {
139+
reqctx := &PluginContext{&h, args}
140+
res, err := reqplugin.HandleRequest(reqctx, rw, req)
141+
if err != nil {
142+
h.Log.Printf("Plugin %s HandleResponse error: %v", name, err)
143+
}
144+
if res != nil {
145+
for _, resfilter := range h.ResponseFilters {
146+
name, args, err := resfilter.Filter(req, res)
147+
if err != nil {
148+
h.Log.Printf("ServeHTTP RequestFilter error: %v", err)
149+
}
150+
if name == "" {
151+
continue
152+
}
153+
if resplugin, ok := h.ResponsePlugins[name]; ok {
154+
resctx := &PluginContext{&h, args}
155+
err := resplugin.HandleResponse(resctx, rw, req, res, err)
156+
if err != nil {
157+
h.Log.Printf("Plugin %s HandleResponse error: %v", name, err)
158+
}
159+
}
160+
break
161+
}
162+
}
129163
break
130164
} else {
131165
h.Log.Fatalf("plugin \"%s\" not registered", name)

network.go

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,39 +2,86 @@ package main
22

33
import (
44
"crypto/tls"
5+
// "fmt"
6+
"io"
57
"net"
68
"net/http"
9+
"sync"
710
"time"
811
)
912

1013
type SimpleNetwork struct {
1114
Net2
1215
}
1316

14-
func (s *SimpleNetwork) NetResolveIPAddr(network, addr string) (*net.IPAddr, error) {
17+
func (sn *SimpleNetwork) NetResolveIPAddr(network, addr string) (*net.IPAddr, error) {
1518
return net.ResolveIPAddr(network, addr)
1619
}
1720

18-
func (s *SimpleNetwork) NetDialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
21+
func (sn *SimpleNetwork) NetDialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
1922
return net.DialTimeout(network, address, timeout)
2023
}
2124

22-
func (s *SimpleNetwork) TlsDialTimeout(network string, addr string, config *tls.Config, timeout time.Duration) (*tls.Conn, error) {
25+
func (sn *SimpleNetwork) TlsDialTimeout(network string, addr string, config *tls.Config, timeout time.Duration) (*tls.Conn, error) {
2326
return tls.Dial(network, addr, config)
2427
}
2528

26-
func (s *SimpleNetwork) HttpClientDo(req *http.Request) (*http.Response, error) {
29+
func (sn *SimpleNetwork) HttpClientDo(req *http.Request) (*http.Response, error) {
2730
client := &http.Client{}
2831
return client.Do(req)
2932
}
3033

31-
func (s *SimpleNetwork) GetTimeout() time.Duration {
34+
func (sn *SimpleNetwork) CopyResponseBody(w io.Writer, res *http.Response) (int64, error) {
35+
return io.Copy(w, res.Body)
36+
}
37+
38+
func (sn *SimpleNetwork) GetTimeout() time.Duration {
39+
return 8 * time.Second
40+
}
41+
42+
func (sn *SimpleNetwork) SetTimeout() {
43+
}
44+
45+
func (sn *SimpleNetwork) GetAddressAlias(addr string) (alias string) {
46+
return ""
47+
}
48+
49+
type AdvancedNetwork struct {
50+
Net2
51+
dnsCache map[string]*net.IPAddr
52+
dnsCacheMu sync.Mutex
53+
}
54+
55+
func NewAdvancedNetwork() *AdvancedNetwork {
56+
return &AdvancedNetwork{
57+
dnsCache: map[string]*net.IPAddr{},
58+
}
59+
}
60+
61+
func (an *AdvancedNetwork) NetResolveIPAddr(network, addr string) (*net.IPAddr, error) {
62+
return net.ResolveIPAddr(network, addr)
63+
}
64+
65+
func (an *AdvancedNetwork) NetDialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
66+
return net.DialTimeout(network, address, timeout)
67+
}
68+
69+
func (an *AdvancedNetwork) TlsDialTimeout(network string, addr string, config *tls.Config, timeout time.Duration) (*tls.Conn, error) {
70+
return tls.Dial(network, addr, config)
71+
}
72+
73+
func (an *AdvancedNetwork) HttpClientDo(req *http.Request) (*http.Response, error) {
74+
client := &http.Client{}
75+
return client.Do(req)
76+
}
77+
78+
func (an *AdvancedNetwork) GetTimeout() time.Duration {
3279
return 8 * time.Second
3380
}
3481

35-
func (s *SimpleNetwork) SetTimeout() {
82+
func (s *AdvancedNetwork) SetTimeout() {
3683
}
3784

38-
func (s *SimpleNetwork) GetAddressAlias(addr string) (alias string) {
85+
func (s *AdvancedNetwork) GetAddressAlias(addr string) (alias string) {
3986
return ""
4087
}

0 commit comments

Comments
 (0)