Skip to content

Commit 2d04c4b

Browse files
author
Achille
authored
add encoding.TextMarshaler and encoding.TextUnmarshaler implementations (#754)
* add encoding.TextMarshaler and encoding.TextUnmarshaler implementations * support numeric codes as well * PR feedback
1 parent c03923d commit 2d04c4b

File tree

4 files changed

+138
-0
lines changed

4 files changed

+138
-0
lines changed

compress/compress.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
package compress
22

33
import (
4+
"encoding"
5+
"fmt"
46
"io"
7+
"strconv"
8+
"strings"
59

610
"github.com/segmentio/kafka-go/compress/gzip"
711
"github.com/segmentio/kafka-go/compress/lz4"
@@ -13,6 +17,7 @@ import (
1317
type Compression int8
1418

1519
const (
20+
None Compression = 0
1621
Gzip Compression = 1
1722
Snappy Compression = 2
1823
Lz4 Compression = 3
@@ -33,6 +38,50 @@ func (c Compression) String() string {
3338
return "uncompressed"
3439
}
3540

41+
func (c Compression) MarshalText() ([]byte, error) {
42+
return []byte(c.String()), nil
43+
}
44+
45+
func (c *Compression) UnmarshalText(b []byte) error {
46+
switch string(b) {
47+
case "none", "uncompressed":
48+
*c = None
49+
return nil
50+
}
51+
52+
for _, codec := range Codecs[None+1:] {
53+
if codec.Name() == string(b) {
54+
*c = Compression(codec.Code())
55+
return nil
56+
}
57+
}
58+
59+
i, err := strconv.ParseInt(string(b), 10, 64)
60+
if err == nil && i >= 0 && i < int64(len(Codecs)) {
61+
*c = Compression(i)
62+
return nil
63+
}
64+
65+
s := &strings.Builder{}
66+
s.WriteString("none, uncompressed")
67+
68+
for i, codec := range Codecs[None+1:] {
69+
if i < (len(Codecs) - 1) {
70+
s.WriteString(", ")
71+
} else {
72+
s.WriteString(", or ")
73+
}
74+
s.WriteString(codec.Name())
75+
}
76+
77+
return fmt.Errorf("compression format must be one of %s, not %q", s, b)
78+
}
79+
80+
var (
81+
_ encoding.TextMarshaler = Compression(0)
82+
_ encoding.TextUnmarshaler = (*Compression)(nil)
83+
)
84+
3685
// Codec represents a compression codec to encode and decode the messages.
3786
// See : https://cwiki.apache.org/confluence/display/KAFKA/Compression
3887
//
@@ -66,6 +115,7 @@ var (
66115

67116
// The global table of compression codecs supported by the kafka protocol.
68117
Codecs = [...]Codec{
118+
None: nil,
69119
Gzip: &GzipCodec,
70120
Snappy: &SnappyCodec,
71121
Lz4: &Lz4Codec,

compress/compress_test.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,31 @@ func testEncodeDecode(t *testing.T, m kafka.Message, codec pkg.Codec) {
8585
var r1, r2 []byte
8686
var err error
8787

88+
t.Run("text format of "+codec.Name(), func(t *testing.T) {
89+
c := pkg.Compression(codec.Code())
90+
a := strconv.Itoa(int(c))
91+
x := pkg.Compression(-1)
92+
y := pkg.Compression(-1)
93+
b, err := c.MarshalText()
94+
if err != nil {
95+
t.Fatal(err)
96+
}
97+
98+
if err := x.UnmarshalText([]byte(a)); err != nil {
99+
t.Fatal(err)
100+
}
101+
if err := y.UnmarshalText(b); err != nil {
102+
t.Fatal(err)
103+
}
104+
105+
if x != c {
106+
t.Errorf("compression mismatch after marshal/unmarshal: want=%s got=%s", c, x)
107+
}
108+
if y != c {
109+
t.Errorf("compression mismatch after marshal/unmarshal: want=%s got=%s", c, y)
110+
}
111+
})
112+
88113
t.Run("encode with "+codec.Name(), func(t *testing.T) {
89114
r1, err = compress(codec, m.Value)
90115
if err != nil {

produce.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@ package kafka
33
import (
44
"bufio"
55
"context"
6+
"encoding"
67
"errors"
78
"fmt"
89
"net"
10+
"strconv"
911
"time"
1012

1113
"github.com/segmentio/kafka-go/protocol"
@@ -33,6 +35,34 @@ func (acks RequiredAcks) String() string {
3335
}
3436
}
3537

38+
func (acks RequiredAcks) MarshalText() ([]byte, error) {
39+
return []byte(acks.String()), nil
40+
}
41+
42+
func (acks *RequiredAcks) UnmarshalText(b []byte) error {
43+
switch string(b) {
44+
case "none":
45+
*acks = RequireNone
46+
case "one":
47+
*acks = RequireOne
48+
case "all":
49+
*acks = RequireAll
50+
default:
51+
x, err := strconv.ParseInt(string(b), 10, 64)
52+
parsed := RequiredAcks(x)
53+
if err != nil || (parsed != RequireNone && parsed != RequireOne && parsed != RequireAll) {
54+
return fmt.Errorf("required acks must be one of none, one, or all, not %q", b)
55+
}
56+
*acks = parsed
57+
}
58+
return nil
59+
}
60+
61+
var (
62+
_ encoding.TextMarshaler = RequiredAcks(0)
63+
_ encoding.TextUnmarshaler = (*RequiredAcks)(nil)
64+
)
65+
3666
// ProduceRequest represents a request sent to a kafka broker to produce records
3767
// to a topic partition.
3868
type ProduceRequest struct {

produce_test.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,45 @@ package kafka
22

33
import (
44
"context"
5+
"strconv"
56
"testing"
67
"time"
78

89
"github.com/segmentio/kafka-go/compress"
910
)
1011

12+
func TestRequiredAcks(t *testing.T) {
13+
for _, acks := range []RequiredAcks{
14+
RequireNone,
15+
RequireOne,
16+
RequireAll,
17+
} {
18+
t.Run(acks.String(), func(t *testing.T) {
19+
a := strconv.Itoa(int(acks))
20+
x := RequiredAcks(-2)
21+
y := RequiredAcks(-2)
22+
b, err := acks.MarshalText()
23+
if err != nil {
24+
t.Fatal(err)
25+
}
26+
27+
if err := x.UnmarshalText([]byte(a)); err != nil {
28+
t.Fatal(err)
29+
}
30+
if err := y.UnmarshalText(b); err != nil {
31+
t.Fatal(err)
32+
}
33+
34+
if x != acks {
35+
t.Errorf("required acks mismatch after marshal/unmarshal text: want=%s got=%s", acks, x)
36+
}
37+
if y != acks {
38+
t.Errorf("required acks mismatch after marshal/unmarshal value: want=%s got=%s", acks, y)
39+
}
40+
})
41+
}
42+
}
43+
1144
func TestClientProduce(t *testing.T) {
1245
client, topic, shutdown := newLocalClientAndTopic()
1346
defer shutdown()

0 commit comments

Comments
 (0)