Skip to content

Commit 4e77570

Browse files
committed
fix connection reading in UDP
1 parent b52725c commit 4e77570

File tree

5 files changed

+97
-3
lines changed

5 files changed

+97
-3
lines changed

common/buf/multi_buffer.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,17 @@ func SplitBytes(mb MultiBuffer, b []byte) (MultiBuffer, int) {
122122
return mb, totalBytes
123123
}
124124

125+
// SplitFirstBytes splits the first buffer from MultiBuffer, and then copy its content into the given slice.
126+
func SplitFirstBytes(mb MultiBuffer, p []byte) (MultiBuffer, int) {
127+
mb, b := SplitFirst(mb)
128+
if b == nil {
129+
return mb, 0
130+
}
131+
n := copy(p, b.Bytes())
132+
b.Release()
133+
return mb, n
134+
}
135+
125136
// Compact returns another MultiBuffer by merging all content of the given one together.
126137
func Compact(mb MultiBuffer) MultiBuffer {
127138
if len(mb) == 0 {

common/buf/reader.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ type BufferedReader struct {
5858
Reader Reader
5959
// Buffer is the internal buffer to be read from first
6060
Buffer MultiBuffer
61+
// Spliter is a function to read bytes from MultiBuffer
62+
Spliter func(MultiBuffer, []byte) (MultiBuffer, int)
6163
}
6264

6365
// BufferedBytes returns the number of bytes that is cached in this reader.
@@ -74,8 +76,13 @@ func (r *BufferedReader) ReadByte() (byte, error) {
7476

7577
// Read implements io.Reader. It reads from internal buffer first (if available) and then reads from the underlying reader.
7678
func (r *BufferedReader) Read(b []byte) (int, error) {
79+
spliter := r.Spliter
80+
if spliter == nil {
81+
spliter = SplitBytes
82+
}
83+
7784
if !r.Buffer.IsEmpty() {
78-
buffer, nBytes := SplitBytes(r.Buffer, b)
85+
buffer, nBytes := spliter(r.Buffer, b)
7986
r.Buffer = buffer
8087
if r.Buffer.IsEmpty() {
8188
r.Buffer = nil
@@ -88,7 +95,7 @@ func (r *BufferedReader) Read(b []byte) (int, error) {
8895
return 0, err
8996
}
9097

91-
mb, nBytes := SplitBytes(mb, b)
98+
mb, nBytes := spliter(mb, b)
9299
if !mb.IsEmpty() {
93100
r.Buffer = mb
94101
}

common/net/connection.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,15 @@ func ConnectionOutputMulti(reader buf.Reader) ConnectionOption {
4848
}
4949
}
5050

51+
func ConnectionOutputMultiUDP(reader buf.Reader) ConnectionOption {
52+
return func(c *connection) {
53+
c.reader = &buf.BufferedReader{
54+
Reader: reader,
55+
Spliter: buf.SplitFirstBytes,
56+
}
57+
}
58+
}
59+
5160
func ConnectionOnClose(n io.Closer) ConnectionOption {
5261
return func(c *connection) {
5362
c.onClose = n

functions.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,13 @@ func Dial(ctx context.Context, v *Instance, dest net.Destination) (net.Conn, err
5353
if err != nil {
5454
return nil, err
5555
}
56-
return net.NewConnection(net.ConnectionInputMulti(r.Writer), net.ConnectionOutputMulti(r.Reader)), nil
56+
var readerOpt net.ConnectionOption
57+
if dest.Network == net.Network_TCP {
58+
readerOpt = net.ConnectionOutputMulti(r.Reader)
59+
} else {
60+
readerOpt = net.ConnectionOutputMultiUDP(r.Reader)
61+
}
62+
return net.NewConnection(net.ConnectionInputMulti(r.Writer), readerOpt), nil
5763
}
5864

5965
// DialUDP provides a way to exchange UDP packets through V2Ray instance to remote servers.

functions_test.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"crypto/rand"
66
"io"
77
"testing"
8+
"time"
89

910
"github.com/golang/protobuf/proto"
1011
"github.com/google/go-cmp/cmp"
@@ -86,6 +87,66 @@ func TestV2RayDial(t *testing.T) {
8687
}
8788
}
8889

90+
func TestV2RayDialUDPConn(t *testing.T) {
91+
udpServer := udp.Server{
92+
MsgProcessor: xor,
93+
}
94+
dest, err := udpServer.Start()
95+
common.Must(err)
96+
defer udpServer.Close()
97+
98+
config := &core.Config{
99+
App: []*serial.TypedMessage{
100+
serial.ToTypedMessage(&dispatcher.Config{}),
101+
serial.ToTypedMessage(&proxyman.InboundConfig{}),
102+
serial.ToTypedMessage(&proxyman.OutboundConfig{}),
103+
},
104+
Outbound: []*core.OutboundHandlerConfig{
105+
{
106+
ProxySettings: serial.ToTypedMessage(&freedom.Config{}),
107+
},
108+
},
109+
}
110+
111+
cfgBytes, err := proto.Marshal(config)
112+
common.Must(err)
113+
114+
server, err := core.StartInstance("protobuf", cfgBytes)
115+
common.Must(err)
116+
defer server.Close()
117+
118+
conn, err := core.Dial(context.Background(), server, dest)
119+
common.Must(err)
120+
defer conn.Close()
121+
122+
const size = 1024
123+
payload := make([]byte, size)
124+
common.Must2(rand.Read(payload))
125+
126+
for i := 0; i < 2; i++ {
127+
if _, err := conn.Write(payload); err != nil {
128+
t.Fatal(err)
129+
}
130+
}
131+
132+
time.Sleep(time.Millisecond * 500)
133+
134+
receive := make([]byte, size*2)
135+
for i := 0; i < 2; i++ {
136+
n, err := conn.Read(receive)
137+
if err != nil {
138+
t.Fatal("expect no error, but got ", err)
139+
}
140+
if n != size {
141+
t.Fatal("expect read size ", size, " but got ", n)
142+
}
143+
144+
if r := cmp.Diff(xor(receive[:n]), payload); r != "" {
145+
t.Fatal(r)
146+
}
147+
}
148+
}
149+
89150
func TestV2RayDialUDP(t *testing.T) {
90151
udpServer1 := udp.Server{
91152
MsgProcessor: xor,

0 commit comments

Comments
 (0)