diff --git a/src/bert_training_ddp.py b/src/bert_training_ddp.py index ca74cf3..7f72b55 100644 --- a/src/bert_training_ddp.py +++ b/src/bert_training_ddp.py @@ -26,7 +26,7 @@ from torch.distributed.fsdp.wrap import ( ) import functools -def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, сheсkpoints_dir): +def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, checkpoints_dir): checkpoint = { 'encoder': { 'state_dict': encoder.state_dict(), @@ -45,7 +45,7 @@ def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, roca 'train_uniq_client_ids_path': credit_dataset.train_uniq_client_ids_path, 'test_uniq_client_ids_path': credit_dataset.test_uniq_client_ids_path } - path = сheсkpoints_dir + f"epoch_{epoch}_{rocauc:.4f}.pth" + path = checkpoints_dir + f"epoch_{epoch}_{rocauc:.4f}.pth" # if torch.distributed.get_rank() == 0: torch.save(checkpoint, path) print(f"\nCheckpoint saved to {path}") @@ -330,11 +330,11 @@ if __name__ == "__main__": comment = sys.argv[1] logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' writer = SummaryWriter(logs_dir) - сheсkpoints_dir = f'checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' + checkpoints_dir = f'checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).name) - Path(сheсkpoints_dir).mkdir(parents=True, exist_ok=True) + Path(checkpoints_dir).mkdir(parents=True, exist_ok=True) print("Logs dir:", logs_dir) - print("Chekpoints dir:", сheсkpoints_dir) + print("Checkpoints dir:", checkpoints_dir) script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script script_snapshot_path.chmod(0o400) # with read-only permission @@ -432,7 +432,7 @@
if __name__ == "__main__": save_checkpoint( credit_dataset=credit_train_dataset, encoder = model.module.encoder, model=model.module.classifier, optimizer=optimizer, epoch=epoch, - loss=ddp_loss[0], rocauc=rocauc, сheсkpoints_dir=сheсkpoints_dir) + loss=ddp_loss[0], rocauc=rocauc, checkpoints_dir=checkpoints_dir) torch.distributed.barrier() except KeyboardInterrupt: print() @@ -444,7 +444,7 @@ if __name__ == "__main__": save_checkpoint( credit_dataset=credit_train_dataset, encoder = model.module.encoder, model=model.module.classifier, optimizer=optimizer, epoch=epoch, - loss=ddp_loss[0], rocauc=rocauc, сheсkpoints_dir=сheсkpoints_dir) + loss=ddp_loss[0], rocauc=rocauc, checkpoints_dir=checkpoints_dir) if rank == 0: writer.close() torch.distributed.destroy_process_group()