Skip to content

Commit eecd233

Browse files
committed
feature(prost-build): Generate less boxed if nested type is boxed manually
1 parent 75692e6 commit eecd233

File tree

7 files changed

+310
-14
lines changed

7 files changed

+310
-14
lines changed

prost-build/src/code_generator.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1075,7 +1075,7 @@ impl<'a> CodeGenerator<'a> {
10751075
&& (fd_type == Type::Message || fd_type == Type::Group)
10761076
&& self
10771077
.message_graph
1078-
.is_nested(field.type_name(), fq_message_name)
1078+
.is_directly_nested(field.type_name(), fq_message_name)
10791079
{
10801080
return true;
10811081
}

prost-build/src/message_graph.rs

Lines changed: 129 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
use std::collections::HashMap;
1+
use std::collections::{HashMap, HashSet};
22

3-
use petgraph::algo::has_path_connecting;
43
use petgraph::graph::NodeIndex;
5-
use petgraph::Graph;
4+
use petgraph::visit::{EdgeRef, VisitMap};
5+
use petgraph::{Direction, Graph};
66

77
use prost_types::{
88
field_descriptor_proto::{Label, Type},
@@ -15,9 +15,13 @@ use crate::path::PathMap;
1515
/// The goal is to recognize when message types are recursively nested, so
1616
/// that fields can be boxed when necessary.
1717
pub struct MessageGraph {
18+
/// Map<fq type name, graph node index>
1819
index: HashMap<String, NodeIndex>,
19-
graph: Graph<String, ()>,
20+
/// Graph with fq type name as node, field name as edge
21+
graph: Graph<String, String>,
22+
/// Map<fq type name, DescriptorProto>
2023
messages: HashMap<String, DescriptorProto>,
24+
/// Manually boxed fields
2125
boxed: PathMap<()>,
2226
}
2327

@@ -71,7 +75,8 @@ impl MessageGraph {
7175
for field in &msg.field {
7276
if field.r#type() == Type::Message && field.label() != Label::Repeated {
7377
let field_index = self.get_or_insert_index(field.type_name.clone().unwrap());
74-
self.graph.add_edge(msg_index, field_index, ());
78+
self.graph
79+
.add_edge(msg_index, field_index, field.name.clone().unwrap());
7580
}
7681
}
7782
self.messages.insert(msg_name.clone(), msg.clone());
@@ -86,8 +91,9 @@ impl MessageGraph {
8691
self.messages.get(message)
8792
}
8893

89-
/// Returns true if message type `inner` is nested in message type `outer`.
90-
pub fn is_nested(&self, outer: &str, inner: &str) -> bool {
94+
/// Returns true if message type `inner` is nested in message type `outer`,
95+
/// and no field edge in the chain of dependencies is manually boxed.
96+
pub fn is_directly_nested(&self, outer: &str, inner: &str) -> bool {
9197
let outer = match self.index.get(outer) {
9298
Some(outer) => *outer,
9399
None => return false,
@@ -97,7 +103,12 @@ impl MessageGraph {
97103
None => return false,
98104
};
99105

100-
has_path_connecting(&self.graph, outer, inner, None)
106+
// Check if `inner` is nested in `outer` and ensure that all edge fields are not boxed manually.
107+
is_connected_with_edge_filter(&self.graph, outer, inner, |node, field_name| {
108+
self.boxed
109+
.get_first_field(&self.graph[node], field_name)
110+
.is_none()
111+
})
101112
}
102113

103114
/// Returns `true` if this message can automatically derive Copy trait.
@@ -123,11 +134,11 @@ impl MessageGraph {
123134
false
124135
} else if field.r#type() == Type::Message {
125136
// nested and boxed messages cannot derive Copy
126-
if self.is_nested(field.type_name(), fq_message_name)
127-
|| self
128-
.boxed
129-
.get_first_field(fq_message_name, field.name())
130-
.is_some()
137+
if self
138+
.boxed
139+
.get_first_field(fq_message_name, field.name())
140+
.is_some()
141+
|| self.is_directly_nested(field.type_name(), fq_message_name)
131142
{
132143
false
133144
} else {
@@ -154,3 +165,108 @@ impl MessageGraph {
154165
}
155166
}
156167
}
168+
169+
/// Check two nodes is connected with edge filter
170+
fn is_connected_with_edge_filter<F, N, E>(
171+
graph: &Graph<N, E>,
172+
start: NodeIndex,
173+
end: NodeIndex,
174+
mut is_good_edge: F,
175+
) -> bool
176+
where
177+
F: FnMut(NodeIndex, &E) -> bool,
178+
{
179+
fn visitor<F, N, E>(
180+
graph: &Graph<N, E>,
181+
start: NodeIndex,
182+
end: NodeIndex,
183+
is_good_edge: &mut F,
184+
visited: &mut HashSet<NodeIndex>,
185+
) -> bool
186+
where
187+
F: FnMut(NodeIndex, &E) -> bool,
188+
{
189+
if start == end {
190+
return true;
191+
}
192+
visited.visit(start);
193+
for edge in graph.edges_directed(start, Direction::Outgoing) {
194+
// if the edge doesn't pass the filter, skip it
195+
if !is_good_edge(start, edge.weight()) {
196+
continue;
197+
}
198+
let target = edge.target();
199+
if visited.is_visited(&target) {
200+
continue;
201+
}
202+
if visitor(graph, target, end, is_good_edge, visited) {
203+
return true;
204+
}
205+
}
206+
false
207+
}
208+
let mut visited = HashSet::new();
209+
visitor(graph, start, end, &mut is_good_edge, &mut visited)
210+
}
211+
212+
#[cfg(test)]
213+
mod tests {
214+
use super::*;
215+
216+
#[test]
217+
fn test_connected() {
218+
let mut graph = Graph::new();
219+
let n1 = graph.add_node(1);
220+
let n2 = graph.add_node(2);
221+
let n3 = graph.add_node(3);
222+
let n4 = graph.add_node(4);
223+
let n5 = graph.add_node(5);
224+
let n6 = graph.add_node(6);
225+
let n7 = graph.add_node(7);
226+
let n8 = graph.add_node(8);
227+
graph.add_edge(n1, n2, 1.);
228+
graph.add_edge(n2, n3, 2.);
229+
graph.add_edge(n3, n4, 3.);
230+
graph.add_edge(n4, n5, 4.);
231+
graph.add_edge(n5, n6, 5.);
232+
graph.add_edge(n6, n7, 6.);
233+
graph.add_edge(n7, n8, 7.);
234+
graph.add_edge(n8, n1, 8.);
235+
assert!(is_connected_with_edge_filter(&graph, n2, n1, |_, edge| {
236+
dbg!(edge);
237+
true
238+
}),);
239+
assert!(is_connected_with_edge_filter(&graph, n2, n1, |_, edge| {
240+
dbg!(edge);
241+
edge < &8.5
242+
}),);
243+
assert!(!is_connected_with_edge_filter(&graph, n2, n1, |_, edge| {
244+
dbg!(edge);
245+
edge < &7.5
246+
}),);
247+
}
248+
249+
#[test]
250+
fn test_connected_multi_circle() {
251+
let mut graph = Graph::new();
252+
let n0 = graph.add_node(0);
253+
let n1 = graph.add_node(1);
254+
let n2 = graph.add_node(2);
255+
let n3 = graph.add_node(3);
256+
let n4 = graph.add_node(4);
257+
graph.add_edge(n0, n1, 0.);
258+
graph.add_edge(n1, n2, 1.);
259+
graph.add_edge(n2, n3, 2.);
260+
graph.add_edge(n3, n0, 3.);
261+
graph.add_edge(n1, n4, 1.5);
262+
graph.add_edge(n4, n0, 2.5);
263+
assert!(is_connected_with_edge_filter(&graph, n1, n0, |_, edge| {
264+
dbg!(edge);
265+
edge < &2.8
266+
}),);
267+
assert!(!is_connected_with_edge_filter(&graph, n1, n0, |_, edge| {
268+
dbg!(edge);
269+
edge < &2.1
270+
}),);
271+
}
272+
}

tests/src/build.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,19 @@ fn main() {
151151

152152
std::fs::create_dir_all(&out_path).unwrap();
153153

154+
prost_build::Config::new()
155+
.out_dir(src.join("nesting_complex/boxed"))
156+
.boxed("Foo.bar")
157+
.boxed("BazB.baz_c")
158+
.boxed("BakC.bak_d")
159+
.compile_protos(&[src.join("nesting_complex.proto")], includes)
160+
.unwrap();
161+
162+
prost_build::Config::new()
163+
.out_dir(src.join("nesting_complex/"))
164+
.compile_protos(&[src.join("nesting_complex.proto")], includes)
165+
.unwrap();
166+
154167
prost_build::Config::new()
155168
.bytes(["."])
156169
.out_dir(out_path)

tests/src/lib.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,14 @@ pub mod proto3 {
134134
}
135135
}
136136

137+
pub mod nesting_complex_boxed {
138+
include!("nesting_complex/boxed/nesting_complex.rs");
139+
}
140+
141+
pub mod nesting_complex {
142+
include!("nesting_complex/nesting_complex.rs");
143+
}
144+
137145
pub mod invalid {
138146
pub mod doctest {
139147
include!(concat!(env!("OUT_DIR"), "/invalid.doctest.rs"));

tests/src/nesting_complex.proto

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
syntax = "proto2";
2+
3+
package nesting_complex;
4+
5+
// ----- Directly nested
6+
message Foo {
7+
optional Bar bar = 1;
8+
}
9+
10+
message Bar {
11+
optional Foo foo = 1;
12+
}
13+
14+
// ----- Transitively nested
15+
message BazA {
16+
optional BazB baz_b = 1;
17+
}
18+
19+
message BazB {
20+
optional BazC baz_c = 1;
21+
}
22+
23+
message BazC {
24+
optional BazA baz_a = 1;
25+
}
26+
27+
// ----- Transitively nested in two chain
28+
message BakA {
29+
optional BakB bak_b = 1;
30+
}
31+
32+
message BakB {
33+
optional BakC bak_c = 1;
34+
optional BakE bak_e = 2;
35+
}
36+
37+
message BakC {
38+
optional BakD bak_d = 1;
39+
}
40+
41+
message BakD {
42+
optional BakA bak_a = 1;
43+
}
44+
45+
message BakE {
46+
optional BakA bak_a = 1;
47+
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// This file is @generated by prost-build.
2+
/// ----- Directly nested
3+
#[derive(Clone, PartialEq, ::prost::Message)]
4+
pub struct Foo {
5+
#[prost(message, optional, boxed, tag = "1")]
6+
pub bar: ::core::option::Option<::prost::alloc::boxed::Box<Bar>>,
7+
}
8+
#[derive(Clone, PartialEq, ::prost::Message)]
9+
pub struct Bar {
10+
#[prost(message, optional, tag = "1")]
11+
pub foo: ::core::option::Option<Foo>,
12+
}
13+
/// ----- Transitively nested
14+
#[derive(Clone, PartialEq, ::prost::Message)]
15+
pub struct BazA {
16+
#[prost(message, optional, tag = "1")]
17+
pub baz_b: ::core::option::Option<BazB>,
18+
}
19+
#[derive(Clone, PartialEq, ::prost::Message)]
20+
pub struct BazB {
21+
#[prost(message, optional, boxed, tag = "1")]
22+
pub baz_c: ::core::option::Option<::prost::alloc::boxed::Box<BazC>>,
23+
}
24+
#[derive(Clone, PartialEq, ::prost::Message)]
25+
pub struct BazC {
26+
#[prost(message, optional, tag = "1")]
27+
pub baz_a: ::core::option::Option<BazA>,
28+
}
29+
/// ----- Transitively nested in two chain
30+
#[derive(Clone, PartialEq, ::prost::Message)]
31+
pub struct BakA {
32+
#[prost(message, optional, boxed, tag = "1")]
33+
pub bak_b: ::core::option::Option<::prost::alloc::boxed::Box<BakB>>,
34+
}
35+
#[derive(Clone, PartialEq, ::prost::Message)]
36+
pub struct BakB {
37+
#[prost(message, optional, tag = "1")]
38+
pub bak_c: ::core::option::Option<BakC>,
39+
#[prost(message, optional, boxed, tag = "2")]
40+
pub bak_e: ::core::option::Option<::prost::alloc::boxed::Box<BakE>>,
41+
}
42+
#[derive(Clone, PartialEq, ::prost::Message)]
43+
pub struct BakC {
44+
#[prost(message, optional, boxed, tag = "1")]
45+
pub bak_d: ::core::option::Option<::prost::alloc::boxed::Box<BakD>>,
46+
}
47+
#[derive(Clone, PartialEq, ::prost::Message)]
48+
pub struct BakD {
49+
#[prost(message, optional, tag = "1")]
50+
pub bak_a: ::core::option::Option<BakA>,
51+
}
52+
#[derive(Clone, PartialEq, ::prost::Message)]
53+
pub struct BakE {
54+
#[prost(message, optional, boxed, tag = "1")]
55+
pub bak_a: ::core::option::Option<::prost::alloc::boxed::Box<BakA>>,
56+
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// This file is @generated by prost-build.
2+
/// ----- Directly nested
3+
#[derive(Clone, PartialEq, ::prost::Message)]
4+
pub struct Foo {
5+
#[prost(message, optional, boxed, tag = "1")]
6+
pub bar: ::core::option::Option<::prost::alloc::boxed::Box<Bar>>,
7+
}
8+
#[derive(Clone, PartialEq, ::prost::Message)]
9+
pub struct Bar {
10+
#[prost(message, optional, boxed, tag = "1")]
11+
pub foo: ::core::option::Option<::prost::alloc::boxed::Box<Foo>>,
12+
}
13+
/// ----- Transitively nested
14+
#[derive(Clone, PartialEq, ::prost::Message)]
15+
pub struct BazA {
16+
#[prost(message, optional, boxed, tag = "1")]
17+
pub baz_b: ::core::option::Option<::prost::alloc::boxed::Box<BazB>>,
18+
}
19+
#[derive(Clone, PartialEq, ::prost::Message)]
20+
pub struct BazB {
21+
#[prost(message, optional, boxed, tag = "1")]
22+
pub baz_c: ::core::option::Option<::prost::alloc::boxed::Box<BazC>>,
23+
}
24+
#[derive(Clone, PartialEq, ::prost::Message)]
25+
pub struct BazC {
26+
#[prost(message, optional, boxed, tag = "1")]
27+
pub baz_a: ::core::option::Option<::prost::alloc::boxed::Box<BazA>>,
28+
}
29+
/// ----- Transitively nested in two chain
30+
#[derive(Clone, PartialEq, ::prost::Message)]
31+
pub struct BakA {
32+
#[prost(message, optional, boxed, tag = "1")]
33+
pub bak_b: ::core::option::Option<::prost::alloc::boxed::Box<BakB>>,
34+
}
35+
#[derive(Clone, PartialEq, ::prost::Message)]
36+
pub struct BakB {
37+
#[prost(message, optional, boxed, tag = "1")]
38+
pub bak_c: ::core::option::Option<::prost::alloc::boxed::Box<BakC>>,
39+
#[prost(message, optional, boxed, tag = "2")]
40+
pub bak_e: ::core::option::Option<::prost::alloc::boxed::Box<BakE>>,
41+
}
42+
#[derive(Clone, PartialEq, ::prost::Message)]
43+
pub struct BakC {
44+
#[prost(message, optional, boxed, tag = "1")]
45+
pub bak_d: ::core::option::Option<::prost::alloc::boxed::Box<BakD>>,
46+
}
47+
#[derive(Clone, PartialEq, ::prost::Message)]
48+
pub struct BakD {
49+
#[prost(message, optional, boxed, tag = "1")]
50+
pub bak_a: ::core::option::Option<::prost::alloc::boxed::Box<BakA>>,
51+
}
52+
#[derive(Clone, PartialEq, ::prost::Message)]
53+
pub struct BakE {
54+
#[prost(message, optional, boxed, tag = "1")]
55+
pub bak_a: ::core::option::Option<::prost::alloc::boxed::Box<BakA>>,
56+
}

0 commit comments

Comments
 (0)