Skip to content

Commit 0188361

Browse files
double-vinzhuww
andauthored
add seed_torch function (#147)
Co-authored-by: zhuww <[email protected]>
1 parent 71735d5 commit 0188361

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

inference.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,17 @@
4747
if int(torch.__version__.split(".")[0]) >= 1 and int(torch.__version__.split(".")[1]) > 11:
4848
torch.backends.cuda.matmul.allow_tf32 = True
4949

50+
def seed_torch(seed=1029):
51+
random.seed(seed)
52+
os.environ['PYTHONHASHSEED'] = str(seed)
53+
np.random.seed(seed)
54+
torch.manual_seed(seed)
55+
torch.cuda.manual_seed(seed)
56+
torch.cuda.manual_seed_all(seed)
57+
torch.backends.cudnn.benchmark = False
58+
torch.backends.cudnn.deterministic = True
59+
torch.use_deterministic_algorithms(True)
60+
5061
@contextlib.contextmanager
5162
def temp_fasta_file(fasta_str: str):
5263
with tempfile.NamedTemporaryFile('w', suffix='.fasta') as fasta_file:
@@ -215,9 +226,11 @@ def inference_multimer_model(args):
215226
)
216227

217228
output_dir_base = args.output_dir
229+
218230
random_seed = args.data_random_seed
219231
if random_seed is None:
220232
random_seed = random.randrange(sys.maxsize)
233+
# seed_torch(seed=1029)
221234

222235
feature_processor = feature_pipeline.FeaturePipeline(
223236
config.data
@@ -347,9 +360,12 @@ def inference_monomer_model(args):
347360
data_processor = data_pipeline.DataPipeline(template_featurizer=template_featurizer,)
348361

349362
output_dir_base = args.output_dir
363+
350364
random_seed = args.data_random_seed
351365
if random_seed is None:
352366
random_seed = random.randrange(sys.maxsize)
367+
# seed_torch(seed=1029)
368+
353369
feature_processor = feature_pipeline.FeaturePipeline(config.data)
354370
if not os.path.exists(output_dir_base):
355371
os.makedirs(output_dir_base)

0 commit comments

Comments
 (0)