@@ -233,9 +233,6 @@ class TransformerLayer(nn.Module):
        # Pre-LN feed-forward sub-block with residual; F.dropout (vs. dropout1d,
        # which drops whole channels and expects (N, C, L) input) drops elements,
        # and training=self.training disables dropout in eval mode.
        x = x + F.dropout(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
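
# Illustrative check, not part of the PR: once training=self.training is
# threaded through F.dropout, a TransformerLayer should be deterministic after
# eval(). layer, x, and mask are placeholders for a built layer and its inputs.
def _check_dropout_eval(layer, x, mask):
    layer.eval()
    with torch.no_grad():
        assert torch.equal(layer(x, mask), layer(x, mask))
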
def prepend(element, tensor):
    # Broadcast element (1, 1, H) over the batch and concatenate it in front of
    # tensor (B, L, H) along the sequence dimension, yielding (B, L+1, H).
    return torch.cat([element.expand([tensor.shape[0], element.shape[1], tensor.shape[2]]), tensor], dim=1)
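
# Shape sanity check for prepend (illustrative sketch, not part of the PR;
# assumes torch is imported at the top of the file).
def _check_prepend():
    tok = torch.randn(1, 1, 64)     # e.g. a learned CLS embedding
    batch = torch.randn(8, 16, 64)  # (B, L, H)
    out = prepend(tok, batch)
    assert out.shape == (8, 17, 64)
    # The first position of every batch element is the broadcast token.
    assert torch.equal(out[:, 0, :], tok[0, 0].expand(8, 64))
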
class BertClassifier(nn.Module):
    def __init__(self, layers_num=1, h_dim=64, class_num=2, max_seq_len=128, num_heads=4, dropout_rate=0.1):
        super().__init__()
@@ -247,8 +244,8 @@ class BertClassifier(nn.Module):
        # Learned absolute position embeddings; max_seq_len must cover the
        # sequence length *after* the CLS token is prepended in forward().
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))

    def forward(self, x, padding_mask):
        # Prepend the learned CLS token, then extend the padding mask so the
        # CLS position counts as a real token (True, matching the mask
        # convention assumed by TransformerLayer).
        x = prepend(self.cls_token, x)
        padding_mask = torch.cat([torch.ones(x.shape[0], 1, dtype=torch.bool, device=x.device), padding_mask], dim=1)
        # Add position embeddings for all x.shape[1] = L + 1 positions, CLS included.
        x = x + self.pos_embeds[:, :x.shape[1], :]
        for l in self.layers:
            x = l(x, padding_mask)
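
# Usage sketch (illustrative, not part of the PR). Assumes the parts of
# __init__ elided by the diff define self.cls_token, self.layers, and a
# classification head, and that True in padding_mask marks real tokens.
def _demo():
    # max_seq_len=129 leaves room for the CLS token on 128-token inputs.
    model = BertClassifier(layers_num=2, h_dim=64, class_num=2, max_seq_len=129, num_heads=4)
    x = torch.randn(8, 128, 64)                  # pre-embedded inputs, (B, L, H)
    mask = torch.ones(8, 128, dtype=torch.bool)  # every position is a real token
    return model(x, mask)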