|
|
|
@ -19,7 +19,7 @@ import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
|
import functools
|
|
|
|
|
|
|
|
|
|
def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, сheсkpoints_dir):
|
|
|
|
|
def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, checkpoints_dir):
|
|
|
|
|
checkpoint = {
|
|
|
|
|
'encoder': {
|
|
|
|
|
'state_dict': encoder.state_dict(),
|
|
|
|
@ -38,7 +38,7 @@ def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, roca
|
|
|
|
|
'train_uniq_client_ids_path': credit_dataset.train_uniq_client_ids_path,
|
|
|
|
|
'test_uniq_client_ids_path': credit_dataset.test_uniq_client_ids_path
|
|
|
|
|
}
|
|
|
|
|
path = сheсkpoints_dir + f"epoch_{epoch}_{rocauc:.4f}.pth"
|
|
|
|
|
path = checkpoints_dir + f"epoch_{epoch}_{rocauc:.4f}.pth"
|
|
|
|
|
# if torch.distributed.get_rank() == 0:
|
|
|
|
|
torch.save(checkpoint, path)
|
|
|
|
|
print(f"\nCheckpoint saved to {path}")
|
|
|
|
@ -284,11 +284,11 @@ num_workers = 10
|
|
|
|
|
comment = sys.argv[1]
|
|
|
|
|
logs_dir = f'runs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
|
|
|
|
|
writer = SummaryWriter(logs_dir)
|
|
|
|
|
сheсkpoints_dir = f'checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
|
|
|
|
|
checkpoints_dir = f'checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
|
|
|
|
|
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).name)
|
|
|
|
|
Path(сheсkpoints_dir).mkdir(parents=True, exist_ok=True)
|
|
|
|
|
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
|
|
|
|
|
print("Logs dir:", logs_dir)
|
|
|
|
|
print("Chekpoints dir:", сheсkpoints_dir)
|
|
|
|
|
print("Chekpoints dir:", checkpoints_dir)
|
|
|
|
|
script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
|
|
|
|
|
script_snapshot_path.chmod(0o400) # with read-only permission
|
|
|
|
|
|
|
|
|
@ -397,7 +397,7 @@ try:
|
|
|
|
|
epoch=epoch,
|
|
|
|
|
loss=loss.item(),
|
|
|
|
|
rocauc=rocauc,
|
|
|
|
|
сheсkpoints_dir=сheсkpoints_dir
|
|
|
|
|
checkpoints_dir=checkpoints_dir
|
|
|
|
|
)
|
|
|
|
|
except KeyboardInterrupt:
|
|
|
|
|
print()
|
|
|
|
@ -422,6 +422,6 @@ finally:
|
|
|
|
|
epoch=epoch,
|
|
|
|
|
loss=loss.item(),
|
|
|
|
|
rocauc=rocauc,
|
|
|
|
|
сheсkpoints_dir=сheсkpoints_dir
|
|
|
|
|
checkpoints_dir=checkpoints_dir
|
|
|
|
|
)
|
|
|
|
|
writer.close()
|
|
|
|
|