From 96d18b4873cb392527be2051e9f7b6c446b9e6cf Mon Sep 17 00:00:00 2001
From: Vladimir
Date: Wed, 21 May 2025 11:34:51 +0400
Subject: [PATCH] maybe more readable code

---
 src/bert_training.py | 175 ++++++++++++++++----------------------------
 1 file changed, 65 insertions(+), 110 deletions(-)

diff --git a/src/bert_training.py b/src/bert_training.py
index 08ff8c3..dd1d453 100644
--- a/src/bert_training.py
+++ b/src/bert_training.py
@@ -1,7 +1,5 @@
 import os
 import sys
-# DEVICE_IDX = sys.argv[1]
-# os.environ['CUDA_VISIBLE_DEVICES'] = f"{DEVICE_IDX}"
 comment = sys.argv[1]
 
 from torch import nn
@@ -21,31 +19,7 @@ import torch.nn.functional as F
 from typing import Optional
 import functools
 
-def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, checkpoints_dir):
-    checkpoint = {
-        'encoder': {
-            'state_dict': encoder.state_dict(),
-            **{k:v for k,v in encoder.__dict__.items() if k[0] != '_' and k != 'training'}
-        },
-        'model': {
-            'state_dict': model.state_dict(),
-            **{k:v for k,v in model.__dict__.items() if k[0] != '_' and k != 'training'}
-        },
-        'optimizer': {
-            'state_dict': optimizer.state_dict(),
-        },
-        'epoch': epoch,
-        'loss': loss,
-        'rocauc': rocauc,
-        'train_uniq_client_ids_path': credit_dataset.train_uniq_client_ids_path,
-        'test_uniq_client_ids_path': credit_dataset.test_uniq_client_ids_path
-    }
-    path = checkpoints_dir + f"epoch_{epoch}_{rocauc:.4f}.pth"
-    # if torch.distributed.get_rank() == 0:
-    torch.save(checkpoint, path)
-    print(f"\nCheckpoint saved to {path}")
-
-######################################## Dataset #########################################################
+######################################## Dataset definition #########################################################
 
 class CreditProductsDataset:
     def __init__(self,
@@ -140,7 +114,7 @@ class WrapperDataset(Dataset):
         cat_inputs, num_inputs, padding_mask, targets = self.credit_dataset.get_train_batch(batch_size=self.batch_size)
         return cat_inputs, num_inputs, padding_mask, targets
 
-##################################### Model ###########################################################################################
+##################################### Model definition ###########################################################################################
 
 class Encoder(nn.Module):
     def __init__(self, cat_columns, num_columns, cat_features_max_id, category_feature_dim=4, out_dim=64, features_dropout_rate=0.0):
@@ -230,7 +204,7 @@ class TransformerLayer(nn.Module):
         return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
 
     def attention(self, x, padding_mask):
-        padding_mask = padding_mask.unsqueeze(-1).expand(*padding_mask.shape+(self.num_heads,))
+        padding_mask = padding_mask.unsqueeze(-1).expand(*padding_mask.shape+(self.num_heads,)) # B, T -> B, T, num_heads
         padding_mask = self.split_to_heads(padding_mask, *padding_mask.shape)
         q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
         k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
@@ -251,7 +225,7 @@ class BertClassifier(nn.Module):
         self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
         self.cls_token = nn.Parameter(torch.randn(1,1,h_dim))
         self.max_seq_len = max_seq_len + self.cls_token.shape[1]
-        self.layers = nn.ModuleList([TransformerLayer(h_dim=h_dim, num_heads=num_heads, dropout_rate = dropout_rate, max_seq_len=self.max_seq_len) for _ in range(layers_num)])
+        self.layers = nn.ModuleList([TransformerLayer(h_dim=h_dim, num_heads=num_heads, dropout_rate=dropout_rate, max_seq_len=self.max_seq_len) for _ in range(layers_num)])
         self.classifier_head = nn.Sequential(nn.Linear(h_dim, h_dim), nn.GELU(), nn.Linear(h_dim, class_num))
         self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
 
@@ -274,25 +248,7 @@ class Model(nn.Module):
         inputs = self.encoder(cat_inputs, num_inputs)
         return self.classifier(inputs, padding_mask)
 
-def test(start_time, epoch, batches_per_epoch, batch_size, model, optimizer, credit_dataset, test_auroc, writer):
-    model.eval()
-    optimizer.eval()
-    with torch.no_grad():
-        test_iterator = credit_dataset.get_test_batch_iterator(batch_size=batch_size)
-        for test_batch_id, (test_cat_inputs, test_num_inputs, test_padding_mask, test_targets) in enumerate(test_iterator):
-            test_cat_inputs = test_cat_inputs.to("cuda", non_blocking=True)
-            test_num_inputs = test_num_inputs.to("cuda", non_blocking=True)
-            test_padding_mask = test_padding_mask.to("cuda", non_blocking=True)
-            test_targets = test_targets.to("cuda", non_blocking=True)
-            outputs = model(test_cat_inputs, test_num_inputs, test_padding_mask)
-            test_auroc.update(outputs, test_targets.long())
-            print(f"\r {test_batch_id}/{len(credit_dataset.test_uniq_client_ids)//batch_size} {test_auroc.compute().item():.5f}", end = " "*20)
-    if not writer is None:
-        writer.add_scalar('test_roc_auc', test_auroc.compute().item(), epoch * batches_per_epoch)
-    print(f"\r {datetime.now() - start_time} {epoch}/{epochs} Test rocauc: {test_auroc.compute().item():.5f}", end = " "*20)
-    print()
-
-######################################### Training ################################################################
+######################################### Training definition ################################################################
 
 h_dim = 64
 category_feature_dim = 8
@@ -306,7 +262,6 @@ batch_size = 30000
 datasets_per_epoch = 1
 num_workers = 10
 
-
 logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
 writer = SummaryWriter(logs_dir)
 checkpoints_dir = f'checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
@@ -365,34 +320,74 @@ batches_per_epoch = len(training_data)
 print(f"Number of batches per epoch: {batches_per_epoch}, Number of datasets per epoch : {datasets_per_epoch}")
 
 test_auroc = AUROC(task='binary')
+epoch = -1
+loss = torch.tensor([-1])
 
 start_time = datetime.now()
 print("Started at:", start_time)
 last_display_time = start_time
 last_checkpoint_time = start_time
+
+def test(is_tensorboard_logging=False):
+    model.eval()
+    optimizer.eval()
+    with torch.no_grad():
+        test_iterator = credit_train_dataset.get_test_batch_iterator(batch_size=batch_size)
+        for test_batch_id, (test_cat_inputs, test_num_inputs, test_padding_mask, test_targets) in enumerate(test_iterator):
+            test_cat_inputs = test_cat_inputs.to("cuda", non_blocking=True)
+            test_num_inputs = test_num_inputs.to("cuda", non_blocking=True)
+            test_padding_mask = test_padding_mask.to("cuda", non_blocking=True)
+            test_targets = test_targets.to("cuda", non_blocking=True)
+            outputs = model(test_cat_inputs, test_num_inputs, test_padding_mask)
+            test_auroc.update(outputs, test_targets.long())
+            print(f"\r {test_batch_id}/{len(credit_train_dataset.test_uniq_client_ids)//batch_size} {test_auroc.compute().item():.5f}", end = " "*20)
+    if is_tensorboard_logging:
+        writer.add_scalar('test_roc_auc', test_auroc.compute().item(), epoch * batches_per_epoch)
+    print(f"\r {datetime.now() - start_time} {epoch}/{epochs} Test rocauc: {test_auroc.compute().item():.5f}", end = " "*20)
+    print()
+
+def save_checkpoint():
+    test(is_tensorboard_logging=False)
+    checkpoint = {
+        'encoder': {
+            'state_dict': encoder.state_dict(),
+            **{k:v for k,v in encoder.__dict__.items() if k[0] != '_' and k != 'training'}
+        },
+        'model': {
+            'state_dict': model.state_dict(),
+            **{k:v for k,v in model.__dict__.items() if k[0] != '_' and k != 'training'}
+        },
+        'optimizer': {
+            'state_dict': optimizer.state_dict(),
+        },
+        'epoch': epoch,
+        'loss': loss.item(),
+        'rocauc': test_auroc.compute().item(),
+        'train_uniq_client_ids_path': credit_train_dataset.train_uniq_client_ids_path,
+        'test_uniq_client_ids_path': credit_train_dataset.test_uniq_client_ids_path
+    }
+    path = checkpoints_dir + f"epoch_{epoch}_{test_auroc.compute().item():.4f}.pth"
+    torch.save(checkpoint, path)
+    print(f"\nCheckpoint saved to {path}")
+
+###################################### Training loop ################################################################################
+
 try:
     for epoch in range(epochs):
-        test(
-            start_time=start_time,
-            epoch=epoch,
-            batches_per_epoch=batches_per_epoch,
-            batch_size=batch_size,
-            model=model,
-            optimizer=optimizer,
-            credit_dataset=credit_train_dataset,
-            test_auroc=test_auroc,
-            writer=writer
-        )
+        test(is_tensorboard_logging=True)
         for batch_id, (cat_inputs, num_inputs, padding_mask, targets) in enumerate(dataloader):
             model.train()
            optimizer.train()
             optimizer.zero_grad()
             outputs = model(
-                cat_inputs[0].to("cuda"),
-                num_inputs[0].to("cuda"),
-                padding_mask[0].to("cuda")
+                cat_inputs=cat_inputs[0].to("cuda"),
+                num_inputs=num_inputs[0].to("cuda"),
+                padding_mask=padding_mask[0].to("cuda")
+            )
+            loss = criterion(
+                input=outputs,
+                target=targets[0].to("cuda")
             )
-            loss = criterion(outputs, targets[0].to("cuda"))
             loss.backward()
             optimizer.step()
 
@@ -401,53 +396,13 @@ try:
                 last_display_time = current_time
                 writer.add_scalar(f'Loss', loss.item(), epoch*batches_per_epoch+batch_id)
                 print(f"\r {current_time-start_time} {epoch+1}/{epochs} {batch_id}/{batches_per_epoch} loss: {loss.item():.6f} {comment}", end = " "*2)
+
             if current_time - last_checkpoint_time > timedelta(hours=8):
                 last_checkpoint_time = current_time
-                test(
-                    start_time=start_time,
-                    epoch=epoch,
-                    batches_per_epoch=batches_per_epoch,
-                    batch_size=batch_size,
-                    model=model,
-                    optimizer=optimizer,
-                    credit_dataset=credit_train_dataset,
-                    test_auroc=test_auroc,
-                    writer=None
-                )
-                rocauc = test_auroc.compute().item()
-                save_checkpoint(
-                    credit_dataset=credit_train_dataset,
-                    encoder = model.module.encoder,
-                    model=model.module.classifier,
-                    optimizer=optimizer,
-                    epoch=epoch,
-                    loss=loss.item(),
-                    rocauc=rocauc,
-                    checkpoints_dir=checkpoints_dir
-                )
+                save_checkpoint()
 except KeyboardInterrupt:
     print()
 finally:
-    test(
-        start_time=start_time,
-        epoch=epoch+1,
-        batches_per_epoch=batches_per_epoch,
-        batch_size=batch_size,
-        model=model,
-        optimizer=optimizer,
-        credit_dataset=credit_train_dataset,
-        test_auroc=test_auroc,
-        writer=writer
-    )
-    rocauc = test_auroc.compute().item()
-    save_checkpoint(
-        credit_dataset=credit_train_dataset,
-        encoder = model.encoder,
-        model=model.classifier,
-        optimizer=optimizer,
-        epoch=epoch,
-        loss=loss.item(),
-        rocauc=rocauc,
-        checkpoints_dir=checkpoints_dir
-    )
-    writer.close()
+    epoch = epoch + 1
+    save_checkpoint()
+    writer.close()
\ No newline at end of file
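
Note (not part of the patch): a minimal sketch of how a checkpoint written by the refactored save_checkpoint() could be inspected before reuse. The file name below is hypothetical; the dictionary keys are exactly the ones built in the hunk above, and weights_only=False is assumed to be needed because the checkpoint stores plain Python attributes alongside tensors.

import torch

# Hypothetical path; real files are written into the timestamped checkpoints_dir
# with names of the form f"epoch_{epoch}_{rocauc:.4f}.pth".
ckpt = torch.load("checkpoints/epoch_3_0.7800.pth", map_location="cpu", weights_only=False)

print("epoch:", ckpt["epoch"], "loss:", ckpt["loss"], "rocauc:", ckpt["rocauc"])

# 'encoder' and 'model' each hold a state_dict plus the module's public attributes,
# so weights can be restored into freshly built modules with matching constructor
# arguments, e.g. encoder.load_state_dict(ckpt["encoder"]["state_dict"]).
print("encoder tensors:", len(ckpt["encoder"]["state_dict"]))
print("model tensors:", len(ckpt["model"]["state_dict"]))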