main
vlpr 6 months ago
parent f077add65b
commit 96d27bfd05

@ -6,7 +6,7 @@ from common.utils import round_func
from pathlib import Path from pathlib import Path
from common import lut from common import lut
from . import rclut from . import rclut
from common.layers import DenseConvUpscaleBlock from common import layers
class ReconstructedConvCentered(nn.Module): class ReconstructedConvCentered(nn.Module):
def __init__(self, hidden_dim, window_size=7): def __init__(self, hidden_dim, window_size=7):
@ -40,7 +40,7 @@ class RCBlockCentered(nn.Module):
super(RCBlockCentered, self).__init__() super(RCBlockCentered, self).__init__()
self.window_size = window_size self.window_size = window_size
self.rc_conv = ReconstructedConvCentered(hidden_dim=hidden_dim, window_size=window_size) self.rc_conv = ReconstructedConvCentered(hidden_dim=hidden_dim, window_size=window_size)
self.dense_conv_block = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor) self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor)
def forward(self, x): def forward(self, x):
b,c,hs,ws = x.shape b,c,hs,ws = x.shape
@ -127,7 +127,7 @@ class RCBlockRot90(nn.Module):
super(RCBlockRot90, self).__init__() super(RCBlockRot90, self).__init__()
self.window_size = window_size self.window_size = window_size
self.rc_conv = ReconstructedConvRot90(hidden_dim=hidden_dim, window_size=window_size) self.rc_conv = ReconstructedConvRot90(hidden_dim=hidden_dim, window_size=window_size)
self.dense_conv_block = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor) self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor)
def forward(self, x): def forward(self, x):
b,c,hs,ws = x.shape b,c,hs,ws = x.shape
@ -390,7 +390,7 @@ class RCBlockRot90Unlutable(nn.Module):
super(RCBlockRot90Unlutable, self).__init__() super(RCBlockRot90Unlutable, self).__init__()
self.window_size = window_size self.window_size = window_size
self.rc_conv = ReconstructedConvRot90Unlutable(hidden_dim=hidden_dim, window_size=window_size) self.rc_conv = ReconstructedConvRot90Unlutable(hidden_dim=hidden_dim, window_size=window_size)
self.dense_conv_block = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor) self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor)
def forward(self, x): def forward(self, x):
b,c,hs,ws = x.shape b,c,hs,ws = x.shape
@ -495,7 +495,7 @@ class RCBlockCenteredUnlutable(nn.Module):
super(RCBlockRot90Unlutable, self).__init__() super(RCBlockRot90Unlutable, self).__init__()
self.window_size = window_size self.window_size = window_size
self.rc_conv = ReconstructedConvCenteredUnlutable(hidden_dim=hidden_dim, window_size=window_size) self.rc_conv = ReconstructedConvCenteredUnlutable(hidden_dim=hidden_dim, window_size=window_size)
self.dense_conv_block = DenseConvUpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor) self.dense_conv_block = layers.UpscaleBlock(hidden_dim=hidden_dim, layers_count=dense_conv_layer_count, upscale_factor=upscale_factor)
def forward(self, x): def forward(self, x):
b,c,hs,ws = x.shape b,c,hs,ws = x.shape

Loading…
Cancel
Save