@@ -139,17 +139,27 @@ func (f fieldset) index(i int) (int, int) {
139
139
// ParseRewriteTemplate constructs a Rewriter for a protobuf type using the
140
140
// given json template to describe the rewrite rules.
141
141
//
142
- // The json template contains a representation of the
143
- func ParseRewriteTemplate (typ Type , jsonTemplate []byte ) (Rewriter , error ) {
142
+ // The json template contains a representation of the message that is used as the
143
+ // source values to overwrite in the protobuf targeted by the resulting rewriter.
144
+ //
145
+ // The rules are an optional set of RewriterRules that can provide alternative
146
+ // Rewriters from the default used for the field type. These rules are given the
147
+ // json.RawMessage bytes from the template, and they are expected to create a
148
+ // Rewriter to be applied against the target protobuf.
149
+ func ParseRewriteTemplate (typ Type , jsonTemplate []byte , rules ... RewriterRules ) (Rewriter , error ) {
144
150
switch typ .Kind () {
145
151
case Struct :
146
- return parseRewriteTemplateStruct (typ , 0 , jsonTemplate )
152
+ return parseRewriteTemplateStruct (typ , 0 , jsonTemplate , rules ... )
147
153
default :
148
154
return nil , fmt .Errorf ("cannot construct a rewrite template from a non-struct type %s" , typ .Name ())
149
155
}
150
156
}
151
157
152
- func parseRewriteTemplate (t Type , f FieldNumber , j json.RawMessage ) (Rewriter , error ) {
158
+ func parseRewriteTemplate (t Type , f FieldNumber , j json.RawMessage , rule any ) (Rewriter , error ) {
159
+ if rwer , ok := rule .(Rewriterer ); ok {
160
+ return rwer .Rewriter (t , f , j )
161
+ }
162
+
153
163
switch t .Kind () {
154
164
case Bool :
155
165
return parseRewriteTemplateBool (t , f , j )
@@ -184,7 +194,11 @@ func parseRewriteTemplate(t Type, f FieldNumber, j json.RawMessage) (Rewriter, e
184
194
case Map :
185
195
return parseRewriteTemplateMap (t , f , j )
186
196
case Struct :
187
- return parseRewriteTemplateStruct (t , f , j )
197
+ sub , n , ok := [1 ]RewriterRules {}, 0 , false
198
+ if sub [0 ], ok = rule .(RewriterRules ); ok {
199
+ n = 1
200
+ }
201
+ return parseRewriteTemplateStruct (t , f , j , sub [:n ]... )
188
202
default :
189
203
return nil , fmt .Errorf ("cannot construct a rewriter from type %s" , t .Name ())
190
204
}
@@ -376,7 +390,7 @@ func parseRewriteTemplateMap(t Type, f FieldNumber, j json.RawMessage) (Rewriter
376
390
return MultiRewriter (rewriters ... ), nil
377
391
}
378
392
379
- func parseRewriteTemplateStruct (t Type , f FieldNumber , j json.RawMessage ) (Rewriter , error ) {
393
+ func parseRewriteTemplateStruct (t Type , f FieldNumber , j json.RawMessage , rules ... RewriterRules ) (Rewriter , error ) {
380
394
template := map [string ]json.RawMessage {}
381
395
382
396
if err := json .Unmarshal (j , & template ); err != nil {
@@ -408,10 +422,18 @@ func parseRewriteTemplateStruct(t Type, f FieldNumber, j json.RawMessage) (Rewri
408
422
fields = []json.RawMessage {v }
409
423
}
410
424
425
+ var rule any
426
+ for i := range rules {
427
+ if r , ok := rules [i ][f.Name ]; ok {
428
+ rule = r
429
+ break
430
+ }
431
+ }
432
+
411
433
rewriters = rewriters [:0 ]
412
434
413
435
for _ , v := range fields {
414
- rw , err := parseRewriteTemplate (f .Type , f .Number , v )
436
+ rw , err := parseRewriteTemplate (f .Type , f .Number , v , rule )
415
437
if err != nil {
416
438
return nil , fmt .Errorf ("%s: %w" , k , err )
417
439
}
@@ -462,3 +484,117 @@ func (f *embddedRewriter) Rewrite(out, in []byte) ([]byte, error) {
462
484
copy (out [prefix :], b [:tagAndLen ])
463
485
return out , nil
464
486
}
487
+
488
+ // RewriterRules defines a set of rules for overriding the Rewriter used for any
489
+ // particular field. These maps may be nested for defining rules for struct members.
490
+ //
491
+ // For example:
492
+ //
493
+ // rules := proto.RewriterRules {
494
+ // "flags": proto.BitOr[uint64]{},
495
+ // "nested": proto.RewriterRules {
496
+ // "name": myCustomRewriter,
497
+ // },
498
+ // }
499
+ type RewriterRules map [string ]any
500
+
501
+ // Rewriterer is the interface for producing a Rewriter for a given Type, FieldNumber
502
+ // and json.RawMessage. The JSON value is the JSON-encoded payload that should be
503
+ // decoded to produce the appropriate Rewriter. Implementations of the Rewriterer
504
+ // interface are added to the RewriterRules to specify the rules for performing
505
+ // custom rewrite logic.
506
+ type Rewriterer interface {
507
+ Rewriter (Type , FieldNumber , json.RawMessage ) (Rewriter , error )
508
+ }
509
+
510
+ // BitOr implments the Rewriterer interface for providing a bitwise-or rewrite
511
+ // logic for integers rather than replacing them. Instances of this type are
512
+ // zero-size, carrying only the generic type for creating the appropriate
513
+ // Rewriter when requested.
514
+ //
515
+ // Adding these to a RewriterRules looks like:
516
+ //
517
+ // rules := proto.RewriterRules {
518
+ // "flags": proto.BitOr[uint64]{},
519
+ // }
520
+ //
521
+ // When used as a rule when rewriting from a template, the BitOr expects a JSON-
522
+ // encoded integer passed into the Rewriter method. This parsed integer is then
523
+ // used to perform a bitwise-or against the protobuf message that is being rewritten.
524
+ //
525
+ // The above example can then be used like:
526
+ //
527
+ // template := []byte(`{"flags": 8}`) // n |= 0b1000
528
+ // rw, err := proto.ParseRewriteTemplate(typ, template, rules)
529
+ type BitOr [T integer ] struct {}
530
+
531
+ // integer is the contraint used by the BitOr Rewriterer and the bitOrRW Rewriter.
532
+ // Because these perform bitwise-or operations, the types must be integer-like.
533
+ type integer interface {
534
+ ~ int | ~ int32 | ~ int64 | ~ uint | ~ uint32 | ~ uint64
535
+ }
536
+
537
+ // Rewriter implements the Rewriterer interface. The JSON value provided to this
538
+ // method comes from the template used for rewriting. The returned Rewriter will use
539
+ // this JSON-encoded integer to perform a bitwise-or against the protobuf message
540
+ // that is being rewritten.
541
+ func (BitOr [T ]) Rewriter (t Type , f FieldNumber , j json.RawMessage ) (Rewriter , error ) {
542
+ var v T
543
+ err := json .Unmarshal (j , & v )
544
+ if err != nil {
545
+ return nil , err
546
+ }
547
+ return BitOrRewriter (t , f , v )
548
+ }
549
+
550
+ // BitOrRewriter creates a bitwise-or Rewriter for a given field type and number.
551
+ // The mask is the value or'ed with values in the target protobuf.
552
+ func BitOrRewriter [T integer ](t Type , f FieldNumber , mask T ) (Rewriter , error ) {
553
+ switch t .Kind () {
554
+ case Int32 , Int64 , Sint32 , Sint64 , Uint32 , Uint64 , Fix32 , Fix64 , Sfix32 , Sfix64 :
555
+ default :
556
+ return nil , fmt .Errorf ("cannot construct a rewriter from type %s" , t .Name ())
557
+ }
558
+ return bitOrRW [T ]{mask : mask , t : t , f : f }, nil
559
+ }
560
+
561
+ // bitOrRW is the Rewriter returned by the BitOr Rewriter method.
562
+ type bitOrRW [T integer ] struct {
563
+ mask T
564
+ t Type
565
+ f FieldNumber
566
+ }
567
+
568
+ // Rewrite implements the Rewriter interface performing a bitwise-or between the
569
+ // template value and the input value.
570
+ func (r bitOrRW [T ]) Rewrite (out , in []byte ) ([]byte , error ) {
571
+ var v T
572
+ if err := Unmarshal (in , & v ); err != nil {
573
+ return nil , err
574
+ }
575
+
576
+ v |= r .mask
577
+
578
+ switch r .t .Kind () {
579
+ case Int32 :
580
+ return r .f .Int32 (int32 (v )).Rewrite (out , in )
581
+ case Int64 :
582
+ return r .f .Int64 (int64 (v )).Rewrite (out , in )
583
+ case Sint32 :
584
+ return r .f .Uint32 (encodeZigZag32 (int32 (v ))).Rewrite (out , in )
585
+ case Sint64 :
586
+ return r .f .Uint64 (encodeZigZag64 (int64 (v ))).Rewrite (out , in )
587
+ case Uint32 , Uint64 :
588
+ return r .f .Uint64 (uint64 (v )).Rewrite (out , in )
589
+ case Fix32 :
590
+ return r .f .Fixed32 (uint32 (v )).Rewrite (out , in )
591
+ case Fix64 :
592
+ return r .f .Fixed64 (uint64 (v )).Rewrite (out , in )
593
+ case Sfix32 :
594
+ return r .f .Fixed32 (encodeZigZag32 (int32 (v ))).Rewrite (out , in )
595
+ case Sfix64 :
596
+ return r .f .Fixed64 (encodeZigZag64 (int64 (v ))).Rewrite (out , in )
597
+ }
598
+
599
+ panic ("unreachable" ) // Kind is validated when creating instances
600
+ }
0 commit comments