
Commit ef1f96f

Created main.py to use lightning cli
Signed-off-by: George Araujo <[email protected]>
1 parent 1498553 commit ef1f96f

File tree

3 files changed: 364 additions, 0 deletions

    configs/all.yml
    configs/train_default_sr.yml
    main.py

configs/all.yml

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
# this file should only be used to see possible configuration params and how to set them
# it should NOT be called directly by the training script
file_log_level: info
log_level: warning
seed_everything: true
seed: 42

data:
  augment: true
  batch_size: 16
  datasets_dir: /datasets
  eval_datasets:
    - B100
    - DIV2K
    - Set14
    - Set5
    - Urban100
  patch_size: 128
  predict_datasets: []
  scale_factor: 4
  train_datasets:
    - DIV2K

model:
  class_path: SRCNN
  init_args:
    # batch_size: 16 # linked to data.batch_size
    channels: 3
    default_root_dir: .
    # devices: null
    # eval_datasets: # linked to data.eval_datasets
    #   - B100
    #   - DIV2K
    #   - Set14
    #   - Set5
    #   - Urban100
    log_loss_every_n_epochs: 50
    log_weights_every_n_epochs: ${trainer.check_val_every_n_epoch}
    losses: l1
    # max_epochs: 20 # linked to trainer.max_epochs
    metrics:
      - BRISQUE
      - FLIP
      - LPIPS
      - MS-SSIM
      - PSNR
      - SSIM
    metrics_for_pbar: # can be only metric name (PSNR) or dataset/metric name (DIV2K/PSNR)
      - DIV2K/PSNR
      - DIV2K/SSIM
    model_gpus: []
    model_parallel: false
    optimizer: ADAM
    optimizer_params: []
    # patch_size: 128 # linked to data.patch_size
    precision: 32
    predict_datasets: []
    save_results: -1
    save_results_from_epoch: last
    # scale_factor: 4 # linked to data.scale_factor

trainer:
  # https://lightning.ai/docs/pytorch/stable/common/trainer.html
  accelerator: auto
  accumulate_grad_batches: 1
  barebones: false
  benchmark: null
  callbacks:
    - class_path: pytorch_lightning.callbacks.ModelCheckpoint
      init_args:
        dirpath: ${trainer.default_root_dir}/checkpoints
        every_n_epochs: ${trainer.check_val_every_n_epoch}
        filename: ${model.class_path}_X${data.scale_factor}_e_${trainer.max_epochs}_p_${data.patch_size}
        mode: max # could be different for different monitored metrics
        monitor: DIV2K/PSNR
        save_last: true
        save_top_k: 3
        verbose: false
  check_val_every_n_epoch: 200
  default_root_dir: experiments/${model.class_path}_X${data.scale_factor}_e_${trainer.max_epochs}_p_${data.patch_size}
  detect_anomaly: false
  deterministic: null
  devices: [0]
  enable_checkpointing: null
  enable_model_summary: null
  enable_progress_bar: null
  fast_dev_run: false
  gradient_clip_algorithm: null
  gradient_clip_val: null
  inference_mode: true
  logger:
    - class_path: pytorch_lightning.loggers.CometLogger
      # for this to work, create the file ~/.comet.config with
      # [comet]
      # api_key = YOUR API KEY
      # for more info, see https://www.comet.com/docs/v2/api-and-sdk/python-sdk/advanced/configuration/#configuration-parameters
      init_args:
        experiment_name: ${model.class_path}_X${data.scale_factor}_e_${trainer.max_epochs}_p_${data.patch_size}
        offline: false
        project_name: sr-pytorch-lightning
        save_dir: ${trainer.default_root_dir}
    - class_path: pytorch_lightning.loggers.TensorBoardLogger
      init_args:
        default_hp_metric: false
        log_graph: true
        name: tensorboard_logs
        save_dir: ${trainer.default_root_dir}
  limit_predict_batches: null
  limit_test_batches: null
  limit_train_batches: null
  limit_val_batches: null
  log_every_n_steps: null
  max_epochs: 2000
  max_steps: -1
  max_time: null
  min_epochs: null
  min_steps: null
  num_nodes: 1
  num_sanity_val_steps: null
  overfit_batches: 0.0
  plugins: null
  precision: 32-true
  profiler: null
  reload_dataloaders_every_n_epochs: 0
  strategy: auto
  sync_batchnorm: false
  use_distributed_sampler: true
  val_check_interval: null
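
Since main.py (shown below) builds its LightningCLI with parser_kwargs={"parser_mode": "omegaconf"}, the ${...} references in this file are OmegaConf-style interpolations. A minimal sketch for sanity-checking how they resolve, illustrative only and not part of this commit; it assumes the omegaconf package is installed and the snippet is run from the repository root:

    # resolve_all_config.py (hypothetical helper, not in the repo)
    from omegaconf import OmegaConf

    cfg = OmegaConf.load('configs/all.yml')
    resolved = OmegaConf.to_container(cfg, resolve=True)

    # interpolations such as ${model.class_path} and ${data.scale_factor} resolve
    # against this same file, so this should print experiments/SRCNN_X4_e_2000_p_128
    print(resolved['trainer']['default_root_dir'])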

configs/train_default_sr.yml

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
data:
  augment: true
  batch_size: 16
  datasets_dir: /datasets
  eval_datasets:
    - B100
    - DIV2K
    - Set14
    - Set5
    - Urban100
  patch_size: 128
  scale_factor: 4
  train_datasets:
    - DIV2K

model:
  init_args:
    channels: 3
    log_loss_every_n_epochs: 50
    losses: l1
    metrics:
      - BRISQUE
      - FLIP
      - LPIPS
      - MS-SSIM
      - PSNR
      - SSIM
    metrics_for_pbar: # can be only metric name (PSNR) or dataset/metric name (DIV2K/PSNR)
      - DIV2K/PSNR
      - DIV2K/SSIM
    optimizer: ADAM
    save_results: -1
    save_results_from_epoch: last

trainer:
  # https://lightning.ai/docs/pytorch/stable/common/trainer.html
  callbacks:
    - class_path: pytorch_lightning.callbacks.ModelCheckpoint
      init_args:
        every_n_epochs: ${trainer.check_val_every_n_epoch}
        filename: model
        mode: max # could be different for different monitored metrics
        monitor: DIV2K/PSNR
        save_last: true
        save_top_k: 3
        verbose: false
    # - class_path: pytorch_lightning.callbacks.RichModelSummary
    #   init_args:
    #     max_depth: -1
    # - class_path: pytorch_lightning.callbacks.RichProgressBar
  check_val_every_n_epoch: 200
  default_root_dir: experiments/test
  logger:
    - class_path: pytorch_lightning.loggers.CometLogger
      init_args:
        experiment_name: test
        offline: false
        project_name: sr-pytorch-lightning
        save_dir: ${trainer.default_root_dir} # without save_dir defined here, Trainer throws an assertion error
    # - class_path: pytorch_lightning.loggers.TensorBoardLogger
    #   init_args:
    #     default_hp_metric: false
    #     log_graph: true
    #     name: tensorboard_logs
    #     save_dir: ${trainer.default_root_dir}
  max_epochs: 2000
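
Unlike configs/all.yml, which is reference-only, this config is meant to be passed to the new CLI entry point. A minimal sketch of how a run could be launched from Python by reusing cli_main() from main.py below; the equivalent shell call would be python main.py fit --config configs/train_default_sr.yml. Illustrative only, not part of this commit, and it assumes main.py is importable from the repository root and the datasets under /datasets are available:

    # launch_train.py (hypothetical helper, not in the repo)
    import sys

    from main import cli_main

    # LightningCLI parses sys.argv, so emulate the command line here
    sys.argv = ['main.py', 'fit', '--config', 'configs/train_default_sr.yml']
    cli_main()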

main.py

Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
import logging
from logging.handlers import RotatingFileHandler
from pathlib import Path

import numpy as np
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning.loggers import CometLogger

import models
from srdata import SRData


class CustomLightningCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        parser.add_argument('--log_level', type=str, default='warning',
                            choices=('debug', 'info', 'warning', 'error', 'critical'))
        parser.add_argument('--file_log_level', type=str, default='info',
                            choices=('debug', 'info', 'warning', 'error', 'critical'))

        # https://lightning.ai/docs/pytorch/LTS/cli/lightning_cli_expert.html#argument-linking
        parser.link_arguments('data.batch_size', 'model.init_args.batch_size')
        parser.link_arguments('data.eval_datasets', 'model.init_args.eval_datasets')
        parser.link_arguments('data.patch_size', 'model.init_args.patch_size')
        parser.link_arguments('data.scale_factor', 'model.init_args.scale_factor')

        parser.link_arguments('trainer.check_val_every_n_epoch', 'model.init_args.log_weights_every_n_epochs')
        parser.link_arguments('trainer.check_val_every_n_epoch', 'trainer.callbacks.init_args.every_n_epochs')
        parser.link_arguments('trainer.default_root_dir', 'model.init_args.default_root_dir')
        parser.link_arguments('trainer.default_root_dir', 'trainer.logger.init_args.save_dir')  # not working for comet logger
        parser.link_arguments('trainer.default_root_dir', 'trainer.callbacks.init_args.dirpath',
                              compute_fn=lambda x: f'{x}/checkpoints')
        parser.link_arguments('trainer.max_epochs', 'model.init_args.max_epochs')

    def before_fit(self):
        # setup logging
        default_root_dir = Path(self.config['fit']['trainer']['default_root_dir'])
        default_root_dir.mkdir(parents=True, exist_ok=True)

        setup_log(
            level=self.config['fit']['log_level'],
            log_file=default_root_dir / 'run.log',
            file_level=self.config['fit']['file_log_level'],
            logs_to_silence=['PIL'],
        )

        for logger in self.trainer.loggers:
            if isinstance(logger, CometLogger):
                # all code will be under /work when running on docker
                logger.experiment.log_code(folder='/work')
                logger.experiment.log_parameters(self.config.as_dict())
                logger.experiment.set_model_graph(str(self.model))
                logger.experiment.log_other(
                    'trainable params', sum(p.numel() for p in self.model.parameters() if p.requires_grad))

                total_params = sum(p.numel() for p in self.model.parameters())
                logger.experiment.log_other('total params', total_params)

                total_loss_params = 0
                total_loss_trainable_params = 0
                for loss in self.model._losses:
                    if loss.name.find('adaptive') >= 0:
                        total_loss_params += sum(p.numel() for p in loss.loss.parameters())
                        total_loss_trainable_params += sum(p.numel() for p in loss.loss.parameters() if p.requires_grad)

                if total_loss_params > 0:
                    logger.experiment.log_other('loss total params', total_loss_params)
                    logger.experiment.log_other('loss trainable params', total_loss_trainable_params)

                # assume 4 bytes/number (float on cuda)
                denom = 1024 ** 2.
                input_size = abs(np.prod(self.model.example_input_array.size()) * 4. / denom)
                params_size = abs(total_params * 4. / denom)
                logger.experiment.log_other('input size (MB)', input_size)
                logger.experiment.log_other('params size (MB)', params_size)
                break

    def after_fit(self):
        for logger in self.trainer.loggers:
            if isinstance(logger, CometLogger):
                default_root_dir = Path(self.config['fit']['trainer']['default_root_dir'])
                last_checkpoint = default_root_dir / 'checkpoints' / 'last.ckpt'
                model_name = self.config['fit']['model']['class_path'].split('.')[-1]
                logger.experiment.log_model(f'{model_name}', f'{last_checkpoint}', overwrite=True)
                logger.experiment.log_asset(f'{default_root_dir / "run.log"}')
                break


def cli_main() -> None:
    _ = CustomLightningCLI(
        model_class=models.SRModel,
        subclass_mode_model=True,
        datamodule_class=SRData,
        parser_kwargs={"parser_mode": "omegaconf"},
    )


def setup_log(
    level: str = 'warning',
    log_file: str | Path = Path('run.log'),
    file_level: str = 'info',
    logs_to_silence: list[str] = [],
) -> None:
    """
    Set up logging.

    Args:
        level (str): stdout log level. Defaults to 'warning'.
        log_file (str | Path): file where the log output should be stored. Defaults to 'run.log'.
        file_level (str): file log level. Defaults to 'info'.
        logs_to_silence (list[str]): list of loggers to be silenced. Useful when using log level < 'warning'. Defaults to [].
    """
    # TODO: fix this according to
    # https://stackoverflow.com/questions/384076/how-can-i-color-python-logging-output
    # https://www.electricmonk.nl/log/2017/08/06/understanding-pythons-logging-module/

    # convert log levels to int
    int_log_level = {
        'debug': logging.DEBUG,        # 10
        'info': logging.INFO,          # 20
        'warning': logging.WARNING,    # 30
        'error': logging.ERROR,        # 40
        'critical': logging.CRITICAL,  # 50
    }

    stdout_log_level = int_log_level[level]
    file_log_level = int_log_level[file_level]

    # create a handler to log to stderr
    stderr_handler = logging.StreamHandler()
    stderr_handler.setLevel(stdout_log_level)

    # create a logging format
    if stdout_log_level >= logging.WARNING:
        stderr_formatter = logging.Formatter('{message}', style='{')
    else:
        stderr_formatter = logging.Formatter(
            # format:
            # <10 = pad with spaces if needed until it reaches 10 chars length
            # .10 = limit the length to 10 chars
            '{name:<10.10} [{levelname:.1}] {message}', style='{')
    stderr_handler.setFormatter(stderr_formatter)

    # create a file handler that has a size limit
    if isinstance(log_file, str):
        log_file = Path(log_file).expanduser()

    file_handler = RotatingFileHandler(log_file, maxBytes=5_000_000, backupCount=5)  # ~5 MB
    file_handler.setLevel(file_log_level)

    # https://docs.python.org/3/library/logging.html#logrecord-attributes
    file_formatter = logging.Formatter(
        '{asctime} - {name:<20.20} {levelname:<8} {message}', datefmt='%Y-%m-%d %H:%M:%S', style='{')
    file_handler.setFormatter(file_formatter)

    # add the handlers to the root logger
    logging.basicConfig(handlers=[file_handler, stderr_handler], level=logging.DEBUG)

    # change logger level of logs_to_silence to warning
    for other_logger in logs_to_silence:
        logging.getLogger(other_logger).setLevel(logging.WARNING)

    # create logger
    logger = logging.getLogger(__name__)

    logger.info(f'Saving logs to {log_file.absolute()}')
    logger.info(f'Log level: {logging.getLevelName(stdout_log_level)}')


if __name__ == "__main__":
    cli_main()
    # note: it is good practice to implement the CLI in a function and call it in the main if block
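
A quick way to see what setup_log() above does: the root logger gets both handlers, so records below the stderr level still reach the rotating log file, and levels below warning use the compact "{name} [{level initial}] {message}" format on stderr. A minimal sketch, illustrative only (not part of this commit) and assuming main.py and its imports are available from the current directory:

    # logging_demo.py (hypothetical, not in the repo)
    import logging
    from pathlib import Path

    from main import setup_log

    setup_log(level='info', log_file=Path('example_run.log'),
              file_level='debug', logs_to_silence=['PIL'])

    log = logging.getLogger('example')
    log.debug('written only to example_run.log (file handler level is DEBUG)')
    log.info('shown on stderr in the compact format and also written to the file')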
