
Commit 926565e

Merge branch 'main' into feat-merge-inference
2 parents 73ca3a8 + f408e24

File tree

1 file changed: +67 −0

src/deep_neurographs/utils/ml_util.py

Lines changed: 67 additions & 0 deletions
@@ -20,6 +20,73 @@
 GNN_DEPTH = 2
 
 
+# --- GPU Scheduler ---
+import torch
+from multiprocessing import Process, Queue
+from queue import Empty
+
+class GPUScheduler:
+    def __init__(self, model_path, num_gpus):
+        self.model_path = model_path
+        self.num_gpus = num_gpus
+        self.job_queues = []
+        self.return_queues = []
+        self.processes = []
+        self._init_workers()
+
+    def _init_workers(self):
+        for gpu_id in range(self.num_gpus):
+            job_q = Queue()
+            ret_q = Queue()
+            p = Process(
+                target=self._gpu_worker,
+                args=(gpu_id, self.model_path, job_q, ret_q),
+            )
+            p.start()
+            self.job_queues.append(job_q)
+            self.return_queues.append(ret_q)
+            self.processes.append(p)
+
+    def _gpu_worker(self, gpu_id, model_path, job_queue, return_queue):
+        device = torch.device(f"cuda:{gpu_id}")
+        model = torch.load(model_path, map_location=device)
+        model.eval()
+        while True:
+            job = job_queue.get()
+            if job is None:
+                break  # Sentinel to exit
+            batch, job_id = job
+            with torch.no_grad():
+                batch = batch.to(device)
+                preds = model(batch)
+            return_queue.put((job_id, preds.cpu()))
+
+    def submit(self, batch, job_id):
+        """Submit a batch to the next GPU in round-robin fashion."""
+        gpu_id = job_id % self.num_gpus
+        self.job_queues[gpu_id].put((batch, job_id))
+
+    def get_result(self, job_id, timeout=None):
+        """Retrieve results from the return queues."""
+        for q in self.return_queues:
+            try:
+                result_job_id, preds = q.get(timeout=timeout)
+                if result_job_id == job_id:
+                    return preds
+                else:
+                    # Re-enqueue if it's not the one we're looking for
+                    q.put((result_job_id, preds))
+            except Empty:
+                continue
+        return None  # or raise TimeoutError
+
+    def shutdown(self):
+        """Stop all GPU worker processes cleanly."""
+        for q in self.job_queues:
+            q.put(None)
+        for p in self.processes:
+            p.join()
+
 # --- Batch Generation ---
 def get_batch(graph, proposals, batch_size, flagged_proposals=set()):
     """

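One caveat worth noting about get_result() as committed: with timeout=None, q.get() blocks indefinitely on the first return queue even when the job ran on another GPU, and a single pass over the queues can return None before a slow worker has finished. Since submit() routes jobs deterministically to job_id % num_gpus, a lookup only needs that GPU's queue. A sketch of that idea (an illustration, not part of the commit; the early-result cache handles results that arrive out of request order):

    from queue import Empty

    _early = {}  # job_id -> preds, for results fetched before they were requested

    def get_result_targeted(scheduler, job_id, timeout=None):
        """Poll only the return queue that can hold this job's result."""
        if job_id in _early:
            return _early.pop(job_id)
        q = scheduler.return_queues[job_id % scheduler.num_gpus]
        while True:
            try:
                result_job_id, preds = q.get(timeout=timeout)
            except Empty:
                return None  # timed out waiting on this GPU's queue
            if result_job_id == job_id:
                return preds
            _early[result_job_id] = preds  # stash an out-of-order result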
Comments (0)