@@ -233,9 +233,6 @@ class TransformerLayer(nn.Module):
         x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate)
         return x
 
-def prepend(element, tensor):
-    return torch.cat([element.expand([tensor.shape[0], element.shape[1], tensor.shape[2]]), tensor], dim=1)
-
 class BertClassifier(nn.Module):
     def __init__(self, layers_num=1, h_dim=64, class_num=2, max_seq_len=128, num_heads=4, dropout_rate = 0.1):
         super().__init__()
@@ -247,8 +244,8 @@ class BertClassifier(nn.Module):
         self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
 
     def forward(self, x, padding_mask):
-        x = prepend(self.cls_token, x)
-        padding_mask = torch.cat([torch.ones(x.shape[0], 1, dtype=torch.bool, device=x.device), padding_mask], dim=1)
+        x = torch.cat([self.cls_token.expand([x.shape[0], self.cls_token.shape[1], x.shape[2]]), x], dim=1) # prepend
+        padding_mask = torch.cat([torch.ones(x.shape[0], 1, dtype=torch.bool, device=x.device), padding_mask], dim=1) # prepend
         x = x + self.pos_embeds[:, :x.shape[1], :]
         for l in self.layers:
            x = l(x, padding_mask)
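
For readers who want to see what the inlined lines actually do, below is a minimal, self-contained sketch of the expand-and-cat "prepend" pattern this diff adopts. The surrounding module is not shown in the hunks, so the shapes are assumptions: `self.cls_token` is taken to be a learned `(1, 1, h_dim)` parameter and `x` a `(batch, seq_len, h_dim)` batch of token embeddings.

import torch

batch, seq_len, h_dim = 2, 5, 64
cls_token = torch.nn.Parameter(torch.randn(1, 1, h_dim))  # assumed shape
x = torch.randn(batch, seq_len, h_dim)
padding_mask = torch.ones(batch, seq_len, dtype=torch.bool)

# expand() broadcasts the single CLS embedding across the batch without
# copying data; cat() then prepends it along the sequence dimension (dim=1).
x = torch.cat([cls_token.expand([x.shape[0], cls_token.shape[1], x.shape[2]]), x], dim=1)

# One mask entry is prepended for the new CLS position (ones, following the
# diff's mask convention), keeping the mask aligned with the longer sequence.
padding_mask = torch.cat([torch.ones(x.shape[0], 1, dtype=torch.bool, device=x.device), padding_mask], dim=1)

print(x.shape)             # torch.Size([2, 6, 64])
print(padding_mask.shape)  # torch.Size([2, 6])

Inlining trades the named helper for locality: the token and the mask are now extended side by side in forward(), which keeps the two concatenations visibly in sync.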