@@ -113,30 +113,49 @@ def get_node(n):
113113 return nodes , edges
114114
115115
116- def draw_graph (graph : Graph , ax = None ):
116+ def draw_graph (graph : Graph , ax = None , * , adjust_axes = None ):
117117 if ax is None :
118118 fig , ax = plt .subplots ()
119+ if adjust_axes is None :
120+ adjust_axes = True
121+
122+ inverted = adjust_axes or ax .yaxis .get_inverted ()
119123
120124 origin_y = 0
125+ xmax = 0
121126
122127 for sg in graph ._subgraphs :
123128 nodes , edges = _position_subgraph (sg )
129+ annotations = {}
124130 # Draw nodes
125131 for node in nodes :
126- ax .annotate (
127- node .format (), (node .x , node .y + origin_y ), bbox = {"boxstyle" : "round" }
132+ annotations [node .format ()] = ax .annotate (
133+ node .format (),
134+ (node .x , node .y + origin_y ),
135+ ha = "center" ,
136+ va = "center" ,
137+ bbox = {"boxstyle" : "round" , "facecolor" : "none" },
128138 )
129139
130140 # Draw edges
131141 for edge in edges :
132- ax .annotate (
142+ arr = ax .annotate (
133143 "" ,
134- (edge .child .x , edge .child .y + origin_y ),
135- (edge .parent .x , edge .parent .y + origin_y ),
144+ (0.5 , 1.05 if inverted else - 0.05 ),
145+ (0.5 , - 0.05 if inverted else 1.05 ),
146+ xycoords = annotations [edge .child .format ()],
147+ textcoords = annotations [edge .parent .format ()],
136148 arrowprops = {"arrowstyle" : "->" },
149+ annotation_clip = True ,
137150 )
138- mid_x = (edge .child .x + edge .parent .x ) / 2
139- mid_y = (edge .child .y + edge .parent .y ) / 2
140- ax .text (mid_x , mid_y + origin_y , edge .name )
151+ ax .annotate (edge .name , (0.5 , 0.5 ), xytext = (0.5 , 0.5 ), textcoords = arr )
141152
142153 origin_y += max (node .y for node in nodes ) + 1
154+ xmax = max (xmax , max (node .x for node in nodes ))
155+
156+ if adjust_axes :
157+ ax .set_ylim (origin_y , - 1 )
158+ ax .set_xlim (- 1 , xmax + 1 )
159+ ax .spines [:].set_visible (False )
160+ ax .set_xticks ([])
161+ ax .set_yticks ([])
0 commit comments