"""Model registry plus helpers to save and restore model checkpoints."""
from pathlib import Path

import numpy as np
import torch

from . import rcnet
from . import rclut
from . import srnet
from . import srlut
from . import sdynet

# Maps the class name recorded by SaveCheckpoint (``model.__class__.__name__``)
# to the class LoadCheckpoint uses to rebuild the model. Each key is therefore
# expected to equal the corresponding class's ``__name__``.
AVAILABLE_MODELS = {
    'SRNet': srnet.SRNet,
    'SRLut': srlut.SRLut,
    'SRNetRot90': srnet.SRNetRot90,
    'SRLutRot90': srlut.SRLutRot90,
    'RCNetCentered_3x3': rcnet.RCNetCentered_3x3,
    'RCLutCentered_3x3': rclut.RCLutCentered_3x3,
    'RCNetCentered_7x7': rcnet.RCNetCentered_7x7,
    'RCLutCentered_7x7': rclut.RCLutCentered_7x7,
    'RCNetRot90_3x3': rcnet.RCNetRot90_3x3,
    'RCLutRot90_3x3': rclut.RCLutRot90_3x3,
    'RCNetRot90_7x7': rcnet.RCNetRot90_7x7,
    'RCLutRot90_7x7': rclut.RCLutRot90_7x7,
    'RCNetx1': rcnet.RCNetx1,
    'RCLutx1': rclut.RCLutx1,
    'RCNetx2': rcnet.RCNetx2,
    'RCLutx2': rclut.RCLutx2,
    'RCNetx2Centered': rcnet.RCNetx2Centered,
    'RCLutx2Centered': rclut.RCLutx2Centered,
    'SDYNetx1': sdynet.SDYNetx1,
}

# Container keys used for bookkeeping; every other key in a checkpoint is
# passed to the model constructor as a keyword argument by LoadCheckpoint.
_RESERVED_KEYS = ("model", "state_dict")


def SaveCheckpoint(model, path):
    """Save ``model`` to ``path`` together with what is needed to rebuild it.

    The saved container is a dict holding:
      * ``"model"``: the class name of ``model``;
      * ``"state_dict"``: the result of ``model.state_dict()``;
      * every public instance attribute of ``model`` (names not starting with
        ``_``) except ``training``.

    LoadCheckpoint passes those attributes back to the class constructor as
    keyword arguments. This therefore assumes each model's ``__init__``
    accepts its public attributes under the same names.
    NOTE(review): confirm this holds for every class in AVAILABLE_MODELS.

    Args:
        model: The model instance to save.
        path: Destination file, forwarded unchanged to ``torch.save``.

    Raises:
        ValueError: ``model`` has a public attribute whose name collides with
            a reserved container key. Saving it would silently overwrite the
            class name or the weights and make the checkpoint unloadable.
    """
    hyperparams = {
        name: value
        for name, value in model.__dict__.items()
        # ``training`` is the train/eval flag, not a constructor argument.
        if not name.startswith("_") and name != "training"
    }
    clashes = [name for name in _RESERVED_KEYS if name in hyperparams]
    if clashes:
        raise ValueError(
            f"Model attribute(s) {clashes} collide with reserved checkpoint keys."
        )
    model_container = {
        "model": model.__class__.__name__,
        "state_dict": model.state_dict(),
        **hyperparams,
    }
    torch.save(model_container, path)


def LoadCheckpoint(model_path, map_location=None):
    """Rebuild a model from a checkpoint written by SaveCheckpoint.

    Args:
        model_path: Path (str or path-like) to the checkpoint file.
        map_location: Forwarded to ``torch.load``. For example, ``"cpu"``
            lets a checkpoint saved from GPU tensors load on a CPU-only
            machine. ``None`` (the default) keeps ``torch.load``'s own
            behavior.

    Returns:
        A newly constructed model with the saved weights loaded using
        ``strict=True``.

    Raises:
        FileNotFoundError: ``model_path`` does not exist. This is a subclass
            of ``Exception``, so existing ``except Exception`` handlers still
            catch it.
        KeyError: the class name recorded in the checkpoint is not a key of
            ``AVAILABLE_MODELS``.
    """
    model_path = Path(model_path).absolute()
    if not model_path.exists():
        raise FileNotFoundError(f"Path {model_path} does not exist.")

    # SECURITY: torch.load unpickles the file; only load checkpoints that come
    # from a trusted source.
    model_container = torch.load(model_path, map_location=map_location)

    model_name = model_container["model"]
    try:
        model_class = AVAILABLE_MODELS[model_name]
    except KeyError:
        raise KeyError(
            f"Unknown model class {model_name!r} in checkpoint {model_path}; "
            f"expected one of {sorted(AVAILABLE_MODELS)}."
        ) from None

    # Everything except the bookkeeping keys is treated as a constructor
    # keyword argument (see SaveCheckpoint).
    constructor_kwargs = {
        key: value
        for key, value in model_container.items()
        if key not in _RESERVED_KEYS
    }
    model = model_class(**constructor_kwargs)
    model.load_state_dict(model_container["state_dict"], strict=True)
    return model