You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

39 lines
1.6 KiB
Python

from . import rcnet
from . import rclut
from . import srnet
from . import srlut
from . import sdynet
import torch
import numpy as np
from pathlib import Path
# Registry mapping a model name to its class.
# SaveCheckpoint stores ``model.__class__.__name__`` and LoadCheckpoint looks
# that string up here. Each key should therefore equal its class's
# ``__name__`` for a saved checkpoint to be loadable again.
AVAILABLE_MODELS = dict(
    # from srnet / srlut
    SRNet=srnet.SRNet,
    SRLut=srlut.SRLut,
    SRNetRot90=srnet.SRNetRot90,
    SRLutRot90=srlut.SRLutRot90,
    # from rcnet / rclut
    RCNetCentered_3x3=rcnet.RCNetCentered_3x3,
    RCLutCentered_3x3=rclut.RCLutCentered_3x3,
    RCNetCentered_7x7=rcnet.RCNetCentered_7x7,
    RCLutCentered_7x7=rclut.RCLutCentered_7x7,
    RCNetRot90_3x3=rcnet.RCNetRot90_3x3,
    RCLutRot90_3x3=rclut.RCLutRot90_3x3,
    RCNetRot90_7x7=rcnet.RCNetRot90_7x7,
    RCLutRot90_7x7=rclut.RCLutRot90_7x7,
    RCNetx1=rcnet.RCNetx1,
    RCLutx1=rclut.RCLutx1,
    RCNetx2=rcnet.RCNetx2,
    RCLutx2=rclut.RCLutx2,
    RCNetx2Centered=rcnet.RCNetx2Centered,
    RCLutx2Centered=rclut.RCLutx2Centered,
    # from sdynet
    SDYNetx1=sdynet.SDYNetx1,
)
def SaveCheckpoint(model, path):
    """Serialize ``model`` into a dict that ``LoadCheckpoint`` can rebuild from.

    The saved dict contains:
      * ``'model'``: the class name (``model.__class__.__name__``);
      * ``'state_dict'``: ``model.state_dict()``;
      * every public instance attribute of ``model``, meaning each entry of
        ``model.__dict__`` whose name does not start with ``'_'``, except
        ``training``.

    Args:
        model: the model to save; it must provide ``state_dict()`` and
            ``__dict__``.
        path: destination forwarded unchanged to ``torch.save``.
    """
    # Public attributes are stored so they can be passed back to the
    # constructor as keyword arguments when the checkpoint is loaded.
    init_kwargs = {
        name: value
        for name, value in model.__dict__.items()
        if not name.startswith('_') and name != 'training'
    }
    model_container = {
        **init_kwargs,
        # The reserved keys come last so that an attribute which happens to be
        # called 'model' or 'state_dict' cannot overwrite them. Previously it
        # did, which corrupted the class name or the weights.
        'model': model.__class__.__name__,
        'state_dict': model.state_dict(),
    }
    torch.save(model_container, path)
def LoadCheckpoint(model_path, map_location=None):
    """Rebuild a model from a checkpoint dict such as the one ``SaveCheckpoint`` writes.

    The checkpoint must contain ``'model'`` (a key of ``AVAILABLE_MODELS``)
    and ``'state_dict'``. Every other entry is passed to the model class's
    constructor as a keyword argument. The weights are then loaded with
    ``strict=True``.

    Args:
        model_path: path (``str`` or ``Path``) to the checkpoint file.
        map_location: forwarded to ``torch.load``. For example, use ``"cpu"``
            to load a checkpoint saved on a GPU machine. ``None`` (the
            default) keeps ``torch.load``'s default behavior.

    Returns:
        The constructed model with its state dict loaded.

    Raises:
        FileNotFoundError: if ``model_path`` does not exist.
        KeyError: if the stored model name is not in ``AVAILABLE_MODELS``.
    """
    model_path = Path(model_path).absolute()
    if not model_path.exists():
        # FileNotFoundError subclasses Exception, so callers that caught the
        # previous generic Exception keep working.
        raise FileNotFoundError(f"Path {model_path} does not exist.")
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from trusted sources. On torch versions where weights_only defaults to
    # True, saved attributes of non-allowlisted types will fail to load.
    model_container = torch.load(model_path, map_location=map_location)
    model_name = model_container['model']
    try:
        model_class = AVAILABLE_MODELS[model_name]
    except KeyError:
        raise KeyError(
            f"Unknown model {model_name!r} in checkpoint {model_path}; "
            f"expected one of {sorted(AVAILABLE_MODELS)}"
        ) from None
    init_kwargs = {
        k: v for k, v in model_container.items() if k not in ('model', 'state_dict')
    }
    model = model_class(**init_kwargs)
    model.load_state_dict(model_container['state_dict'], strict=True)
    return model