Skip to content

Commit a06fbd1

Browse files
[examples] fix multiprocessing on Linux
- use multiprocessing context to specify the spawn start method, which fixes the "RuntimeError: Cannot re-initialize CUDA in forked subprocess" on Linux (verified with Ubuntu 22.04 and kernel 6.2.0-39) - call `.join()` to wait for processes to complete (avoids exiting program immediately)
1 parent 03e2a7f commit a06fbd1

File tree

3 files changed

+24
-16
lines changed

3 files changed

+24
-16
lines changed

examples/optimal-performance/multi.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import threading
44
import time
55
import tkinter as tk
6-
from multiprocessing import Process, Queue
6+
from multiprocessing import Process, Queue, get_context
77
from typing import List, Literal
88

99
import fire
@@ -174,17 +174,20 @@ def main(
174174
"""
175175
Main function to start the image generation and viewer processes.
176176
"""
177-
queue = Queue()
178-
fps_queue = Queue()
179-
process1 = Process(
177+
ctx = get_context('spawn')
178+
queue = ctx.Queue()
179+
fps_queue = ctx.Queue()
180+
process1 = ctx.Process(
180181
target=image_generation_process,
181182
args=(queue, fps_queue, prompt, model_id_or_path, batch_size, acceleration),
182183
)
183184
process1.start()
184185

185-
process2 = Process(target=receive_images, args=(queue, fps_queue))
186+
process2 = ctx.Process(target=receive_images, args=(queue, fps_queue))
186187
process2.start()
187188

189+
process1.join()
190+
process2.join()
188191

189192
if __name__ == "__main__":
190193
fire.Fire(main)

examples/optimal-performance/single.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
import sys
33
import time
4-
from multiprocessing import Process, Queue
4+
from multiprocessing import Process, Queue, get_context
55
from typing import Literal
66

77
import fire
@@ -72,17 +72,20 @@ def main(
7272
"""
7373
Main function to start the image generation and viewer processes.
7474
"""
75-
queue = Queue()
76-
fps_queue = Queue()
77-
process1 = Process(
75+
ctx = get_context('spawn')
76+
queue = ctx.Queue()
77+
fps_queue = ctx.Queue()
78+
process1 = ctx.Process(
7879
target=image_generation_process,
7980
args=(queue, fps_queue, prompt, model_id_or_path, acceleration),
8081
)
8182
process1.start()
8283

83-
process2 = Process(target=receive_images, args=(queue, fps_queue))
84+
process2 = ctx.Process(target=receive_images, args=(queue, fps_queue))
8485
process2.start()
8586

87+
process1.join()
88+
process2.join()
8689

8790
if __name__ == "__main__":
8891
fire.Fire(main)

examples/screen/main.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import sys
33
import time
44
import threading
5-
from multiprocessing import Process, Queue
5+
from multiprocessing import Process, Queue, get_context
66
from typing import List, Literal, Dict, Optional
77
import torch
88
import PIL.Image
@@ -216,10 +216,10 @@ def main(
216216
Main function to start the image generation and viewer processes.
217217
"""
218218
monitor = dummy_screen(width, height)
219-
220-
queue = Queue()
221-
fps_queue = Queue()
222-
process1 = Process(
219+
ctx = get_context('spawn')
220+
queue = ctx.Queue()
221+
fps_queue = ctx.Queue()
222+
process1 = ctx.Process(
223223
target=image_generation_process,
224224
args=(
225225
queue,
@@ -246,9 +246,11 @@ def main(
246246
)
247247
process1.start()
248248

249-
process2 = Process(target=receive_images, args=(queue, fps_queue))
249+
process2 = ctx.Process(target=receive_images, args=(queue, fps_queue))
250250
process2.start()
251251

252+
process1.join()
253+
process2.join()
252254

253255
if __name__ == "__main__":
254256
fire.Fire(main)

0 commit comments

Comments
 (0)