diff --git a/src/models/__init__.py b/src/models/__init__.py index 26e5516..b1ba782 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -8,22 +8,25 @@ import numpy as np from pathlib import Path AVAILABLE_MODELS = { - 'SRNet': srnet.SRNet, 'SRLut': srlut.SRLut, - 'SRNetDense': srnet.SRNetDense, - 'SRNetDenseRot90': srnet.SRNetDenseRot90, 'SRLutRot90': srlut.SRLutRot90, - 'SRNetDenseRot90Y': srnet.SRNetDenseRot90Y, 'SRLutRot90Y': srlut.SRLutRot90Y, - 'RCNetCentered_3x3': rcnet.RCNetCentered_3x3, 'RCLutCentered_3x3': rclut.RCLutCentered_3x3, - 'RCNetCentered_7x7': rcnet.RCNetCentered_7x7, 'RCLutCentered_7x7': rclut.RCLutCentered_7x7, - 'RCNetRot90_3x3': rcnet.RCNetRot90_3x3, 'RCLutRot90_3x3': rclut.RCLutRot90_3x3, - 'RCNetRot90_7x7': rcnet.RCNetRot90_7x7, 'RCLutRot90_7x7': rclut.RCLutRot90_7x7, - 'RCNetx1': rcnet.RCNetx1, 'RCLutx1': rclut.RCLutx1, - 'RCNetx2': rcnet.RCNetx2, 'RCLutx2': rclut.RCLutx2, - 'RCNetx2Centered': rcnet.RCNetx2Centered, 'RCLutx2Centered': rclut.RCLutx2Centered, - 'SDYNetx1': sdynet.SDYNetx1, 'SDYLutx1': sdylut.SDYLutx1, - 'SDYNetCenteredx1': sdynet.SDYNetCenteredx1, 'SDYLutCenteredx1': sdylut.SDYLutCenteredx1, - 'SDYNetx2': sdynet.SDYNetx2, 'SDYLutx2': sdylut.SDYLutx2, - 'RCNetx2Unlutable': rcnet.RCNetx2Unlutable, - 'RCNetx2CenteredUnlutable': rcnet.RCNetx2CenteredUnlutable, + 'SRNet': srnet.SRNet, + 'SRLut': srlut.SRLut, + 'SRNetR90': srnet.SRNetR90, + 'SRLutR90': srlut.SRLutR90, + 'SRNetR90Y': srnet.SRNetR90Y, + 'SRLutR90Y': srlut.SRLutR90Y, + 'SDYNetx1': sdynet.SDYNetx1, + 'SDYLutx1': sdylut.SDYLutx1, + 'SDYNetx2': sdynet.SDYNetx2, + 'SDYLutx2': sdylut.SDYLutx2, + # 'RCNetCentered_3x3': rcnet.RCNetCentered_3x3, 'RCLutCentered_3x3': rclut.RCLutCentered_3x3, + # 'RCNetCentered_7x7': rcnet.RCNetCentered_7x7, 'RCLutCentered_7x7': rclut.RCLutCentered_7x7, + # 'RCNetRot90_3x3': rcnet.RCNetRot90_3x3, 'RCLutRot90_3x3': rclut.RCLutRot90_3x3, + # 'RCNetRot90_7x7': rcnet.RCNetRot90_7x7, 'RCLutRot90_7x7': rclut.RCLutRot90_7x7, + # 'RCNetx1': rcnet.RCNetx1, 'RCLutx1': rclut.RCLutx1, + # 'RCNetx2': rcnet.RCNetx2, 'RCLutx2': rclut.RCLutx2, + # 'RCNetx2Centered': rcnet.RCNetx2Centered, 'RCLutx2Centered': rclut.RCLutx2Centered, + # 'RCNetx2Unlutable': rcnet.RCNetx2Unlutable, + # 'RCNetx2CenteredUnlutable': rcnet.RCNetx2CenteredUnlutable, } def SaveCheckpoint(model, path):