Skip to content

Commit 449bb82

Browse files
TroyGarden authored and facebook-github-bot committed
skip tests with internal/external discrepancy (#2759)
Summary: Pull Request resolved: #2759 # context * in torchrec github (OSS env) a few tests are [failing](https://github.com/pytorch/torchrec/actions/runs/13449271251/job/37580767712) * however, these tests pass internally due to different set up * torch.export uses training ir externally but inference ir internally * dlrm transformer tests use random.seed(0) to generate initial weights and the numeric values might be different internally and externally Reviewed By: dstaay-fb, iamzainhuda Differential Revision: D69996988 fbshipit-source-id: 87ad94e3b06f2c9fe9c0e4cd3236ac57cd339e12
1 parent 856ff3c commit 449bb82

File tree

4 files changed

+26
-1
lines changed

4 files changed

+26
-1
lines changed

torchrec/inference/inference_legacy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,4 @@
2424
- `examples/dlrm/inference/dlrm_predict.py`: this shows how to use `PredictModule` and `PredictFactory` based on an existing model.
2525
"""
2626

27-
from . import model_packager, modules # noqa # noqa
27+
from . import model_packager # noqa

torchrec/inference/tests/test_inference.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,3 +410,7 @@ def test_fused_params_overwrite(self) -> None:
410410

411411
# Make sure that overwrite of ebc_fused_params is not reflected in ec_fused_params
412412
self.assertEqual(ec_fused_params[FUSED_PARAM_REGISTER_TBE_BOOL], orig_value)
413+
414+
# change it back to the original value because it modifies the global variable
415+
# otherwise it will affect other tests
416+
ebc_fused_params[FUSED_PARAM_REGISTER_TBE_BOOL] = orig_value

torchrec/ir/tests/test_serializer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,13 @@ def test_serialize_deserialize_ebc(self) -> None:
253253
self.assertEqual(deserialized.shape, orginal.shape)
254254
self.assertTrue(torch.allclose(deserialized, orginal))
255255

256+
# pyre-ignore[56]: Pyre was not able to infer the type of argument
257+
@unittest.skipIf(
258+
torch.cuda.device_count() == 0,
259+
"skip this test in OSS (no GPU available) because torch.export uses training ir in OSS",
260+
)
256261
def test_dynamic_shape_ebc(self) -> None:
262+
# TODO: https://fb.workplace.com/groups/1028545332188949/permalink/1138699244506890/
257263
model = self.generate_model()
258264
feature1 = KeyedJaggedTensor.from_offsets_sync(
259265
keys=["f1", "f2", "f3"],

torchrec/models/experimental/test_transformerdlrm.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,11 @@ def test_larger(self) -> None:
6161
concat_dense = inter_arch(dense_features, sparse_features)
6262
self.assertEqual(concat_dense.size(), (B, D * (F + 1)))
6363

64+
# pyre-ignore[56]: Pyre was not able to infer the type of argument
65+
@unittest.skipIf(
66+
torch.cuda.device_count() == 0,
67+
"skip this test in OSS (no GPU available) because seed might be different in OSS",
68+
)
6469
def test_correctness(self) -> None:
6570
D = 4
6671
B = 3
@@ -165,6 +170,11 @@ def test_correctness(self) -> None:
165170
)
166171
)
167172

173+
# pyre-ignore[56]: Pyre was not able to infer the type of argument
174+
@unittest.skipIf(
175+
torch.cuda.device_count() == 0,
176+
"skip this test in OSS (no GPU available) because seed might be different in OSS",
177+
)
168178
def test_numerical_stability(self) -> None:
169179
D = 4
170180
B = 3
@@ -194,6 +204,11 @@ def test_numerical_stability(self) -> None:
194204

195205

196206
class DLRMTransformerTest(unittest.TestCase):
207+
# pyre-ignore[56]: Pyre was not able to infer the type of argument
208+
@unittest.skipIf(
209+
torch.cuda.device_count() == 0,
210+
"skip this test in OSS (no GPU available) because seed might be different in OSS",
211+
)
197212
def test_basic(self) -> None:
198213
torch.manual_seed(0)
199214
B = 2

0 commit comments

Comments (0)