Skip to content

Commit f09b8db

Browse files
Implement torch.Tensor APIs for TorchRec wrappers (pytorch#3096)
Summary: Pull Request resolved: pytorch#3096 ### Diff Context Sometime trainer `state_dict` input to checkpointing can contain `LocalShardsWrapper` from TorchRec, which is a `torch.Tensor`. However, it doesn't implement some `torch.Tensor` operations like `copy_`, `zeros_like`, `empty_like`. This diff aims to implement those. Reviewed By: iamzainhuda Differential Revision: D75553113
1 parent be4e6d7 commit f09b8db

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

torchrec/distributed/shards_wrapper.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,9 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
116116
aten.detach.default: cls.handle_detach,
117117
aten.clone.default: cls.handle_clone,
118118
aten.new_empty.default: cls.handle_new_empty,
119+
aten.copy_.default: cls.handle_copy_,
120+
aten.zeros_like.default: cls.handle_zeros_like,
121+
aten.empty_like.default: cls.handle_empty_like,
119122
}
120123

121124
if func in dispatcher:
@@ -125,6 +128,36 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
125128
f"{func} is not supported for LocalShardsWrapper!"
126129
)
127130

131+
@staticmethod
132+
# pyre-fixme[3]: Return type must be annotated.
133+
# pyre-fixme[2]: Parameter must be annotated.
134+
def handle_zeros_like(args, kwargs):
135+
return LocalShardsWrapper(
136+
[torch.zeros_like(shard) for shard in args[0].local_shards()],
137+
args[0].local_offsets(),
138+
)
139+
140+
@staticmethod
141+
# pyre-fixme[3]: Return type must be annotated.
142+
# pyre-fixme[2]: Parameter must be annotated.
143+
def handle_empty_like(args, kwargs):
144+
return LocalShardsWrapper(
145+
[torch.empty_like(shard) for shard in args[0].local_shards()],
146+
args[0].local_offsets(),
147+
)
148+
149+
@staticmethod
150+
# pyre-fixme[3]: Return type must be annotated.
151+
# pyre-fixme[2]: Parameter must be annotated.
152+
def handle_copy_(args, kwargs):
153+
src = args[1]
154+
dst = args[0]
155+
156+
for i, shard in enumerate(src.local_shards()):
157+
dst.local_shards()[i].copy_(shard, **kwargs)
158+
159+
return args[0]
160+
128161
@staticmethod
129162
# pyre-fixme[3]: Return type must be annotated.
130163
# pyre-fixme[2]: Parameter must be annotated.

0 commit comments

Comments
 (0)