diff --git a/src/models/__init__.py b/src/models/__init__.py index b61bbf5..875d7f7 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -34,7 +34,7 @@ AVAILABLE_MODELS = { 'HDBLNet': hdbnet.HDBLNet, 'HDBHNet': hdbnet.HDBHNet, 'SRMsbLsbNet': srnet.SRMsbLsbNet, - 'SRMsbLsbShift2Net': srnet.SRMsbLsbShift2Net, + 'SRMsbLsbShiftNet': srnet.SRMsbLsbShiftNet, 'SRMsbLsbR90Net': srnet.SRMsbLsbR90Net, 'SRMsbLsb4R90Net': srnet.SRMsbLsb4R90Net, # 'RCNetCentered_3x3': rcnet.RCNetCentered_3x3, 'RCLutCentered_3x3': rclut.RCLutCentered_3x3, diff --git a/src/models/srnet.py b/src/models/srnet.py index 15e5c4b..fbe2acc 100644 --- a/src/models/srnet.py +++ b/src/models/srnet.py @@ -201,9 +201,9 @@ class SRMsbLsbNet(SRNetBase): raise NotImplementedError -class SRMsbLsbShift2Net(SRNetBase): +class SRMsbLsbShiftNet(SRNetBase): def __init__(self, hidden_dim = 64, layers_count = 4, scale = 4): - super(SRMsbLsbShift2Net, self).__init__() + super(SRMsbLsbShiftNet, self).__init__() self.scale = scale self.hidden_dim = hidden_dim self.layers_count = layers_count