main
protsenkovi 6 months ago
parent 64674aab60
commit 1a9a5fab15

@ -2,12 +2,12 @@
Example
```
python train.py --model SRNetRot90
python validate.py --val_datasets Set5,Set14,B100,Urban100,Manga109 --model_path /wd/lut_reproduce/models/last_trained_net.pth
python validate.py --val_datasets Set5,Set14,B100,Urban100,Manga109
python transfer_to_lut.py
python train.py --model_path /wd/lut_reproduce/models/last_transfered_lut.pth --total_iter 2000
python validate.py --val_datasets Set5,Set14,B100,Urban100,Manga109 --model_path /wd/lut_reproduce/models/last_trained_lut.pth
python validate.py --val_datasets Set5,Set14,B100,Urban100,Manga109
python image_demo.py
```

@ -55,12 +55,12 @@ class SRTrainDataset(Dataset):
hr_patch = hr_image[
(i*scale):(i*scale + self.sz*scale),
(j*scale):(j*scale + self.sz*scale),
c
c:(c+1)
]
lr_patch = lr_image[
i:(i + self.sz),
j:(j + self.sz),
c
c:(c+1)
]
if self.rigid_aug:
@ -78,8 +78,8 @@ class SRTrainDataset(Dataset):
hr_patch = torch.tensor(hr_patch.copy()).type(torch.float32)
lr_patch = torch.tensor(lr_patch.copy()).type(torch.float32)
hr_patch = hr_patch.unsqueeze(0)
lr_patch = lr_patch.unsqueeze(0)
hr_patch = hr_patch.permute(2,0,1)
lr_patch = lr_patch.permute(2,0,1)
return hr_patch, lr_patch

@ -66,5 +66,37 @@ class DenseConvUpscaleBlock(nn.Module):
x = torch.cat([x, torch.relu(conv(x))], dim=1)
x = self.shuffle(self.project_channels(x))
x = torch.tanh(x)
x = round_func(x*127.5 + 127.5)
x = x*127.5 + 127.5
x = round_func(x)
return x
class ConvUpscaleBlock(nn.Module):
    """Plain (non-dense) convolutional upscale block.

    Takes a single-channel patch tensor with values in [0, 255], embeds the
    2x2 receptive field with a 'valid' conv, applies a stack of 1x1 convs,
    projects to upscale_factor**2 channels and pixel-shuffles, then maps the
    result back to the [0, 255] range and quantizes it with round_func.
    """

    def __init__(self, hidden_dim=32, layers_count=5, upscale_factor=1):
        super(ConvUpscaleBlock, self).__init__()
        assert layers_count > 0
        self.upscale_factor = upscale_factor
        self.hidden_dim = hidden_dim
        # 2x2 'valid' conv consumes the unfolded patch (shrinks H and W by 1).
        self.embed = nn.Conv2d(1, hidden_dim, kernel_size=(2, 2), padding='valid', stride=1, dilation=1, bias=True)
        self.convs = nn.ModuleList(
            nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=1, stride=1, padding=0, dilation=1, bias=True)
            for _ in range(layers_count)
        )
        # Kaiming-init the parameters registered so far (embed + convs).
        # NOTE(review): project_channels is created *after* this loop, so it
        # keeps PyTorch's default init — confirm that this is intentional.
        for param_name, param in self.named_parameters():
            if "weight" in param_name:
                nn.init.kaiming_normal_(param)
            if "bias" in param_name:
                nn.init.constant_(param, 0)
        self.project_channels = nn.Conv2d(in_channels=hidden_dim, out_channels=upscale_factor * upscale_factor, kernel_size=1, stride=1, padding=0, dilation=1, bias=True)
        self.shuffle = nn.PixelShuffle(upscale_factor)

    def forward(self, x):
        # Normalize 8-bit input into [-1, 1].
        out = (x - 127.5) / 127.5
        out = torch.relu(self.embed(out))
        for conv in self.convs:
            out = torch.relu(conv(out))
        out = self.shuffle(self.project_channels(out))
        # tanh -> [-1, 1]; rescale to [0, 255] and quantize.
        out = torch.tanh(out) * 127.5 + 127.5
        out = round_func(out)
        return out

@ -1,16 +1,18 @@
import torch
import pandas as pd
import numpy as np
from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr, modcrop
from pathlib import Path
from PIL import Image
# import ray
import time
# @ray.remote(num_cpus=1, num_gpus=0.3)
def val_image_pair(model, hr_image, lr_image, output_image_path=None):
def val_image_pair(model, hr_image, lr_image, output_image_path=None, device='cuda'):
with torch.inference_mode():
start_time = time.perf_counter_ns()
# prepare lr_image
lr_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1)
lr_image = lr_image.unsqueeze(0).cuda()
lr_image = lr_image.unsqueeze(0).to(torch.device(device))
b, c, h, w = lr_image.shape
lr_image = lr_image.reshape(b*c, 1, h, w)
# predict
@ -18,6 +20,7 @@ def val_image_pair(model, hr_image, lr_image, output_image_path=None):
# postprocess
pred_lr_image = pred_lr_image.reshape(b, c, h*model.scale, w*model.scale).squeeze(0).permute(1,2,0).type(torch.uint8)
pred_lr_image = pred_lr_image.cpu().numpy()
run_time_ns = time.perf_counter_ns() - start_time
torch.cuda.empty_cache()
if not output_image_path is None:
@ -26,16 +29,17 @@ def val_image_pair(model, hr_image, lr_image, output_image_path=None):
# metrics
hr_image = modcrop(hr_image, model.scale)
left, right = _rgb2ycbcr(pred_lr_image)[:, :, 0], _rgb2ycbcr(hr_image)[:, :, 0]
torch.cuda.empty_cache()
return PSNR(left, right, model.scale), cal_ssim(left, right)
return PSNR(left, right, model.scale), cal_ssim(left, right), run_time_ns
def valid_steps(model, datasets, config, log_prefix=""):
# ray.init(num_cpus=16, num_gpus=1, ignore_reinit_error=True, log_to_driver=False, runtime_env={"working_dir": "../"})
dataset_names = list(datasets.keys())
results = []
for i in range(len(dataset_names)):
dataset_name = dataset_names[i]
psnrs, ssims = [], []
psnrs, ssims = [], []
run_times_ns = []
predictions_path = config.valout_dir / dataset_name
if not predictions_path.exists():
@ -45,9 +49,8 @@ def valid_steps(model, datasets, config, log_prefix=""):
tasks = []
for hr_image, lr_image, hr_image_path, lr_image_path in test_dataset:
output_image_path = predictions_path / f'{Path(hr_image_path).stem}_rcnet.png' if config.save_predictions else None
task = val_image_pair(model, hr_image, lr_image, output_image_path)
task = val_image_pair(model, hr_image, lr_image, output_image_path, device=config.device)
tasks.append(task)
# ready_refs, remaining_refs = ray.wait(tasks, num_returns=1, timeout=None)
# while len(remaining_refs) > 0:
# print(f"\rReady {len(ready_refs)+1}/{len(test_dataset)}", end=" ")
@ -55,11 +58,22 @@ def valid_steps(model, datasets, config, log_prefix=""):
# print("\r", end=" ")
# tasks = [ray.get(task) for task in tasks]
for psnr, ssim in tasks:
for psnr, ssim, run_time_ns in tasks:
psnrs.append(psnr)
ssims.append(ssim)
run_times_ns.append(run_time_ns)
config.logger.info(
'\r{} | Dataset {} | AVG Val PSNR: {:02f}, AVG: SSIM: {:04f}'.format(log_prefix, dataset_name, np.mean(np.asarray(psnrs)), np.mean(np.asarray(ssims))))
# config.writer.add_scalar('PSNR_valid/{}'.format(dataset_name), np.mean(np.asarray(psnrs)), config.current_iter)
config.writer.flush()
config.writer.flush()
results.append([
dataset_name,
np.mean(psnrs),
np.mean(ssims),
np.mean(run_times_ns)*1e-9,
np.percentile(run_times_ns, q=95)*1e-9])
results = pd.DataFrame(results, columns=['Dataset', 'PSNR', 'SSIM', f'AVG {config.device} Time, s', f'P95 {config.device} Time, s']).set_index('Dataset')
return results

@ -9,7 +9,8 @@ from pathlib import Path
AVAILABLE_MODELS = {
'SRNet': srnet.SRNet, 'SRLut': srlut.SRLut,
'SRNetRot90': srnet.SRNetRot90, 'SRLutRot90': srlut.SRLutRot90,
'SRNetDense': srnet.SRNetDense,
'SRNetDenseRot90': srnet.SRNetDenseRot90, 'SRLutRot90': srlut.SRLutRot90,
'RCNetCentered_3x3': rcnet.RCNetCentered_3x3, 'RCLutCentered_3x3': rclut.RCLutCentered_3x3,
'RCNetCentered_7x7': rcnet.RCNetCentered_7x7, 'RCLutCentered_7x7': rclut.RCLutCentered_7x7,
'RCNetRot90_3x3': rcnet.RCNetRot90_3x3, 'RCLutRot90_3x3': rclut.RCLutRot90_3x3,

@ -90,17 +90,17 @@ class SDYLutx2(nn.Module):
@staticmethod
def init_from_lut(
stageS_1, stageD_1, stageY_1, stageS_2, stageD_2, stageY_2
stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y
):
scale = int(stageS_2.shape[-1])
quantization_interval = 256//(stageS_2.shape[0]-1)
scale = int(stage2_S.shape[-1])
quantization_interval = 256//(stage2_S.shape[0]-1)
lut_model = SDYLutx2(quantization_interval=quantization_interval, scale=scale)
lut_model.stageS_1 = nn.Parameter(torch.tensor(stageS_1).type(torch.float32))
lut_model.stageD_1 = nn.Parameter(torch.tensor(stageD_1).type(torch.float32))
lut_model.stageY_1 = nn.Parameter(torch.tensor(stageY_1).type(torch.float32))
lut_model.stageS_2 = nn.Parameter(torch.tensor(stageS_2).type(torch.float32))
lut_model.stageD_2 = nn.Parameter(torch.tensor(stageD_2).type(torch.float32))
lut_model.stageY_2 = nn.Parameter(torch.tensor(stageY_2).type(torch.float32))
lut_model.stage1_S = nn.Parameter(torch.tensor(stage1_S).type(torch.float32))
lut_model.stage1_D = nn.Parameter(torch.tensor(stage1_D).type(torch.float32))
lut_model.stage1_Y = nn.Parameter(torch.tensor(stage1_Y).type(torch.float32))
lut_model.stage2_S = nn.Parameter(torch.tensor(stage2_S).type(torch.float32))
lut_model.stage2_D = nn.Parameter(torch.tensor(stage2_D).type(torch.float32))
lut_model.stage2_Y = nn.Parameter(torch.tensor(stage2_Y).type(torch.float32))
return lut_model
def forward(self, x):
@ -111,17 +111,17 @@ class SDYLutx2(nn.Module):
rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
rb,rc,rh,rw = rotated.shape
s = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_S(rotated), lut=self.stageS_1)
s = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_S(rotated), lut=self.stage1_S)
s = s.view(rb*rc, 1, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh, rw)
s = torch.rot90(s, k=-rotations_count, dims=[-2, -1])
output += s
d = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_D(rotated), lut=self.stageD_1)
d = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_D(rotated), lut=self.stage1_D)
d = d.view(rb*rc, 1, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh, rw)
d = torch.rot90(d, k=-rotations_count, dims=[-2, -1])
output += d
y = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_Y(rotated), lut=self.stageY_1)
y = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_Y(rotated), lut=self.stage1_Y)
y = y.view(rb*rc, 1, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh, rw)
y = torch.rot90(y, k=-rotations_count, dims=[-2, -1])
output += y
@ -135,17 +135,17 @@ class SDYLutx2(nn.Module):
rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
rb,rc,rh,rw = rotated.shape
s = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_S(rotated), lut=self.stageS_2)
s = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_S(rotated), lut=self.stage2_S)
s = s.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
s = torch.rot90(s, k=-rotations_count, dims=[-2, -1])
output += s
d = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_D(rotated), lut=self.stageD_2)
d = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_D(rotated), lut=self.stage2_D)
d = d.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
d = torch.rot90(d, k=-rotations_count, dims=[-2, -1])
output += d
y = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_Y(rotated), lut=self.stageY_2)
y = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_Y(rotated), lut=self.stage2_Y)
y = y.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
y = torch.rot90(y, k=-rotations_count, dims=[-2, -1])
output += y

@ -41,7 +41,7 @@ class SDYNetx1(nn.Module):
y = y.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
y = torch.rot90(y, k=-rotations_count, dims=[-2, -1])
output += y
output /= 4*3
output = output.view(b, c, h*self.scale, w*self.scale)
return output
@ -61,12 +61,12 @@ class SDYNetx2(nn.Module):
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=3)
self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
self.stageS_1 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
self.stageD_1 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
self.stageY_1 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
self.stageS_2 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
self.stageD_2 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
self.stageY_2 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
self.stage1_S = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
self.stage1_D = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
self.stage1_Y = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1)
self.stage2_S = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
self.stage2_D = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
self.stage2_Y = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
def forward(self, x):
b,c,h,w = x.shape
@ -76,17 +76,17 @@ class SDYNetx2(nn.Module):
rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
rb,rc,rh,rw = rotated.shape
s = self.stageS_1(self._extract_pattern_S(rotated))
s = self.stage1_S(self._extract_pattern_S(rotated))
s = s.view(rb*rc, 1, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh, rw)
s = torch.rot90(s, k=-rotations_count, dims=[-2, -1])
output_1 += s
d = self.stageD_1(self._extract_pattern_D(rotated))
d = self.stage1_D(self._extract_pattern_D(rotated))
d = d.view(rb*rc, 1, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh, rw)
d = torch.rot90(d, k=-rotations_count, dims=[-2, -1])
output_1 += d
y = self.stageY_1(self._extract_pattern_Y(rotated))
y = self.stage1_Y(self._extract_pattern_Y(rotated))
y = y.view(rb*rc, 1, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh, rw)
y = torch.rot90(y, k=-rotations_count, dims=[-2, -1])
output_1 += y
@ -100,17 +100,17 @@ class SDYNetx2(nn.Module):
rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1])
rb,rc,rh,rw = rotated.shape
s = self.stageS_2(self._extract_pattern_S(rotated))
s = self.stage2_S(self._extract_pattern_S(rotated))
s = s.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
s = torch.rot90(s, k=-rotations_count, dims=[-2, -1])
output_2 += s
d = self.stageD_2(self._extract_pattern_D(rotated))
d = self.stage2_D(self._extract_pattern_D(rotated))
d = d.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
d = torch.rot90(d, k=-rotations_count, dims=[-2, -1])
output_2 += d
y = self.stageY_2(self._extract_pattern_Y(rotated))
y = self.stage2_Y(self._extract_pattern_Y(rotated))
y = y.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale)
y = torch.rot90(y, k=-rotations_count, dims=[-2, -1])
output_2 += y
@ -120,13 +120,13 @@ class SDYNetx2(nn.Module):
return output_2
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stageS_1 = lut.transfer_2x2_input_SxS_output(self.stageS_1, quantization_interval=quantization_interval, batch_size=batch_size)
stageD_1 = lut.transfer_2x2_input_SxS_output(self.stageD_1, quantization_interval=quantization_interval, batch_size=batch_size)
stageY_1 = lut.transfer_2x2_input_SxS_output(self.stageY_1, quantization_interval=quantization_interval, batch_size=batch_size)
stageS_2 = lut.transfer_2x2_input_SxS_output(self.stageS_2, quantization_interval=quantization_interval, batch_size=batch_size)
stageD_2 = lut.transfer_2x2_input_SxS_output(self.stageD_2, quantization_interval=quantization_interval, batch_size=batch_size)
stageY_2 = lut.transfer_2x2_input_SxS_output(self.stageY_2, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = sdylut.SDYLutx2.init_from_lut(stageS_1, stageD_1, stageY_1, stageS_2, stageD_2, stageY_2)
stage1_S = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
stage1_D = lut.transfer_2x2_input_SxS_output(self.stage1_D, quantization_interval=quantization_interval, batch_size=batch_size)
stage1_Y = lut.transfer_2x2_input_SxS_output(self.stage1_Y, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_S = lut.transfer_2x2_input_SxS_output(self.stage2_S, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_D = lut.transfer_2x2_input_SxS_output(self.stage2_D, quantization_interval=quantization_interval, batch_size=batch_size)
stage2_Y = lut.transfer_2x2_input_SxS_output(self.stage2_Y, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = sdylut.SDYLutx2.init_from_lut(stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y)
return lut_model

@ -6,20 +6,44 @@ from common.utils import round_func
from common import lut
from pathlib import Path
from .srlut import SRLut, SRLutRot90
from common.layers import DenseConvUpscaleBlock
from common.layers import PercievePattern, DenseConvUpscaleBlock, ConvUpscaleBlock
class SRNet(nn.Module):
    """Trainable super-resolution net: unfolds 2x2 patches, runs them through
    a ConvUpscaleBlock, and folds the per-pixel scale x scale outputs back
    into the upscaled image. Transferable to a lookup table via get_lut_model.
    """

    def __init__(self, hidden_dim=64, layers_count=4, scale=4):
        super(SRNet, self).__init__()
        # Fix: the assignment was duplicated (merge residue); keep it once.
        self.scale = scale
        # 2x2 window anchored at its top-left pixel.
        self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
        self.stage = ConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)

    def forward(self, x):
        """x: (b, c, h, w) tensor in the 8-bit value range -> (b, c, h*scale, w*scale)."""
        b, c, h, w = x.shape
        # Process each channel as an independent single-channel image.
        x = x.view(b*c, 1, h, w)
        x = self._extract_pattern_S(x)
        x = self.stage(x)
        # Re-interleave each pixel's scale x scale block into the output grid.
        x = x.view(b*c, 1, h, w, self.scale, self.scale).permute(0,1,2,4,3,5)
        x = x.reshape(b, c, h*self.scale, w*self.scale)
        return x

    def get_lut_model(self, quantization_interval=16, batch_size=2**10):
        """Transfer the trained stage into a lookup table wrapped in SRLut."""
        stage_lut = lut.transfer_2x2_input_SxS_output(self.stage, quantization_interval=quantization_interval, batch_size=batch_size)
        lut_model = SRLut.init_from_lut(stage_lut)
        return lut_model
class SRNetDense(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRNetDense, self).__init__()
self.scale = scale
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
self.stage = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
def forward(self, x):
b,c,h,w = x.shape
x = x.view(b*c, 1, h, w)
x = F.pad(x, pad=[0,1,0,1], mode='replicate')
x = self._extract_pattern_S(x)
x = self.stage(x)
x = x.view(b, c, h*self.scale, w*self.scale)
x = x.view(b*c, 1, h, w, self.scale, self.scale).permute(0,1,2,4,3,5)
x = x.reshape(b, c, h*self.scale, w*self.scale)
return x
def get_lut_model(self, quantization_interval=16, batch_size=2**10):
@ -27,11 +51,11 @@ class SRNet(nn.Module):
lut_model = SRLut.init_from_lut(stage_lut)
return lut_model
class SRNetRot90(nn.Module):
class SRNetDenseRot90(nn.Module):
def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4):
super(SRNetRot90, self).__init__()
super(SRNetDenseRot90, self).__init__()
self.scale = scale
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2)
self.stage = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale)
def forward(self, x):
@ -39,11 +63,12 @@ class SRNetRot90(nn.Module):
x = x.view(b*c, 1, h, w)
output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device)
for rotations_count in range(4):
rotated = torch.rot90(x, k=rotations_count, dims=[2, 3])
rotated_padded = F.pad(rotated, pad=[0,1,0,1], mode='replicate')
rotated_prediction = self.stage(rotated_padded)
unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3])
output += unrotated_prediction
rx = torch.rot90(x, k=rotations_count, dims=[2, 3])
_,_,rh,rw = rx.shape
rx = self._extract_pattern_S(rx)
rx = self.stage(rx)
rx = rx.view(b*c, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(b*c, 1, rh*self.scale, rw*self.scale)
output += torch.rot90(rx, k=-rotations_count, dims=[2, 3])
output /= 4
output = output.view(b, c, h*self.scale, w*self.scale)
return output

@ -45,6 +45,7 @@ class TrainOptions:
parser.add_argument('--worker_num', '-n', type=int, default=1, help="Number of dataloader workers")
parser.add_argument('--prefetch_factor', '-p', type=int, default=16, help="Prefetch factor of dataloader workers.")
parser.add_argument('--save_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name')
parser.add_argument('--device', default='cuda', help='Device of the model')
self.parser = parser
def parse_args(self):
@ -90,6 +91,20 @@ def prepare_experiment_folder(config):
config.logs_dir.mkdir()
def dice_loss(inputs, targets, smooth=1):
    """Soft Dice loss.

    inputs: raw logits; a sigmoid is applied here — comment it out if your
        model already ends with a sigmoid or equivalent activation layer.
    targets: ground-truth tensor of the same shape, values in [0, 1].
    smooth: smoothing term that avoids division by zero when both
        inputs and targets are empty.
    Returns 1 - dice, so perfect overlap yields a loss of 0.
    """
    # F.sigmoid is deprecated; torch.sigmoid is the supported equivalent.
    inputs = torch.sigmoid(inputs)
    # Flatten label and prediction tensors (reshape tolerates
    # non-contiguous inputs, unlike view).
    inputs = inputs.reshape(-1)
    targets = targets.reshape(-1)
    intersection = (inputs * targets).sum()
    dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)
    return 1 - dice
if __name__ == "__main__":
script_start_time = datetime.now()
@ -101,8 +116,9 @@ if __name__ == "__main__":
config.model = model.__class__.__name__
else:
model = AVAILABLE_MODELS[config.model]( hidden_dim = config.hidden_dim, scale = config.scale)
model = model.cuda()
model = model.to(torch.device(config.device))
optimizer = AdamWScheduleFree(model.parameters())
print(optimizer)
prepare_experiment_folder(config)
@ -153,9 +169,8 @@ if __name__ == "__main__":
if not config.model_path is None:
config.current_iter = i
valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Iter {i}")
for i in range(config.start_iter + 1, config.total_iter + 1):
# prof.step()
torch.cuda.empty_cache()
start_time = time.time()
try:
@ -163,12 +178,12 @@ if __name__ == "__main__":
except StopIteration:
train_iter = iter(train_loader)
hr_patch, lr_patch = next(train_iter)
hr_patch = hr_patch.cuda()
lr_patch = lr_patch.cuda()
hr_patch = hr_patch.to(torch.device(config.device))
lr_patch = lr_patch.to(torch.device(config.device))
prepare_data_time += time.time() - start_time
start_time = time.time()
pred = model(lr_patch)
loss = F.mse_loss(pred/255, hr_patch/255)
loss.backward()
@ -216,6 +231,10 @@ if __name__ == "__main__":
if link.exists():
link.unlink()
link.symlink_to(model_path)
link = Path(config.models_dir / f"last.pth")
if link.exists():
link.unlink()
link.symlink_to(model_path)
total_script_time = datetime.now() - script_start_time
config.logger.info(f"Completed after {total_script_time}")

@ -25,8 +25,8 @@ class TransferToLutOptions():
def parse_args(self):
args = self.parser.parse_args()
args.model_path = Path(args.model_path)
args.models_dir = Path(args.model_path).resolve().parent.parent.parent
args.model_path = Path(args.model_path).resolve()
args.models_dir = args.model_path.parent.parent.parent
args.checkpoint_dir = Path(args.model_path).resolve().parent
return args
@ -73,13 +73,21 @@ if __name__ == "__main__":
link = Path(config.models_dir / f"last_transfered_net.pth")
if link.exists():
link.unlink()
link.symlink_to(config.model_path.resolve())
link.symlink_to(config.model_path)
print("Updated link", link)
link = Path(config.models_dir / f"last_transfered_lut.pth")
if link.exists():
link.unlink()
link.symlink_to(lut_path.resolve())
print("Updated link", config.models_dir / f"last_transfered_net.pth")
print("Updated link", config.models_dir / f"last_transfered_lut.pth")
print("Updated link", link)
link = Path(config.models_dir / f"last.pth")
if link.exists():
link.unlink()
link.symlink_to(lut_path.resolve())
print("Updated link", link)
print()
print("Completed after", datetime.now()-start_time)

@ -23,22 +23,24 @@ import argparse
class ValOptions():
def __init__(self):
self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
self.parser.add_argument('--model_path', type=str, help="Model path.")
self.parser.add_argument('--model_path', type=str, default="../models/last.pth", help="Model path.")
self.parser.add_argument('--datasets_dir', type=str, default="../data/", help="Path to datasets.")
self.parser.add_argument('--val_datasets', type=str, default='Set5,Set14', help="Names of validation datasets.")
self.parser.add_argument('--save_predictions', action='store_true', default=True, help='Save model predictions to exp_dir/val/dataset_name')
self.parser.add_argument('--device', type=str, default='cuda', help='Device of the model')
def parse_args(self):
args = self.parser.parse_args()
args.datasets_dir = Path(args.datasets_dir).resolve()
args.val_datasets = args.val_datasets.split(',')
args.exp_dir = Path(args.model_path).absolute().parent.parent
args.model_path = Path(args.model_path)
args.exp_dir = Path(args.model_path).resolve().parent.parent
args.model_path = Path(args.model_path).resolve()
args.model_name = args.model_path.stem
args.valout_dir = Path(args.exp_dir)/ 'val'
args.valout_dir = Path(args.exp_dir).resolve() / 'val'
if not args.valout_dir.exists():
args.valout_dir.mkdir()
args.current_iter = args.model_name.split('_')[-1]
args.results_path = os.path.join(args.valout_dir, f'results_{args.device}.csv')
# Tensorboard for monitoring
writer = SummaryWriter(log_dir=args.valout_dir)
logger_name = f'val_{args.model_path.stem}'
@ -50,7 +52,7 @@ class ValOptions():
return args
def __repr__(self):
config = self.parser.parse_args()
config = self.parse_args()
message = ''
message += '----------------- Options ---------------\n'
for k, v in sorted(vars(config).items()):
@ -72,7 +74,7 @@ if __name__ == "__main__":
config.logger.info(config_inst)
model = LoadCheckpoint(config.model_path)
model = model.cuda()
model = model.to(torch.device(config.device))
print(model)
test_datasets = {}
@ -82,7 +84,13 @@ if __name__ == "__main__":
lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{model.scale}",
)
valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {model.__class__.__name__}")
results = valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {model.__class__.__name__}")
results.to_csv(config.results_path)
print()
print(results)
print()
print(f"Results saved to {config.results_path}")
total_script_time = datetime.now() - script_start_time
config.logger.info(f"Completed after {total_script_time}")
Loading…
Cancel
Save