From 1a9a5fab1507aabdeaf5a11b8113606023aa217f Mon Sep 17 00:00:00 2001 From: protsenkovi Date: Mon, 6 May 2024 01:05:41 +0400 Subject: [PATCH] update --- readme.md | 4 ++-- src/common/data.py | 8 +++---- src/common/layers.py | 34 +++++++++++++++++++++++++++- src/common/validation.py | 36 ++++++++++++++++++++--------- src/models/__init__.py | 3 ++- src/models/sdylut.py | 30 ++++++++++++------------ src/models/sdynet.py | 40 ++++++++++++++++---------------- src/models/srnet.py | 49 ++++++++++++++++++++++++++++++---------- src/train.py | 31 ++++++++++++++++++++----- src/transfer_to_lut.py | 18 +++++++++++---- src/validate.py | 22 ++++++++++++------ 11 files changed, 191 insertions(+), 84 deletions(-) diff --git a/readme.md b/readme.md index 03c2a5a..f1bbf71 100644 --- a/readme.md +++ b/readme.md @@ -2,12 +2,12 @@ Example ``` python train.py --model SRNetRot90 -python validate.py --val_datasets Set5,Set14,B100,Urban100,Manga109 --model_path /wd/lut_reproduce/models/last_trained_net.pth +python validate.py --val_datasets Set5,Set14,B100,Urban100,Manga109 python transfer_to_lut.py python train.py --model_path /wd/lut_reproduce/models/last_transfered_lut.pth --total_iter 2000 -python validate.py --val_datasets Set5,Set14,B100,Urban100,Manga109 --model_path /wd/lut_reproduce/models/last_trained_lut.pth +python validate.py --val_datasets Set5,Set14,B100,Urban100,Manga109 python image_demo.py ``` diff --git a/src/common/data.py b/src/common/data.py index a19c0c5..ff3214a 100644 --- a/src/common/data.py +++ b/src/common/data.py @@ -55,12 +55,12 @@ class SRTrainDataset(Dataset): hr_patch = hr_image[ (i*scale):(i*scale + self.sz*scale), (j*scale):(j*scale + self.sz*scale), - c + c:(c+1) ] lr_patch = lr_image[ i:(i + self.sz), j:(j + self.sz), - c + c:(c+1) ] if self.rigid_aug: @@ -78,8 +78,8 @@ class SRTrainDataset(Dataset): hr_patch = torch.tensor(hr_patch.copy()).type(torch.float32) lr_patch = torch.tensor(lr_patch.copy()).type(torch.float32) - hr_patch = hr_patch.unsqueeze(0) - lr_patch = lr_patch.unsqueeze(0) + hr_patch = hr_patch.permute(2,0,1) + lr_patch = lr_patch.permute(2,0,1) return hr_patch, lr_patch diff --git a/src/common/layers.py b/src/common/layers.py index 738837d..54cf3f1 100644 --- a/src/common/layers.py +++ b/src/common/layers.py @@ -66,5 +66,37 @@ class DenseConvUpscaleBlock(nn.Module): x = torch.cat([x, torch.relu(conv(x))], dim=1) x = self.shuffle(self.project_channels(x)) x = torch.tanh(x) - x = round_func(x*127.5 + 127.5) + x = x*127.5 + 127.5 + x = round_func(x) + return x + +class ConvUpscaleBlock(nn.Module): + def __init__(self, hidden_dim = 32, layers_count=5, upscale_factor=1): + super(ConvUpscaleBlock, self).__init__() + assert layers_count > 0 + self.upscale_factor = upscale_factor + self.hidden_dim = hidden_dim + self.embed = nn.Conv2d(1, hidden_dim, kernel_size=(2, 2), padding='valid', stride=1, dilation=1, bias=True) + + self.convs = [] + for i in range(layers_count): + self.convs.append(nn.Conv2d(in_channels = hidden_dim, out_channels = hidden_dim, kernel_size = 1, stride=1, padding=0, dilation=1, bias=True)) + self.convs = nn.ModuleList(self.convs) + + for name, p in self.named_parameters(): + if "weight" in name: nn.init.kaiming_normal_(p) + if "bias" in name: nn.init.constant_(p, 0) + + self.project_channels = nn.Conv2d(in_channels = hidden_dim, out_channels = upscale_factor * upscale_factor, kernel_size = 1, stride=1, padding=0, dilation=1, bias=True) + self.shuffle = nn.PixelShuffle(upscale_factor) + + def forward(self, x): + x = (x-127.5)/127.5 + x = torch.relu(self.embed(x)) + for conv in self.convs: + x = torch.relu(conv(x)) + x = self.shuffle(self.project_channels(x)) + x = torch.tanh(x) + x = x*127.5 + 127.5 + x = round_func(x) return x \ No newline at end of file diff --git a/src/common/validation.py b/src/common/validation.py index 385147b..d585f56 100644 --- a/src/common/validation.py +++ b/src/common/validation.py @@ -1,16 +1,18 @@ import torch +import pandas as pd import numpy as np from common.utils import PSNR, cal_ssim, logger_info, _rgb2ycbcr, modcrop from pathlib import Path from PIL import Image -# import ray +import time # @ray.remote(num_cpus=1, num_gpus=0.3) -def val_image_pair(model, hr_image, lr_image, output_image_path=None): +def val_image_pair(model, hr_image, lr_image, output_image_path=None, device='cuda'): with torch.inference_mode(): + start_time = time.perf_counter_ns() # prepare lr_image lr_image = torch.tensor(lr_image).type(torch.float32).permute(2,0,1) - lr_image = lr_image.unsqueeze(0).cuda() + lr_image = lr_image.unsqueeze(0).to(torch.device(device)) b, c, h, w = lr_image.shape lr_image = lr_image.reshape(b*c, 1, h, w) # predict @@ -18,6 +20,7 @@ def val_image_pair(model, hr_image, lr_image, output_image_path=None): # postprocess pred_lr_image = pred_lr_image.reshape(b, c, h*model.scale, w*model.scale).squeeze(0).permute(1,2,0).type(torch.uint8) pred_lr_image = pred_lr_image.cpu().numpy() + run_time_ns = time.perf_counter_ns() - start_time torch.cuda.empty_cache() if not output_image_path is None: @@ -26,16 +29,17 @@ def val_image_pair(model, hr_image, lr_image, output_image_path=None): # metrics hr_image = modcrop(hr_image, model.scale) left, right = _rgb2ycbcr(pred_lr_image)[:, :, 0], _rgb2ycbcr(hr_image)[:, :, 0] - torch.cuda.empty_cache() - return PSNR(left, right, model.scale), cal_ssim(left, right) + return PSNR(left, right, model.scale), cal_ssim(left, right), run_time_ns def valid_steps(model, datasets, config, log_prefix=""): # ray.init(num_cpus=16, num_gpus=1, ignore_reinit_error=True, log_to_driver=False, runtime_env={"working_dir": "../"}) dataset_names = list(datasets.keys()) + results = [] for i in range(len(dataset_names)): dataset_name = dataset_names[i] - psnrs, ssims = [], [] + psnrs, ssims = [], [] + run_times_ns = [] predictions_path = config.valout_dir / dataset_name if not predictions_path.exists(): @@ -45,9 +49,8 @@ def valid_steps(model, datasets, config, log_prefix=""): tasks = [] for hr_image, lr_image, hr_image_path, lr_image_path in test_dataset: output_image_path = predictions_path / f'{Path(hr_image_path).stem}_rcnet.png' if config.save_predictions else None - task = val_image_pair(model, hr_image, lr_image, output_image_path) + task = val_image_pair(model, hr_image, lr_image, output_image_path, device=config.device) tasks.append(task) - # ready_refs, remaining_refs = ray.wait(tasks, num_returns=1, timeout=None) # while len(remaining_refs) > 0: # print(f"\rReady {len(ready_refs)+1}/{len(test_dataset)}", end=" ") @@ -55,11 +58,22 @@ def valid_steps(model, datasets, config, log_prefix=""): # print("\r", end=" ") # tasks = [ray.get(task) for task in tasks] - for psnr, ssim in tasks: + for psnr, ssim, run_time_ns in tasks: psnrs.append(psnr) ssims.append(ssim) + run_times_ns.append(run_time_ns) config.logger.info( '\r{} | Dataset {} | AVG Val PSNR: {:02f}, AVG: SSIM: {:04f}'.format(log_prefix, dataset_name, np.mean(np.asarray(psnrs)), np.mean(np.asarray(ssims)))) - # config.writer.add_scalar('PSNR_valid/{}'.format(dataset_name), np.mean(np.asarray(psnrs)), config.current_iter) - config.writer.flush() \ No newline at end of file + config.writer.flush() + + results.append([ + dataset_name, + np.mean(psnrs), + np.mean(ssims), + np.mean(run_times_ns)*1e-9, + np.percentile(run_times_ns, q=95)*1e-9]) + + results = pd.DataFrame(results, columns=['Dataset', 'PSNR', 'SSIM', f'AVG {config.device} Time, s', f'P95 {config.device} Time, s']).set_index('Dataset') + + return results \ No newline at end of file diff --git a/src/models/__init__.py b/src/models/__init__.py index 2a8e030..42a48db 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -9,7 +9,8 @@ from pathlib import Path AVAILABLE_MODELS = { 'SRNet': srnet.SRNet, 'SRLut': srlut.SRLut, - 'SRNetRot90': srnet.SRNetRot90, 'SRLutRot90': srlut.SRLutRot90, + 'SRNetDense': srnet.SRNetDense, + 'SRNetDenseRot90': srnet.SRNetDenseRot90, 'SRLutRot90': srlut.SRLutRot90, 'RCNetCentered_3x3': rcnet.RCNetCentered_3x3, 'RCLutCentered_3x3': rclut.RCLutCentered_3x3, 'RCNetCentered_7x7': rcnet.RCNetCentered_7x7, 'RCLutCentered_7x7': rclut.RCLutCentered_7x7, 'RCNetRot90_3x3': rcnet.RCNetRot90_3x3, 'RCLutRot90_3x3': rclut.RCLutRot90_3x3, diff --git a/src/models/sdylut.py b/src/models/sdylut.py index e10ca3c..f5dde9f 100644 --- a/src/models/sdylut.py +++ b/src/models/sdylut.py @@ -90,17 +90,17 @@ class SDYLutx2(nn.Module): @staticmethod def init_from_lut( - stageS_1, stageD_1, stageY_1, stageS_2, stageD_2, stageY_2 + stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y ): - scale = int(stageS_2.shape[-1]) - quantization_interval = 256//(stageS_2.shape[0]-1) + scale = int(stage2_S.shape[-1]) + quantization_interval = 256//(stage2_S.shape[0]-1) lut_model = SDYLutx2(quantization_interval=quantization_interval, scale=scale) - lut_model.stageS_1 = nn.Parameter(torch.tensor(stageS_1).type(torch.float32)) - lut_model.stageD_1 = nn.Parameter(torch.tensor(stageD_1).type(torch.float32)) - lut_model.stageY_1 = nn.Parameter(torch.tensor(stageY_1).type(torch.float32)) - lut_model.stageS_2 = nn.Parameter(torch.tensor(stageS_2).type(torch.float32)) - lut_model.stageD_2 = nn.Parameter(torch.tensor(stageD_2).type(torch.float32)) - lut_model.stageY_2 = nn.Parameter(torch.tensor(stageY_2).type(torch.float32)) + lut_model.stage1_S = nn.Parameter(torch.tensor(stage1_S).type(torch.float32)) + lut_model.stage1_D = nn.Parameter(torch.tensor(stage1_D).type(torch.float32)) + lut_model.stage1_Y = nn.Parameter(torch.tensor(stage1_Y).type(torch.float32)) + lut_model.stage2_S = nn.Parameter(torch.tensor(stage2_S).type(torch.float32)) + lut_model.stage2_D = nn.Parameter(torch.tensor(stage2_D).type(torch.float32)) + lut_model.stage2_Y = nn.Parameter(torch.tensor(stage2_Y).type(torch.float32)) return lut_model def forward(self, x): @@ -111,17 +111,17 @@ class SDYLutx2(nn.Module): rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1]) rb,rc,rh,rw = rotated.shape - s = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_S(rotated), lut=self.stageS_1) + s = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_S(rotated), lut=self.stage1_S) s = s.view(rb*rc, 1, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh, rw) s = torch.rot90(s, k=-rotations_count, dims=[-2, -1]) output += s - d = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_D(rotated), lut=self.stageD_1) + d = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_D(rotated), lut=self.stage1_D) d = d.view(rb*rc, 1, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh, rw) d = torch.rot90(d, k=-rotations_count, dims=[-2, -1]) output += d - y = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_Y(rotated), lut=self.stageY_1) + y = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_Y(rotated), lut=self.stage1_Y) y = y.view(rb*rc, 1, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh, rw) y = torch.rot90(y, k=-rotations_count, dims=[-2, -1]) output += y @@ -135,17 +135,17 @@ class SDYLutx2(nn.Module): rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1]) rb,rc,rh,rw = rotated.shape - s = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_S(rotated), lut=self.stageS_2) + s = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_S(rotated), lut=self.stage2_S) s = s.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale) s = torch.rot90(s, k=-rotations_count, dims=[-2, -1]) output += s - d = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_D(rotated), lut=self.stageD_2) + d = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_D(rotated), lut=self.stage2_D) d = d.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale) d = torch.rot90(d, k=-rotations_count, dims=[-2, -1]) output += d - y = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_Y(rotated), lut=self.stageY_2) + y = forward_unfolded_2x2_input_SxS_output(index=self._extract_pattern_Y(rotated), lut=self.stage2_Y) y = y.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale) y = torch.rot90(y, k=-rotations_count, dims=[-2, -1]) output += y diff --git a/src/models/sdynet.py b/src/models/sdynet.py index 19a7aba..0f87a54 100644 --- a/src/models/sdynet.py +++ b/src/models/sdynet.py @@ -41,7 +41,7 @@ class SDYNetx1(nn.Module): y = y.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale) y = torch.rot90(y, k=-rotations_count, dims=[-2, -1]) output += y - + output /= 4*3 output = output.view(b, c, h*self.scale, w*self.scale) return output @@ -61,12 +61,12 @@ class SDYNetx2(nn.Module): self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=3) self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3) self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3) - self.stageS_1 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1) - self.stageD_1 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1) - self.stageY_1 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1) - self.stageS_2 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) - self.stageD_2 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) - self.stageY_2 = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) + self.stage1_S = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1) + self.stage1_D = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1) + self.stage1_Y = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=1) + self.stage2_S = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) + self.stage2_D = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) + self.stage2_Y = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) def forward(self, x): b,c,h,w = x.shape @@ -76,17 +76,17 @@ class SDYNetx2(nn.Module): rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1]) rb,rc,rh,rw = rotated.shape - s = self.stageS_1(self._extract_pattern_S(rotated)) + s = self.stage1_S(self._extract_pattern_S(rotated)) s = s.view(rb*rc, 1, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh, rw) s = torch.rot90(s, k=-rotations_count, dims=[-2, -1]) output_1 += s - d = self.stageD_1(self._extract_pattern_D(rotated)) + d = self.stage1_D(self._extract_pattern_D(rotated)) d = d.view(rb*rc, 1, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh, rw) d = torch.rot90(d, k=-rotations_count, dims=[-2, -1]) output_1 += d - y = self.stageY_1(self._extract_pattern_Y(rotated)) + y = self.stage1_Y(self._extract_pattern_Y(rotated)) y = y.view(rb*rc, 1, rh, rw, 1, 1).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh, rw) y = torch.rot90(y, k=-rotations_count, dims=[-2, -1]) output_1 += y @@ -100,17 +100,17 @@ class SDYNetx2(nn.Module): rotated = torch.rot90(x, k=rotations_count, dims=[-2, -1]) rb,rc,rh,rw = rotated.shape - s = self.stageS_2(self._extract_pattern_S(rotated)) + s = self.stage2_S(self._extract_pattern_S(rotated)) s = s.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale) s = torch.rot90(s, k=-rotations_count, dims=[-2, -1]) output_2 += s - d = self.stageD_2(self._extract_pattern_D(rotated)) + d = self.stage2_D(self._extract_pattern_D(rotated)) d = d.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale) d = torch.rot90(d, k=-rotations_count, dims=[-2, -1]) output_2 += d - y = self.stageY_2(self._extract_pattern_Y(rotated)) + y = self.stage2_Y(self._extract_pattern_Y(rotated)) y = y.view(rb*rc, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(rb*rc, 1, rh*self.scale, rw*self.scale) y = torch.rot90(y, k=-rotations_count, dims=[-2, -1]) output_2 += y @@ -120,13 +120,13 @@ class SDYNetx2(nn.Module): return output_2 def get_lut_model(self, quantization_interval=16, batch_size=2**10): - stageS_1 = lut.transfer_2x2_input_SxS_output(self.stageS_1, quantization_interval=quantization_interval, batch_size=batch_size) - stageD_1 = lut.transfer_2x2_input_SxS_output(self.stageD_1, quantization_interval=quantization_interval, batch_size=batch_size) - stageY_1 = lut.transfer_2x2_input_SxS_output(self.stageY_1, quantization_interval=quantization_interval, batch_size=batch_size) - stageS_2 = lut.transfer_2x2_input_SxS_output(self.stageS_2, quantization_interval=quantization_interval, batch_size=batch_size) - stageD_2 = lut.transfer_2x2_input_SxS_output(self.stageD_2, quantization_interval=quantization_interval, batch_size=batch_size) - stageY_2 = lut.transfer_2x2_input_SxS_output(self.stageY_2, quantization_interval=quantization_interval, batch_size=batch_size) - lut_model = sdylut.SDYLutx2.init_from_lut(stageS_1, stageD_1, stageY_1, stageS_2, stageD_2, stageY_2) + stage1_S = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) + stage1_D = lut.transfer_2x2_input_SxS_output(self.stage1_D, quantization_interval=quantization_interval, batch_size=batch_size) + stage1_Y = lut.transfer_2x2_input_SxS_output(self.stage1_Y, quantization_interval=quantization_interval, batch_size=batch_size) + stage2_S = lut.transfer_2x2_input_SxS_output(self.stage2_S, quantization_interval=quantization_interval, batch_size=batch_size) + stage2_D = lut.transfer_2x2_input_SxS_output(self.stage2_D, quantization_interval=quantization_interval, batch_size=batch_size) + stage2_Y = lut.transfer_2x2_input_SxS_output(self.stage2_Y, quantization_interval=quantization_interval, batch_size=batch_size) + lut_model = sdylut.SDYLutx2.init_from_lut(stage1_S, stage1_D, stage1_Y, stage2_S, stage2_D, stage2_Y) return lut_model diff --git a/src/models/srnet.py b/src/models/srnet.py index cd30d8c..e52d53c 100644 --- a/src/models/srnet.py +++ b/src/models/srnet.py @@ -6,20 +6,44 @@ from common.utils import round_func from common import lut from pathlib import Path from .srlut import SRLut, SRLutRot90 -from common.layers import DenseConvUpscaleBlock +from common.layers import PercievePattern, DenseConvUpscaleBlock, ConvUpscaleBlock class SRNet(nn.Module): def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): super(SRNet, self).__init__() - self.scale = scale + self.scale = scale + self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2) + self.stage = ConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) + + def forward(self, x): + b,c,h,w = x.shape + x = x.view(b*c, 1, h, w) + x = self._extract_pattern_S(x) + x = self.stage(x) + x = x.view(b*c, 1, h, w, self.scale, self.scale).permute(0,1,2,4,3,5) + x = x.reshape(b, c, h*self.scale, w*self.scale) + return x + + def get_lut_model(self, quantization_interval=16, batch_size=2**10): + stage_lut = lut.transfer_2x2_input_SxS_output(self.stage, quantization_interval=quantization_interval, batch_size=batch_size) + lut_model = SRLut.init_from_lut(stage_lut) + return lut_model + + +class SRNetDense(nn.Module): + def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): + super(SRNetDense, self).__init__() + self.scale = scale + self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2) self.stage = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) def forward(self, x): b,c,h,w = x.shape x = x.view(b*c, 1, h, w) - x = F.pad(x, pad=[0,1,0,1], mode='replicate') + x = self._extract_pattern_S(x) x = self.stage(x) - x = x.view(b, c, h*self.scale, w*self.scale) + x = x.view(b*c, 1, h, w, self.scale, self.scale).permute(0,1,2,4,3,5) + x = x.reshape(b, c, h*self.scale, w*self.scale) return x def get_lut_model(self, quantization_interval=16, batch_size=2**10): @@ -27,11 +51,11 @@ class SRNet(nn.Module): lut_model = SRLut.init_from_lut(stage_lut) return lut_model - -class SRNetRot90(nn.Module): +class SRNetDenseRot90(nn.Module): def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): - super(SRNetRot90, self).__init__() + super(SRNetDenseRot90, self).__init__() self.scale = scale + self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=2) self.stage = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=layers_count, upscale_factor=scale) def forward(self, x): @@ -39,11 +63,12 @@ class SRNetRot90(nn.Module): x = x.view(b*c, 1, h, w) output = torch.zeros([b*c, 1, h*self.scale, w*self.scale], dtype=x.dtype, device=x.device) for rotations_count in range(4): - rotated = torch.rot90(x, k=rotations_count, dims=[2, 3]) - rotated_padded = F.pad(rotated, pad=[0,1,0,1], mode='replicate') - rotated_prediction = self.stage(rotated_padded) - unrotated_prediction = torch.rot90(rotated_prediction, k=-rotations_count, dims=[2, 3]) - output += unrotated_prediction + rx = torch.rot90(x, k=rotations_count, dims=[2, 3]) + _,_,rh,rw = rx.shape + rx = self._extract_pattern_S(rx) + rx = self.stage(rx) + rx = rx.view(b*c, 1, rh, rw, self.scale, self.scale).permute(0,1,2,4,3,5).reshape(b*c, 1, rh*self.scale, rw*self.scale) + output += torch.rot90(rx, k=-rotations_count, dims=[2, 3]) output /= 4 output = output.view(b, c, h*self.scale, w*self.scale) return output diff --git a/src/train.py b/src/train.py index 9e468ed..0f5f936 100644 --- a/src/train.py +++ b/src/train.py @@ -45,6 +45,7 @@ class TrainOptions: parser.add_argument('--worker_num', '-n', type=int, default=1, help="Number of dataloader workers") parser.add_argument('--prefetch_factor', '-p', type=int, default=16, help="Prefetch factor of dataloader workers.") parser.add_argument('--save_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name') + parser.add_argument('--device', default='cuda', help='Device of the model') self.parser = parser def parse_args(self): @@ -90,6 +91,20 @@ def prepare_experiment_folder(config): config.logs_dir.mkdir() + +def dice_loss(inputs, targets, smooth=1): + #comment out if your model contains a sigmoid or equivalent activation layer + inputs = F.sigmoid(inputs) + + #flatten label and prediction tensors + inputs = inputs.view(-1) + targets = targets.view(-1) + + intersection = (inputs * targets).sum() + dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth) + return 1 - dice + + if __name__ == "__main__": script_start_time = datetime.now() @@ -101,8 +116,9 @@ if __name__ == "__main__": config.model = model.__class__.__name__ else: model = AVAILABLE_MODELS[config.model]( hidden_dim = config.hidden_dim, scale = config.scale) - model = model.cuda() + model = model.to(torch.device(config.device)) optimizer = AdamWScheduleFree(model.parameters()) + print(optimizer) prepare_experiment_folder(config) @@ -153,9 +169,8 @@ if __name__ == "__main__": if not config.model_path is None: config.current_iter = i valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Iter {i}") - + for i in range(config.start_iter + 1, config.total_iter + 1): - # prof.step() torch.cuda.empty_cache() start_time = time.time() try: @@ -163,12 +178,12 @@ if __name__ == "__main__": except StopIteration: train_iter = iter(train_loader) hr_patch, lr_patch = next(train_iter) - hr_patch = hr_patch.cuda() - lr_patch = lr_patch.cuda() + hr_patch = hr_patch.to(torch.device(config.device)) + lr_patch = lr_patch.to(torch.device(config.device)) prepare_data_time += time.time() - start_time start_time = time.time() - + pred = model(lr_patch) loss = F.mse_loss(pred/255, hr_patch/255) loss.backward() @@ -216,6 +231,10 @@ if __name__ == "__main__": if link.exists(): link.unlink() link.symlink_to(model_path) + link = Path(config.models_dir / f"last.pth") + if link.exists(): + link.unlink() + link.symlink_to(model_path) total_script_time = datetime.now() - script_start_time config.logger.info(f"Completed after {total_script_time}") \ No newline at end of file diff --git a/src/transfer_to_lut.py b/src/transfer_to_lut.py index d098780..fbc2052 100644 --- a/src/transfer_to_lut.py +++ b/src/transfer_to_lut.py @@ -25,8 +25,8 @@ class TransferToLutOptions(): def parse_args(self): args = self.parser.parse_args() - args.model_path = Path(args.model_path) - args.models_dir = Path(args.model_path).resolve().parent.parent.parent + args.model_path = Path(args.model_path).resolve() + args.models_dir = args.model_path.parent.parent.parent args.checkpoint_dir = Path(args.model_path).resolve().parent return args @@ -73,13 +73,21 @@ if __name__ == "__main__": link = Path(config.models_dir / f"last_transfered_net.pth") if link.exists(): link.unlink() - link.symlink_to(config.model_path.resolve()) + link.symlink_to(config.model_path) + print("Updated link", link) + link = Path(config.models_dir / f"last_transfered_lut.pth") if link.exists(): link.unlink() link.symlink_to(lut_path.resolve()) - print("Updated link", config.models_dir / f"last_transfered_net.pth") - print("Updated link", config.models_dir / f"last_transfered_lut.pth") + print("Updated link", link) + + link = Path(config.models_dir / f"last.pth") + if link.exists(): + link.unlink() + link.symlink_to(lut_path.resolve()) + print("Updated link", link) + print() print("Completed after", datetime.now()-start_time) \ No newline at end of file diff --git a/src/validate.py b/src/validate.py index 47587da..76640c0 100644 --- a/src/validate.py +++ b/src/validate.py @@ -23,22 +23,24 @@ import argparse class ValOptions(): def __init__(self): self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - self.parser.add_argument('--model_path', type=str, help="Model path.") + self.parser.add_argument('--model_path', type=str, default="../models/last.pth", help="Model path.") self.parser.add_argument('--datasets_dir', type=str, default="../data/", help="Path to datasets.") self.parser.add_argument('--val_datasets', type=str, default='Set5,Set14', help="Names of validation datasets.") self.parser.add_argument('--save_predictions', action='store_true', default=True, help='Save model predictions to exp_dir/val/dataset_name') + self.parser.add_argument('--device', type=str, default='cuda', help='Device of the model') def parse_args(self): args = self.parser.parse_args() args.datasets_dir = Path(args.datasets_dir).resolve() args.val_datasets = args.val_datasets.split(',') - args.exp_dir = Path(args.model_path).absolute().parent.parent - args.model_path = Path(args.model_path) + args.exp_dir = Path(args.model_path).resolve().parent.parent + args.model_path = Path(args.model_path).resolve() args.model_name = args.model_path.stem - args.valout_dir = Path(args.exp_dir)/ 'val' + args.valout_dir = Path(args.exp_dir).resolve() / 'val' if not args.valout_dir.exists(): args.valout_dir.mkdir() args.current_iter = args.model_name.split('_')[-1] + args.results_path = os.path.join(args.valout_dir, f'results_{args.device}.csv') # Tensorboard for monitoring writer = SummaryWriter(log_dir=args.valout_dir) logger_name = f'val_{args.model_path.stem}' @@ -50,7 +52,7 @@ class ValOptions(): return args def __repr__(self): - config = self.parser.parse_args() + config = self.parse_args() message = '' message += '----------------- Options ---------------\n' for k, v in sorted(vars(config).items()): @@ -72,7 +74,7 @@ if __name__ == "__main__": config.logger.info(config_inst) model = LoadCheckpoint(config.model_path) - model = model.cuda() + model = model.to(torch.device(config.device)) print(model) test_datasets = {} @@ -82,7 +84,13 @@ if __name__ == "__main__": lr_dir_path = Path(config.datasets_dir) / test_dataset_name / "LR" / f"X{model.scale}", ) - valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {model.__class__.__name__}") + results = valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {model.__class__.__name__}") + + results.to_csv(config.results_path) + print() + print(results) + print() + print(f"Results saved to {config.results_path}") total_script_time = datetime.now() - script_start_time config.logger.info(f"Completed after {total_script_time}") \ No newline at end of file