|
|
@ -26,7 +26,7 @@ from torch.distributed.fsdp.wrap import (
|
|
|
|
)
|
|
|
|
)
|
|
|
|
import functools
|
|
|
|
import functools
|
|
|
|
|
|
|
|
|
|
|
|
def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, сheсkpoints_dir):
|
|
|
|
def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, checkpoints_dir):
|
|
|
|
checkpoint = {
|
|
|
|
checkpoint = {
|
|
|
|
'encoder': {
|
|
|
|
'encoder': {
|
|
|
|
'state_dict': encoder.state_dict(),
|
|
|
|
'state_dict': encoder.state_dict(),
|
|
|
@ -45,7 +45,7 @@ def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, roca
|
|
|
|
'train_uniq_client_ids_path': credit_dataset.train_uniq_client_ids_path,
|
|
|
|
'train_uniq_client_ids_path': credit_dataset.train_uniq_client_ids_path,
|
|
|
|
'test_uniq_client_ids_path': credit_dataset.test_uniq_client_ids_path
|
|
|
|
'test_uniq_client_ids_path': credit_dataset.test_uniq_client_ids_path
|
|
|
|
}
|
|
|
|
}
|
|
|
|
path = сheсkpoints_dir + f"epoch_{epoch}_{rocauc:.4f}.pth"
|
|
|
|
path = checkpoints_dir + f"epoch_{epoch}_{rocauc:.4f}.pth"
|
|
|
|
# if torch.distributed.get_rank() == 0:
|
|
|
|
# if torch.distributed.get_rank() == 0:
|
|
|
|
torch.save(checkpoint, path)
|
|
|
|
torch.save(checkpoint, path)
|
|
|
|
print(f"\nCheckpoint saved to {path}")
|
|
|
|
print(f"\nCheckpoint saved to {path}")
|
|
|
@ -330,11 +330,11 @@ if __name__ == "__main__":
|
|
|
|
comment = sys.argv[1]
|
|
|
|
comment = sys.argv[1]
|
|
|
|
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
|
|
|
|
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
|
|
|
|
writer = SummaryWriter(logs_dir)
|
|
|
|
writer = SummaryWriter(logs_dir)
|
|
|
|
сheсkpoints_dir = f'checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
|
|
|
|
checkpoints_dir = f'checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
|
|
|
|
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).name)
|
|
|
|
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).name)
|
|
|
|
Path(сheсkpoints_dir).mkdir(parents=True, exist_ok=True)
|
|
|
|
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
|
|
|
|
print("Logs dir:", logs_dir)
|
|
|
|
print("Logs dir:", logs_dir)
|
|
|
|
print("Chekpoints dir:", сheсkpoints_dir)
|
|
|
|
print("Chekpoints dir:", checkpoints_dir)
|
|
|
|
script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
|
|
|
|
script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
|
|
|
|
script_snapshot_path.chmod(0o400) # with read-only permission
|
|
|
|
script_snapshot_path.chmod(0o400) # with read-only permission
|
|
|
|
|
|
|
|
|
|
|
@ -432,7 +432,7 @@ if __name__ == "__main__":
|
|
|
|
save_checkpoint(
|
|
|
|
save_checkpoint(
|
|
|
|
credit_dataset=credit_train_dataset,
|
|
|
|
credit_dataset=credit_train_dataset,
|
|
|
|
encoder = model.module.encoder, model=model.module.classifier, optimizer=optimizer, epoch=epoch,
|
|
|
|
encoder = model.module.encoder, model=model.module.classifier, optimizer=optimizer, epoch=epoch,
|
|
|
|
loss=ddp_loss[0], rocauc=rocauc, сheсkpoints_dir=сheсkpoints_dir)
|
|
|
|
loss=ddp_loss[0], rocauc=rocauc, checkpoints_dir=checkpoints_dir)
|
|
|
|
torch.distributed.barrier()
|
|
|
|
torch.distributed.barrier()
|
|
|
|
except KeyboardInterrupt:
|
|
|
|
except KeyboardInterrupt:
|
|
|
|
print()
|
|
|
|
print()
|
|
|
@ -444,7 +444,7 @@ if __name__ == "__main__":
|
|
|
|
save_checkpoint(
|
|
|
|
save_checkpoint(
|
|
|
|
credit_dataset=credit_train_dataset,
|
|
|
|
credit_dataset=credit_train_dataset,
|
|
|
|
encoder = model.module.encoder, model=model.module.classifier, optimizer=optimizer, epoch=epoch,
|
|
|
|
encoder = model.module.encoder, model=model.module.classifier, optimizer=optimizer, epoch=epoch,
|
|
|
|
loss=ddp_loss[0], rocauc=rocauc, сheсkpoints_dir=сheсkpoints_dir)
|
|
|
|
loss=ddp_loss[0], rocauc=rocauc, checkpoints_dir=checkpoints_dir)
|
|
|
|
if rank == 0:
|
|
|
|
if rank == 0:
|
|
|
|
writer.close()
|
|
|
|
writer.close()
|
|
|
|
torch.distributed.destroy_process_group()
|
|
|
|
torch.distributed.destroy_process_group()
|
|
|
|