protsenkovi 6 months ago
parent 6d1ffb551b
commit 64674aab60

@ -21,4 +21,7 @@ python image_demo.py --help
```
Requirements:
- [shedulefree](https://github.com/facebookresearch/schedule_free)
- [shedulefree](https://github.com/facebookresearch/schedule_free)
- tensorboard
- opencv-python-headless
- scipy

@ -19,8 +19,12 @@ class PercievePattern():
def __call__(self, x):
b,c,h,w = x.shape
x = F.pad(x, pad=[self.center[0], self.window_size-self.center[0]-1,
self.center[1], self.window_size-self.center[1]-1], mode='replicate')
x = F.pad(
x,
pad=[self.center[0], self.window_size-self.center[0]-1,
self.center[1], self.window_size-self.center[1]-1],
mode='replicate'
)
x = F.unfold(input=x, kernel_size=self.window_size)
x = torch.stack([
x[:,self.receptive_field_idxes[0],:],
@ -34,13 +38,15 @@ class PercievePattern():
# Huang G. et al. Densely connected convolutional networks //Proceedings of the IEEE conference on computer vision and pattern recognition. 2017. С. 4700-4708.
# https://ar5iv.labs.arxiv.org/html/1608.06993
# https://github.com/andreasveit/densenet-pytorch/blob/63152f4a40644b62717749536ed2e011c6e4d9ab/densenet.py#L40
# refactoring to linear give slight speed up, but require total rewrite to be consistent
class DenseConvUpscaleBlock(nn.Module):
def __init__(self, hidden_dim = 32, layers_count=5, upscale_factor=1):
super(DenseConvUpscaleBlock, self).__init__()
assert layers_count > 0
self.upscale_factor = upscale_factor
self.hidden_dim = hidden_dim
self.percieve = nn.Conv2d(1, hidden_dim, kernel_size=(2, 2), padding='valid', stride=1, dilation=1, bias=True)
self.embed = nn.Conv2d(1, hidden_dim, kernel_size=(2, 2), padding='valid', stride=1, dilation=1, bias=True)
self.convs = []
for i in range(layers_count):
self.convs.append(nn.Conv2d(in_channels = (i+1)*hidden_dim, out_channels = hidden_dim, kernel_size = 1, stride=1, padding=0, dilation=1, bias=True))
@ -55,7 +61,7 @@ class DenseConvUpscaleBlock(nn.Module):
def forward(self, x):
x = (x-127.5)/127.5
x = torch.relu(self.percieve(x))
x = torch.relu(self.embed(x))
for conv in self.convs:
x = torch.cat([x, torch.relu(conv(x))], dim=1)
x = self.shuffle(self.project_channels(x))

@ -1,6 +1,5 @@
from pathlib import Path
import sys
sys.path.insert(0, str(Path("../").resolve()) + "/")
from models import LoadCheckpoint
import torch
@ -12,11 +11,11 @@ import argparse
class ImageDemoOptions():
def __init__(self):
self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
self.parser.add_argument('--net_model_path', '-n', type=str, default="../../models/last_transfered_net.pth", help="Net model path folder")
self.parser.add_argument('--lut_model_path', '-l', type=str, default="../../models/last_transfered_lut.pth", help="Lut model path folder")
self.parser.add_argument('--hr_image_path', '-a', type=str, default="../../data/Set14/HR/monarch.png", help="HR image path")
self.parser.add_argument('--lr_image_path', '-b', type=str, default="../../data/Set14/LR/X4/monarch.png", help="LR image path")
self.parser.add_argument('--project_path', type=str, default="../../", help="Project path.")
self.parser.add_argument('--net_model_path', '-n', type=str, default="../models/last_transfered_net.pth", help="Net model path folder")
self.parser.add_argument('--lut_model_path', '-l', type=str, default="../models/last_transfered_lut.pth", help="Lut model path folder")
self.parser.add_argument('--hr_image_path', '-a', type=str, default="../data/Set14/HR/monarch.png", help="HR image path")
self.parser.add_argument('--lr_image_path', '-b', type=str, default="../data/Set14/LR/X4/monarch.png", help="LR image path")
self.parser.add_argument('--project_path', type=str, default="../", help="Project path.")
self.parser.add_argument('--batch_size', type=int, default=2**10, help="Size of the batch for the input domain values.")
def parse_args(self):

@ -15,9 +15,9 @@ class SDYLutx1(nn.Module):
super(SDYLutx1, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]])
self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]])
self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]])
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=3)
self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
self.stageS = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.stageD = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.stageY = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
@ -78,9 +78,9 @@ class SDYLutx2(nn.Module):
super(SDYLutx2, self).__init__()
self.scale = scale
self.quantization_interval = quantization_interval
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]])
self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]])
self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]])
self._extract_pattern_S = PercievePattern(receptive_field_idxes=[[0,0],[0,1],[1,0],[1,1]], center=[0,0], window_size=3)
self._extract_pattern_D = PercievePattern(receptive_field_idxes=[[0,0],[2,0],[0,2],[2,2]], center=[0,0], window_size=3)
self._extract_pattern_Y = PercievePattern(receptive_field_idxes=[[0,0],[1,1],[1,2],[2,1]], center=[0,0], window_size=3)
self.stageS_1 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.stageD_1 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))
self.stageY_1 = nn.Parameter(torch.randint(0, 255, size=(256//quantization_interval+1,)*4 + (scale,scale)).type(torch.float32))

@ -1,5 +1,5 @@
import sys
sys.path.insert(0, "../") # run under the project directory
from pickle import dump
import logging
import math
@ -35,8 +35,8 @@ class TrainOptions:
parser.add_argument('--hidden_dim', type=int, default=64, help="number of filters of convolutional layers")
parser.add_argument('--crop_size', type=int, default=48, help='input LR training patch size')
parser.add_argument('--batch_size', type=int, default=16, help="Batch size for training")
parser.add_argument('--models_dir', type=str, default='../../models/', help="experiment folder")
parser.add_argument('--datasets_dir', type=str, default="../../data/")
parser.add_argument('--models_dir', type=str, default='../models/', help="experiment folder")
parser.add_argument('--datasets_dir', type=str, default="../data/")
parser.add_argument('--start_iter', type=int, default=0, help='Set 0 for from scratch, else will load saved params and trains further')
parser.add_argument('--total_iter', type=int, default=200000, help='Total number of training iterations')
parser.add_argument('--display_step', type=int, default=100, help='display info every N iteration')

@ -1,5 +1,5 @@
import sys
sys.path.insert(0, "../") # run under the project directory
import logging
import math
import os
@ -19,7 +19,7 @@ import models
class TransferToLutOptions():
def __init__(self):
self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
self.parser.add_argument('--model_path', '-m', type=str, default='../../models/last_trained_net.pth', help="model path folder")
self.parser.add_argument('--model_path', '-m', type=str, default='../models/last_trained_net.pth', help="model path folder")
self.parser.add_argument('--quantization_bits', '-q', type=int, default=4, help="Number of 4DLUT buckets defined as 2**bits. Value is in range [1, 8].")
self.parser.add_argument('--batch_size', '-b', type=int, default=2**10, help="Size of the batch for the input domain values.")

@ -1,5 +1,5 @@
import sys
sys.path.insert(0, "../") # run under the project directory
import logging
import math
import os
@ -24,7 +24,7 @@ class ValOptions():
def __init__(self):
self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
self.parser.add_argument('--model_path', type=str, help="Model path.")
self.parser.add_argument('--datasets_dir', type=str, default="../../data/", help="Path to datasets.")
self.parser.add_argument('--datasets_dir', type=str, default="../data/", help="Path to datasets.")
self.parser.add_argument('--val_datasets', type=str, default='Set5,Set14', help="Names of validation datasets.")
self.parser.add_argument('--save_predictions', action='store_true', default=True, help='Save model predictions to exp_dir/val/dataset_name')
Loading…
Cancel
Save