Skip to content

Commit 62596fd

Browse files
Merge pull request cumulo-autumn#39 from GradientSurfer/main
[examples] fix multiprocessing on Linux
2 parents 03e2a7f + cce1125 commit 62596fd

File tree

4 files changed

+25
-17
lines changed

4 files changed

+25
-17
lines changed

examples/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ If you want to maximize performance, you need to install with following steps ex
1010

1111
## `screen/`
1212

13-
Take a screen capture and process it. **This script only works on Windows.**
13+
Take a screen capture and process it.
1414

1515
When you run the script, a translucent window appears. Position it at where you want to capture the screen and press the enter key to finalize the capture area.
1616

examples/optimal-performance/multi.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import threading
44
import time
55
import tkinter as tk
6-
from multiprocessing import Process, Queue
6+
from multiprocessing import Process, Queue, get_context
77
from typing import List, Literal
88

99
import fire
@@ -174,17 +174,20 @@ def main(
174174
"""
175175
Main function to start the image generation and viewer processes.
176176
"""
177-
queue = Queue()
178-
fps_queue = Queue()
179-
process1 = Process(
177+
ctx = get_context('spawn')
178+
queue = ctx.Queue()
179+
fps_queue = ctx.Queue()
180+
process1 = ctx.Process(
180181
target=image_generation_process,
181182
args=(queue, fps_queue, prompt, model_id_or_path, batch_size, acceleration),
182183
)
183184
process1.start()
184185

185-
process2 = Process(target=receive_images, args=(queue, fps_queue))
186+
process2 = ctx.Process(target=receive_images, args=(queue, fps_queue))
186187
process2.start()
187188

189+
process1.join()
190+
process2.join()
188191

189192
if __name__ == "__main__":
190193
fire.Fire(main)

examples/optimal-performance/single.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
import sys
33
import time
4-
from multiprocessing import Process, Queue
4+
from multiprocessing import Process, Queue, get_context
55
from typing import Literal
66

77
import fire
@@ -72,17 +72,20 @@ def main(
7272
"""
7373
Main function to start the image generation and viewer processes.
7474
"""
75-
queue = Queue()
76-
fps_queue = Queue()
77-
process1 = Process(
75+
ctx = get_context('spawn')
76+
queue = ctx.Queue()
77+
fps_queue = ctx.Queue()
78+
process1 = ctx.Process(
7879
target=image_generation_process,
7980
args=(queue, fps_queue, prompt, model_id_or_path, acceleration),
8081
)
8182
process1.start()
8283

83-
process2 = Process(target=receive_images, args=(queue, fps_queue))
84+
process2 = ctx.Process(target=receive_images, args=(queue, fps_queue))
8485
process2.start()
8586

87+
process1.join()
88+
process2.join()
8689

8790
if __name__ == "__main__":
8891
fire.Fire(main)

examples/screen/main.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import sys
33
import time
44
import threading
5-
from multiprocessing import Process, Queue
5+
from multiprocessing import Process, Queue, get_context
66
from typing import List, Literal, Dict, Optional
77
import torch
88
import PIL.Image
@@ -216,10 +216,10 @@ def main(
216216
Main function to start the image generation and viewer processes.
217217
"""
218218
monitor = dummy_screen(width, height)
219-
220-
queue = Queue()
221-
fps_queue = Queue()
222-
process1 = Process(
219+
ctx = get_context('spawn')
220+
queue = ctx.Queue()
221+
fps_queue = ctx.Queue()
222+
process1 = ctx.Process(
223223
target=image_generation_process,
224224
args=(
225225
queue,
@@ -246,9 +246,11 @@ def main(
246246
)
247247
process1.start()
248248

249-
process2 = Process(target=receive_images, args=(queue, fps_queue))
249+
process2 = ctx.Process(target=receive_images, args=(queue, fps_queue))
250250
process2.start()
251251

252+
process1.join()
253+
process2.join()
252254

253255
if __name__ == "__main__":
254256
fire.Fire(main)

0 commit comments

Comments
 (0)