You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
	
	
		
			115 lines
		
	
	
		
			4.9 KiB
		
	
	
	
		
			Python
		
	
			
		
		
	
	
			115 lines
		
	
	
		
			4.9 KiB
		
	
	
	
		
			Python
		
	
import torch
 | 
						|
import torch.nn as nn
 | 
						|
from torch.nn import functional as F
 | 
						|
from einops import rearrange
 | 
						|
import numpy as np
 | 
						|
from torch.utils.tensorboard import SummaryWriter
 | 
						|
from datetime import datetime
 | 
						|
import sys
 | 
						|
from pathlib import Path
 | 
						|
# Seed PyTorch's global RNG with a fixed value so parameter initialisation and
# sampling are repeatable across runs.
# NOTE(review): this executes at import time and therefore affects every module
# that imports this file.
torch.manual_seed(1337)
 | 
						|
 | 
						|
#################################### Model #########################################
 | 
						|
 | 
						|
 | 
						|
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
    """Rotary position embedding applied along the last dimension of its input.

    A table of rotation angles, one row per position up to ``max_seq_len``, is
    built once and kept in the ``emb`` buffer.  ``forward`` rotates each pair of
    features ``(i, i + dim/2)`` of position ``p`` by the angle stored at
    ``emb[p, i]``.

    Args:
        dim: feature width of the tensors that will be rotated.
        max_seq_len: number of positions covered by the angle table.
    """

    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1.0 / (10000 ** exponents)
        positions = torch.arange(max_seq_len).type_as(inv_freq)
        # angles[p, j] = p * inv_freq[j]
        angles = torch.outer(positions, inv_freq)
        # Duplicate so both halves of the feature vector share the same angles.
        self.register_buffer('emb', torch.cat([angles, angles], dim=-1))

    def rotate_half(self, x):
        """Map the last-dim halves ``(a, b)`` to ``(-b, a)``."""
        first, second = x.chunk(2, dim=-1)
        return torch.cat((second.neg(), first), dim=-1)

    def forward(self, x, offset=0):
        """Rotate ``x``, whose sequence axis is dim 1, starting at position ``offset``.

        Assumes ``x`` is ``(N, seq_len, dim)``; the angle slice is reshaped to
        ``(1, seq_len, dim)`` and broadcast over the leading axis.
        """
        length = x.size(1)
        angles = self.emb[offset:offset + length].view(1, length, -1)
        rotated = self.rotate_half(x)
        return (x * angles.cos()) + (rotated * angles.sin())
 | 
						|
 | 
						|
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh: an element-wise ``weight * tanh(alpha * x) + bias`` layer.

    Used in place of a normalisation layer.  ``alpha`` is a single learnable
    scalar; ``weight`` and ``bias`` are learnable per-feature vectors applied
    along the last dimension.

    Args:
        num_features: size of the last dimension of the input.
        alpha_init_value: initial value of the scalar ``alpha``.
    """

    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        # Registration order (alpha, weight, bias) is kept stable for state_dict /
        # optimizer parameter ordering.
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        scaled = squashed * self.weight
        return scaled + self.bias
 | 
						|
 | 
						|
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm causal self-attention + feed-forward block with RoPE and DyT.

    Each sub-layer is applied residually as ``x + dropout(sublayer(norm(x)))``.
    Queries and keys are rotated per head with RoPE, and a lower-triangular mask
    makes attention causal.  Input and output are ``(B, T, h_dim)``.

    Args:
        h_dim: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads (``<= 1`` disables head splitting).
        dropout_rate: dropout probability on each sub-layer output; only active
            while the module is in training mode.
        max_seq_len: number of positions covered by the RoPE table.

    Raises:
        ValueError: if ``h_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        if h_dim % num_heads:
            raise ValueError(f"h_dim ({h_dim}) must be divisible by num_heads ({num_heads})")
        self.h_dim = h_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.max_seq_len = max_seq_len
        self.head_dim = h_dim // num_heads
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=self.head_dim, max_seq_len=max_seq_len)

    def split_to_heads(self, x, B, T, H):
        """Reshape ``(B, T, n*h)`` to ``(B*n, T, h)``; identity when ``num_heads <= 1``."""
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        """Inverse of ``split_to_heads``: ``(B*n, T, h)`` back to ``(B, T, n*h)``."""
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        """Causal multi-head self-attention over ``x`` of shape ``(B, T, h_dim)``."""
        B, T, H = x.shape
        q = self.rope(self.split_to_heads(self.q_proj(x), B, T, H))
        k = self.rope(self.split_to_heads(self.k_proj(x), B, T, H))
        v = self.split_to_heads(self.v_proj(x), B, T, H)
        # Scale by 1/sqrt(d_k) with d_k the per-head width (Vaswani et al., Eq. 1);
        # using the full model width over-smooths the softmax when num_heads > 1.
        scores = (q @ k.transpose(1, 2)) * (self.head_dim ** -0.5)
        causal = torch.ones(T, T, dtype=torch.bool, device=x.device).tril()
        scores = scores.masked_fill(~causal, float('-inf'))  # encoder does not need this line
        attention = F.softmax(scores, dim=-1)
        return self.o_proj(self.gather_heads(attention @ v, B, T, H))

    def forward(self, x):
        # Element-wise dropout gated on self.training: the previous
        # F.dropout1d(...) defaulted to training=True (so it also fired in eval /
        # generation) and zeroed entire token rows of the (B, T, h_dim) tensor.
        x = x + F.dropout(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
 | 
						|
 | 
						|
class GPT2(nn.Module):
    """Decoder-only language model.

    The model adds token embeddings to learned position embeddings, runs them
    through a stack of ``TransformerLayer`` modules, and projects the result to
    vocabulary logits.

    Args:
        vocab_size: number of token ids.
        layers_num: number of transformer layers.
        h_dim: model width.
        max_seq_len: context length. ``forward`` only has position embeddings
            for this many positions. ``generate`` crops its context to fit.
        num_heads: attention heads per layer.
        dropout_rate: dropout probability passed to every layer.
        pixel_size: stored on the instance; not used by this class.
    """

    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=128, num_heads=4, dropout_rate = 0.1, pixel_size=None):
        super().__init__()
        self.vocab_size = vocab_size
        self.layers_num = layers_num
        self.h_dim = h_dim
        self.max_seq_len = max_seq_len
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.pixel_size = pixel_size
        # Submodule creation order is kept unchanged so seeded initialisation and
        # state_dict ordering match previous checkpoints.
        self.layers = nn.ModuleList([
            TransformerLayer(h_dim=h_dim, num_heads=num_heads, dropout_rate=dropout_rate, max_seq_len=max_seq_len)
            for _ in range(layers_num)])
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)

    def forward(self, x, targets=None):
        """Compute next-token logits, plus cross-entropy loss when ``targets`` is given.

        Args:
            x: ``(B, T)`` token ids with ``T <= max_seq_len``.
            targets: optional ``(B, T)`` token ids.

        Returns:
            ``(logits, loss)``. ``logits`` is ``(B, T, vocab_size)``. ``loss`` is
            a scalar tensor, or ``None`` when ``targets`` is ``None``.
        """
        h = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            h = layer(h)
        logits = self.lm_head(h)  # B, T, vocab_size
        loss = None
        if targets is not None:
            # cross_entropy wants the class dimension second: (B, vocab_size, T).
            loss = F.cross_entropy(logits.transpose(1, 2), targets)
        return logits, loss

    # what is the purpose? autoregressive inference!
    @torch.no_grad()
    def generate(self, start_idx, max_new_tokens):
        """Autoregressively sample ``max_new_tokens`` tokens after ``start_idx``.

        Each step feeds at most the last ``max_seq_len`` tokens to the model. It
        then samples the next token from the softmax of the final position's
        logits.

        Gradients are disabled because sampled token ids are not
        differentiable. The module's train/eval mode is not changed.

        Args:
            start_idx: ``(B, T0)`` token ids on the model's device.
            max_new_tokens: number of tokens to append.

        Returns:
            ``(B, T0 + max_new_tokens)`` token ids.
        """
        idx = start_idx
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.max_seq_len:]
            logits, _ = self(idx_cond)
            probs = F.softmax(logits[:, -1, :], dim=-1)  # B, vocab_size
            # multinomial returns ids on probs' device, i.e. the model's device.
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx