From 4377e655e2547360e87cbe28c7840ba19bdcaf51 Mon Sep 17 00:00:00 2001 From: vlpr Date: Wed, 15 May 2024 15:48:51 +0000 Subject: [PATCH] fix transfer --- src/common/lut.py | 4 ++-- src/models/srnet.py | 10 +++++----- src/validate.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/common/lut.py b/src/common/lut.py index 9791197..3de99b5 100644 --- a/src/common/lut.py +++ b/src/common/lut.py @@ -68,8 +68,8 @@ def transfer_2x2_input_SxS_output(block, quantization_interval=16, batch_size=2* for idx, (ix1s, ix2s, ix3s, ix4s, batch) in enumerate(domain_values_loader): inputs = batch.type(torch.float32).cuda() with torch.no_grad(): - outputs = block(inputs) - lut[ix1s, ix2s, ix3s, ix4s, ...] = outputs.cpu().numpy().astype(np.uint8) #[:,:,:scale,:scale] # TODO first layer automatically pad image + outputs = block(inputs)[:,:,:] + lut[ix1s, ix2s, ix3s, ix4s, ...] = outputs.cpu().numpy().astype(np.uint8)[:,:,:scale,:scale] counter += inputs.shape[0] print(f"\r {block.__class__.__name__} {counter}/{len(domain_values)}", end=" ") print() diff --git a/src/models/srnet.py b/src/models/srnet.py index 6b8a9f7..c73030f 100644 --- a/src/models/srnet.py +++ b/src/models/srnet.py @@ -30,7 +30,7 @@ class SRNet(nn.Module): return x def get_lut_model(self, quantization_interval=16, batch_size=2**10): - stage_lut = lut.transfer_2x2_input_SxS_output(self.stage, quantization_interval=quantization_interval, batch_size=batch_size) + stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) lut_model = srlut.SRLut.init_from_lut(stage_lut) return lut_model @@ -61,8 +61,8 @@ class SRNetR90(nn.Module): return output def get_lut_model(self, quantization_interval=16, batch_size=2**10): - stage_lut = lut.transfer_2x2_input_SxS_output(self.stage, quantization_interval=quantization_interval, batch_size=batch_size) - lut_model = srlut.SRLutRot90.init_from_lut(stage_lut) + stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) + lut_model = srlut.SRLutR90.init_from_lut(stage_lut) return lut_model class SRNetR90Y(nn.Module): @@ -100,6 +100,6 @@ class SRNetR90Y(nn.Module): return output def get_lut_model(self, quantization_interval=16, batch_size=2**10): - stage_lut = lut.transfer_2x2_input_SxS_output(self.stage, quantization_interval=quantization_interval, batch_size=batch_size) - lut_model = srlut.SRLutRot90Y.init_from_lut(stage_lut) + stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size) + lut_model = srlut.SRLutR90Y.init_from_lut(stage_lut) return lut_model \ No newline at end of file diff --git a/src/validate.py b/src/validate.py index 630e37c..ace279a 100644 --- a/src/validate.py +++ b/src/validate.py @@ -87,7 +87,7 @@ if __name__ == "__main__": results = valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}") results.to_csv(config.results_path) - print(config.model_name) + print(config.exp_dir.stem) print(results) print() print(f"Results saved to {config.results_path}")