Skip to content

Commit 1498553

Browse files
committed
Updated to latest pytorch version
Signed-off-by: George Araujo <[email protected]>
1 parent c98e745 commit 1498553

File tree

15 files changed

+151
-286
lines changed

15 files changed

+151
-286
lines changed

Dockerfile

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Created by: George Corrêa de Araújo ([email protected])
22

3-
FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-devel
3+
# FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-devel
4+
FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel
45

56
# used to make generated files belong to actual user
67
ARG GROUPID=901
@@ -59,11 +60,13 @@ RUN APT_INSTALL="apt-get install -y --no-install-recommends" && \
5960
kornia \
6061
matplotlib \
6162
numpy \
63+
omegaconf \
6264
pillow \
6365
piq \
6466
prettytable \
6567
pytorch-lightning \
6668
"pytorch-lightning[extra]" \
69+
rich \
6770
tensorboard \
6871
torch_optimizer \
6972
tqdm && \

models/common.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from math import log2
2-
from typing import Tuple, Union
32

43
import torch
54
from torch import nn
@@ -12,8 +11,8 @@ class DefaultConv2d(nn.Conv2d):
1211

1312
def __init__(
1413
self,
15-
kernel_size: Union[int, Tuple[int, int]],
16-
padding: Union[str, int, Tuple[int, int]] = 'same',
14+
kernel_size: int | tuple[int, int],
15+
padding: str | int | tuple[int, int] = 'same',
1716
**kwargs
1817
):
1918
if isinstance(padding, str):
@@ -22,7 +21,10 @@ def __init__(
2221
if lower_padding == 'valid':
2322
padding = 0
2423
else: # if lower_padding == 'same':
25-
padding = kernel_size//2
24+
if isinstance(kernel_size, int):
25+
padding = kernel_size // 2
26+
else:
27+
padding = tuple(k // 2 for k in kernel_size)
2628

2729
super(DefaultConv2d, self).__init__(
2830
kernel_size=kernel_size, padding=padding, **kwargs)
@@ -57,8 +59,8 @@ class MeanShift(nn.Conv2d):
5759
def __init__(
5860
self,
5961
rgb_range: int = 1,
60-
rgb_mean: Tuple[float, float, float] = (0.4488, 0.4371, 0.4040),
61-
rgb_std: Tuple[float, float, float] = (1.0, 1.0, 1.0),
62+
rgb_mean: tuple[float, float, float] = (0.4488, 0.4371, 0.4040),
63+
rgb_std: tuple[float, float, float] = (1.0, 1.0, 1.0),
6264
sign: int = -1
6365
):
6466
super(MeanShift, self).__init__(3, 3, kernel_size=1)

models/ddbpn.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from argparse import ArgumentParser
2-
from typing import Any, Dict
1+
from typing import Any
32

43
import torch
54
import torch.nn as nn
@@ -69,12 +68,7 @@ class DDBPN(SRModel):
6968
"""
7069
LightningModule for DDBPN, https://openaccess.thecvf.com/content_cvpr_2018/papers/Haris_Deep_Back-Projection_Networks_CVPR_2018_paper.pdf.
7170
"""
72-
@staticmethod
73-
def add_model_specific_args(parent: ArgumentParser) -> ArgumentParser:
74-
parent = SRModel.add_model_specific_args(parent)
75-
return parent
76-
77-
def __init__(self, **kwargs: Dict[str, Any]):
71+
def __init__(self, **kwargs: dict[str, Any]):
7872
super(DDBPN, self).__init__(**kwargs)
7973

8074
n0 = 128

models/edsr.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from argparse import ArgumentParser
2-
from typing import Any, Dict
1+
from typing import Any
32

43
import torch.nn as nn
54

@@ -11,19 +10,7 @@ class EDSR(SRModel):
1110
"""
1211
LightningModule for EDSR, https://openaccess.thecvf.com/content_cvpr_2017_workshops/w12/papers/Lim_Enhanced_Deep_Residual_CVPR_2017_paper.pdf.
1312
"""
14-
@staticmethod
15-
def add_model_specific_args(parent: ArgumentParser) -> ArgumentParser:
16-
parent = SRModel.add_model_specific_args(parent)
17-
parser = ArgumentParser(parents=[parent], add_help=False)
18-
parser.add_argument('--n_feats', type=int, default=64,
19-
help='number of feature maps')
20-
parser.add_argument('--n_resblocks', type=int, default=16,
21-
help='number of residual blocks')
22-
parser.add_argument('--res_scale', type=float, default=1,
23-
help='residual scaling')
24-
return parser
25-
26-
def __init__(self, n_feats: int=64, n_resblocks: int=16, res_scale: int=1, **kwargs: Dict[str, Any]):
13+
def __init__(self, n_feats: int=64, n_resblocks: int=16, res_scale: int=1, **kwargs: dict[str, Any]):
2714
super(EDSR, self).__init__(**kwargs)
2815
kernel_size = 3
2916

models/rcan.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from argparse import ArgumentParser
2-
from typing import Any, Dict
1+
from typing import Any
32

43
import torch.nn as nn
54

@@ -80,23 +79,7 @@ class RCAN(SRModel):
8079
"""
8180
LightningModule for RCAN, https://openaccess.thecvf.com/content_ECCV_2018/papers/Yulun_Zhang_Image_Super-Resolution_Using_ECCV_2018_paper.pdf.
8281
"""
83-
@staticmethod
84-
def add_model_specific_args(parent: ArgumentParser) -> ArgumentParser:
85-
parent = SRModel.add_model_specific_args(parent)
86-
parser = ArgumentParser(parents=[parent], add_help=False)
87-
parser.add_argument('--n_feats', type=int, default=64,
88-
help='number of feature maps')
89-
parser.add_argument('--n_resblocks', type=int, default=16,
90-
help='number of residual blocks')
91-
parser.add_argument('--n_resgroups', type=int, default=10,
92-
help='number of residual groups')
93-
parser.add_argument('--reduction', type=int, default=16,
94-
help='number of feature maps reduction')
95-
parser.add_argument('--res_scale', type=float, default=1,
96-
help='residual scaling')
97-
return parser
98-
99-
def __init__(self, n_feats: int=64, n_resblocks: int=16, n_resgroups: int=10, reduction: int=16, res_scale: int=1, **kwargs: Dict[str, Any]):
82+
def __init__(self, n_feats: int=64, n_resblocks: int=16, n_resgroups: int=10, reduction: int=16, res_scale: int=1, **kwargs: dict[str, Any]):
10083
super(RCAN, self).__init__(**kwargs)
10184
kernel_size = 3
10285

models/rdn.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from argparse import ArgumentParser
2-
from typing import Any, Dict
1+
from typing import Any
32

43
import torch
54
import torch.nn as nn
@@ -45,17 +44,7 @@ class RDN(SRModel):
4544
"""
4645
LightningModule for RDN, https://openaccess.thecvf.com/content_cvpr_2018/papers/Zhang_Residual_Dense_Network_CVPR_2018_paper.pdf.
4746
"""
48-
@staticmethod
49-
def add_model_specific_args(parent: ArgumentParser) -> ArgumentParser:
50-
parent = SRModel.add_model_specific_args(parent)
51-
parser = ArgumentParser(parents=[parent], add_help=False)
52-
parser.add_argument('--G0', type=int, default=64)
53-
parser.add_argument('--kernel_size', type=int, default=3)
54-
parser.add_argument('--rdn_config', type=str, default='B',
55-
choices=['A', 'B'])
56-
return parser
57-
58-
def __init__(self, rdn_config: str='B', G0: int=64, kernel_size: int=3, **kwargs: Dict[str, Any]):
47+
def __init__(self, rdn_config: str='B', G0: int=64, kernel_size: int=3, **kwargs: dict[str, Any]):
5948
super(RDN, self).__init__(**kwargs)
6049

6150
# number of RDB blocks, conv layers, out channels

models/srcnn.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from argparse import ArgumentParser
2-
from typing import Any, Dict
1+
from typing import Any
32

43
import torch.nn as nn
54
import torch.nn.functional as F
@@ -12,12 +11,7 @@ class SRCNN(SRModel):
1211
LightningModule for SRCNN, https://ieeexplore.ieee.org/document/7115171?arnumber=7115171
1312
https://arxiv.org/pdf/1501.00092.pdf.
1413
"""
15-
@staticmethod
16-
def add_model_specific_args(parent: ArgumentParser) -> ArgumentParser:
17-
parent = SRModel.add_model_specific_args(parent)
18-
return parent
19-
20-
def __init__(self, **kwargs: Dict[str, Any]):
14+
def __init__(self, **kwargs: dict[str, Any]):
2115
super(SRCNN, self).__init__(**kwargs)
2216
self._net = nn.Sequential(
2317
nn.Conv2d(self._channels, 64, 9, padding=4),

models/srgan.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
from argparse import ArgumentParser
21
from math import ceil, sqrt
3-
from typing import Any, Dict
2+
from typing import Any
43

54
import piq
65
import torch
@@ -113,16 +112,7 @@ class SRGAN(SRModel):
113112
"""
114113
LightningModule for SRGAN, https://arxiv.org/pdf/1609.04802.
115114
"""
116-
@staticmethod
117-
def add_model_specific_args(parent: ArgumentParser) -> ArgumentParser:
118-
parent = SRModel.add_model_specific_args(parent)
119-
parser = ArgumentParser(parents=[parent], add_help=False)
120-
parser.add_argument('--ngf', type=int, default=64)
121-
parser.add_argument('--n_blocks', type=int, default=16)
122-
parser.add_argument('--ndf', type=int, default=64)
123-
return parser
124-
125-
def __init__(self, ngf: int=64, ndf: int=64, n_blocks: int=16, **kwargs: Dict[str, Any]):
115+
def __init__(self, ngf: int=64, ndf: int=64, n_blocks: int=16, **kwargs: dict[str, Any]):
126116

127117
super(SRGAN, self).__init__(**kwargs)
128118

0 commit comments

Comments
 (0)