Commit 3681ed5

Merge pull request lucidrains#206 from LWprogramming/inference_mode
Change torch.no_grad() to torch.inference_mode()
2 parents b157179 + e811e86 commit 3681ed5

File tree

7 files changed: +23 -23 lines changed
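The change is mechanical: every torch.no_grad() in the library becomes torch.inference_mode(). Available since PyTorch 1.9, inference_mode() disables gradient tracking like no_grad() and additionally skips autograd's view and version-counter bookkeeping, making it the faster choice when results will never be used in a backward pass. A minimal sketch of the difference (illustrative module and shapes, not code from this repo):

import torch

net = torch.nn.Linear(16, 16)
x = torch.randn(2, 16)

with torch.no_grad():
    y1 = net(x)  # gradients off, but view/version metadata still tracked

with torch.inference_mode():
    y2 = net(x)  # gradients off and no autograd bookkeeping at all

print(y1.requires_grad, y2.requires_grad)  # False False
print(y2.is_inference())                   # True: y2 is an inference tensor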

audiolm_pytorch/audiolm_pytorch.py

Lines changed: 14 additions & 14 deletions
@@ -531,7 +531,7 @@ def forward(
 
         text_mask = None
         if not exists(text_embeds) and exists(text):
-            with torch.no_grad():
+            with torch.inference_mode():
                 text_embeds = self.embed_text(text, output_device = device)
                 text_mask = torch.any(text_embeds != 0, dim = -1)
 
@@ -677,7 +677,7 @@ def forward(
         assert not (self.has_condition ^ has_text)
 
         if not exists(text_embeds) and exists(text):
-            with torch.no_grad():
+            with torch.inference_mode():
                 text_embeds = self.embed_text(text, output_device = device)
 
         text_mask = None
@@ -907,7 +907,7 @@ def forward(
 
         text_mask = None
         if not exists(text_embeds) and exists(text):
-            with torch.no_grad():
+            with torch.inference_mode():
                 text_embeds = self.embed_text(text, output_device = device)
                 text_mask = torch.any(text_embeds != 0, dim = -1)
 
@@ -1141,7 +1141,7 @@ def embed_text(self, text):
         return self.transformer.embed_text(text, output_device = self.device)
 
     @eval_decorator
-    @torch.no_grad()
+    @torch.inference_mode()
     @beartype
     def generate(
         self,
@@ -1186,7 +1186,7 @@ def generate(
         assert not (self.transformer.has_condition ^ has_text)
 
         if not exists(text_embeds) and exists(text):
-            with torch.no_grad():
+            with torch.inference_mode():
                 text_embeds = self.transformer.embed_text(text, output_device = device)
 
         # start length and get running id output
@@ -1323,7 +1323,7 @@ def device(self):
         return next(self.parameters()).device
 
     @eval_decorator
-    @torch.no_grad()
+    @torch.inference_mode()
     @beartype
     def generate(
         self,
@@ -1353,7 +1353,7 @@ def generate(
             coarse_token_ids = prime_coarse_token_ids
         elif exists(prime_wave):
            assert exists(self.codec)
-            with torch.no_grad():
+            with torch.inference_mode():
                 self.codec.eval()
                 _, indices, _ = self.codec(prime_wave, return_encoded = True)
                 coarse_token_ids = indices[..., :self.num_coarse_quantizers]
@@ -1367,7 +1367,7 @@ def generate(
         assert not (self.transformer.has_condition ^ has_text)
 
         if not exists(text_embeds) and exists(text):
-            with torch.no_grad():
+            with torch.inference_mode():
                 text_embeds = self.transformer.embed_text(text, output_device = device)
 
         if self.unique_consecutive:
@@ -1444,7 +1444,7 @@ def forward(
         if not exists(coarse_token_ids):
             assert exists(self.codec), 'Codec must be provided if given raw wave for training'
 
-            with torch.no_grad():
+            with torch.inference_mode():
                 self.codec.eval()
                 _, indices, _ = self.codec(raw_wave_for_codec, return_encoded = True)
 
@@ -1568,7 +1568,7 @@ def device(self):
         return next(self.parameters()).device
 
     @eval_decorator
-    @torch.no_grad()
+    @torch.inference_mode()
     @beartype
     def generate(
         self,
@@ -1597,7 +1597,7 @@ def generate(
         assert not (self.transformer.has_condition ^ has_text)
 
         if not exists(text_embeds) and exists(text):
-            with torch.no_grad():
+            with torch.inference_mode():
                 text_embeds = self.transformer.embed_text(text, output_device = device)
 
         # initialize fine token ids
@@ -1609,7 +1609,7 @@ def generate(
             fine_token_ids = prime_fine_token_ids
         elif exists(prime_wave):
             assert exists(self.codec)
-            with torch.no_grad():
+            with torch.inference_mode():
                 self.codec.eval()
                 _, token_ids, _ = self.codec(prime_wave, return_encoded = True)
 
@@ -1698,7 +1698,7 @@ def forward(
         if exists(raw_wave):
             assert exists(self.codec), 'Codec must be provided if given raw wave for training'
 
-            with torch.no_grad():
+            with torch.inference_mode():
                 self.codec.eval()
                 _, token_ids, _ = self.codec(raw_wave, return_encoded = True)
 
@@ -1829,7 +1829,7 @@ def device(self):
         return next(self.parameters()).device
 
     @eval_decorator
-    @torch.no_grad()
+    @torch.inference_mode()
     def forward(
         self,
         *,

audiolm_pytorch/encodec.py

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ def forward(
         wav = rearrange(x, f'b t -> b {self.model.channels} t')
 
         # Extract discrete codes from EnCodec
-        with torch.no_grad():
+        with torch.inference_mode():
             encoded_frames = self.model.encode(wav)
         # encoded_frames is a list of (frame, scale) tuples. Scale is a scalar but we don't use it. Frame is a tensor
         # of shape [batch, num_quantizers, num_samples_per_frame]. We want to concatenate the frames to get all the
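As a toy illustration of the concatenation the comment above describes (fabricated shapes and codebook size, and the concatenation axis is assumed to be the last, time-like dim):

import torch

# stand-in for the (frame, scale) tuples returned by model.encode():
# each frame is [batch, num_quantizers, num_samples_per_frame]
encoded_frames = [(torch.randint(0, 1024, (1, 8, 75)), None) for _ in range(3)]

codes = torch.cat([frame for frame, _scale in encoded_frames], dim = -1)
print(codes.shape)  # torch.Size([1, 8, 225])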

audiolm_pytorch/hubert_kmeans.py

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ def downsample_factor(self):
         # todo: double check
         return 320
 
-    @torch.no_grad()
+    @torch.inference_mode()
     def forward(
         self,
         wav_input,

audiolm_pytorch/t5.py

Lines changed: 1 addition & 1 deletion
@@ -93,7 +93,7 @@ def t5_encode_text(
 
     t5.eval()
 
-    with torch.no_grad():
+    with torch.inference_mode():
        output = t5(input_ids = input_ids, attention_mask = attn_mask)
        encoded_text = output.last_hidden_state.detach()
 
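One caveat this hunk relies on: tensors created under inference_mode() are inference tensors and can never be re-enrolled in autograd, whereas a no_grad() result can be. The retained .detach() is harmless, but callers must not expect these embeddings to accept gradients later. A small sketch of the failure mode (toy tensor, not the T5 output; the exact message varies by PyTorch version):

import torch

with torch.inference_mode():
    emb = torch.randn(1, 8)

try:
    emb.requires_grad_(True)  # disallowed on inference tensors
except RuntimeError as err:
    print(err)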

audiolm_pytorch/trainer.py

Lines changed: 4 additions & 4 deletions
@@ -507,7 +507,7 @@ def train_step(self):
         for model, label in models:
             model.eval()
 
-            with torch.no_grad():
+            with torch.inference_mode():
                 recons = model(wave, return_recons_only = True)
 
             for ind, recon in enumerate(recons.unbind(dim = 0)):
@@ -753,7 +753,7 @@ def train_step(self):
         if self.is_main and not (steps % self.save_results_every):
             data_kwargs = self.data_tuple_to_kwargs(next(self.valid_dl_iter))
 
-            with torch.no_grad():
+            with torch.inference_mode():
                 self.train_wrapper.eval()
                 valid_loss = self.train_wrapper(**data_kwargs, return_loss = True)
 
@@ -1002,7 +1002,7 @@ def train_step(self):
         if self.is_main and not (steps % self.save_results_every):
             data_kwargs = dict(zip(self.ds_fields, next(self.valid_dl_iter)))
 
-            with torch.no_grad():
+            with torch.inference_mode():
                 self.train_wrapper.eval()
 
                 valid_loss = self.train_wrapper(
@@ -1252,7 +1252,7 @@ def train_step(self):
         if self.is_main and not (steps % self.save_results_every):
             data_kwargs = self.data_tuple_to_kwargs(next(self.valid_dl_iter))
 
-            with torch.no_grad():
+            with torch.inference_mode():
                 self.train_wrapper.eval()
                 valid_loss = self.train_wrapper(**data_kwargs, return_loss = True)
 
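All four trainer hunks wrap the same validation pattern. A condensed sketch of that pattern, using the names from the diff (the surrounding step/save logic is omitted):

import torch

def validation_step(train_wrapper, data_kwargs):
    # the shared shape of the four hunks: enter inference mode, switch
    # to eval, and compute the validation loss with no autograd state
    with torch.inference_mode():
        train_wrapper.eval()
        valid_loss = train_wrapper(**data_kwargs, return_loss = True)
    return valid_loss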

audiolm_pytorch/version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = '1.2.11'
+__version__ = '1.2.12'

audiolm_pytorch/vq_wav2vec.py

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ def downsample_factor(self):
     def codebook_size(self):
         return self.model.vector_quantizer.embedding.shape[0]
 
-    @torch.no_grad()
+    @torch.inference_mode()
     def forward(
         self,
         wav_input,
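As in hubert_kmeans.py, the swap here is on a decorator rather than a context manager; torch.inference_mode() supports both forms. A minimal sketch of the decorator form (hypothetical feature extractor, not this repo's class):

import torch

class Extractor(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(320, 64)

    @torch.inference_mode()
    def forward(self, wav_input):
        # every call to forward runs with autograd fully disabled
        return self.proj(wav_input)

codes = Extractor()(torch.randn(1, 320))
print(codes.is_inference())  # True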
