
Commit 1b4dbf2

Add files via upload
1 parent 35aa8a6 commit 1b4dbf2

2 files changed: 231 additions & 0 deletions


generate_patches.py

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
from glob import glob
from tqdm import tqdm
import numpy as np
import os
import shutil
from natsort import natsorted
import cv2
from joblib import Parallel, delayed
import argparse

parser = argparse.ArgumentParser(description='Generate patches from Full Resolution images')
parser.add_argument('--src_dir', default='C:/Users/Lab722 BX/Desktop/IO-Haze/train', type=str, help='Directory for full resolution images')
parser.add_argument('--tar_dir', default='G:/IO_Haze/train', type=str, help='Directory for image patches')
parser.add_argument('--ps', default=256, type=int, help='Image Patch Size')
parser.add_argument('--num_patches', default=200, type=int, help='Number of patches per image')
parser.add_argument('--num_cores', default=6, type=int, help='Number of CPU Cores')

args = parser.parse_args()

src = args.src_dir
tar = args.tar_dir
PS = args.ps
NUM_PATCHES = args.num_patches
NUM_CORES = args.num_cores

noisy_patchDir = os.path.join(tar, 'input')
clean_patchDir = os.path.join(tar, 'target')

# Start from a clean target directory (shutil.rmtree works on both Windows and Unix paths)
if os.path.exists(tar):
    shutil.rmtree(tar)

os.makedirs(noisy_patchDir)
os.makedirs(clean_patchDir)

# get sorted folders
files = natsorted(glob(os.path.join(src, '*', '*.JPG')))

# Split into hazy inputs ('hazy' in the filename) and ground-truth targets ('GT' in the filename)
noisy_files, clean_files = [], []
for file_ in files:
    filename = os.path.split(file_)[-1]
    if 'GT' in filename:
        clean_files.append(file_)
    if 'hazy' in filename:
        noisy_files.append(file_)
    # if 'gt' in file_:
    #     clean_files.append(file_)
    # if 'data' in file_:
    #     noisy_files.append(file_)


def save_files(i):
    """Crop NUM_PATCHES random, spatially aligned patches from the i-th hazy/clean pair."""
    noisy_file, clean_file = noisy_files[i], clean_files[i]
    noisy_img = cv2.imread(noisy_file)
    clean_img = cv2.imread(clean_file)

    H = noisy_img.shape[0]
    W = noisy_img.shape[1]
    for j in range(NUM_PATCHES):
        # Random top-left corner; the same crop window is applied to both images
        rr = np.random.randint(0, H - PS)
        cc = np.random.randint(0, W - PS)
        noisy_patch = noisy_img[rr:rr + PS, cc:cc + PS, :]
        clean_patch = clean_img[rr:rr + PS, cc:cc + PS, :]

        cv2.imwrite(os.path.join(noisy_patchDir, '{}_{}.png'.format(i + 1, j + 1)), noisy_patch)
        cv2.imwrite(os.path.join(clean_patchDir, '{}_{}.png'.format(i + 1, j + 1)), clean_patch)


Parallel(n_jobs=NUM_CORES)(delayed(save_files)(i) for i in tqdm(range(len(noisy_files))))
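
Usage note: the script is invoked as, e.g., python generate_patches.py --src_dir <train folder> --tar_dir <patch folder> --ps 256 --num_patches 200 --num_cores 6, and writes spatially aligned hazy/clean crops to <tar_dir>/input and <tar_dir>/target with matching {image}_{patch}.png names. Below is a minimal sketch of how those pairs could be loaded for training; the PairedPatchDataset name and the PyTorch Dataset/DataLoader wrapper are illustrative assumptions, not part of this commit.

# Illustrative loader for the patch pairs written above (PairedPatchDataset is a
# hypothetical name; PyTorch is assumed here but not required by generate_patches.py).
import os
from glob import glob

import cv2
import numpy as np
import torch
from natsort import natsorted
from torch.utils.data import DataLoader, Dataset


class PairedPatchDataset(Dataset):
    """Loads matching hazy ('input') and clean ('target') patches by sorted filename."""

    def __init__(self, patch_root):
        self.noisy_paths = natsorted(glob(os.path.join(patch_root, 'input', '*.png')))
        self.clean_paths = natsorted(glob(os.path.join(patch_root, 'target', '*.png')))
        assert len(self.noisy_paths) == len(self.clean_paths)

    def __len__(self):
        return len(self.noisy_paths)

    def __getitem__(self, idx):
        def to_tensor(path):
            # BGR uint8 (H, W, 3) -> RGB float32 (3, H, W) in [0, 1]
            img = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
            return torch.from_numpy(img.astype(np.float32) / 255.).permute(2, 0, 1)
        return to_tensor(self.noisy_paths[idx]), to_tensor(self.clean_paths[idx])


if __name__ == '__main__':
    loader = DataLoader(PairedPatchDataset('G:/IO_Haze/train'), batch_size=8, shuffle=True)
    noisy_batch, clean_batch = next(iter(loader))
    print(noisy_batch.shape, clean_batch.shape)  # e.g. torch.Size([8, 3, 256, 256]) twice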

losses.py

Lines changed: 167 additions & 0 deletions
@@ -0,0 +1,167 @@
import torch
from torch import nn as nn
from torch.nn import functional as F
import numpy as np
from typing import Tuple

"""
Reference from: https://github.com/swz30/MPRNet/blob/main/Denoising/generate_patches_SIDD.py
"""


def gaussian(window_size, sigma):
    """1-D Gaussian window of length `window_size`, normalised to sum to 1."""
    def gauss_fcn(x):
        return -(x - window_size // 2) ** 2 / float(2 * sigma ** 2)
    gauss = torch.stack(
        [torch.exp(torch.tensor(gauss_fcn(x))) for x in range(window_size)])
    return gauss / gauss.sum()


def get_gaussian_kernel(ksize: int, sigma: float) -> torch.Tensor:
    if not isinstance(ksize, int) or ksize % 2 == 0 or ksize <= 0:
        raise TypeError("ksize must be an odd positive integer. Got {}"
                        .format(ksize))
    window_1d: torch.Tensor = gaussian(ksize, sigma)
    return window_1d


def get_gaussian_kernel2d(ksize: Tuple[int, int],
                          sigma: Tuple[float, float]) -> torch.Tensor:
    if not isinstance(ksize, tuple) or len(ksize) != 2:
        raise TypeError("ksize must be a tuple of length two. Got {}"
                        .format(ksize))
    if not isinstance(sigma, tuple) or len(sigma) != 2:
        raise TypeError("sigma must be a tuple of length two. Got {}"
                        .format(sigma))
    ksize_x, ksize_y = ksize
    sigma_x, sigma_y = sigma
    kernel_x: torch.Tensor = get_gaussian_kernel(ksize_x, sigma_x)
    kernel_y: torch.Tensor = get_gaussian_kernel(ksize_y, sigma_y)
    # Outer product of the two 1-D windows gives a separable 2-D Gaussian kernel
    kernel_2d: torch.Tensor = torch.matmul(
        kernel_x.unsqueeze(-1), kernel_y.unsqueeze(-1).t())
    return kernel_2d


class PSNRLoss(nn.Module):
    """
    Reference from: https://github.com/megvii-model/HINet/blob/main/basicsr/models/losses/losses.py
    """
    def __init__(self, loss_weight=1.0, reduction='mean', toY=False):
        super(PSNRLoss, self).__init__()
        assert reduction == 'mean'
        self.loss_weight = loss_weight
        self.scale = 10 / np.log(10)  # converts ln(MSE) into 10 * log10(MSE)
        self.toY = toY
        # BT.601 RGB -> Y coefficients, used when toY=True
        self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1)
        self.first = True

    def forward(self, pred, target):
        assert len(pred.size()) == 4
        if self.toY:
            if self.first:
                self.coef = self.coef.to(pred.device)
                self.first = False

            pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
            target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.

            pred, target = pred / 255., target / 255.

        # -scale * log(MSE) is the PSNR in dB for inputs in [0, 1]
        loss = -(self.loss_weight * self.scale * torch.log(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean())
        return loss


class SSIMLoss(nn.Module):
    def __init__(self, window_size: int = 11, reduction: str = 'mean', max_val: float = 1.0) -> None:
        super(SSIMLoss, self).__init__()
        self.window_size: int = window_size
        self.max_val: float = max_val
        self.reduction: str = reduction

        self.window: torch.Tensor = get_gaussian_kernel2d(
            (window_size, window_size), (1.5, 1.5))
        self.padding: int = self.compute_zero_padding(window_size)

        self.C1: float = (0.01 * self.max_val) ** 2
        self.C2: float = (0.03 * self.max_val) ** 2

    @staticmethod
    def compute_zero_padding(kernel_size: int) -> int:
        """Computes zero padding."""
        return (kernel_size - 1) // 2

    def filter2D(
            self,
            input: torch.Tensor,
            kernel: torch.Tensor,
            channel: int) -> torch.Tensor:
        return F.conv2d(input, kernel, padding=self.padding, groups=channel)

    def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
        # prepare kernel
        b, c, h, w = img1.shape
        tmp_kernel: torch.Tensor = self.window.to(img1.device).to(img1.dtype)
        kernel: torch.Tensor = tmp_kernel.repeat(c, 1, 1, 1)

        # compute local mean per channel
        mu1: torch.Tensor = self.filter2D(img1, kernel, c)
        mu2: torch.Tensor = self.filter2D(img2, kernel, c)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        # compute local sigma per channel
        sigma1_sq = self.filter2D(img1 * img1, kernel, c) - mu1_sq
        sigma2_sq = self.filter2D(img2 * img2, kernel, c) - mu2_sq
        sigma12 = self.filter2D(img1 * img2, kernel, c) - mu1_mu2

        ssim_map = ((2 * mu1_mu2 + self.C1) * (2 * sigma12 + self.C2)) / \
            ((mu1_sq + mu2_sq + self.C1) * (sigma1_sq + sigma2_sq + self.C2))

        # SSIM in [-1, 1] -> loss in [0, 1]
        loss = torch.clamp(1. - ssim_map, min=0, max=1) / 2.

        if self.reduction == 'mean':
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)
        elif self.reduction == 'none':
            pass
        return loss

# ------------------------------------------------------------------------------


class CharbonnierLoss(nn.Module):
    """Charbonnier Loss (a differentiable, smooth L1)"""

    def __init__(self, eps=1e-3):
        super(CharbonnierLoss, self).__init__()
        self.eps = eps

    def forward(self, x, y):
        diff = x - y
        # loss = torch.sum(torch.sqrt(diff * diff + self.eps))
        loss = torch.mean(torch.sqrt((diff * diff) + (self.eps * self.eps)))
        return loss


class EdgeLoss(nn.Module):
    def __init__(self):
        super(EdgeLoss, self).__init__()
        # 5x5 binomial (Gaussian-like) blur kernel, one per RGB channel
        k = torch.Tensor([[.05, .25, .4, .25, .05]])
        self.kernel = torch.matmul(k.t(), k).unsqueeze(0).repeat(3, 1, 1, 1)
        if torch.cuda.is_available():
            self.kernel = self.kernel.cuda()
        self.loss = CharbonnierLoss()

    def conv_gauss(self, img):
        n_channels, _, kw, kh = self.kernel.shape
        img = F.pad(img, (kw // 2, kh // 2, kw // 2, kh // 2), mode='replicate')
        return F.conv2d(img, self.kernel, groups=n_channels)

    def laplacian_kernel(self, current):
        # Laplacian-of-Gaussian band: blur, drop odd rows/cols, re-expand, blur, subtract
        filtered = self.conv_gauss(current)       # filter
        down = filtered[:, :, ::2, ::2]           # downsample
        new_filter = torch.zeros_like(filtered)
        new_filter[:, :, ::2, ::2] = down * 4     # upsample
        filtered = self.conv_gauss(new_filter)    # filter
        diff = current - filtered
        return diff

    def forward(self, x, y):
        loss = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y))
        return loss
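
A minimal sketch of how these criteria might be combined for a restoration model that outputs images in [0, 1]; the specific weights and the restored/target names are illustrative assumptions, not something defined in this commit. Note that EdgeLoss keeps its blur kernel on the GPU when CUDA is available, so inputs are expected on the same device.

# Illustrative combination of the losses defined above (weights are assumptions).
import torch
from losses import CharbonnierLoss, EdgeLoss, SSIMLoss

device = 'cuda' if torch.cuda.is_available() else 'cpu'
criterion_char = CharbonnierLoss()
criterion_edge = EdgeLoss()   # its 5x5 kernel is already on the GPU if one is available
criterion_ssim = SSIMLoss()


def compute_loss(restored, target):
    # restored, target: (B, 3, H, W) tensors in [0, 1] on `device`
    loss = criterion_char(restored, target)
    loss = loss + 0.05 * criterion_edge(restored, target)  # assumed edge weight
    loss = loss + 0.50 * criterion_ssim(restored, target)  # assumed SSIM weight
    return loss


# Example with random tensors standing in for a model's output and the ground truth:
restored = torch.rand(2, 3, 256, 256, device=device)
target = torch.rand(2, 3, 256, 256, device=device)
print(compute_loss(restored, target))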
