Commit 1fd9b75

Merge pull request cumulo-autumn#51 from Maoku/patch-screen-closing
added termination processing to the examples/screen demo
2 parents 59de06f + a6f3388 commit 1fd9b75


examples/screen/main.py

Lines changed: 28 additions & 12 deletions
@@ -17,26 +17,26 @@
 from utils.wrapper import StreamDiffusionWrapper
 
 inputs = []
-stop_capture = False
 top = 0
 left = 0
 
 def screen(
+    event: threading.Event,
     height: int = 512,
     width: int = 512,
     monitor: Dict[str, int] = {"top": 300, "left": 200, "width": 512, "height": 512},
 ):
     global inputs
-    global stop_capture
     with mss.mss() as sct:
         while True:
+            if event.is_set():
+                print("terminate read thread")
+                break
             img = sct.grab(monitor)
             img = PIL.Image.frombytes("RGB", img.size, img.bgra, "raw", "BGRX")
             img.resize((height, width))
             inputs.append(pil2tensor(img))
-            if stop_capture:
-                return
-
+    print('exit : screen')
 def dummy_screen(
     width: int,
     height: int,
@@ -61,6 +61,7 @@ def update_geometry(event):
 def image_generation_process(
     queue: Queue,
     fps_queue: Queue,
+    close_queue: Queue,
     model_id_or_path: str,
     lora_dict: Optional[Dict[str, float]],
     prompt: str,
@@ -133,7 +134,6 @@ def image_generation_process(
     """
 
     global inputs
-    global stop_capture
     stream = StreamDiffusionWrapper(
         model_id_or_path=model_id_or_path,
         lora_dict=lora_dict,
@@ -160,13 +160,15 @@
         guidance_scale=guidance_scale,
         delta=delta,
     )
-
-    input_screen = threading.Thread(target=screen, args=(height, width, monitor))
+    event = threading.Event()
+    input_screen = threading.Thread(target=screen, args=(event, height, width, monitor))
     input_screen.start()
     time.sleep(5)
 
     while True:
         try:
+            if not close_queue.empty(): # closing check
+                break
             if len(inputs) < frame_buffer_size:
                 time.sleep(0.005)
                 continue
@@ -188,10 +190,12 @@
             fps = 1 / (time.time() - start_time)
             fps_queue.put(fps)
         except KeyboardInterrupt:
-            stop_capture = True
-            print(f"fps: {fps}")
-            return
+            break
 
+    print("closing image_generation_process...")
+    event.set() # stop capture thread
+    input_screen.join()
+    print(f"fps: {fps}")
 
 def main(
     model_id_or_path: str = "KBlueLeaf/kohaku-v2.1",
@@ -219,11 +223,13 @@ def main(
     ctx = get_context('spawn')
     queue = ctx.Queue()
     fps_queue = ctx.Queue()
+    close_queue = Queue()
     process1 = ctx.Process(
         target=image_generation_process,
         args=(
            queue,
            fps_queue,
+           close_queue,
            model_id_or_path,
            lora_dict,
            prompt,
@@ -249,8 +255,18 @@ def main(
     process2 = ctx.Process(target=receive_images, args=(queue, fps_queue))
     process2.start()
 
-    process1.join()
+    # terminate
     process2.join()
+    print("process2 terminated.")
+    close_queue.put(True)
+    print("process1 terminating...")
+    process1.join(5) # with timeout
+    if process1.is_alive():
+        print("process1 still alive. force killing...")
+        process1.terminate() # force kill...
+    process1.join()
+    print("process1 terminated.")
+
 
 if __name__ == "__main__":
     fire.Fire(main)
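For reference, the shutdown pattern this patch introduces, a threading.Event that stops the capture thread, a multiprocessing Queue that tells the generation process to exit, and a join with a timeout followed by a forced terminate, can be exercised on its own. The sketch below is illustrative only and is not the demo's code: capture and worker are simplified stand-ins for screen() and image_generation_process(), and they do no real screen capture or image generation.

# Minimal sketch of the termination pattern used in this commit (assumed names:
# "capture" and "worker" stand in for screen() and image_generation_process()).
import threading
import time
from multiprocessing import get_context


def capture(event: threading.Event) -> None:
    # Reader thread: loop until the event is set, mirroring the event.is_set() check in screen().
    while not event.is_set():
        time.sleep(0.01)  # a real reader would grab a frame here
    print("terminate read thread")


def worker(close_queue) -> None:
    # Worker process: owns the reader thread and polls close_queue once per iteration.
    event = threading.Event()
    reader = threading.Thread(target=capture, args=(event,))
    reader.start()
    try:
        while True:
            if not close_queue.empty():  # closing check
                break
            time.sleep(0.01)  # a real worker would generate an image here
    except KeyboardInterrupt:
        pass
    print("closing worker...")
    event.set()      # stop the capture thread
    reader.join()


if __name__ == "__main__":
    ctx = get_context("spawn")
    close_queue = ctx.Queue()
    process1 = ctx.Process(target=worker, args=(close_queue,))
    process1.start()

    time.sleep(1.0)            # stand-in for waiting on the viewer process to exit
    close_queue.put(True)      # ask the worker to stop
    print("process1 terminating...")
    process1.join(5)           # join with a timeout
    if process1.is_alive():
        print("process1 still alive. force killing...")
        process1.terminate()   # last resort
        process1.join()
    print("process1 terminated.")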
