 from utils.wrapper import StreamDiffusionWrapper

 inputs = []
-stop_capture = False
 top = 0
 left = 0

 def screen(
+    event: threading.Event,
     height: int = 512,
     width: int = 512,
     monitor: Dict[str, int] = {"top": 300, "left": 200, "width": 512, "height": 512},
 ):
     global inputs
-    global stop_capture
     with mss.mss() as sct:
         while True:
+            if event.is_set():
+                print("terminate read thread")
+                break
             img = sct.grab(monitor)
             img = PIL.Image.frombytes("RGB", img.size, img.bgra, "raw", "BGRX")
             img.resize((height, width))
             inputs.append(pil2tensor(img))
-            if stop_capture:
-                return
-
+    print('exit : screen')
 def dummy_screen(
     width: int,
     height: int,
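This hunk replaces the module-level `stop_capture` flag with a `threading.Event` that is passed into the capture thread, so the thread can be asked to stop and then joined. Below is a minimal, self-contained sketch of that stop-event pattern; it is not the project's code, and the frame grab is a stand-in:

```python
import threading
import time

frames = []

def capture_loop(stop_event: threading.Event) -> None:
    # Poll the event once per iteration and exit cleanly when it is set.
    while not stop_event.is_set():
        frames.append(time.time())  # stand-in for sct.grab() + PIL conversion
        time.sleep(0.01)
    print("capture thread stopped")

stop_event = threading.Event()
t = threading.Thread(target=capture_loop, args=(stop_event,))
t.start()
time.sleep(0.1)
stop_event.set()  # ask the thread to stop
t.join()          # wait until it has actually exited
```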
@@ -61,6 +61,7 @@ def update_geometry(event):
 def image_generation_process(
     queue: Queue,
     fps_queue: Queue,
+    close_queue: Queue,
     model_id_or_path: str,
     lora_dict: Optional[Dict[str, float]],
     prompt: str,
@@ -133,7 +134,6 @@ def image_generation_process(
     """

     global inputs
-    global stop_capture
     stream = StreamDiffusionWrapper(
         model_id_or_path=model_id_or_path,
         lora_dict=lora_dict,
@@ -160,13 +160,15 @@
         guidance_scale=guidance_scale,
         delta=delta,
     )
-
-    input_screen = threading.Thread(target=screen, args=(height, width, monitor))
+    event = threading.Event()
+    input_screen = threading.Thread(target=screen, args=(event, height, width, monitor))
     input_screen.start()
     time.sleep(5)

     while True:
         try:
+            if not close_queue.empty():  # closing check
+                break
             if len(inputs) < frame_buffer_size:
                 time.sleep(0.005)
                 continue
@@ -188,10 +190,12 @@ def image_generation_process(
             fps = 1 / (time.time() - start_time)
             fps_queue.put(fps)
         except KeyboardInterrupt:
-            stop_capture = True
-            print(f"fps: {fps}")
-            return
+            break

+    print("closing image_generation_process...")
+    event.set()  # stop capture thread
+    input_screen.join()
+    print(f"fps: {fps}")

 def main(
     model_id_or_path: str = "KBlueLeaf/kohaku-v2.1",
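Taken together, the hunks above give `image_generation_process` a single teardown path: the generation loop breaks either when the parent posts to `close_queue` or on `KeyboardInterrupt`, and then the function sets the event, joins the capture thread, and reports the last fps. A hedged, standalone sketch of that worker-side flow follows; the names mirror the diff, but the capture thread and the per-frame work are placeholders:

```python
import threading
import time
from multiprocessing import Queue


def worker(close_queue: Queue) -> None:
    event = threading.Event()
    input_screen = threading.Thread(target=event.wait)  # placeholder capture thread
    input_screen.start()

    while True:
        try:
            if not close_queue.empty():  # closing check, as in the diff
                break
            time.sleep(0.005)  # placeholder for one generation step
        except KeyboardInterrupt:
            break  # fall through to the same teardown path

    event.set()          # stop capture thread
    input_screen.join()  # make sure it exited before the worker returns
    print("worker closed")


if __name__ == "__main__":
    q = Queue()
    t = threading.Thread(target=worker, args=(q,))
    t.start()
    time.sleep(0.1)
    q.put(True)  # simulate the parent requesting shutdown
    t.join()
```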
@@ -219,11 +223,13 @@ def main(
     ctx = get_context('spawn')
     queue = ctx.Queue()
     fps_queue = ctx.Queue()
+    close_queue = Queue()
     process1 = ctx.Process(
         target=image_generation_process,
         args=(
             queue,
             fps_queue,
+            close_queue,
             model_id_or_path,
             lora_dict,
             prompt,
@@ -249,8 +255,18 @@ def main(
     process2 = ctx.Process(target=receive_images, args=(queue, fps_queue))
     process2.start()

-    process1.join()
+    # terminate
     process2.join()
+    print("process2 terminated.")
+    close_queue.put(True)
+    print("process1 terminating...")
+    process1.join(5)  # with timeout
+    if process1.is_alive():
+        print("process1 still alive. force killing...")
+        process1.terminate()  # force kill...
+        process1.join()
+    print("process1 terminated.")
+

 if __name__ == "__main__":
     fire.Fire(main)
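The final hunk gives the parent a graceful-then-forceful shutdown: wait for the viewer (`process2`), post a sentinel to `close_queue`, join the generator with a 5-second grace period, and only `terminate()` it if it is still alive. A hedged standalone sketch of that join-with-timeout pattern; the worker body is a stand-in, and unlike the diff the queue here comes from the spawn context rather than the default one:

```python
import time
from multiprocessing import Queue, get_context


def worker(close_queue: Queue) -> None:
    # Spin until the parent posts the close signal.
    while close_queue.empty():
        time.sleep(0.01)  # placeholder for generation work
    print("worker: close signal received, exiting")


if __name__ == "__main__":
    ctx = get_context("spawn")
    close_queue = ctx.Queue()
    p = ctx.Process(target=worker, args=(close_queue,))
    p.start()

    time.sleep(0.1)
    close_queue.put(True)  # ask the worker to stop
    p.join(5)              # grace period, as in the diff
    if p.is_alive():       # fall back to a hard kill if it hangs
        p.terminate()
        p.join()
    print("worker terminated.")
```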