From 38a19a32bdd72d514d62cef14106712d66d78796 Mon Sep 17 00:00:00 2001
From: Vladimir
Date: Tue, 20 May 2025 20:47:06 +0400
Subject: [PATCH] prepend function is unnecessary when it is used once

---
 src/bert_training.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/src/bert_training.py b/src/bert_training.py
index 48c9b4b..49636dd 100644
--- a/src/bert_training.py
+++ b/src/bert_training.py
@@ -233,9 +233,6 @@ class TransformerLayer(nn.Module):
         x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate)
         return x
 
-def prepend(element, tensor):
-    return torch.cat([element.expand([tensor.shape[0], element.shape[1], tensor.shape[2]]), tensor], dim=1)
-
 class BertClassifier(nn.Module):
     def __init__(self, layers_num=1, h_dim=64, class_num=2, max_seq_len=128, num_heads=4, dropout_rate = 0.1):
         super().__init__()
@@ -247,8 +244,8 @@ class BertClassifier(nn.Module):
         self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
 
     def forward(self, x, padding_mask):
-        x = prepend(self.cls_token, x)
-        padding_mask = torch.cat([torch.ones(x.shape[0], 1, dtype=torch.bool, device=x.device), padding_mask], dim=1)
+        x = torch.cat([self.cls_token.expand([x.shape[0], self.cls_token.shape[1], x.shape[2]]), x], dim=1)  # prepend CLS token
+        padding_mask = torch.cat([torch.ones(x.shape[0], 1, dtype=torch.bool, device=x.device), padding_mask], dim=1)  # prepend mask bit for CLS
         x = x + self.pos_embeds[:, :x.shape[1], :]
         for l in self.layers:
            x = l(x, padding_mask)
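
Note (not part of the patch): a minimal standalone sketch checking that the inlined expand/cat expression reproduces what the removed prepend helper returned. The tensor shapes (batch=8, seq_len=16, h_dim=64) are illustrative assumptions, not values from the training script.

import torch

batch, seq_len, h_dim = 8, 16, 64
cls_token = torch.randn(1, 1, h_dim)    # learned [CLS] embedding, shape (1, 1, h_dim)
x = torch.randn(batch, seq_len, h_dim)  # token embeddings, shape (batch, seq_len, h_dim)

# the helper removed by this patch
def prepend(element, tensor):
    return torch.cat([element.expand([tensor.shape[0], element.shape[1], tensor.shape[2]]), tensor], dim=1)

# the inlined version used in BertClassifier.forward after this patch
inlined = torch.cat([cls_token.expand([x.shape[0], cls_token.shape[1], x.shape[2]]), x], dim=1)

assert torch.equal(prepend(cls_token, x), inlined)   # identical result
assert inlined.shape == (batch, seq_len + 1, h_dim)  # one extra leading position for [CLS]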