From 64674aab60c9803dd793329d0856b4ce453d2808 Mon Sep 17 00:00:00 2001
From: protsenkovi
Date: Fri, 3 May 2024 01:37:25 +0400
Subject: [PATCH] upd

---
 readme.md                            |  5 ++++-
 src/common/layers.py                 | 14 ++++++++++----
 src/{scripts => }/image_demo.py      | 11 +++++------
 src/models/sdylut.py                 | 12 ++++++------
 src/{scripts => }/train.py           |  6 +++---
 src/{scripts => }/transfer_to_lut.py |  4 ++--
 src/{scripts => }/validate.py        |  4 ++--
 7 files changed, 32 insertions(+), 24 deletions(-)
 rename src/{scripts => }/image_demo.py (87%)
 rename src/{scripts => }/train.py (98%)
 rename src/{scripts => }/transfer_to_lut.py (94%)
 rename src/{scripts => }/validate.py (95%)

diff --git a/readme.md b/readme.md
index b526e2f..03c2a5a 100644
--- a/readme.md
+++ b/readme.md
@@ -21,4 +21,7 @@ python image_demo.py --help
 ```
 
 Requirements:
-- [shedulefree](https://github.com/facebookresearch/schedule_free)
\ No newline at end of file
+- [shedulefree](https://github.com/facebookresearch/schedule_free)
+- tensorboard
+- opencv-python-headless
+- scipy
\ No newline at end of file
diff --git a/src/common/layers.py b/src/common/layers.py
index 62a0bd3..738837d 100644
--- a/src/common/layers.py
+++ b/src/common/layers.py
@@ -19,8 +19,12 @@ class PercievePattern():
 
     def __call__(self, x):
         b,c,h,w = x.shape
-        x = F.pad(x, pad=[self.center[0], self.window_size-self.center[0]-1,
-                          self.center[1], self.window_size-self.center[1]-1], mode='replicate')
+        x = F.pad(
+            x,
+            pad=[self.center[0], self.window_size-self.center[0]-1,
+                 self.center[1], self.window_size-self.center[1]-1],
+            mode='replicate'
+        )
         x = F.unfold(input=x, kernel_size=self.window_size)
         x = torch.stack([
             x[:,self.receptive_field_idxes[0],:],
@@ -34,13 +38,15 @@ class PercievePattern():
 # Huang G. et al. Densely connected convolutional networks //Proceedings of the IEEE conference on computer vision and pattern recognition. – 2017. – С. 4700-4708.
 # https://ar5iv.labs.arxiv.org/html/1608.06993
 # https://github.com/andreasveit/densenet-pytorch/blob/63152f4a40644b62717749536ed2e011c6e4d9ab/densenet.py#L40
+# refactoring to linear give slight speed up, but require total rewrite to be consistent
 class DenseConvUpscaleBlock(nn.Module):
     def __init__(self, hidden_dim = 32, layers_count=5, upscale_factor=1):
         super(DenseConvUpscaleBlock, self).__init__()
         assert layers_count > 0
         self.upscale_factor = upscale_factor
         self.hidden_dim = hidden_dim
-        self.percieve = nn.Conv2d(1, hidden_dim, kernel_size=(2, 2), padding='valid', stride=1, dilation=1, bias=True)
+        self.embed = nn.Conv2d(1, hidden_dim, kernel_size=(2, 2), padding='valid', stride=1, dilation=1, bias=True)
+
         self.convs = []
         for i in range(layers_count):
             self.convs.append(nn.Conv2d(in_channels = (i+1)*hidden_dim, out_channels = hidden_dim, kernel_size = 1, stride=1, padding=0, dilation=1, bias=True))
@@ -55,7 +61,7 @@ class DenseConvUpscaleBlock(nn.Module):
 
     def forward(self, x):
         x = (x-127.5)/127.5
-        x = torch.relu(self.percieve(x))
+        x = torch.relu(self.embed(x))
         for conv in self.convs:
             x = torch.cat([x, torch.relu(conv(x))], dim=1)
         x = self.shuffle(self.project_channels(x))
diff --git a/src/scripts/image_demo.py b/src/image_demo.py
similarity index 87%
rename from src/scripts/image_demo.py
rename to src/image_demo.py
index 48a7c1c..609981d 100644
--- a/src/scripts/image_demo.py
+++ b/src/image_demo.py
@@ -1,6 +1,5 @@
 from pathlib import Path
 import sys
-sys.path.insert(0, str(Path("../").resolve()) + "/")
 
 from models import LoadCheckpoint
 import torch
@@ -12,11 +11,11 @@ import argparse
 class ImageDemoOptions():
     def __init__(self):
         self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-        self.parser.add_argument('--net_model_path', '-n', type=str, default="../../models/last_transfered_net.pth", help="Net model path folder")
-        self.parser.add_argument('--lut_model_path', '-l', type=str, default="../../models/last_transfered_lut.pth", help="Lut model path folder")
-        self.parser.add_argument('--hr_image_path', '-a', type=str, default="../../data/Set14/HR/monarch.png", help="HR image path")
-        self.parser.add_argument('--lr_image_path', '-b', type=str, default="../../data/Set14/LR/X4/monarch.png", help="LR image path")
-        self.parser.add_argument('--project_path', type=str, default="../../", help="Project path.")
+        self.parser.add_argument('--net_model_path', '-n', type=str, default="../models/last_transfered_net.pth", help="Net model path folder")
+        self.parser.add_argument('--lut_model_path', '-l', type=str, default="../models/last_transfered_lut.pth", help="Lut model path folder")
+        self.parser.add_argument('--hr_image_path', '-a', type=str, default="../data/Set14/HR/monarch.png", help="HR image path")
+        self.parser.add_argument('--lr_image_path', '-b', type=str, default="../data/Set14/LR/X4/monarch.png", help="LR image path")
+        self.parser.add_argument('--project_path', type=str, default="../", help="Project path.")
         self.parser.add_argument('--batch_size', type=int, default=2**10, help="Size of the batch for the input domain values.")
 
     def parse_args(self):
diff --git a/src/models/sdylut.py b/src/models/sdylut.py
index 813cdbd..e10ca3c 100644
--- a/src/models/sdylut.py
+++ b/src/models/sdylut.py
@@ -15,9 +15,9 @@ class SDYLutx1(nn.Module):
         super(SDYLutx1, self).__init__()
         self.scale = scale
         self.quantization_interval = quantization_interval
-        self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]])
-        self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]])
-        self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]])
+        self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=3)
+        self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
+        self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
         self.stageS = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
         self.stageD = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
         self.stageY = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
@@ -78,9 +78,9 @@ class SDYLutx2(nn.Module):
         super(SDYLutx2, self).__init__()
         self.scale = scale
         self.quantization_interval = quantization_interval
-        self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]])
-        self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]])
-        self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]])
+        self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=3)
+        self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
+        self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
         self.stageS_1 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
         self.stageD_1 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
         self.stageY_1 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
diff --git a/src/scripts/train.py b/src/train.py
similarity index 98%
rename from src/scripts/train.py
rename to src/train.py
index 4d1dbec..9e468ed 100644
--- a/src/scripts/train.py
+++ b/src/train.py
@@ -1,5 +1,5 @@
 import sys
-sys.path.insert(0, "../") # run under the project directory
+
 from pickle import dump
 import logging
 import math
@@ -35,8 +35,8 @@
         parser.add_argument('--hidden_dim', type=int, default=64, help="number of filters of convolutional layers")
         parser.add_argument('--crop_size', type=int, default=48, help='input LR training patch size')
         parser.add_argument('--batch_size', type=int, default=16, help="Batch size for training")
-        parser.add_argument('--models_dir', type=str, default='../../models/', help="experiment folder")
-        parser.add_argument('--datasets_dir', type=str, default="../../data/")
+        parser.add_argument('--models_dir', type=str, default='../models/', help="experiment folder")
+        parser.add_argument('--datasets_dir', type=str, default="../data/")
         parser.add_argument('--start_iter', type=int, default=0, help='Set 0 for from scratch, else will load saved params and trains further')
         parser.add_argument('--total_iter', type=int, default=200000, help='Total number of training iterations')
         parser.add_argument('--display_step', type=int, default=100, help='display info every N iteration')
diff --git a/src/scripts/transfer_to_lut.py b/src/transfer_to_lut.py
similarity index 94%
rename from src/scripts/transfer_to_lut.py
rename to src/transfer_to_lut.py
index b06475d..d098780 100644
--- a/src/scripts/transfer_to_lut.py
+++ b/src/transfer_to_lut.py
@@ -1,5 +1,5 @@
 import sys
-sys.path.insert(0, "../") # run under the project directory
+
 import logging
 import math
 import os
@@ -19,7 +19,7 @@ import models
 class TransferToLutOptions():
     def __init__(self):
         self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-        self.parser.add_argument('--model_path', '-m', type=str, default='../../models/last_trained_net.pth', help="model path folder")
+        self.parser.add_argument('--model_path', '-m', type=str, default='../models/last_trained_net.pth', help="model path folder")
         self.parser.add_argument('--quantization_bits', '-q', type=int, default=4, help="Number of 4DLUT buckets defined as 2**bits. Value is in range [1, 8].")
         self.parser.add_argument('--batch_size', '-b', type=int, default=2**10, help="Size of the batch for the input domain values.")
 
diff --git a/src/scripts/validate.py b/src/validate.py
similarity index 95%
rename from src/scripts/validate.py
rename to src/validate.py
index 396aab0..47587da 100644
--- a/src/scripts/validate.py
+++ b/src/validate.py
@@ -1,5 +1,5 @@
 import sys
-sys.path.insert(0, "../") # run under the project directory
+
 import logging
 import math
 import os
@@ -24,7 +24,7 @@
 class ValOptions():
     def __init__(self):
         self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
         self.parser.add_argument('--model_path', type=str, help="Model path.")
-        self.parser.add_argument('--datasets_dir', type=str, default="../../data/", help="Path to datasets.")
+        self.parser.add_argument('--datasets_dir', type=str, default="../data/", help="Path to datasets.")
         self.parser.add_argument('--val_datasets', type=str, default='Set5,Set14', help="Names of validation datasets.")
         self.parser.add_argument('--save_predictions', action='store_true', default=True, help='Save model predictions to exp_dir/val/dataset_name')