ru char to en

main
Vladimir 2 weeks ago
parent e6c390e273
commit bd2e602814
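
This commit replaces the Cyrillic letter с (U+0441, CYRILLIC SMALL LETTER ES) with the Latin c (U+0063) in the identifier сheсkpoints_dir. Python 3 accepts Unicode identifiers, so the homoglyph spelling parsed and ran, but it names a different variable than checkpoints_dir while being visually indistinguishable from it. A minimal sketch of a check that would have caught this (hypothetical tooling, not part of the commit; find_homoglyph_identifiers is an invented name):

# Hedged sketch: flag non-ASCII characters inside Python identifiers.
import sys
import tokenize
import unicodedata

def find_homoglyph_identifiers(path):
    with tokenize.open(path) as f:  # honors the file's declared encoding
        for tok in tokenize.generate_tokens(f.readline):
            if tok.type == tokenize.NAME and not tok.string.isascii():
                suspects = ", ".join(
                    f"{ch!r} (U+{ord(ch):04X} {unicodedata.name(ch, '?')})"
                    for ch in tok.string if not ch.isascii()
                )
                print(f"{path}:{tok.start[0]}: {tok.string} contains {suspects}")

if __name__ == "__main__":
    for source_path in sys.argv[1:]:
        find_homoglyph_identifiers(source_path)

Run against the pre-commit file, this would print the line number of every сheсkpoints_dir occurrence with U+0441 called out.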

@@ -26,7 +26,7 @@ from torch.distributed.fsdp.wrap import (
 )
 import functools
 
-def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, сheсkpoints_dir):
+def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, checkpoints_dir):
     checkpoint = {
         'encoder': {
             'state_dict': encoder.state_dict(),
@@ -45,7 +45,7 @@ def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, сheсkpoints_dir):
         'train_uniq_client_ids_path': credit_dataset.train_uniq_client_ids_path,
         'test_uniq_client_ids_path': credit_dataset.test_uniq_client_ids_path
     }
-    path = сheсkpoints_dir + f"epoch_{epoch}_{rocauc:.4f}.pth"
+    path = checkpoints_dir + f"epoch_{epoch}_{rocauc:.4f}.pth"
     # if torch.distributed.get_rank() == 0:
     torch.save(checkpoint, path)
     print(f"\nCheckpoint saved to {path}")
@@ -330,11 +330,11 @@ if __name__ == "__main__":
     comment = sys.argv[1]
     logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
     writer = SummaryWriter(logs_dir)
-    сheсkpoints_dir = f'checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
+    checkpoints_dir = f'checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
     script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).name)
-    Path(сheсkpoints_dir).mkdir(parents=True, exist_ok=True)
+    Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
     print("Logs dir:", logs_dir)
-    print("Chekpoints dir:", сheсkpoints_dir)
+    print("Chekpoints dir:", checkpoints_dir)
     script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
     script_snapshot_path.chmod(0o400) # with read-only permission
@@ -432,7 +432,7 @@ if __name__ == "__main__":
                 save_checkpoint(
                     credit_dataset=credit_train_dataset,
                     encoder = model.module.encoder, model=model.module.classifier, optimizer=optimizer, epoch=epoch,
-                    loss=ddp_loss[0], rocauc=rocauc, сheсkpoints_dir=сheсkpoints_dir)
+                    loss=ddp_loss[0], rocauc=rocauc, checkpoints_dir=checkpoints_dir)
             torch.distributed.barrier()
     except KeyboardInterrupt:
         print()
@@ -444,7 +444,7 @@ if __name__ == "__main__":
     save_checkpoint(
         credit_dataset=credit_train_dataset,
         encoder = model.module.encoder, model=model.module.classifier, optimizer=optimizer, epoch=epoch,
-        loss=ddp_loss[0], rocauc=rocauc, сheсkpoints_dir=сheсkpoints_dir)
+        loss=ddp_loss[0], rocauc=rocauc, checkpoints_dir=checkpoints_dir)
     if rank == 0:
         writer.close()
     torch.distributed.destroy_process_group()
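
For completeness, a hedged sketch of restoring a checkpoint written by save_checkpoint. Only the 'encoder' -> 'state_dict' nesting, the filename pattern, and the *_uniq_client_ids_path keys are visible in this diff; the 'model', 'optimizer', 'epoch', and 'rocauc' keys below mirror the function signature and are assumptions about the elided layout:

# Hedged sketch: invert save_checkpoint. Key names other than
# checkpoint['encoder']['state_dict'] are assumed, not shown in the diff.
import torch

def load_checkpoint(path, encoder, model, optimizer):
    checkpoint = torch.load(path, map_location="cpu")
    encoder.load_state_dict(checkpoint['encoder']['state_dict'])
    model.load_state_dict(checkpoint['model']['state_dict'])          # assumed key
    optimizer.load_state_dict(checkpoint['optimizer']['state_dict'])  # assumed key
    return checkpoint.get('epoch'), checkpoint.get('rocauc')          # assumed keys

# Example path, following the f"epoch_{epoch}_{rocauc:.4f}.pth" pattern:
# load_checkpoint(checkpoints_dir + "epoch_3_0.7900.pth", encoder, model, optimizer)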
