diff --git a/src/train.py b/src/train.py
index 8a97294..216a7e8 100644
--- a/src/train.py
+++ b/src/train.py
@@ -46,6 +46,8 @@ class TrainOptions:
         parser.add_argument('--prefetch_factor', '-p', type=int, default=16, help="Prefetch factor of dataloader workers.")
         parser.add_argument('--save_predictions', action='store_true', default=False, help='Save model predictions to exp_dir/val/dataset_name')
         parser.add_argument('--device', default='cuda', help='Device of the model')
+        parser.add_argument('--quantization_bits', '-q', type=int, default=4, choices=range(1, 9), help="Used when model is LUT. Number of 4DLUT buckets defined as 2**bits. Value is in range [1, 8].")
+
         self.parser = parser

     def parse_args(self):
@@ -101,7 +103,15 @@ if __name__ == "__main__":
         model = LoadCheckpoint(config.model_path)
         config.model = model.__class__.__name__
     else:
-        model = AVAILABLE_MODELS[config.model]( hidden_dim = config.hidden_dim, scale = config.scale)
+        # Constructor kwargs depend on the model family, inferred from its name.
+        # 'lut' is tested first so a name containing both substrings still
+        # resolves to the LUT variant; unknown names fail loudly here.
+        model_name = config.model.lower()
+        if 'lut' in model_name:
+            model = AVAILABLE_MODELS[config.model](quantization_interval=2 ** (8 - config.quantization_bits), scale=config.scale)
+        elif 'net' in model_name:
+            model = AVAILABLE_MODELS[config.model](hidden_dim=config.hidden_dim, scale=config.scale)
+        else:
+            raise ValueError(f"Cannot infer constructor arguments for model '{config.model}': expected 'net' or 'lut' in its name.")
     model = model.to(torch.device(config.device))
     optimizer = AdamWScheduleFree(model.parameters())
     print(optimizer)