added padding mask

main
Vladimir 2 weeks ago
parent 3986a4d5c7
commit 6d7e1b0095

@ -11,6 +11,6 @@ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python src/bert_training_dp.py fold3_18l_dy
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --standalone --nproc-per-node=8 src/bert_training_ddp.py fold3_18l_dyt_04_04_3750
```
Logging is done with TensorBoard to the `./runs/` folder. The current version of the script is copied into the logs folder at launch. Model checkpoints are saved to the `./checkpoints/` folder.
Logging is done with TensorBoard to the `./logs/` folder. The current version of the script is copied into the logs folder at launch. Model checkpoints are saved to the `./checkpoints/` folder.
The train/test split is produced by the `train_test_split.py` script.
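The resulting runs can then be inspected by pointing TensorBoard at that folder, e.g. with the standard invocation (not part of this repo):
```
tensorboard --logdir ./logs/
```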

@ -2,6 +2,7 @@ import os
import sys
# DEVICE_IDX = sys.argv[1]
# os.environ['CUDA_VISIBLE_DEVICES'] = f"{DEVICE_IDX}"
comment = sys.argv[1]
from torch import nn
import torch
@ -14,9 +15,10 @@ from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import schedulefree
import einops
from einops import rearrange, repeat
import torch.nn.functional as F
from typing import Optional
import functools
def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, checkpoints_dir):
@ -48,8 +50,7 @@ def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, roca
class CreditProductsDataset:
def __init__(self,
features_path, targets_path, train_test_split_ratio=0.9,
train_uniq_client_ids_path=None, test_uniq_client_ids_path=None,
dropout_rate=0.0
train_uniq_client_ids_path=None, test_uniq_client_ids_path=None
):
self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
if Path(self.train_uniq_client_ids_path).exists():
@ -97,6 +98,8 @@ class CreditProductsDataset:
self.cat_features = pad_sequence(torch.split(self.cat_features, self.user_seq_lengths.tolist()), batch_first=True) # implicit max seq
self.num_features = torch.tensor(self.features_df[self.num_columns].values, dtype=torch.float32)
self.num_features = pad_sequence(torch.split(self.num_features, self.user_seq_lengths.tolist()), batch_first=True)
self.padding_mask = torch.ones(len(self.features_df), dtype=torch.bool)
self.padding_mask = pad_sequence(torch.split(self.padding_mask, self.user_seq_lengths.tolist()), batch_first=True)
self.targets_df = self.targets_df.set_index('id')
self.targets_df = self.targets_df.sort_index()
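A minimal sketch (toy sequence lengths assumed, not from the repo) of how the padding mask above is produced: a flat vector of `True` values, one per real record, is split per client and padded, so padded slots come out `False`.
```python
import torch
from torch.nn.utils.rnn import pad_sequence

user_seq_lengths = [2, 3, 1]  # records per client
flat_mask = torch.ones(sum(user_seq_lengths), dtype=torch.bool)
padded = pad_sequence(torch.split(flat_mask, user_seq_lengths), batch_first=True)
print(padded)
# tensor([[ True,  True, False],
#         [ True,  True,  True],
#         [ True, False, False]])
```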
@ -104,20 +107,22 @@ class CreditProductsDataset:
def get_train_batch(self, batch_size=4):
sampled_ids = np.random.choice(self.train_uniq_client_ids, batch_size, replace=False) # think about replace=True
cat_features_batch = self.cat_features[sampled_ids]
num_features_batch = self.num_features[sampled_ids]
cat_features_batch *= torch.empty_like(cat_features_batch).bernoulli_(1-self.dropout_rate) # arg is keep_probability
num_features_batch *= torch.empty_like(num_features_batch).bernoulli_(1-self.dropout_rate)
targets_batch = self.targets[sampled_ids]
return cat_features_batch, num_features_batch, targets_batch
return (
self.cat_features[sampled_ids],
self.num_features[sampled_ids],
self.padding_mask[sampled_ids],
self.targets[sampled_ids]
)
def get_test_batch_iterator(self, batch_size=4):
for i in range(0, len(self.test_uniq_client_ids), batch_size):
ids = self.test_uniq_client_ids[i:i+batch_size]
cat_features_batch = self.cat_features[ids]
num_features_batch = self.num_features[ids]
targets_batch = self.targets[ids]
yield cat_features_batch, num_features_batch, targets_batch
sampled_ids = self.test_uniq_client_ids[i:i+batch_size]
yield (
self.cat_features[sampled_ids],
self.num_features[sampled_ids],
self.padding_mask[sampled_ids],
self.targets[sampled_ids]
)
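For contrast, a tiny self-contained sketch (toy id arrays assumed) of the two selection strategies above: training batches sample client ids without replacement, while the test iterator walks all test clients in fixed-size slices.
```python
import numpy as np

train_ids = np.arange(10)
test_ids = np.arange(10, 16)
print(np.random.choice(train_ids, 4, replace=False))          # e.g. [7 2 9 0]
print([test_ids[i:i+4] for i in range(0, len(test_ids), 4)])  # [array([10, 11, 12, 13]), array([14, 15])]
```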
# for parallel data selection
class WrapperDataset(Dataset):
@ -132,15 +137,15 @@ class WrapperDataset(Dataset):
return self.num_batches
def __getitem__(self, idx):
cat_inputs, num_inputs, targets = self.credit_dataset.get_train_batch(batch_size=self.batch_size)
return cat_inputs, num_inputs, targets
cat_inputs, num_inputs, padding_mask, targets = self.credit_dataset.get_train_batch(batch_size=self.batch_size)
return cat_inputs, num_inputs, padding_mask, targets
##################################### Model ###########################################################################################
class Encoder(nn.Module):
def __init__(self, cat_columns, num_columns, cat_features_max_id, category_feature_dim=4, out_dim=64):
def __init__(self, cat_columns, num_columns, cat_features_max_id, category_feature_dim=4, out_dim=64, features_dropout_rate=0.0):
super().__init__()
self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) # all args are added as object variables
self.total_h_dim = len(self.cat_columns) * category_feature_dim + len(self.num_columns)
self.cat_embeds = nn.Embedding(cat_features_max_id + 1, self.category_feature_dim, padding_idx=0)
self.num_scales = nn.Parameter(torch.randn(1, len(self.num_columns)))
@ -148,10 +153,11 @@ class Encoder(nn.Module):
self.proj = nn.Linear(self.total_h_dim, self.out_dim, bias=False)
def forward(self, cat_features_batch, num_features_batch):
cat_embed_tensor = self.cat_embeds(cat_features_batch.type(torch.int32))
cat_embed_tensor = cat_embed_tensor.reshape(cat_features_batch.shape[0], cat_features_batch.shape[1], -1)
num_embed_tensor = self.num_scales * num_features_batch + self.num_shifts
embed_tensor = torch.concat([cat_embed_tensor, num_embed_tensor], dim=-1)
cat_embed_tensor = self.cat_embeds(cat_features_batch.data.type(torch.int32))
cat_embed_tensor = cat_embed_tensor.reshape(cat_features_batch.data.shape[0], cat_features_batch.data.shape[1], -1)
num_embed_tensor = self.num_scales * num_features_batch.data + self.num_shifts
embed_tensor = torch.concat([cat_embed_tensor, num_embed_tensor], dim=-1) # no .data here: detaching would stop gradients from reaching the embeddings and the num scales/shifts
embed_tensor = F.dropout(embed_tensor, self.features_dropout_rate, training=self.training) # feature dropout only while training
inputs = self.proj(embed_tensor)
return inputs
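A quick self-contained sketch (toy tensor, values assumed) of how `F.dropout` behaves with the `training` flag used above: kept elements are rescaled by 1/(1-p) during training, and the call is a no-op in eval mode.
```python
import torch
import torch.nn.functional as F

x = torch.ones(1, 8)
print(F.dropout(x, p=0.5, training=True))   # random zeros, survivors scaled to 2.0
print(F.dropout(x, p=0.5, training=False))  # identity: all ones
```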
@ -204,24 +210,32 @@ class TransformerLayer(nn.Module):
self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
def split_to_heads(self, x, B, T, H):
return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) if self.num_heads > 1 else x
if self.num_heads <= 1: return x
return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
def gather_heads(self, x, B, T, H):
return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) if self.num_heads > 1 else x
if self.num_heads <= 1: return x
return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
def attention(self, x):
def attention(self, x, padding_mask):
padding_mask = padding_mask.unsqueeze(-1).expand(*padding_mask.shape+(self.num_heads,))
padding_mask = self.split_to_heads(padding_mask, *padding_mask.shape)
q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
v = self.split_to_heads(self.v_proj(x), *x.shape)
scores = (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5)
scores = scores.masked_fill(~padding_mask.transpose(1, 2), -1e9) # (b*n, 1, t): hide padded key positions from every query
attention = nn.functional.softmax(scores, dim=2)
return self.o_proj(self.gather_heads(attention @ v, *x.shape))
def forward(self, x):
x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate)
def forward(self, x, padding_mask):
x = x + F.dropout1d(self.attention(self.ln1(x), padding_mask), p=self.dropout_rate, training=self.training) # respect eval mode
x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
return x
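As a toy illustration (shapes and values assumed, not from the repo) of the key-side masking in `attention` above: padded key positions receive a large negative score, so softmax assigns them essentially zero weight.
```python
import torch
import torch.nn.functional as F

scores = torch.zeros(1, 3, 3)                              # (batch, queries, keys)
key_mask = torch.tensor([[True, True, False]])             # last key position is padding
scores = scores.masked_fill(~key_mask.unsqueeze(1), -1e9)  # broadcast the mask over queries
print(F.softmax(scores, dim=-1))                           # padded key gets ~0 attention weight
```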
def prepend(element, tensor):
return torch.cat([element.expand([tensor.shape[0], element.shape[1], tensor.shape[2]]), tensor], dim=1)
class BertClassifier(nn.Module):
def __init__(self, layers_num=1, h_dim=64, class_num=2, max_seq_len=128, num_heads=4, dropout_rate = 0.1):
super().__init__()
@ -232,11 +246,12 @@ class BertClassifier(nn.Module):
self.classifier_head = nn.Sequential(nn.Linear(h_dim, h_dim), nn.GELU(), nn.Linear(h_dim, class_num))
self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
def forward(self, x):
x = torch.concat([self.cls_token.expand([x.shape[0], self.cls_token.shape[1], self.cls_token.shape[2]]), x], dim=1)
def forward(self, x, padding_mask):
x = prepend(self.cls_token, x)
padding_mask = torch.cat([torch.ones(x.shape[0], 1, dtype=torch.bool, device=x.device), padding_mask], dim=1)
x = x + self.pos_embeds[:, :x.shape[1], :]
for l in self.layers:
x = l(x)
x = l(x, padding_mask)
x = self.classifier_head(x[:,0,:])
return x[:,:] if self.class_num > 1 else x[:,0]
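A small sketch (toy mask, shapes assumed) of what the two prepend steps above do to the padding mask: the CLS position must be marked valid, otherwise it would be masked out of attention.
```python
import torch

padding_mask = torch.tensor([[True, True, False],
                             [True, False, False]])  # (batch, seq_len)
cls_mask = torch.ones(padding_mask.shape[0], 1, dtype=torch.bool)
print(torch.cat([cls_mask, padding_mask], dim=1))
# tensor([[ True,  True,  True, False],
#         [ True,  True, False, False]])
```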
@ -246,20 +261,21 @@ class Model(nn.Module):
self.encoder = encoder
self.classifier = classifier
def forward(self, cat_inputs, num_inputs):
def forward(self, cat_inputs, num_inputs, padding_mask):
inputs = self.encoder(cat_inputs, num_inputs)
return self.classifier(inputs)
return self.classifier(inputs, padding_mask)
def test(start_time, epoch, batches_per_epoch, batch_size, model, optimizer, credit_dataset, test_auroc, writer):
model.eval()
optimizer.eval()
with torch.no_grad():
test_iterator = credit_dataset.get_test_batch_iterator(batch_size=batch_size)
for test_batch_id, (test_cat_inputs, test_num_inputs, test_targets) in enumerate(test_iterator):
for test_batch_id, (test_cat_inputs, test_num_inputs, test_padding_mask, test_targets) in enumerate(test_iterator):
test_cat_inputs = test_cat_inputs.to("cuda", non_blocking=True)
test_num_inputs = test_num_inputs.to("cuda", non_blocking=True)
test_padding_mask = test_padding_mask.to("cuda", non_blocking=True)
test_targets = test_targets.to("cuda", non_blocking=True)
outputs = model(test_cat_inputs, test_num_inputs)
outputs = model(test_cat_inputs, test_num_inputs, test_padding_mask)
test_auroc.update(outputs, test_targets.long())
print(f"\r {test_batch_id}/{len(credit_dataset.test_uniq_client_ids)//batch_size} {test_auroc.compute().item():.5f}", end = " "*20)
if writer is not None:
@ -281,8 +297,8 @@ batch_size = 30000
datasets_per_epoch = 1
num_workers = 10
comment = sys.argv[1]
logs_dir = f'runs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
checkpoints_dir = f'checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).name)
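For illustration (date, time and comment values assumed), the naming scheme above produces run directories like this:
```python
from datetime import datetime

now = datetime(2025, 3, 31, 9, 5, 7)
comment = "fold3_18l_dyt_04_04_3750"
print(f'logs/{now.date()}_{now.hour:02d}_{now.minute:02d}_{now.second:02d}_{comment}/')
# logs/2025-03-31_09_05_07_fold3_18l_dyt_04_04_3750/
```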
@ -297,8 +313,7 @@ credit_train_dataset = CreditProductsDataset(
features_path="/wd/data/train_data/",
targets_path="/wd/data/train_target.csv",
train_uniq_client_ids_path=f"/wd/fold3_train_ids.csv",
test_uniq_client_ids_path=f"/wd/fold3_test_ids.csv",
dropout_rate=features_dropout_rate
test_uniq_client_ids_path=f"/wd/fold3_test_ids.csv"
)
print(f"Dataset preparation time: {datetime.now() - start_prep_time}")
@ -307,7 +322,8 @@ encoder = Encoder(
num_columns=credit_train_dataset.num_columns,
cat_features_max_id=credit_train_dataset.cat_features.max(),
category_feature_dim=category_feature_dim,
out_dim=h_dim
out_dim=h_dim,
features_dropout_rate=features_dropout_rate
)
classifier = BertClassifier(
@ -358,13 +374,14 @@ try:
test_auroc=test_auroc,
writer=writer
)
for batch_id, (cat_inputs, num_inputs, targets) in enumerate(dataloader):
for batch_id, (cat_inputs, num_inputs, padding_mask, targets) in enumerate(dataloader):
model.train()
optimizer.train()
optimizer.zero_grad()
outputs = model(
cat_inputs[0].to("cuda"),
num_inputs[0].to("cuda")
num_inputs[0].to("cuda"),
padding_mask[0].to("cuda")
)
loss = criterion(outputs, targets[0].to("cuda"))
loss.backward()
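To make the `[0]` indexing above concrete, here is a self-contained toy (shapes assumed, much smaller than the real batches): `WrapperDataset.__getitem__` already returns a full pre-built batch, and `DataLoader`'s default collate stacks items, adding a leading dimension of size 1 that the loop strips off.
```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyWrapper(Dataset):
    def __len__(self):
        return 3  # three pre-built batches per epoch
    def __getitem__(self, idx):
        return torch.zeros(8, 5), torch.ones(8, dtype=torch.bool)  # (batch, features), (batch,) mask

loader = DataLoader(ToyWrapper(), batch_size=1)
inputs, mask = next(iter(loader))
print(inputs.shape, mask.shape)    # torch.Size([1, 8, 5]) torch.Size([1, 8])
inputs, mask = inputs[0], mask[0]  # drop the artificial leading dimension
```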
