commit 1a9a5fab15 (parent 64674aab60) on branch main, by protsenkovi, 8 months ago

@@ -2,12 +2,12 @@
 Example
 ```
 python train.py --model SRNetRot90
-python validate.py --val_datasets Set5,Set14,B100,Urban100,Manga109 --model_path /wd/lut_reproduce/models/last_trained_net.pth
+python validate.py --val_datasets Set5,Set14,B100,Urban100,Manga109
 python transfer_to_lut.py
 python train.py --model_path /wd/lut_reproduce/models/last_transfered_lut.pth --total_iter 2000
-python validate.py --val_datasets Set5,Set14,B100,Urban100,Manga109 --model_path /wd/lut_reproduce/models/last_trained_lut.pth
+python validate.py --val_datasets Set5,Set14,B100,Urban100,Manga109
 python image_demo.py
 ```
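Note: the two `validate.py` invocations drop `--model_path` because this commit gives that flag a default of `../models/last.pth` (see the `ValOptions` hunk below) and teaches both `train.py` and `transfer_to_lut.py` to refresh a `last.pth` symlink after saving, so validation always picks up the most recent checkpoint.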

@@ -55,12 +55,12 @@ class SRTrainDataset(Dataset):
             hr_patch = hr_image[
                 (i*scale):(i*scale + self.sz*scale),
                 (j*scale):(j*scale + self.sz*scale),
-                c
+                c:(c+1)
             ]
             lr_patch = lr_image[
                 i:(i + self.sz),
                 j:(j + self.sz),
-                c
+                c:(c+1)
             ]
             if self.rigid_aug:
@@ -78,8 +78,8 @@ class SRTrainDataset(Dataset):
         hr_patch = torch.tensor(hr_patch.copy()).type(torch.float32)
         lr_patch = torch.tensor(lr_patch.copy()).type(torch.float32)
-        hr_patch = hr_patch.unsqueeze(0)
-        lr_patch = lr_patch.unsqueeze(0)
+        hr_patch = hr_patch.permute(2,0,1)
+        lr_patch = lr_patch.permute(2,0,1)
         return hr_patch, lr_patch
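The change from `c` to `c:(c+1)` keeps the channel axis when a single channel is sliced out, so patches stay HWC and the new `permute(2,0,1)` can move channels first. A small sketch of the shape difference (array sizes are illustrative, not from the repo):

```python
import numpy as np

hr_image = np.zeros((48, 48, 3), dtype=np.uint8)  # illustrative HWC image
old_patch = hr_image[0:24, 0:24, 0]      # (24, 24): channel axis dropped
new_patch = hr_image[0:24, 0:24, 0:1]    # (24, 24, 1): channel axis kept
# permute(2,0,1) maps (H, W, 1) -> (1, H, W), the same CHW layout that
# unsqueeze(0) produced from the old 2-D patch.
```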

@@ -66,5 +66,37 @@ class DenseConvUpscaleBlock(nn.Module):
             x = torch.cat([x, torch.relu(conv(x))], dim=1)
         x = self.shuffle(self.project_channels(x))
         x = torch.tanh(x)
-        x = round_func(x*127.5 + 127.5)
+        x = x*127.5 + 127.5
+        x = round_func(x)
+        return x
+
+class ConvUpscaleBlock(nn.Module):
+    def __init__(self, hidden_dim = 32, layers_count=5, upscale_factor=1):
+        super(ConvUpscaleBlock, self).__init__()
+        assert layers_count > 0
+        self.upscale_factor = upscale_factor
+        self.hidden_dim = hidden_dim
+        self.embed = nn.Conv2d(1, hidden_dim, kernel_size=(2, 2), padding='valid', stride=1, dilation=1, bias=True)
+        self.convs = []
+        for i in range(layers_count):
+            self.convs.append(nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=1, stride=1, padding=0, dilation=1, bias=True))
+        self.convs = nn.ModuleList(self.convs)
+        for name, p in self.named_parameters():
+            if "weight" in name: nn.init.kaiming_normal_(p)
+            if "bias" in name: nn.init.constant_(p, 0)
+        self.project_channels = nn.Conv2d(in_channels=hidden_dim, out_channels=upscale_factor * upscale_factor, kernel_size=1, stride=1, padding=0, dilation=1, bias=True)
+        self.shuffle = nn.PixelShuffle(upscale_factor)
+
+    def forward(self, x):
+        x = (x-127.5)/127.5
+        x = torch.relu(self.embed(x))
+        for conv in self.convs:
+            x = torch.relu(conv(x))
+        x = self.shuffle(self.project_channels(x))
+        x = torch.tanh(x)
+        x = x*127.5 + 127.5
+        x = round_func(x)
         return x
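`ConvUpscaleBlock` is the plain (non-dense) counterpart of `DenseConvUpscaleBlock`: a valid 2x2 embedding conv, a stack of 1x1 convs with ReLU, then a pixel shuffle, with inputs normalized from [0, 255] to [-1, 1] and outputs mapped back and rounded. One detail worth noting: the kaiming-init loop runs before `project_channels` is created, so that layer keeps PyTorch's default initialization. A minimal smoke test, with input sizes assumed for illustration:

```python
import torch
from common.layers import ConvUpscaleBlock

block = ConvUpscaleBlock(hidden_dim=32, layers_count=5, upscale_factor=4)
x = torch.randint(0, 256, (1, 1, 9, 9)).float()  # one channel, values in [0, 255]
y = block(x)
# The valid 2x2 embedding conv trims one row and one column, so:
print(y.shape)  # torch.Size([1, 1, 32, 32]) == (9-1)*4 x (9-1)*4
# Outputs are rounded back into the [0, 255] range by round_func.
```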

@@ -1,16 +1,18 @@
 import torch
+import pandas as pd
 import numpy as np
 from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr, modcrop
 from pathlib import Path
 from PIL import Image
-# import ray
+import time
 
 # @ray.remote(num_cpus=1, num_gpus=0.3)
-def val_image_pair(model, hr_image, lr_image, output_image_path=None):
+def val_image_pair(model, hr_image, lr_image, output_image_path=None, device='cuda'):
     with torch.inference_mode():
+        start_time = time.perf_counter_ns()
         # prepare lr_image
         lr_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1)
-        lr_image = lr_image.unsqueeze(0).cuda()
+        lr_image = lr_image.unsqueeze(0).to(torch.device(device))
         b, c, h, w = lr_image.shape
         lr_image = lr_image.reshape(b*c, 1, h, w)
         # predict
@@ -18,6 +20,7 @@ def val_image_pair(model, hr_image, lr_image, output_image_path=None):
         # postprocess
         pred_lr_image = pred_lr_image.reshape(b, c, h*model.scale, w*model.scale).squeeze(0).permute(1,2,0).type(torch.uint8)
         pred_lr_image = pred_lr_image.cpu().numpy()
+        run_time_ns = time.perf_counter_ns() - start_time
         torch.cuda.empty_cache()
 
         if not output_image_path is None:
@@ -26,16 +29,17 @@ def val_image_pair(model, hr_image, lr_image, output_image_path=None):
     # metrics
     hr_image = modcrop(hr_image, model.scale)
     left, right = _rgb2ycbcr(pred_lr_image)[:, :, 0], _rgb2ycbcr(hr_image)[:, :, 0]
-    torch.cuda.empty_cache()
-    return PSNR(left, right, model.scale), cal_ssim(left, right)
+    return PSNR(left, right, model.scale), cal_ssim(left, right), run_time_ns
 
 def valid_steps(model, datasets, config, log_prefix=""):
     # ray.init(num_cpus=16, num_gpus=1, ignore_reinit_error=True, log_to_driver=False, runtime_env={"working_dir": "../"})
     dataset_names = list(datasets.keys())
+    results = []
 
     for i in range(len(dataset_names)):
         dataset_name = dataset_names[i]
         psnrs, ssims = [], []
+        run_times_ns = []
         predictions_path = config.valout_dir / dataset_name
         if not predictions_path.exists():
@@ -45,9 +49,8 @@ def valid_steps(model, datasets, config, log_prefix=""):
         tasks = []
         for hr_image, lr_image, hr_image_path, lr_image_path in test_dataset:
             output_image_path = predictions_path / f'{Path(hr_image_path).stem}_rcnet.png' if config.save_predictions else None
-            task = val_image_pair(model, hr_image, lr_image, output_image_path)
+            task = val_image_pair(model, hr_image, lr_image, output_image_path, device=config.device)
             tasks.append(task)
-        # ready_refs, remaining_refs = ray.wait(tasks, num_returns=1, timeout=None)
         # while len(remaining_refs) > 0:
         #     print(f"\rReady {len(ready_refs)+1}/{len(test_dataset)}", end=" ")
@@ -55,11 +58,22 @@ def valid_steps(model, datasets, config, log_prefix=""):
         #     print("\r", end=" ")
         # tasks = [ray.get(task) for task in tasks]
 
-        for psnr, ssim in tasks:
+        for psnr, ssim, run_time_ns in tasks:
             psnrs.append(psnr)
             ssims.append(ssim)
+            run_times_ns.append(run_time_ns)
 
         config.logger.info(
             '\r{} | Dataset {} | AVG Val PSNR: {:02f}, AVG: SSIM: {:04f}'.format(log_prefix, dataset_name, np.mean(np.asarray(psnrs)), np.mean(np.asarray(ssims))))
-        # config.writer.add_scalar('PSNR_valid/{}'.format(dataset_name), np.mean(np.asarray(psnrs)), config.current_iter)
         config.writer.flush()
+
+        results.append([
+            dataset_name,
+            np.mean(psnrs),
+            np.mean(ssims),
+            np.mean(run_times_ns)*1e-9,
+            np.percentile(run_times_ns, q=95)*1e-9])
+
+    results = pd.DataFrame(results, columns=['Dataset', 'PSNR', 'SSIM', f'AVG {config.device} Time, s', f'P95 {config.device} Time, s']).set_index('Dataset')
+    return results
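`valid_steps` now returns a per-dataset summary frame instead of only logging. The timer brackets everything from tensor preparation through the `.cpu()` copy, which synchronizes CUDA, so the measured time does include kernel execution. The returned object looks roughly like this (numbers invented for illustration):

```python
import pandas as pd

results = pd.DataFrame(
    [["Set5", 30.12, 0.8601, 0.0123, 0.0150]],
    columns=["Dataset", "PSNR", "SSIM", "AVG cuda Time, s", "P95 cuda Time, s"],
).set_index("Dataset")
print(results)
```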

@@ -9,7 +9,8 @@ from pathlib import Path
 AVAILABLE_MODELS = {
     'SRNet': srnet.SRNet, 'SRLut': srlut.SRLut,
-    'SRNetRot90': srnet.SRNetRot90, 'SRLutRot90': srlut.SRLutRot90,
+    'SRNetDense': srnet.SRNetDense,
+    'SRNetDenseRot90': srnet.SRNetDenseRot90, 'SRLutRot90': srlut.SRLutRot90,
     'RCNetCentered_3x3': rcnet.RCNetCentered_3x3, 'RCLutCentered_3x3': rclut.RCLutCentered_3x3,
     'RCNetCentered_7x7': rcnet.RCNetCentered_7x7, 'RCLutCentered_7x7': rclut.RCLutCentered_7x7,
     'RCNetRot90_3x3': rcnet.RCNetRot90_3x3, 'RCLutRot90_3x3': rclut.RCLutRot90_3x3,
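Note that the registry key `'SRNetRot90'` is removed in favor of `'SRNetDenseRot90'`, while the Example section above still runs `python train.py --model SRNetRot90`; after this commit that command presumably needs to become `python train.py --model SRNetDenseRot90` (or `--model SRNet` for the new plain variant).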

@@ -90,17 +90,17 @@ class SDYLutx2(nn.Module):
     @staticmethod
     def init_from_lut(
-        stageS_1, stageD_1, stageY_1, stageS_2, stageD_2, stageY_2
+        stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y
     ):
-        scale = int(stageS_2.shape[-1])
-        quantization_interval = 256//(stageS_2.shape[0]-1)
+        scale = int(stage2_S.shape[-1])
+        quantization_interval = 256//(stage2_S.shape[0]-1)
         lut_model = SDYLutx2(quantization_interval=quantization_interval, scale=scale)
-        lut_model.stageS_1 = nn.Parameter(torch.tensor(stageS_1).type(torch.float32))
-        lut_model.stageD_1 = nn.Parameter(torch.tensor(stageD_1).type(torch.float32))
-        lut_model.stageY_1 = nn.Parameter(torch.tensor(stageY_1).type(torch.float32))
-        lut_model.stageS_2 = nn.Parameter(torch.tensor(stageS_2).type(torch.float32))
-        lut_model.stageD_2 = nn.Parameter(torch.tensor(stageD_2).type(torch.float32))
-        lut_model.stageY_2 = nn.Parameter(torch.tensor(stageY_2).type(torch.float32))
+        lut_model.stage1_S = nn.Parameter(torch.tensor(stage1_S).type(torch.float32))
+        lut_model.stage1_D = nn.Parameter(torch.tensor(stage1_D).type(torch.float32))
+        lut_model.stage1_Y = nn.Parameter(torch.tensor(stage1_Y).type(torch.float32))
+        lut_model.stage2_S = nn.Parameter(torch.tensor(stage2_S).type(torch.float32))
+        lut_model.stage2_D = nn.Parameter(torch.tensor(stage2_D).type(torch.float32))
+        lut_model.stage2_Y = nn.Parameter(torch.tensor(stage2_Y).type(torch.float32))
         return lut_model
 
     def forward(self, x):
@@ -111,17 +111,17 @@ class SDYLutx2(nn.Module):
             rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
             rb,rc,rh,rw = rotated.shape
-            s = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_S(rotated), lut=self.stageS_1)
+            s = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_S(rotated), lut=self.stage1_S)
             s = s.view(rb*rc, 1, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh, rw)
             s = torch.rot90(s, k=-rotations_count, dims=[-2, -1])
             output += s
-            d = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_D(rotated), lut=self.stageD_1)
+            d = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_D(rotated), lut=self.stage1_D)
             d = d.view(rb*rc, 1, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh, rw)
             d = torch.rot90(d, k=-rotations_count, dims=[-2, -1])
             output += d
-            y = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_Y(rotated), lut=self.stageY_1)
+            y = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_Y(rotated), lut=self.stage1_Y)
             y = y.view(rb*rc, 1, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh, rw)
             y = torch.rot90(y, k=-rotations_count, dims=[-2, -1])
             output += y
@@ -135,17 +135,17 @@ class SDYLutx2(nn.Module):
             rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
             rb,rc,rh,rw = rotated.shape
-            s = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_S(rotated), lut=self.stageS_2)
+            s = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_S(rotated), lut=self.stage2_S)
             s = s.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
             s = torch.rot90(s, k=-rotations_count, dims=[-2, -1])
             output += s
-            d = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_D(rotated), lut=self.stageD_2)
+            d = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_D(rotated), lut=self.stage2_D)
             d = d.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
             d = torch.rot90(d, k=-rotations_count, dims=[-2, -1])
             output += d
-            y = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_Y(rotated), lut=self.stageY_2)
+            y = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_Y(rotated), lut=self.stage2_Y)
             y = y.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
             y = torch.rot90(y, k=-rotations_count, dims=[-2, -1])
             output += y
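The recurring `view(...).permute(0,1,2,4,3,5).reshape(...)` idiom in these forward passes is a hand-rolled depth-to-space: each input pixel yields a `scale x scale` block, and the permute interleaves those blocks into the upscaled grid. A tiny standalone demonstration (shapes chosen for illustration):

```python
import torch

rh, rw, scale = 2, 3, 2
# One (scale x scale) block of values per input pixel:
blocks = torch.arange(rh * rw * scale * scale).view(1, 1, rh, rw, scale, scale)
# (b, c, rh, rw, s, s) -> (b, c, rh, s, rw, s) -> (b, c, rh*s, rw*s)
upscaled = blocks.permute(0, 1, 2, 4, 3, 5).reshape(1, 1, rh * scale, rw * scale)
# Each block now occupies the scale x scale patch at its source pixel's position.
```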

@@ -41,7 +41,7 @@ class SDYNetx1(nn.Module):
             y = y.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
             y = torch.rot90(y, k=-rotations_count, dims=[-2, -1])
             output += y
         output /= 4*3
         output = output.view(b, c, h*self.scale, w*self.scale)
         return output
@@ -61,12 +61,12 @@ class SDYNetx2(nn.Module):
         self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=3)
         self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
         self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
-        self.stageS_1 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
-        self.stageD_1 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
-        self.stageY_1 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
-        self.stageS_2 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
-        self.stageD_2 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
-        self.stageY_2 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+        self.stage1_S = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+        self.stage1_D = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+        self.stage1_Y = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
+        self.stage2_S = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+        self.stage2_D = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+        self.stage2_Y = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
 
     def forward(self, x):
         b,c,h,w = x.shape
@@ -76,17 +76,17 @@ class SDYNetx2(nn.Module):
             rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
             rb,rc,rh,rw = rotated.shape
-            s = self.stageS_1(self._extract_pattern_S(rotated))
+            s = self.stage1_S(self._extract_pattern_S(rotated))
             s = s.view(rb*rc, 1, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh, rw)
             s = torch.rot90(s, k=-rotations_count, dims=[-2, -1])
             output_1 += s
-            d = self.stageD_1(self._extract_pattern_D(rotated))
+            d = self.stage1_D(self._extract_pattern_D(rotated))
             d = d.view(rb*rc, 1, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh, rw)
             d = torch.rot90(d, k=-rotations_count, dims=[-2, -1])
             output_1 += d
-            y = self.stageY_1(self._extract_pattern_Y(rotated))
+            y = self.stage1_Y(self._extract_pattern_Y(rotated))
             y = y.view(rb*rc, 1, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh, rw)
             y = torch.rot90(y, k=-rotations_count, dims=[-2, -1])
             output_1 += y
@@ -100,17 +100,17 @@ class SDYNetx2(nn.Module):
             rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
             rb,rc,rh,rw = rotated.shape
-            s = self.stageS_2(self._extract_pattern_S(rotated))
+            s = self.stage2_S(self._extract_pattern_S(rotated))
             s = s.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
             s = torch.rot90(s, k=-rotations_count, dims=[-2, -1])
             output_2 += s
-            d = self.stageD_2(self._extract_pattern_D(rotated))
+            d = self.stage2_D(self._extract_pattern_D(rotated))
             d = d.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
             d = torch.rot90(d, k=-rotations_count, dims=[-2, -1])
             output_2 += d
-            y = self.stageY_2(self._extract_pattern_Y(rotated))
+            y = self.stage2_Y(self._extract_pattern_Y(rotated))
             y = y.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
             y = torch.rot90(y, k=-rotations_count, dims=[-2, -1])
             output_2 += y
@@ -120,13 +120,13 @@ class SDYNetx2(nn.Module):
         return output_2
 
     def get_lut_model(self, quantization_interval=16, batch_size=2**10):
-        stageS_1 = lut.transfer_2x2_input_SxS_output(self.stageS_1, quantization_interval=quantization_interval, batch_size=batch_size)
-        stageD_1 = lut.transfer_2x2_input_SxS_output(self.stageD_1, quantization_interval=quantization_interval, batch_size=batch_size)
-        stageY_1 = lut.transfer_2x2_input_SxS_output(self.stageY_1, quantization_interval=quantization_interval, batch_size=batch_size)
-        stageS_2 = lut.transfer_2x2_input_SxS_output(self.stageS_2, quantization_interval=quantization_interval, batch_size=batch_size)
-        stageD_2 = lut.transfer_2x2_input_SxS_output(self.stageD_2, quantization_interval=quantization_interval, batch_size=batch_size)
-        stageY_2 = lut.transfer_2x2_input_SxS_output(self.stageY_2, quantization_interval=quantization_interval, batch_size=batch_size)
-        lut_model = sdylut.SDYLutx2.init_from_lut(stageS_1, stageD_1, stageY_1, stageS_2, stageD_2, stageY_2)
+        stage1_S = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
+        stage1_D = lut.transfer_2x2_input_SxS_output(self.stage1_D, quantization_interval=quantization_interval, batch_size=batch_size)
+        stage1_Y = lut.transfer_2x2_input_SxS_output(self.stage1_Y, quantization_interval=quantization_interval, batch_size=batch_size)
+        stage2_S = lut.transfer_2x2_input_SxS_output(self.stage2_S, quantization_interval=quantization_interval, batch_size=batch_size)
+        stage2_D = lut.transfer_2x2_input_SxS_output(self.stage2_D, quantization_interval=quantization_interval, batch_size=batch_size)
+        stage2_Y = lut.transfer_2x2_input_SxS_output(self.stage2_Y, quantization_interval=quantization_interval, batch_size=batch_size)
+        lut_model = sdylut.SDYLutx2.init_from_lut(stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y)
         return lut_model
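The `stageS_1 -> stage1_S` style rename groups parameters by stage rather than by pattern, but it also changes every `state_dict` key, so checkpoints saved before this commit will presumably fail to load without a remap along these lines (helper and path are hypothetical, and this assumes the checkpoint is a plain state dict):

```python
import torch

# Hypothetical remap of pre-rename keys, e.g. "stageS_1.embed.weight"
# -> "stage1_S.embed.weight"; bare LUT keys like "stageS_1" also match.
remap = {f"stage{p}_{i}": f"stage{i}_{p}" for p in "SDY" for i in (1, 2)}
old_state = torch.load("old_checkpoint.pth", map_location="cpu")  # path illustrative
new_state = {}
for key, value in old_state.items():
    head, dot, tail = key.partition(".")
    new_state[remap.get(head, head) + dot + tail] = value
```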

@@ -6,20 +6,44 @@ from common.utils import round_func
 from common import lut
 from pathlib import Path
 from .srlut import SRLut, SRLutRot90
-from common.layers import DenseConvUpscaleBlock
+from common.layers import PercievePattern, DenseConvUpscaleBlock, ConvUpscaleBlock
 
 class SRNet(nn.Module):
     def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
         super(SRNet, self).__init__()
         self.scale = scale
+        self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
+        self.stage = ConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
+
+    def forward(self, x):
+        b,c,h,w = x.shape
+        x = x.view(b*c, 1, h, w)
+        x = self._extract_pattern_S(x)
+        x = self.stage(x)
+        x = x.view(b*c, 1, h, w, self.scale, self.scale).permute(0,1,2,4,3,5)
+        x = x.reshape(b, c, h*self.scale, w*self.scale)
+        return x
+
+    def get_lut_model(self, quantization_interval=16, batch_size=2**10):
+        stage_lut = lut.transfer_2x2_input_SxS_output(self.stage, quantization_interval=quantization_interval, batch_size=batch_size)
+        lut_model = SRLut.init_from_lut(stage_lut)
+        return lut_model
+
+class SRNetDense(nn.Module):
+    def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
+        super(SRNetDense, self).__init__()
+        self.scale = scale
+        self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
         self.stage = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
 
     def forward(self, x):
         b,c,h,w = x.shape
         x = x.view(b*c, 1, h, w)
-        x = F.pad(x, pad=[0,1,0,1], mode='replicate')
+        x = self._extract_pattern_S(x)
         x = self.stage(x)
-        x = x.view(b, c, h*self.scale, w*self.scale)
+        x = x.view(b*c, 1, h, w, self.scale, self.scale).permute(0,1,2,4,3,5)
+        x = x.reshape(b, c, h*self.scale, w*self.scale)
         return x
 
     def get_lut_model(self, quantization_interval=16, batch_size=2**10):
@@ -27,11 +51,11 @@ class SRNet(nn.Module):
         lut_model = SRLut.init_from_lut(stage_lut)
         return lut_model
 
-class SRNetRot90(nn.Module):
+class SRNetDenseRot90(nn.Module):
     def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
-        super(SRNetRot90, self).__init__()
+        super(SRNetDenseRot90, self).__init__()
         self.scale = scale
+        self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
         self.stage = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
 
     def forward(self, x):
@@ -39,11 +63,12 @@ class SRNetRot90(nn.Module):
         x = x.view(b*c, 1, h, w)
         output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
         for rotations_count in range(4):
-            rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
-            rotated_padded = F.pad(rotated, pad=[0,1,0,1], mode='replicate')
-            rotated_prediction = self.stage(rotated_padded)
-            unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
-            output += unrotated_prediction
+            rx = torch.rot90(x, k=rotations_count, dims=[2, 3])
+            _,_,rh,rw = rx.shape
+            rx = self._extract_pattern_S(rx)
+            rx = self.stage(rx)
+            rx = rx.view(b*c, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(b*c, 1, rh*self.scale, rw*self.scale)
+            output += torch.rot90(rx, k=-rotations_count, dims=[2, 3])
         output /= 4
         output = output.view(b, c, h*self.scale, w*self.scale)
         return output
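Both rewritten forward passes replace the old `F.pad(..., mode='replicate')` call with `self._extract_pattern_S(x)`, so the network indexes its 2x2 receptive field through the same `PercievePattern` abstraction the LUT side uses. An assumed shape walk-through of the new `SRNet.forward` (the exact output layout of `PercievePattern` is defined elsewhere in the repo; this is inferred from the surrounding reshapes):

```python
# x: (b, c, h, w)                input image
# view    -> (b*c, 1, h, w)      one pass per channel
# extract -> one padded 2x2 neighborhood per output pixel (pattern S)
# stage   -> scale*scale values per neighborhood
#            (valid 2x2 conv + 1x1 convs + pixel shuffle)
# view/permute/reshape -> depth-to-space into (b, c, h*scale, w*scale)
```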

@@ -45,6 +45,7 @@ class TrainOptions:
         parser.add_argument('--worker_num', '-n', type=int, default=1, help="Number of dataloader workers")
         parser.add_argument('--prefetch_factor', '-p', type=int, default=16, help="Prefetch factor of dataloader workers.")
         parser.add_argument('--save_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name')
+        parser.add_argument('--device', default='cuda', help='Device of the model')
         self.parser = parser
 
     def parse_args(self):
@@ -90,6 +91,20 @@ def prepare_experiment_folder(config):
     config.logs_dir.mkdir()
 
+def dice_loss(inputs, targets, smooth=1):
+    # comment out if your model contains a sigmoid or equivalent activation layer
+    inputs = F.sigmoid(inputs)
+    # flatten label and prediction tensors
+    inputs = inputs.view(-1)
+    targets = targets.view(-1)
+    intersection = (inputs * targets).sum()
+    dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)
+    return 1 - dice
+
 if __name__ == "__main__":
     script_start_time = datetime.now()
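`dice_loss` is a standard soft-Dice segmentation loss (1 minus the smoothed Dice coefficient over flattened sigmoid outputs). Nothing in this diff calls it yet, and the training loop below still optimizes MSE; also note `F.sigmoid` is deprecated in recent PyTorch in favor of `torch.sigmoid`. A usage sketch with invented tensors:

```python
import torch

# Illustrative logits and binary targets; dice_loss applies sigmoid itself.
pred = torch.randn(4, 1, 64, 64)
target = torch.randint(0, 2, (4, 1, 64, 64)).float()
loss = dice_loss(pred, target)  # scalar, approaching 0 for perfect overlap
```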
@@ -101,8 +116,9 @@ if __name__ == "__main__":
         config.model = model.__class__.__name__
     else:
         model = AVAILABLE_MODELS[config.model]( hidden_dim = config.hidden_dim, scale = config.scale)
-    model = model.cuda()
+    model = model.to(torch.device(config.device))
 
     optimizer = AdamWScheduleFree(model.parameters())
+    print(optimizer)
 
     prepare_experiment_folder(config)
@@ -153,9 +169,8 @@ if __name__ == "__main__":
     if not config.model_path is None:
         config.current_iter = i
         valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Iter {i}")
 
     for i in range(config.start_iter + 1, config.total_iter + 1):
-        # prof.step()
         torch.cuda.empty_cache()
         start_time = time.time()
         try:
@@ -163,12 +178,12 @@ if __name__ == "__main__":
         except StopIteration:
             train_iter = iter(train_loader)
             hr_patch, lr_patch = next(train_iter)
-        hr_patch = hr_patch.cuda()
-        lr_patch = lr_patch.cuda()
+        hr_patch = hr_patch.to(torch.device(config.device))
+        lr_patch = lr_patch.to(torch.device(config.device))
         prepare_data_time += time.time() - start_time
 
         start_time = time.time()
         pred = model(lr_patch)
         loss = F.mse_loss(pred/255, hr_patch/255)
         loss.backward()
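Dividing both tensors by 255 before `mse_loss` only rescales the objective: with mean reduction, `F.mse_loss(pred/255, hr_patch/255)` equals `F.mse_loss(pred, hr_patch) / 255**2`, so it changes the effective gradient magnitude rather than the optimum.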
@@ -216,6 +231,10 @@ if __name__ == "__main__":
         if link.exists():
             link.unlink()
         link.symlink_to(model_path)
+        link = Path(config.models_dir / f"last.pth")
+        if link.exists():
+            link.unlink()
+        link.symlink_to(model_path)
 
     total_script_time = datetime.now() - script_start_time
     config.logger.info(f"Completed after {total_script_time}")

@@ -25,8 +25,8 @@ class TransferToLutOptions():
     def parse_args(self):
         args = self.parser.parse_args()
-        args.model_path = Path(args.model_path)
-        args.models_dir = Path(args.model_path).resolve().parent.parent.parent
+        args.model_path = Path(args.model_path).resolve()
+        args.models_dir = args.model_path.parent.parent.parent
         args.checkpoint_dir = Path(args.model_path).resolve().parent
         return args
@@ -73,13 +73,21 @@ if __name__ == "__main__":
     link = Path(config.models_dir / f"last_transfered_net.pth")
     if link.exists():
         link.unlink()
-    link.symlink_to(config.model_path.resolve())
+    link.symlink_to(config.model_path)
+    print("Updated link", link)
 
     link = Path(config.models_dir / f"last_transfered_lut.pth")
     if link.exists():
         link.unlink()
     link.symlink_to(lut_path.resolve())
-    print("Updated link", config.models_dir / f"last_transfered_net.pth")
-    print("Updated link", config.models_dir / f"last_transfered_lut.pth")
+    print("Updated link", link)
+
+    link = Path(config.models_dir / f"last.pth")
+    if link.exists():
+        link.unlink()
+    link.symlink_to(lut_path.resolve())
+    print("Updated link", link)
 
     print()
     print("Completed after", datetime.now()-start_time)

@@ -23,22 +23,24 @@ import argparse
 class ValOptions():
     def __init__(self):
         self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-        self.parser.add_argument('--model_path', type=str, help="Model path.")
+        self.parser.add_argument('--model_path', type=str, default="../models/last.pth", help="Model path.")
         self.parser.add_argument('--datasets_dir', type=str, default="../data/", help="Path to datasets.")
         self.parser.add_argument('--val_datasets', type=str, default='Set5,Set14', help="Names of validation datasets.")
         self.parser.add_argument('--save_predictions', action='store_true', default=True, help='Save model predictions to exp_dir/val/dataset_name')
+        self.parser.add_argument('--device', type=str, default='cuda', help='Device of the model')
 
     def parse_args(self):
         args = self.parser.parse_args()
         args.datasets_dir = Path(args.datasets_dir).resolve()
         args.val_datasets = args.val_datasets.split(',')
-        args.exp_dir = Path(args.model_path).absolute().parent.parent
-        args.model_path = Path(args.model_path)
+        args.exp_dir = Path(args.model_path).resolve().parent.parent
+        args.model_path = Path(args.model_path).resolve()
         args.model_name = args.model_path.stem
-        args.valout_dir = Path(args.exp_dir)/ 'val'
+        args.valout_dir = Path(args.exp_dir).resolve() / 'val'
         if not args.valout_dir.exists():
             args.valout_dir.mkdir()
         args.current_iter = args.model_name.split('_')[-1]
+        args.results_path = os.path.join(args.valout_dir, f'results_{args.device}.csv')
         # Tensorboard for monitoring
         writer = SummaryWriter(log_dir=args.valout_dir)
         logger_name = f'val_{args.model_path.stem}'
@@ -50,7 +52,7 @@ class ValOptions():
         return args
 
     def __repr__(self):
-        config = self.parser.parse_args()
+        config = self.parse_args()
         message = ''
         message += '----------------- Options ---------------\n'
         for k, v in sorted(vars(config).items()):
@@ -72,7 +74,7 @@ if __name__ == "__main__":
     config.logger.info(config_inst)
 
     model = LoadCheckpoint(config.model_path)
-    model = model.cuda()
+    model = model.to(torch.device(config.device))
     print(model)
 
     test_datasets = {}
@@ -82,7 +84,13 @@ if __name__ == "__main__":
             lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{model.scale}",
         )
 
-    valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {model.__class__.__name__}")
+    results = valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {model.__class__.__name__}")
+    results.to_csv(config.results_path)
+    print()
+    print(results)
+    print()
+    print(f"Results saved to {config.results_path}")
 
     total_script_time = datetime.now() - script_start_time
     config.logger.info(f"Completed after {total_script_time}")
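Since the summary is now persisted per device as `results_{device}.csv`, runs on different backends can be compared after the fact. A sketch for reading it back (path assumed relative to the experiment's `val` directory):

```python
import pandas as pd

df = pd.read_csv("val/results_cuda.csv", index_col="Dataset")  # path illustrative
print(df[["PSNR", "SSIM"]])
```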