@@ -12,7 +12,8 @@
 import abc
 import logging
 import time
-from typing import Any, Dict, List, Optional, Type, Union
+from collections import defaultdict
+from typing import Any, Dict, List, Optional, Sequence, Type, TypeVar, Union
 
 import torch
 import torch.distributed as dist
@@ -106,6 +107,8 @@
 }
 
 
+T = TypeVar("T")
+
 # Label used for emitting model metrics to the corresponding trainer publishers.
 MODEL_METRIC_LABEL: str = "model"
 
@@ -370,31 +373,29 @@ def _get_metric_states( |
         world_size: int,
         process_group: Union[dist.ProcessGroup, DeviceMesh],
     ) -> Dict[str, Dict[str, Union[torch.Tensor, List[torch.Tensor]]]]:
-        metric_computations = metric._metrics_computations
-        tasks = metric._tasks
-
-        state_aggregated = {}
-        for task, metric_computation in zip(tasks, metric_computations):
-            inputs = []
-            state_aggregated[task.name] = {}
+        result = defaultdict(dict)
+        for task, computation in zip(metric._tasks, metric._metrics_computations):
             # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
             # `items`.
-            for attr, reduction_fn in metric_computation._reductions.items():
-                inputs.append((attr, getattr(metric_computation, attr), reduction_fn))
-
-            # TODO: do one all gather call per metric, instead of one per state
-            # this may require more logic as shapes of states are not guranteed to be same
-            # may need padding
-            for state, tensor, reduction_fn in inputs:
-                gather_list = [torch.empty_like(tensor) for _ in range(world_size)]
-                dist.all_gather(gather_list, tensor, group=process_group)
-                state_aggregated[task.name][state] = (
-                    reduction_fn(torch.stack(gather_list))
-                    if reduction_fn is not None
-                    else gather_list
+            for state_name, reduction_fn in computation._reductions.items():
+                tensor_or_list: Union[List[torch.Tensor], torch.Tensor] = getattr(
+                    computation, state_name
+                )
+
+                if isinstance(tensor_or_list, list):
+                    gathered = _all_gather_tensor_list(
+                        tensor_or_list, world_size, process_group
+                    )
+                else:
+                    gathered = torch.stack(
+                        _all_gather_tensor(tensor_or_list, world_size, process_group)
+                    )
+                reduced = (
+                    reduction_fn(gathered) if reduction_fn is not None else gathered
                 )
+                result[task.name][state_name] = reduced
 
-        return state_aggregated
+        return result
 
     def get_pre_compute_states(
         self, pg: Optional[Union[dist.ProcessGroup, DeviceMesh]] = None
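For intuition, the rewritten loop performs the same gather-then-reduce per metric state as before, but now dispatches on the state's type: tensor states are all-gathered, stacked, and reduced, while list states are flattened across ranks. Below is a minimal, non-distributed sketch of those two paths; the values, the sum reduction, and the variable names are illustrative assumptions, not taken from the patch.

import torch

world_size = 2

# Tensor-valued state: pretend these are the copies all-gathered from each
# rank; torch.stack(_all_gather_tensor(...)) would produce the same stack.
per_rank_state = [torch.tensor([3.0, 1.0]), torch.tensor([2.0, 4.0])]
gathered = torch.stack(per_rank_state)  # shape (world_size, 2)
reduced = gathered.sum(dim=0)           # stand-in for reduction_fn
print(reduced)                          # tensor([5., 5.])

# List-valued state: _all_gather_tensor_list gathers each local tensor across
# ranks and flattens, so the order is [t0_rank0, t0_rank1, t1_rank0, t1_rank1].
local_list = [torch.ones(2), torch.zeros(2)]
gathered_per_tensor = [[t.clone() for _ in range(world_size)] for t in local_list]
flattened = [g for copies in gathered_per_tensor for g in copies]
print(len(flattened))                   # 4 == len(local_list) * world_size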
@@ -611,3 +612,26 @@ def generate_metric_module( |
     )
     metrics.to(device)
     return metrics
+
+
+def _all_gather_tensor(
+    tensor: torch.Tensor,
+    world_size: int,
+    pg: Union[dist.ProcessGroup, DeviceMesh],
+) -> List[torch.Tensor]:
+    """All-gather a single tensor and return the gathered list."""
+    out = [torch.empty_like(tensor) for _ in range(world_size)]
+    dist.all_gather(out, tensor, group=pg)
+    return out
+
+
+def _all_gather_tensor_list(
+    tensors: List[torch.Tensor],
+    world_size: int,
+    pg: Union[dist.ProcessGroup, DeviceMesh],
+) -> List[torch.Tensor]:
+    """All-gather every tensor in a list and flatten the result."""
+    gathered: List[torch.Tensor] = []
+    for t in tensors:
+        gathered.extend(_all_gather_tensor(t, world_size, pg))
+    return gathered
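As a quick sanity check, the new helpers can be exercised end to end in a single process. This is a hypothetical smoke test, not part of the patch: it assumes the gloo backend is available, that world_size=1 makes each all-gather an identity copy, and that the helpers live at the module path guessed in the comment below.

import os

import torch
import torch.distributed as dist

# Assumed module path; adjust to wherever this file actually lives.
# from torchrec.metrics.metric_module import _all_gather_tensor, _all_gather_tensor_list

if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    pg = dist.new_group(ranks=[0])

    # Single tensor: one gathered copy per rank.
    out = _all_gather_tensor(torch.arange(3), world_size=1, pg=pg)
    assert len(out) == 1 and torch.equal(out[0], torch.arange(3))

    # List of tensors: len(local_list) * world_size copies, ordered
    # tensor-major ([t0_rank0, ..., t0_rankN, t1_rank0, ...]).
    flat = _all_gather_tensor_list([torch.zeros(2), torch.ones(2)], 1, pg)
    assert len(flat) == 2 and torch.equal(flat[1], torch.ones(2))

    dist.destroy_process_group()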