Skip to content

Commit 5d5caca

Browse files
authored
Make our compile tests actually work (#1522)
1 parent c169bcd commit 5d5caca

File tree

3 files changed

+14
-11
lines changed

3 files changed

+14
-11
lines changed

tests/recipes/test_full_finetune_single_device.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,6 @@ def test_loss(self, compile, config, model_type, ckpt_type, tmpdir, monkeypatch)
7474
ckpt_dir = ckpt_path.parent
7575
log_file = gen_log_file_name(tmpdir)
7676

77-
# To workaround https://github.com/pytorch/torchtune/issues/676
78-
if compile:
79-
os.environ["TORCH_COMPILE_BACKEND"] = "aot_eager"
8077
cmd = f"""
8178
tune run full_finetune_single_device \
8279
--config {config} \
@@ -99,8 +96,13 @@ def test_loss(self, compile, config, model_type, ckpt_type, tmpdir, monkeypatch)
9996
with pytest.raises(SystemExit, match=""):
10097
runpy.run_path(TUNE_PATH, run_name="__main__")
10198

99+
# Make sure to clear compile state in between tests
100+
if compile:
101+
torch._dynamo.reset()
102+
102103
loss_values = get_loss_values_from_metric_logger(log_file)
103104
expected_loss_values = self._fetch_expected_loss_values(model_type)
105+
104106
torch.testing.assert_close(
105107
loss_values, expected_loss_values, rtol=1e-4, atol=1e-4
106108
)

tests/recipes/test_lora_finetune_single_device.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,6 @@ def test_loss(self, compile, config, model_type, ckpt_type, tmpdir, monkeypatch)
7575
ckpt_dir = ckpt_path.parent
7676
log_file = gen_log_file_name(tmpdir)
7777

78-
# To workaround https://github.com/pytorch/torchtune/issues/676
79-
if compile:
80-
os.environ["TORCH_COMPILE_BACKEND"] = "aot_eager"
8178
cmd = f"""
8279
tune run lora_finetune_single_device \
8380
--config {config} \
@@ -100,6 +97,10 @@ def test_loss(self, compile, config, model_type, ckpt_type, tmpdir, monkeypatch)
10097
with pytest.raises(SystemExit, match=""):
10198
runpy.run_path(TUNE_PATH, run_name="__main__")
10299

100+
# Make sure to clear compile state in between tests
101+
if compile:
102+
torch._dynamo.reset()
103+
103104
loss_values = get_loss_values_from_metric_logger(log_file)
104105
expected_loss_values = self._fetch_expected_loss_values(model_type)
105106
torch.testing.assert_close(
@@ -119,10 +120,6 @@ def test_loss_qlora(self, compile, dtype, tmpdir, monkeypatch):
119120
ckpt_dir = ckpt_path.parent
120121
log_file = gen_log_file_name(tmpdir)
121122

122-
# To workaround https://github.com/pytorch/torchtune/issues/676
123-
if compile:
124-
os.environ["TORCH_COMPILE_BACKEND"] = "aot_eager"
125-
126123
cmd = f"""
127124
tune run lora_finetune_single_device
128125
--config llama2/7B_qlora_single_device \
@@ -145,6 +142,10 @@ def test_loss_qlora(self, compile, dtype, tmpdir, monkeypatch):
145142
with pytest.raises(SystemExit):
146143
runpy.run_path(TUNE_PATH, run_name="__main__")
147144

145+
# Make sure to clear compile state in between tests
146+
if compile:
147+
torch._dynamo.reset()
148+
148149
loss_values = get_loss_values_from_metric_logger(log_file)
149150
expected_loss_values = self._fetch_qlora_expected_loss_values(dtype=dtype)
150151
torch.testing.assert_close(

torchtune/training/_compile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def compile_model(
3434
verbose (bool): Whether to log compile info. Default: True
3535
Returns:
3636
None
37+
3738
"""
3839
backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
3940
if torch_version_ge("2.5.0"):
@@ -65,7 +66,6 @@ def compile_loss(loss: nn.Module, verbose: bool = True) -> None:
6566
Returns:
6667
loss (nn.Module): loss with either entire module compiled or (in the case of
6768
CEWithChunkedOutputLoss) only the upcast and cross-entropy calculation compiled.
68-
6969
"""
7070
backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
7171
if verbose:

0 commit comments

Comments
 (0)