diff --git a/runs/.gitignore b/logs/.gitignore
similarity index 100%
rename from runs/.gitignore
rename to logs/.gitignore
diff --git a/readme.md b/readme.md
index f84828a..8fee9b7 100644
--- a/readme.md
+++ b/readme.md
@@ -11,6 +11,6 @@ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python src/bert_training_dp.py fold3_18l_dy
 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --standalone --nproc-per-node=8 src/bert_training_ddp.py fold3_18l_dyt_04_04_3750
 ```
 
-Logging goes to tensorboard in the `./runs/` directory. The current version of the script is copied into the log directory at startup. Model checkpoints are saved to the `./checkpoints/` directory.
+Logging goes to tensorboard in the `./logs/` directory. The current version of the script is copied into the log directory at startup. Model checkpoints are saved to the `./checkpoints/` directory.
 
 The train/test split is produced by the `train_test_split.py` script.
diff --git a/src/bert_training.py b/src/bert_training.py
index b23308a..48c9b4b 100644
--- a/src/bert_training.py
+++ b/src/bert_training.py
@@ -2,6 +2,7 @@ import os
 import sys
 # DEVICE_IDX = sys.argv[1]
 # os.environ['CUDA_VISIBLE_DEVICES'] = f"{DEVICE_IDX}"
+comment = sys.argv[1]
 
 from torch import nn
 import torch
@@ -14,9 +15,10 @@ from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import Dataset, DataLoader
 from pathlib import Path
 import schedulefree
+import einops
 from einops import rearrange, repeat
 import torch.nn.functional as F
-
+from typing import Optional
 import functools
 
 def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, checkpoints_dir):
@@ -48,8 +50,7 @@ class CreditProductsDataset:
     def __init__(self, features_path, targets_path, train_test_split_ratio=0.9,
-                 train_uniq_client_ids_path=None, test_uniq_client_ids_path=None,
-                 dropout_rate=0.0
+                 train_uniq_client_ids_path=None, test_uniq_client_ids_path=None
                 ):
         self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
         if Path(self.train_uniq_client_ids_path).exists():
@@ -97,6 +98,8 @@ class CreditProductsDataset:
         self.cat_features = pad_sequence(torch.split(self.cat_features, self.user_seq_lengths.tolist()), batch_first=True) # implicit max seq
         self.num_features = torch.tensor(self.features_df[self.num_columns].values, dtype=torch.float32)
         self.num_features = pad_sequence(torch.split(self.num_features, self.user_seq_lengths.tolist()), batch_first=True)
+        self.padding_mask = torch.ones(len(self.features_df), dtype=torch.bool)
+        self.padding_mask = pad_sequence(torch.split(self.padding_mask, self.user_seq_lengths.tolist()), batch_first=True)
 
         self.targets_df = self.targets_df.set_index('id')
         self.targets_df = self.targets_df.sort_index()
@@ -104,20 +107,22 @@ class CreditProductsDataset:
 
     def get_train_batch(self, batch_size=4):
         sampled_ids = np.random.choice(self.train_uniq_client_ids, batch_size, replace=False) # think about replace=True
-        cat_features_batch = self.cat_features[sampled_ids]
-        num_features_batch = self.num_features[sampled_ids]
-        cat_features_batch *= torch.empty_like(cat_features_batch).bernoulli_(1-self.dropout_rate) # arg is keep_probability
-        num_features_batch *= torch.empty_like(num_features_batch).bernoulli_(1-self.dropout_rate)
-        targets_batch = self.targets[sampled_ids]
-        return cat_features_batch, num_features_batch, targets_batch
+        return (
+            self.cat_features[sampled_ids],
+            self.num_features[sampled_ids],
+            self.padding_mask[sampled_ids],
+            self.targets[sampled_ids]
+        )
 
     def get_test_batch_iterator(self, batch_size=4):
         for i in range(0, len(self.test_uniq_client_ids), batch_size):
-            ids = self.test_uniq_client_ids[i:i+batch_size]
-            cat_features_batch = self.cat_features[ids]
-            num_features_batch = self.num_features[ids]
-            targets_batch = self.targets[ids]
-            yield cat_features_batch, num_features_batch, targets_batch
+            sampled_ids = self.test_uniq_client_ids[i:i+batch_size]
+            yield (
+                self.cat_features[sampled_ids],
+                self.num_features[sampled_ids],
+                self.padding_mask[sampled_ids],
+                self.targets[sampled_ids]
+            )
 
 # for parallel data selection
 class WrapperDataset(Dataset):
@@ -132,15 +137,15 @@ class WrapperDataset(Dataset):
         return self.num_batches
 
     def __getitem__(self, idx):
-        cat_inputs, num_inputs, targets = self.credit_dataset.get_train_batch(batch_size=self.batch_size)
-        return cat_inputs, num_inputs, targets
+        cat_inputs, num_inputs, padding_mask, targets = self.credit_dataset.get_train_batch(batch_size=self.batch_size)
+        return cat_inputs, num_inputs, padding_mask, targets
 
 ##################################### Model ###########################################################################################
 class Encoder(nn.Module):
-    def __init__(self, cat_columns, num_columns, cat_features_max_id, category_feature_dim=4, out_dim=64):
+    def __init__(self, cat_columns, num_columns, cat_features_max_id, category_feature_dim=4, out_dim=64, features_dropout_rate=0.0):
         super().__init__()
-        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
+        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) # all args are added as object variables
         self.total_h_dim = len(self.cat_columns) * category_feature_dim + len(self.num_columns)
         self.cat_embeds = nn.Embedding(cat_features_max_id + 1, self.category_feature_dim, padding_idx=0)
         self.num_scales = nn.Parameter(torch.randn(1, len(self.num_columns)))
@@ -148,10 +153,11 @@ class Encoder(nn.Module):
         self.proj = nn.Linear(self.total_h_dim, self.out_dim, bias=False)
 
     def forward(self, cat_features_batch, num_features_batch):
-        cat_embed_tensor = self.cat_embeds(cat_features_batch.type(torch.int32))
-        cat_embed_tensor = cat_embed_tensor.reshape(cat_features_batch.shape[0], cat_features_batch.shape[1], -1)
-        num_embed_tensor = self.num_scales * num_features_batch + self.num_shifts
-        embed_tensor = torch.concat([cat_embed_tensor, num_embed_tensor], dim=-1)
+        cat_embed_tensor = self.cat_embeds(cat_features_batch.data.type(torch.int32))
+        cat_embed_tensor = cat_embed_tensor.reshape(cat_features_batch.data.shape[0], cat_features_batch.data.shape[1], -1)
+        num_embed_tensor = self.num_scales * num_features_batch.data + self.num_shifts
+        embed_tensor = torch.concat([cat_embed_tensor.data, num_embed_tensor.data], dim=-1)
+        embed_tensor = F.dropout(embed_tensor, self.features_dropout_rate)
         inputs = self.proj(embed_tensor)
         return inputs
 
@@ -204,24 +210,32 @@ class TransformerLayer(nn.Module):
         self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
 
     def split_to_heads(self, x, B, T, H):
-        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) if self.num_heads > 1 else x
+        if self.num_heads <= 1: return x
+        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
 
     def gather_heads(self, x, B, T, H):
-        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) if self.num_heads > 1 else x
+        if self.num_heads <= 1: return x
+        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
 
-    def attention(self, x):
+    def attention(self, x, padding_mask):
+        padding_mask = padding_mask.unsqueeze(-1).expand(*padding_mask.shape+(self.num_heads,))
+        padding_mask = self.split_to_heads(padding_mask, *padding_mask.shape)
         q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
         k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
         v = self.split_to_heads(self.v_proj(x), *x.shape)
         scores = (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5)
+        scores = scores.masked_fill(~padding_mask, -1e9)
         attention = nn.functional.softmax(scores, dim=2)
         return self.o_proj(self.gather_heads(attention @ v, *x.shape))
 
-    def forward(self, x):
-        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate)
+    def forward(self, x, padding_mask):
+        x = x + F.dropout1d(self.attention(self.ln1(x), padding_mask), p=self.dropout_rate)
         x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate)
         return x
 
+def prepend(element, tensor):
+    return torch.cat([element.expand([tensor.shape[0], element.shape[1], tensor.shape[2]]), tensor], dim=1)
+
 class BertClassifier(nn.Module):
     def __init__(self, layers_num=1, h_dim=64, class_num=2, max_seq_len=128, num_heads=4, dropout_rate = 0.1):
         super().__init__()
@@ -232,11 +246,12 @@ class BertClassifier(nn.Module):
         self.classifier_head = nn.Sequential(nn.Linear(h_dim, h_dim), nn.GELU(), nn.Linear(h_dim, class_num))
         self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
 
-    def forward(self, x):
-        x = torch.concat([self.cls_token.expand([x.shape[0], self.cls_token.shape[1], self.cls_token.shape[2]]), x], dim=1)
+    def forward(self, x, padding_mask):
+        x = prepend(self.cls_token, x)
+        padding_mask = torch.cat([torch.ones(x.shape[0], 1, dtype=torch.bool, device=x.device), padding_mask], dim=1)
         x = x + self.pos_embeds[:, :x.shape[1], :]
         for l in self.layers:
-            x = l(x)
+            x = l(x, padding_mask)
         x = self.classifier_head(x[:,0,:])
         return x[:,:] if self.class_num > 1 else x[:,0]
 
@@ -246,20 +261,21 @@ class Model(nn.Module):
         self.encoder = encoder
         self.classifier = classifier
 
-    def forward(self, cat_inputs, num_inputs):
+    def forward(self, cat_inputs, num_inputs, padding_mask):
         inputs = self.encoder(cat_inputs, num_inputs)
-        return self.classifier(inputs)
+        return self.classifier(inputs, padding_mask)
 
 def test(start_time, epoch, batches_per_epoch, batch_size, model, optimizer, credit_dataset, test_auroc, writer):
     model.eval()
     optimizer.eval()
     with torch.no_grad():
        test_iterator = credit_dataset.get_test_batch_iterator(batch_size=batch_size)
-       for test_batch_id, (test_cat_inputs, test_num_inputs, test_targets) in enumerate(test_iterator):
+       for test_batch_id, (test_cat_inputs, test_num_inputs, test_padding_mask, test_targets) in enumerate(test_iterator):
            test_cat_inputs = test_cat_inputs.to("cuda", non_blocking=True)
            test_num_inputs = test_num_inputs.to("cuda", non_blocking=True)
+           test_padding_mask = test_padding_mask.to("cuda", non_blocking=True)
            test_targets = test_targets.to("cuda", non_blocking=True)
-           outputs = model(test_cat_inputs, test_num_inputs)
+           outputs = model(test_cat_inputs, test_num_inputs, test_padding_mask)
            test_auroc.update(outputs, test_targets.long())
            print(f"\r {test_batch_id}/{len(credit_dataset.test_uniq_client_ids)//batch_size} {test_auroc.compute().item():.5f}", end = " "*20)
     if not writer is None:
@@ -281,8 +297,8 @@ batch_size = 30000
 datasets_per_epoch = 1
 num_workers = 10
 
-comment = sys.argv[1]
-logs_dir = f'runs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
+
+logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
 writer = SummaryWriter(logs_dir)
 checkpoints_dir = f'checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
 script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).name)
@@ -297,8 +313,7 @@ credit_train_dataset = CreditProductsDataset(
     features_path="/wd/data/train_data/",
     targets_path="/wd/data/train_target.csv",
     train_uniq_client_ids_path=f"/wd/fold3_train_ids.csv",
-    test_uniq_client_ids_path=f"/wd/fold3_test_ids.csv",
-    dropout_rate=features_dropout_rate
+    test_uniq_client_ids_path=f"/wd/fold3_test_ids.csv"
 )
 
 print(f"Dataset preparation time: {datetime.now() - start_prep_time}")
@@ -307,7 +322,8 @@ encoder = Encoder(
     num_columns=credit_train_dataset.num_columns,
     cat_features_max_id=credit_train_dataset.cat_features.max(),
     category_feature_dim=category_feature_dim,
-    out_dim=h_dim
+    out_dim=h_dim,
+    features_dropout_rate=features_dropout_rate
 )
 
 classifier = BertClassifier(
@@ -358,13 +374,14 @@ try:
             test_auroc=test_auroc,
             writer=writer
         )
-        for batch_id, (cat_inputs, num_inputs, targets) in enumerate(dataloader):
+        for batch_id, (cat_inputs, num_inputs, padding_mask, targets) in enumerate(dataloader):
             model.train()
             optimizer.train()
             optimizer.zero_grad()
             outputs = model(
                 cat_inputs[0].to("cuda"),
-                num_inputs[0].to("cuda")
+                num_inputs[0].to("cuda"),
+                padding_mask[0].to("cuda")
             )
             loss = criterion(outputs, targets[0].to("cuda"))
             loss.backward()
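
The core of this patch is the boolean padding mask that now travels with every batch and is applied to the attention scores before the softmax. The snippet below is a standalone sketch of that mechanism with toy shapes and variable names (none of it is code from the repository): the mask is built with `pad_sequence`, as in `CreditProductsDataset`, and applied with `masked_fill` on the score tensor, as in `TransformerLayer.attention`.

```python
# Standalone illustration of the padding-mask mechanics introduced by the patch.
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

# Two "clients" with sequence lengths 3 and 1: ones mark real positions,
# pad_sequence fills the rest with False.
seq_lengths = [3, 1]
mask = pad_sequence(
    [torch.ones(l, dtype=torch.bool) for l in seq_lengths],
    batch_first=True,
)  # shape (2, 3): [[True, True, True], [True, False, False]]

# Toy attention scores for batch=2, seq_len=3.
scores = torch.randn(2, 3, 3)

# Broadcasting the mask as (B, T, 1) fills the rows belonging to padded
# positions with a large negative value before softmax, mirroring
# scores.masked_fill(~padding_mask, -1e9) in the patched attention.
scores = scores.masked_fill(~mask.unsqueeze(-1), -1e9)
attention = F.softmax(scores, dim=2)

print(mask)
print(attention[1])  # rows for the padded positions of the second client
```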