Allow training LUT models from scratch

main
vlpr 7 months ago
parent d33b65c235
commit b15b238b53

@@ -46,6 +46,8 @@ class TrainOptions:
         parser.add_argument('--prefetch_factor', '-p', type=int, default=16, help="Prefetch factor of dataloader workers.")
         parser.add_argument('--save_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name')
         parser.add_argument('--device', default='cuda', help='Device of the model')
+        parser.add_argument('--quantization_bits', '-q', type=int, default=4, help="Used when model is LUT. Number of 4DLUT buckets defined as 2**bits. Value is in range [1, 8].")
         self.parser = parser

     def parse_args(self):
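
For context: the new flag feeds the constructor call in the second hunk below as quantization_interval = 2**(8 - quantization_bits), so bits in [1, 8] yield 2**bits buckets per LUT dimension over the 8-bit input range. A minimal sketch (not part of this commit) illustrating the relation:

    # Sketch, not from the repo: how --quantization_bits maps to the
    # quantization_interval passed to LUT constructors below.
    for bits in range(1, 9):              # help text allows values in [1, 8]
        interval = 2 ** (8 - bits)        # spacing between sampled input levels
        buckets = 2 ** bits               # "4DLUT buckets" per dimension
        assert interval * buckets == 256  # together they span the 0..255 range
        print(f"bits={bits} -> interval={interval}, buckets={buckets}")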
@@ -101,7 +103,10 @@ if __name__ == "__main__":
         model = LoadCheckpoint(config.model_path)
         config.model = model.__class__.__name__
     else:
-        model = AVAILABLE_MODELS[config.model]( hidden_dim = config.hidden_dim, scale = config.scale)
+        if 'net' in config.model.lower():
+            model = AVAILABLE_MODELS[config.model]( hidden_dim = config.hidden_dim, scale = config.scale)
+        if 'lut' in config.model.lower():
+            model = AVAILABLE_MODELS[config.model]( quantization_interval = 2**(8-config.quantization_bits), scale = config.scale)
     model = model.to(torch.device(config.device))
     optimizer = AdamWScheduleFree(model.parameters())
     print(optimizer)
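
With this change, a LUT model can be constructed and trained from scratch instead of only being loaded from a checkpoint. A hypothetical invocation (the script name train.py, the --model flag, and the model name SRLut are assumptions, not shown in this diff):

    # Hypothetical usage; 'train.py', '--model', and 'SRLut' are assumptions.
    # A model name containing 'lut' takes the new branch above, so it is
    # built with quantization_interval = 2**(8-4) = 16.
    python train.py --model SRLut --quantization_bits 4 --device cuda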
