|
20 | 20 | GNN_DEPTH = 2 |
21 | 21 |
|
22 | 22 |
|
| 23 | +# --- GPU Scheduler --- |
| 24 | +import torch |
| 25 | +from multiprocessing import Process, Queue |
| 26 | +from queue import Empty |
| 27 | + |
| 28 | +class GPUScheduler: |
| 29 | + def __init__(self, model_path, num_gpus): |
| 30 | + self.model_path = model_path |
| 31 | + self.num_gpus = num_gpus |
| 32 | + self.job_queues = [] |
| 33 | + self.return_queues = [] |
| 34 | + self.processes = [] |
| 35 | + self._init_workers() |
| 36 | + |
| 37 | + def _init_workers(self): |
| 38 | + for gpu_id in range(self.num_gpus): |
| 39 | + job_q = Queue() |
| 40 | + ret_q = Queue() |
| 41 | + p = Process( |
| 42 | + target=self._gpu_worker, |
| 43 | + args=(gpu_id, self.model_path, job_q, ret_q), |
| 44 | + ) |
| 45 | + p.start() |
| 46 | + self.job_queues.append(job_q) |
| 47 | + self.return_queues.append(ret_q) |
| 48 | + self.processes.append(p) |
| 49 | + |
| 50 | + def _gpu_worker(self, gpu_id, model_path, job_queue, return_queue): |
| 51 | + device = torch.device(f"cuda:{gpu_id}") |
| 52 | + model = torch.load(model_path, map_location=device) |
| 53 | + model.eval() |
| 54 | + while True: |
| 55 | + job = job_queue.get() |
| 56 | + if job is None: |
| 57 | + break # Sentinel to exit |
| 58 | + batch, job_id = job |
| 59 | + with torch.no_grad(): |
| 60 | + batch = batch.to(device) |
| 61 | + preds = model(batch) |
| 62 | + return_queue.put((job_id, preds.cpu())) |
| 63 | + |
| 64 | + def submit(self, batch, job_id): |
| 65 | + """Submit a batch to the next GPU in round-robin fashion.""" |
| 66 | + gpu_id = job_id % self.num_gpus |
| 67 | + self.job_queues[gpu_id].put((batch, job_id)) |
| 68 | + |
| 69 | + def get_result(self, job_id, timeout=None): |
| 70 | + """Retrieve results from the return queues.""" |
| 71 | + for q in self.return_queues: |
| 72 | + try: |
| 73 | + result_job_id, preds = q.get(timeout=timeout) |
| 74 | + if result_job_id == job_id: |
| 75 | + return preds |
| 76 | + else: |
| 77 | + # Re-enqueue if it's not the one we're looking for |
| 78 | + q.put((result_job_id, preds)) |
| 79 | + except Empty: |
| 80 | + continue |
| 81 | + return None # or raise TimeoutError |
| 82 | + |
| 83 | + def shutdown(self): |
| 84 | + """Stop all GPU worker processes cleanly.""" |
| 85 | + for q in self.job_queues: |
| 86 | + q.put(None) |
| 87 | + for p in self.processes: |
| 88 | + p.join() |
| 89 | + |
23 | 90 | # --- Batch Generation --- |
24 | 91 | def get_batch(graph, proposals, batch_size, flagged_proposals=set()): |
25 | 92 | """ |
|
0 commit comments