Skip to content

Commit a8de6b8

Browse files
authored
Add protobuf rewrite rule overrides (#144)
* allow rewrite rule overrides and add bitwise-or rewriter * finish docs * remove repeated, these apply a little odd in the rewriter
1 parent 93030d3 commit a8de6b8

File tree

2 files changed

+207
-7
lines changed

2 files changed

+207
-7
lines changed

proto/rewrite.go

Lines changed: 143 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -139,17 +139,27 @@ func (f fieldset) index(i int) (int, int) {
139139
// ParseRewriteTemplate constructs a Rewriter for a protobuf type using the
140140
// given json template to describe the rewrite rules.
141141
//
142-
// The json template contains a representation of the
143-
func ParseRewriteTemplate(typ Type, jsonTemplate []byte) (Rewriter, error) {
142+
// The json template contains a representation of the message that is used as the
143+
// source values to overwrite in the protobuf targeted by the resulting rewriter.
144+
//
145+
// The rules are an optional set of RewriterRules that can provide alternative
146+
// Rewriters from the default used for the field type. These rules are given the
147+
// json.RawMessage bytes from the template, and they are expected to create a
148+
// Rewriter to be applied against the target protobuf.
149+
func ParseRewriteTemplate(typ Type, jsonTemplate []byte, rules ...RewriterRules) (Rewriter, error) {
144150
switch typ.Kind() {
145151
case Struct:
146-
return parseRewriteTemplateStruct(typ, 0, jsonTemplate)
152+
return parseRewriteTemplateStruct(typ, 0, jsonTemplate, rules...)
147153
default:
148154
return nil, fmt.Errorf("cannot construct a rewrite template from a non-struct type %s", typ.Name())
149155
}
150156
}
151157

152-
func parseRewriteTemplate(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) {
158+
func parseRewriteTemplate(t Type, f FieldNumber, j json.RawMessage, rule any) (Rewriter, error) {
159+
if rwer, ok := rule.(Rewriterer); ok {
160+
return rwer.Rewriter(t, f, j)
161+
}
162+
153163
switch t.Kind() {
154164
case Bool:
155165
return parseRewriteTemplateBool(t, f, j)
@@ -184,7 +194,11 @@ func parseRewriteTemplate(t Type, f FieldNumber, j json.RawMessage) (Rewriter, e
184194
case Map:
185195
return parseRewriteTemplateMap(t, f, j)
186196
case Struct:
187-
return parseRewriteTemplateStruct(t, f, j)
197+
sub, n, ok := [1]RewriterRules{}, 0, false
198+
if sub[0], ok = rule.(RewriterRules); ok {
199+
n = 1
200+
}
201+
return parseRewriteTemplateStruct(t, f, j, sub[:n]...)
188202
default:
189203
return nil, fmt.Errorf("cannot construct a rewriter from type %s", t.Name())
190204
}
@@ -376,7 +390,7 @@ func parseRewriteTemplateMap(t Type, f FieldNumber, j json.RawMessage) (Rewriter
376390
return MultiRewriter(rewriters...), nil
377391
}
378392

379-
func parseRewriteTemplateStruct(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) {
393+
func parseRewriteTemplateStruct(t Type, f FieldNumber, j json.RawMessage, rules ...RewriterRules) (Rewriter, error) {
380394
template := map[string]json.RawMessage{}
381395

382396
if err := json.Unmarshal(j, &template); err != nil {
@@ -408,10 +422,18 @@ func parseRewriteTemplateStruct(t Type, f FieldNumber, j json.RawMessage) (Rewri
408422
fields = []json.RawMessage{v}
409423
}
410424

425+
var rule any
426+
for i := range rules {
427+
if r, ok := rules[i][f.Name]; ok {
428+
rule = r
429+
break
430+
}
431+
}
432+
411433
rewriters = rewriters[:0]
412434

413435
for _, v := range fields {
414-
rw, err := parseRewriteTemplate(f.Type, f.Number, v)
436+
rw, err := parseRewriteTemplate(f.Type, f.Number, v, rule)
415437
if err != nil {
416438
return nil, fmt.Errorf("%s: %w", k, err)
417439
}
@@ -462,3 +484,117 @@ func (f *embddedRewriter) Rewrite(out, in []byte) ([]byte, error) {
462484
copy(out[prefix:], b[:tagAndLen])
463485
return out, nil
464486
}
487+
488+
// RewriterRules defines a set of rules for overriding the Rewriter used for any
489+
// particular field. These maps may be nested for defining rules for struct members.
490+
//
491+
// For example:
492+
//
493+
// rules := proto.RewriterRules {
494+
// "flags": proto.BitOr[uint64]{},
495+
// "nested": proto.RewriterRules {
496+
// "name": myCustomRewriter,
497+
// },
498+
// }
499+
type RewriterRules map[string]any
500+
501+
// Rewriterer is the interface for producing a Rewriter for a given Type, FieldNumber
502+
// and json.RawMessage. The JSON value is the JSON-encoded payload that should be
503+
// decoded to produce the appropriate Rewriter. Implementations of the Rewriterer
504+
// interface are added to the RewriterRules to specify the rules for performing
505+
// custom rewrite logic.
506+
type Rewriterer interface {
507+
Rewriter(Type, FieldNumber, json.RawMessage) (Rewriter, error)
508+
}
509+
510+
// BitOr implments the Rewriterer interface for providing a bitwise-or rewrite
511+
// logic for integers rather than replacing them. Instances of this type are
512+
// zero-size, carrying only the generic type for creating the appropriate
513+
// Rewriter when requested.
514+
//
515+
// Adding these to a RewriterRules looks like:
516+
//
517+
// rules := proto.RewriterRules {
518+
// "flags": proto.BitOr[uint64]{},
519+
// }
520+
//
521+
// When used as a rule when rewriting from a template, the BitOr expects a JSON-
522+
// encoded integer passed into the Rewriter method. This parsed integer is then
523+
// used to perform a bitwise-or against the protobuf message that is being rewritten.
524+
//
525+
// The above example can then be used like:
526+
//
527+
// template := []byte(`{"flags": 8}`) // n |= 0b1000
528+
// rw, err := proto.ParseRewriteTemplate(typ, template, rules)
529+
type BitOr[T integer] struct{}
530+
531+
// integer is the contraint used by the BitOr Rewriterer and the bitOrRW Rewriter.
532+
// Because these perform bitwise-or operations, the types must be integer-like.
533+
type integer interface {
534+
~int | ~int32 | ~int64 | ~uint | ~uint32 | ~uint64
535+
}
536+
537+
// Rewriter implements the Rewriterer interface. The JSON value provided to this
538+
// method comes from the template used for rewriting. The returned Rewriter will use
539+
// this JSON-encoded integer to perform a bitwise-or against the protobuf message
540+
// that is being rewritten.
541+
func (BitOr[T]) Rewriter(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) {
542+
var v T
543+
err := json.Unmarshal(j, &v)
544+
if err != nil {
545+
return nil, err
546+
}
547+
return BitOrRewriter(t, f, v)
548+
}
549+
550+
// BitOrRewriter creates a bitwise-or Rewriter for a given field type and number.
551+
// The mask is the value or'ed with values in the target protobuf.
552+
func BitOrRewriter[T integer](t Type, f FieldNumber, mask T) (Rewriter, error) {
553+
switch t.Kind() {
554+
case Int32, Int64, Sint32, Sint64, Uint32, Uint64, Fix32, Fix64, Sfix32, Sfix64:
555+
default:
556+
return nil, fmt.Errorf("cannot construct a rewriter from type %s", t.Name())
557+
}
558+
return bitOrRW[T]{mask: mask, t: t, f: f}, nil
559+
}
560+
561+
// bitOrRW is the Rewriter returned by the BitOr Rewriter method.
562+
type bitOrRW[T integer] struct {
563+
mask T
564+
t Type
565+
f FieldNumber
566+
}
567+
568+
// Rewrite implements the Rewriter interface performing a bitwise-or between the
569+
// template value and the input value.
570+
func (r bitOrRW[T]) Rewrite(out, in []byte) ([]byte, error) {
571+
var v T
572+
if err := Unmarshal(in, &v); err != nil {
573+
return nil, err
574+
}
575+
576+
v |= r.mask
577+
578+
switch r.t.Kind() {
579+
case Int32:
580+
return r.f.Int32(int32(v)).Rewrite(out, in)
581+
case Int64:
582+
return r.f.Int64(int64(v)).Rewrite(out, in)
583+
case Sint32:
584+
return r.f.Uint32(encodeZigZag32(int32(v))).Rewrite(out, in)
585+
case Sint64:
586+
return r.f.Uint64(encodeZigZag64(int64(v))).Rewrite(out, in)
587+
case Uint32, Uint64:
588+
return r.f.Uint64(uint64(v)).Rewrite(out, in)
589+
case Fix32:
590+
return r.f.Fixed32(uint32(v)).Rewrite(out, in)
591+
case Fix64:
592+
return r.f.Fixed64(uint64(v)).Rewrite(out, in)
593+
case Sfix32:
594+
return r.f.Fixed32(encodeZigZag32(int32(v))).Rewrite(out, in)
595+
case Sfix64:
596+
return r.f.Fixed64(encodeZigZag64(int64(v))).Rewrite(out, in)
597+
}
598+
599+
panic("unreachable") // Kind is validated when creating instances
600+
}

proto/rewrite_test.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,70 @@ func TestParseRewriteTemplate(t *testing.T) {
291291
}
292292
}
293293

294+
func TestParseRewriteRules(t *testing.T) {
295+
type submessage struct {
296+
Flags uint64 `protobuf:"varint,1,opt,name=flags,proto3"`
297+
}
298+
299+
type message struct {
300+
Flags uint64 `protobuf:"varint,2,opt,name=flags,proto3"`
301+
Subfield *submessage `protobuf:"bytes,99,opt,name=subfield,proto3"`
302+
}
303+
304+
original := &message{
305+
Flags: 0b00000001,
306+
Subfield: &submessage{
307+
Flags: 0b00000010,
308+
},
309+
}
310+
311+
expected := &message{
312+
Flags: 0b00000001 | 16,
313+
Subfield: &submessage{
314+
Flags: 0b00000010 | 32,
315+
},
316+
}
317+
318+
rules := RewriterRules{
319+
"flags": BitOr[uint64]{},
320+
"subfield": RewriterRules{
321+
"flags": BitOr[uint64]{},
322+
},
323+
}
324+
325+
rw, err := ParseRewriteTemplate(TypeOf(reflect.TypeOf(original)), []byte(`{
326+
"flags": 16,
327+
"subfield": {
328+
"flags": 32
329+
}
330+
}`), rules)
331+
332+
if err != nil {
333+
t.Fatal(err)
334+
}
335+
336+
b1, err := Marshal(original)
337+
if err != nil {
338+
t.Fatal(err)
339+
}
340+
341+
b2, err := rw.Rewrite(nil, b1)
342+
if err != nil {
343+
t.Fatal(err)
344+
}
345+
346+
found := &message{}
347+
if err := Unmarshal(b2, &found); err != nil {
348+
t.Fatal(err)
349+
}
350+
351+
if !reflect.DeepEqual(expected, found) {
352+
t.Error("messages mismatch after rewrite")
353+
t.Logf("want:\n%+v", expected)
354+
t.Logf("got:\n%+v", found)
355+
}
356+
}
357+
294358
func BenchmarkRewrite(b *testing.B) {
295359
type message struct {
296360
A int

0 commit comments

Comments
 (0)