fix transfer

main
vlpr 6 months ago
parent 80d57261e7
commit 4377e655e2

@ -68,8 +68,8 @@ def transfer_2x2_input_SxS_output(block, quantization_interval=16, batch_size=2*
for idx, (ix1s, ix2s, ix3s, ix4s, batch) in enumerate(domain_values_loader): for idx, (ix1s, ix2s, ix3s, ix4s, batch) in enumerate(domain_values_loader):
inputs = batch.type(torch.float32).cuda() inputs = batch.type(torch.float32).cuda()
with torch.no_grad(): with torch.no_grad():
outputs = block(inputs) outputs = block(inputs)[:,:,:]
lut[ix1s, ix2s, ix3s, ix4s, ...] = outputs.cpu().numpy().astype(np.uint8) #[:,:,:scale,:scale] # TODO first layer automatically pad image lut[ix1s, ix2s, ix3s, ix4s, ...] = outputs.cpu().numpy().astype(np.uint8)[:,:,:scale,:scale]
counter += inputs.shape[0] counter += inputs.shape[0]
print(f"\r {block.__class__.__name__} {counter}/{len(domain_values)}", end=" ") print(f"\r {block.__class__.__name__} {counter}/{len(domain_values)}", end=" ")
print() print()

@ -30,7 +30,7 @@ class SRNet(nn.Module):
return x return x
def get_lut_model(self, quantization_interval=16, batch_size=2**10): def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage_lut = lut.transfer_2x2_input_SxS_output(self.stage, quantization_interval=quantization_interval, batch_size=batch_size) stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = srlut.SRLut.init_from_lut(stage_lut) lut_model = srlut.SRLut.init_from_lut(stage_lut)
return lut_model return lut_model
@ -61,8 +61,8 @@ class SRNetR90(nn.Module):
return output return output
def get_lut_model(self, quantization_interval=16, batch_size=2**10): def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage_lut = lut.transfer_2x2_input_SxS_output(self.stage, quantization_interval=quantization_interval, batch_size=batch_size) stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = srlut.SRLutRot90.init_from_lut(stage_lut) lut_model = srlut.SRLutR90.init_from_lut(stage_lut)
return lut_model return lut_model
class SRNetR90Y(nn.Module): class SRNetR90Y(nn.Module):
@ -100,6 +100,6 @@ class SRNetR90Y(nn.Module):
return output return output
def get_lut_model(self, quantization_interval=16, batch_size=2**10): def get_lut_model(self, quantization_interval=16, batch_size=2**10):
stage_lut = lut.transfer_2x2_input_SxS_output(self.stage, quantization_interval=quantization_interval, batch_size=batch_size) stage_lut = lut.transfer_2x2_input_SxS_output(self.stage1_S, quantization_interval=quantization_interval, batch_size=batch_size)
lut_model = srlut.SRLutRot90Y.init_from_lut(stage_lut) lut_model = srlut.SRLutR90Y.init_from_lut(stage_lut)
return lut_model return lut_model

@ -87,7 +87,7 @@ if __name__ == "__main__":
results = valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}") results = valid_steps(model=model, datasets=test_datasets, config=config, log_prefix=f"Model {config.model_name}")
results.to_csv(config.results_path) results.to_csv(config.results_path)
print(config.model_name) print(config.exp_dir.stem)
print(results) print(results)
print() print()
print(f"Results saved to {config.results_path}") print(f"Results saved to {config.results_path}")

Loading…
Cancel
Save