From d9aa740746083e59ac77605b2f2ed8fd6d7d008b Mon Sep 17 00:00:00 2001 From: Vladimir Protsenko Date: Wed, 10 Dec 2025 05:24:29 +0000 Subject: [PATCH 1/9] trainable cylind lens --- src/char_gpt2_ff.py | 115 +++++++++ src/main.py | 21 +- src/optical_matrix_multiplication/config.py | 5 +- .../optical_mul.py | 20 +- .../propagator.py | 41 +++- src/optics_char_gpt2_ff.py | 218 ++++++++++++++++++ src/optics_char_gpt2_traindiag.py | 179 ++++++++++++++ 7 files changed, 577 insertions(+), 22 deletions(-) create mode 100644 src/char_gpt2_ff.py create mode 100644 src/optics_char_gpt2_ff.py create mode 100644 src/optics_char_gpt2_traindiag.py diff --git a/src/char_gpt2_ff.py b/src/char_gpt2_ff.py new file mode 100644 index 0000000..c1cd2cb --- /dev/null +++ b/src/char_gpt2_ff.py @@ -0,0 +1,115 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +import sys +from pathlib import Path +torch.manual_seed(1337) + +#################################### Model ######################################### + + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, 
num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5) + tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) + scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line + attention = nn.functional.softmax(scores, dim=2) + return self.o_proj(self.gather_heads(attention @ v, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + 
F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +class GPT2(nn.Module): + def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.layers = nn.ModuleList([ + TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) + for _ in range(layers_num)]) + self.tok_embeds = nn.Embedding(vocab_size, h_dim) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + self.lm_head = nn.Linear(h_dim, vocab_size) + + def forward(self, x, targets=None): + x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + logits = self.lm_head(x) # B,T,C + loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None + return logits, loss + + # what is the purpose? autoregressive inference! 
+ def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx \ No newline at end of file diff --git a/src/main.py b/src/main.py index bf5c414..24406f8 100644 --- a/src/main.py +++ b/src/main.py @@ -9,23 +9,30 @@ import sys from pathlib import Path from char_gpt2 import GPT2 from optics_char_gpt2 import OpticGPT2 +from optics_char_gpt2_traindiag import OpticGPT2TrainDiag +from optics_char_gpt2_ff import OpticGPT2FF seed = 1337 torch.manual_seed(seed) -models = {'gpt2': GPT2, 'optic_gpt2': OpticGPT2} +models = {'gpt2': GPT2, 'optic_gpt2': OpticGPT2, 'optic_gpt2_ff': OpticGPT2FF, 'optic_gpt2_traindiag':OpticGPT2TrainDiag} -batch_size = 50 -max_iters = 40000 +batch_size = 25 +max_iters = 40000*2 eval_interval = 300 learning_rate = 1e-3 device = 'cuda' if torch.cuda.is_available() else 'cpu' eval_iters = 200 -layers_num = 22 +layers_num = 2 h_dim = 64 -max_seq_len = 256 -num_heads = 4 +max_seq_len = 64 +num_heads = 1 dropout_rate = 0.1 pixel_size = 3.6e-6 - +# CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens seq_128_hdim_64 +# CUDA_VISIBLE_DEVICES=2 python src/main.py optic_gpt2_ff ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens seq_128_hdim_128 +# CUDA_VISIBLE_DEVICES=3 python src/main.py optic_gpt2_ff ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens seq_128_hdim_256 +# CUDA_VISIBLE_DEVICES=4 python src/main.py gpt2 ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens seq_64_hdim_64 +# CUDA_VISIBLE_DEVICES=5 python src/main.py gpt2 ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens 
seq_64_hdim_128 +# CUDA_VISIBLE_DEVICES=6 python src/main.py gpt2 ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens seq_64_hdim_256 # CUDA_VISIBLE_DEVICES=1 python .src/main.py gpt2|optic_gpt2 ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens comment MODEL_CLASS = models[sys.argv[1]] train_data_path = Path(sys.argv[2]) diff --git a/src/optical_matrix_multiplication/config.py b/src/optical_matrix_multiplication/config.py index 26c8589..7b84265 100644 --- a/src/optical_matrix_multiplication/config.py +++ b/src/optical_matrix_multiplication/config.py @@ -274,7 +274,8 @@ class Config(ConfigOpticBase, ConfigModelBase): wavelength: float = 532e-9, distance: float = 0.03, lens_pixel_size: float = 1.8e-6, - lens_size: int = 8192): + lens_size: int = 8192, + trainable_cylind_lens = False): """ Конструктор класса. @@ -294,6 +295,7 @@ class Config(ConfigOpticBase, ConfigModelBase): distance: дистанция в метрах распространения светового поля между плоскостями. lens_pixel_size: размер пикселя в метрах скрещенных линз в оптической системе (нужен исключительно для моделирования). lens_size: размер скрещенных линз в метрах в оптической системе (нужен исключительно для моделирования). 
+ trainable_cylind_lens: обучаемые диагональные матрицы, линза перед фурье плоскостью """ ConfigOpticBase.__init__(self, wavelength, distance) @@ -320,6 +322,7 @@ class Config(ConfigOpticBase, ConfigModelBase): self._input_vector_split_x: int = left_matrix_split_x self._input_vector_split_y: int = left_matrix_split_y self._result_vector_split: int = result_matrix_split + self._trainable_cylind_lens = trainable_cylind_lens @property def matrix_split_x(self) -> int: diff --git a/src/optical_matrix_multiplication/optical_mul.py b/src/optical_matrix_multiplication/optical_mul.py index 4985adc..cb3e398 100644 --- a/src/optical_matrix_multiplication/optical_mul.py +++ b/src/optical_matrix_multiplication/optical_mul.py @@ -19,14 +19,29 @@ class OpticalMul(_nn.Module): prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config) prop_two = _PropCrossLens(config.first_lens_plane, config) prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config) - prop_four = _PropСylindLens(config.matrix_plane, config) + prop_four = _PropСylindLens(config.matrix_plane, config, trainable=config._trainable_cylind_lens) prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config) prop_six = _PropCrossLens(config.second_lens_plane, config).T prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config) - self._propagator_one: _Prop = prop_one + prop_two + prop_three + prop_four + # print(prop_one) + # print(prop_two) + # print(prop_three) + # print(prop_four) + # print(prop_five) + # print((prop_one + prop_two + prop_three)) + # print((prop_one + prop_two + prop_three + prop_four)) + + self._propagator_one: _Prop = prop_one + prop_two + prop_three + self._propagator_between = prop_four self._propagator_two: _Prop = prop_five + prop_six + prop_seven + # print(self._propagator_one) + # print(self._propagator_between) + # print(self._propagator_between.operator_X) + # print(self._propagator_between.operator_Y) + # 
print(self._propagator_two) + kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x)) kron_mat_utils = _torch.ones((config.matrix_split_x, config.matrix_split_y)) self.register_buffer('_kron_vec_utils', kron_vec_utils, persistent=True) @@ -111,6 +126,7 @@ class OpticalMul(_nn.Module): mat_field = self.prepare_matrix(other) vec_field = self._propagator_one(vec_field) + vec_field = self._propagator_between(vec_field) vec_field = self._propagator_two(vec_field * mat_field) return self.prepare_out(vec_field) \ No newline at end of file diff --git a/src/optical_matrix_multiplication/propagator.py b/src/optical_matrix_multiplication/propagator.py index c667bd6..68ee2c4 100644 --- a/src/optical_matrix_multiplication/propagator.py +++ b/src/optical_matrix_multiplication/propagator.py @@ -16,12 +16,20 @@ class Propagator(_ABC, _nn.Module): operator_X: оператор отображающий распроcтранение светового поля вдоль оси абсцисс operator_Y: оператор отображающий распроcтранение светового поля вдоль оси ординат """ - def __init__(self, operator_X: _torch.Tensor, operator_Y: _torch.Tensor): + def __init__(self, operator_X: _torch.Tensor, operator_Y: _torch.Tensor, trainable = False, diagonal = False): super(Propagator, self).__init__() operator_X: _torch.Tensor = _torch.view_as_real(operator_X) operator_Y: _torch.Tensor = _torch.view_as_real(operator_Y) - self.register_buffer('_operator_X', operator_X, persistent=True) - self.register_buffer('_operator_Y', operator_Y, persistent=True) + if trainable: + self._operator_X = _nn.Parameter(operator_X) + self._operator_Y = _nn.Parameter(operator_Y) + self._trainable = trainable + self._diagonal = diagonal + else: + self.register_buffer('_operator_X', operator_X, persistent=True) + self.register_buffer('_operator_Y', operator_Y, persistent=True) + self._trainable = trainable + self._diagonal = diagonal @property def operator_X(self) -> _torch.Tensor: @@ -103,7 +111,13 @@ class Propagator(_ABC, _nn.Module): 
Распределение комплексной амплитуды светового поля, после распространения. """ - return self.operator_Y @ field @ self.operator_X + if self._diagonal: + return _torch.diag_embed(self.operator_Y) @ field @ _torch.diag_embed(self.operator_X) + else: + return self.operator_Y @ field @ self.operator_X + + def __repr__(self): + return f"Diag: {self._diagonal} Trainable: {self._trainable} Y shape: {self.operator_Y.shape}, X shape: {self.operator_X.shape}" class PropagatorLens(Propagator): """ @@ -133,7 +147,7 @@ class PropagatorCrossLens(PropagatorLens): представленной тонким оптическим элементом. """ def __init__(self, plane: _ConfigDesignPlane, - config: _ConfigOpticBase): + config: _ConfigOpticBase, trainable = False): """ Конструктор класса скрещенной линзы. @@ -144,7 +158,8 @@ class PropagatorCrossLens(PropagatorLens): operator_X = _torch.exp(-1j * config.K / config.distance * plane.linspace_by_x**2) operator_Y = _torch.exp(-1j * config.K / 2 / config.distance * plane.linspace_by_y**2) super(PropagatorCrossLens, self).__init__(_torch.diag_embed(operator_X), - _torch.diag_embed(operator_Y)) + _torch.diag_embed(operator_Y), + trainable) class PropagatorСylindLens(PropagatorLens): """ @@ -152,7 +167,7 @@ class PropagatorСylindLens(PropagatorLens): представленной тонким оптическим элементом. """ def __init__(self, plane: _ConfigDesignPlane, - config: _ConfigOpticBase): + config: _ConfigOpticBase, trainable = False): """ Конструктор класса цилиндрической линзы. 
@@ -162,8 +177,10 @@ class PropagatorСylindLens(PropagatorLens): """ operator_X = _torch.exp(-1j * config.K / config.distance * plane.linspace_by_x**2) operator_Y = _torch.ones_like(plane.linspace_by_y, dtype=_torch.cfloat) - super(PropagatorСylindLens, self).__init__(_torch.diag_embed(operator_X), - _torch.diag_embed(operator_Y)) + super(PropagatorСylindLens, self).__init__(operator_X, + operator_Y, + trainable, + diagonal=True) class PropagatorSinc(Propagator): """ @@ -172,7 +189,7 @@ class PropagatorSinc(Propagator): """ def __init__(self, first_plane: _ConfigDesignPlane, second_plane: _ConfigDesignPlane, - config: _ConfigOpticBase): + config: _ConfigOpticBase, trainable = False): """ Конструктор класса распространения в свободном пространстве. @@ -184,7 +201,7 @@ class PropagatorSinc(Propagator): operator_X, operator_Y = self.__get_operators(first_plane, second_plane, config) - super(PropagatorSinc, self).__init__(operator_X, operator_Y) + super(PropagatorSinc, self).__init__(operator_X, operator_Y, trainable) def __get_operator_for_dim(self, pixel_size_in: float, @@ -217,4 +234,4 @@ class PropagatorSinc(Propagator): second_plane.pixel_size_by_y, difference_y, config) - return operator_X, operator_Y + return operator_X, operator_Y \ No newline at end of file diff --git a/src/optics_char_gpt2_ff.py b/src/optics_char_gpt2_ff.py new file mode 100644 index 0000000..2347bf9 --- /dev/null +++ b/src/optics_char_gpt2_ff.py @@ -0,0 +1,218 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F, init +from einops import rearrange +import optical_matrix_multiplication as omm +from optical_matrix_multiplication import propagator +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +from pathlib import Path +import sys +import math +torch.manual_seed(1337) + +#################################### Model ######################################### + +def norm(matrix: torch.Tensor, max_val: float = 1) -> 
torch.Tensor: + return matrix / (max_val + 1e-10) + +def optics_matmul_shift(sim, tensor_1, tensor_2): + # print(tensor_1.shape, tensor_2.shape) + + tensor_1 = tensor_1[None,:,:,:] + tensor_2 = tensor_2[None,None,:,:] + # print(tensor_1.shape, tensor_2.shape) + # raise RuntimeError + + if torch.min(tensor_1) >= 0 and torch.min(tensor_2) >= 0: + max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + a, b = norm(tensor_1, max_abs), norm(tensor_2, max_abs) + return sim(a, b)[0,:,:,:] * max_abs **2 + + min_abs = abs(min(torch.min(tensor_1), torch.min(tensor_2))) + max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + min_abs + + shift_a = min_abs * torch.ones(tensor_1.shape).to(tensor_1.device) + shift_b = min_abs * torch.ones(tensor_2.shape).to(tensor_1.device) + a_a_sh = tensor_1 + shift_a + b_b_sh = tensor_2 + shift_b + + a_a_sh_norm, b_b_sh_norm = norm(a_a_sh, max_abs), norm(b_b_sh, max_abs) + shift_a_norm, shift_b_norm = norm(shift_a, max_abs), norm(shift_b, max_abs) + a_a_sh_b_b_sh = sim(a_a_sh_norm, b_b_sh_norm) + a_a_sh_b_sh = sim(a_a_sh_norm, shift_b_norm) + a_sh_b_b_sh = sim(shift_a_norm, b_b_sh_norm) + a_sh_b_sh = sim(shift_a_norm, shift_b_norm) + + return (a_a_sh_b_b_sh - a_a_sh_b_sh - a_sh_b_b_sh + a_sh_b_sh)[0,:,:,:] * max_abs ** 2 + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * 
sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +class OpticLinear(nn.Module): + def __init__( + self, + in_features, + out_features, + bias = True, + device = None, + dtype = None, + pixel_size = 3.6e-6 + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter( + torch.empty((in_features, out_features), **factory_kwargs) + ) + # print(self.weight.shape) + if bias: + self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs)) + else: + self.register_parameter("bias", None) + self.k = nn.Parameter(torch.randn(1)) + self.sim = omm.OpticalMul( + omm.Config( + right_matrix_count_columns = out_features , + right_matrix_count_rows = in_features, + right_matrix_width = pixel_size * out_features , + right_matrix_height = pixel_size * in_features, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01) + ) + self.reset_parameters() + + def forward(self, input): + """ + Runs the forward pass. + """ + return self.k * optics_matmul_shift(self.sim, input, self.weight) + self.bias + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in ``__init__``. + """ + # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with + # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). 
For details, see + # https://github.com/pytorch/pytorch/issues/57109 + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + init.uniform_(self.bias, -bound, bound) + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}" + + + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = OpticLinear(h_dim, 4*h_dim) + self.ff2 = OpticLinear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5) + tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) + scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line + attention = 
nn.functional.softmax(scores, dim=2) + return self.o_proj(self.gather_heads(attention @ v, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +class OpticGPT2FF(nn.Module): + def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, + pixel_size = 3.6e-6): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.layers = nn.ModuleList([ + TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, + dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) + for _ in range(layers_num)]) + self.tok_embeds = nn.Embedding(vocab_size, h_dim) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + self.lm_head = nn.Linear(h_dim, vocab_size) + + def forward(self, x, targets=None): + x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + logits = self.lm_head(x) # B,T,C + loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None + return logits, loss + + # what is the purpose? autoregressive inference! 
+ def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx \ No newline at end of file diff --git a/src/optics_char_gpt2_traindiag.py b/src/optics_char_gpt2_traindiag.py new file mode 100644 index 0000000..2cdfc81 --- /dev/null +++ b/src/optics_char_gpt2_traindiag.py @@ -0,0 +1,179 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import optical_matrix_multiplication as omm +from optical_matrix_multiplication import propagator +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +from pathlib import Path +import sys +torch.manual_seed(1337) + +#################################### Model ######################################### + +def norm(matrix: torch.Tensor, max_val: float = 1) -> torch.Tensor: + return matrix / (max_val + 1e-10) + +def optics_matmul_shift(sim, tensor_1, tensor_2): + tensor_1 = tensor_1[None,:,:,:] + tensor_2 = tensor_2[None,:,:,:] + + if torch.min(tensor_1) >= 0 and torch.min(tensor_2) >= 0: + max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + a, b = norm(tensor_1, max_abs), norm(tensor_2, max_abs) + return sim(a, b)[0,:,:,:] * max_abs **2 + + min_abs = abs(min(torch.min(tensor_1), torch.min(tensor_2))) + max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + min_abs + + shift_a = min_abs * torch.ones(tensor_1.shape).to(tensor_1.device) + shift_b = min_abs * torch.ones(tensor_2.shape).to(tensor_1.device) + a_a_sh = tensor_1 + shift_a + b_b_sh = tensor_2 + shift_b + + a_a_sh_norm, b_b_sh_norm = norm(a_a_sh, max_abs), norm(b_b_sh, max_abs) + shift_a_norm, shift_b_norm = norm(shift_a, max_abs), norm(shift_b, 
max_abs) + a_a_sh_b_b_sh = sim(a_a_sh_norm, b_b_sh_norm) + a_a_sh_b_sh = sim(a_a_sh_norm, shift_b_norm) + a_sh_b_b_sh = sim(shift_a_norm, b_b_sh_norm) + a_sh_b_sh = sim(shift_a_norm, shift_b_norm) + + return (a_a_sh_b_b_sh - a_a_sh_b_sh - a_sh_b_b_sh + a_sh_b_sh)[0,:,:,:] * max_abs ** 2 + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = 
nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + self.k1 = nn.Parameter(torch.randn(1)) + self.k2 = nn.Parameter(torch.randn(1)) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = self.k1 * optics_matmul_shift(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5) + tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) + scores = scores.masked_fill(tril == 0, float('-inf')) + attention = nn.functional.softmax(scores, dim=2) + output = self.k2 * optics_matmul_shift(self.sim_output, attention, v) + return self.o_proj(self.gather_heads(output, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +class OpticGPT2TrainDiag(nn.Module): + def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, + pixel_size = 3.6e-6): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + 
left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=True) + ) + + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=True) + ) + self.layers = nn.ModuleList([ + TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads, + dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) + for _ in range(layers_num)]) + self.tok_embeds = nn.Embedding(vocab_size, h_dim) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + self.lm_head = nn.Linear(h_dim, vocab_size) + + def forward(self, x, targets=None): + x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + logits = self.lm_head(x) # B,T,C + loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None + return logits, loss + + # what is the purpose? autoregressive inference! + def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx \ No newline at end of file From a042b64d7e5daa3df3a402f526892486ba28fbcd Mon Sep 17 00:00:00 2001 From: Vladimir Protsenko Date: Sat, 17 Jan 2026 12:21:33 +0000 Subject: [PATCH 2/9] new experiments. code refactoring. 
--- ...r_gpt2_ff.py => char_gpt2_scaledmatmul.py} | 8 +- src/main.py | 73 ++++-- src/optical_matrix_multiplication/__init__.py | 3 +- .../optical_mul.py | 31 +-- src/optical_matrix_multiplication/parallel.py | 69 +++++- src/optics_char_gpt2.py | 44 +++- src/optics_char_gpt2_ff.py | 1 - src/optics_char_gpt2_new_formula.py | 227 ++++++++++++++++++ src/optics_char_gpt2_nokoef.py | 209 ++++++++++++++++ src/optics_char_gpt2_nokoef_newf.py | 220 +++++++++++++++++ src/optics_char_gpt2_traindiag.py | 95 +++++--- 11 files changed, 894 insertions(+), 86 deletions(-) rename src/{char_gpt2_ff.py => char_gpt2_scaledmatmul.py} (94%) create mode 100644 src/optics_char_gpt2_new_formula.py create mode 100644 src/optics_char_gpt2_nokoef.py create mode 100644 src/optics_char_gpt2_nokoef_newf.py diff --git a/src/char_gpt2_ff.py b/src/char_gpt2_scaledmatmul.py similarity index 94% rename from src/char_gpt2_ff.py rename to src/char_gpt2_scaledmatmul.py index c1cd2cb..25bb992 100644 --- a/src/char_gpt2_ff.py +++ b/src/char_gpt2_scaledmatmul.py @@ -59,6 +59,8 @@ class TransformerLayer(nn.Module): self.ln1 = DyT(h_dim) self.ln2 = DyT(h_dim) self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + self.k1 = nn.Parameter(torch.randn(1)) + self.k2 = nn.Parameter(torch.randn(1)) def split_to_heads(self, x, B, T, H): if self.num_heads <= 1: return x @@ -72,18 +74,18 @@ class TransformerLayer(nn.Module): q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) v = self.split_to_heads(self.v_proj(x), *x.shape) - scores = (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5) + scores = self.k1 * (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5) tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line attention = nn.functional.softmax(scores, dim=2) - return self.o_proj(self.gather_heads(attention @ v, *x.shape)) + 
return self.o_proj(self.gather_heads(self.k2 * (attention @ v), *x.shape)) def forward(self, x): x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) return x -class GPT2(nn.Module): +class GPT2ScaledMM(nn.Module): def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None): super().__init__() self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) diff --git a/src/main.py b/src/main.py index 24406f8..52fd104 100644 --- a/src/main.py +++ b/src/main.py @@ -11,41 +11,54 @@ from char_gpt2 import GPT2 from optics_char_gpt2 import OpticGPT2 from optics_char_gpt2_traindiag import OpticGPT2TrainDiag from optics_char_gpt2_ff import OpticGPT2FF +from optics_char_gpt2_new_formula import OpticGPT2NewFormula +from char_gpt2_scaledmatmul import GPT2ScaledMM +from optics_char_gpt2_nokoef import OpticGPT2NOKoef +from optics_char_gpt2_nokoef_newf import OpticGPT2NOKoefNewF +import shutil seed = 1337 torch.manual_seed(seed) -models = {'gpt2': GPT2, 'optic_gpt2': OpticGPT2, 'optic_gpt2_ff': OpticGPT2FF, 'optic_gpt2_traindiag':OpticGPT2TrainDiag} - -batch_size = 25 -max_iters = 40000*2 +models = { + 'gpt2': GPT2, + 'optic_gpt2': OpticGPT2, + 'optic_gpt2_ff': OpticGPT2FF, + 'optic_gpt2_traindiag': OpticGPT2TrainDiag, + 'optic_gpt2_newformula': OpticGPT2NewFormula, + 'optic_gpt2_nokoef': OpticGPT2NOKoef, + 'optic_gpt2_nokoef_newformula': OpticGPT2NOKoefNewF, + 'gpt2_scaledmm': GPT2ScaledMM +} + +batch_size = 50 +gradient_accumulation_steps = 2 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = 40000 eval_interval = 300 learning_rate = 1e-3 device = 'cuda' if torch.cuda.is_available() else 'cpu' eval_iters = 200 layers_num = 2 h_dim = 64 -max_seq_len = 64 +max_seq_len = 256 num_heads = 1 dropout_rate = 0.1 pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 # 
CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens seq_128_hdim_64 -# CUDA_VISIBLE_DEVICES=2 python src/main.py optic_gpt2_ff ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens seq_128_hdim_128 -# CUDA_VISIBLE_DEVICES=3 python src/main.py optic_gpt2_ff ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens seq_128_hdim_256 -# CUDA_VISIBLE_DEVICES=4 python src/main.py gpt2 ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens seq_64_hdim_64 -# CUDA_VISIBLE_DEVICES=5 python src/main.py gpt2 ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens seq_64_hdim_128 -# CUDA_VISIBLE_DEVICES=6 python src/main.py gpt2 ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens seq_64_hdim_256 -# CUDA_VISIBLE_DEVICES=1 python .src/main.py gpt2|optic_gpt2 ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens comment -MODEL_CLASS = models[sys.argv[1]] + +MODEL_CLASS = models[sys.argv[1]] train_data_path = Path(sys.argv[2]) val_data_path = Path(sys.argv[3]) test_data_path = Path(sys.argv[4]) -comment = f"{sys.argv[1]}_{train_data_path.name}_{sys.argv[5]}_{seed}" +comment = f"{sys.argv[1]}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[5] if len(sys.argv) >= 6 else ''}" logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' writer = SummaryWriter(logs_dir) -script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).name) +script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name) print("Logs dir:", logs_dir) -script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script -script_snapshot_path.chmod(0o400) # with read-only permission +# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script 
+shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository +script_snapshot_path.chmod(0o500) # with read-only permission + #################################### Dataset ######################################### # wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt @@ -106,33 +119,45 @@ m = MODEL_CLASS( layers_num=layers_num ) m = m.to(device) +writer.add_text('model', str(m), 0) + #################################### Train ######################################### optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) +m.eval() completion = complete(m, encode("\n"*max_seq_len), 2*max_seq_len) print(completion) writer.add_text('completions', completion, 0) for i in range(max_iters): - xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size) - logits, loss = m(xb, yb) + m.train() optimizer.zero_grad(set_to_none=True) - loss.backward() + accumulated_loss = 0.0 + for j in range(gradient_accumulation_steps): + xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps) + logits, loss = m(xb, yb) + loss = loss / gradient_accumulation_steps + loss.backward() + accumulated_loss += loss.item() + if i % 100 == 0: + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i) optimizer.step() - writer.add_scalar('loss', loss.item(), i) - print(f"\r{i}/{max_iters} {loss.item()}", end="") + writer.add_scalar('loss', accumulated_loss, i) + print(f"\r{i}/{max_iters} {accumulated_loss}", end="") if i % 5000 == 0: + m.eval() ppl = perplexity(model=m, data=val_data) writer.add_scalar('val_perplexity', ppl.item(), i) print(f"\rPerplexity at {i}: {ppl}") writer.add_text('completions', complete(m, encode("\n"*max_seq_len), 2*max_seq_len), i) +m.eval() ppl = perplexity(model=m, data=val_data) -print(f"\r{i+1}/{max_iters} {loss.item()}") 
-print(f"\rPerplexity at {i}: {ppl}") +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") writer.add_scalar('val_perplexity', ppl.item(), i+1) -writer.add_scalar('loss', loss.item(), i) +writer.add_scalar('loss', accumulated_loss, i) ppl = perplexity(model=m, data=test_data) writer.add_scalar('test_perplexity', ppl.item(), i+1) diff --git a/src/optical_matrix_multiplication/__init__.py b/src/optical_matrix_multiplication/__init__.py index 9974821..9a1844c 100644 --- a/src/optical_matrix_multiplication/__init__.py +++ b/src/optical_matrix_multiplication/__init__.py @@ -6,4 +6,5 @@ __version__ = "3.0.0" from .config import Config from . import propagator from .optical_mul import OpticalMul -from .parallel import DataParallel \ No newline at end of file +from .parallel import DataParallel +from .parallel import ScatterDataParallel \ No newline at end of file diff --git a/src/optical_matrix_multiplication/optical_mul.py b/src/optical_matrix_multiplication/optical_mul.py index cb3e398..473cd74 100644 --- a/src/optical_matrix_multiplication/optical_mul.py +++ b/src/optical_matrix_multiplication/optical_mul.py @@ -15,33 +15,23 @@ class OpticalMul(_nn.Module): config: конфигурация расчётной системы. 
""" super(OpticalMul, self).__init__() + self.trainable_cylind_lens = config._trainable_cylind_lens prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config) prop_two = _PropCrossLens(config.first_lens_plane, config) prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config) - prop_four = _PropСylindLens(config.matrix_plane, config, trainable=config._trainable_cylind_lens) + prop_four = _PropСylindLens(config.matrix_plane, config, trainable=self.trainable_cylind_lens) prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config) prop_six = _PropCrossLens(config.second_lens_plane, config).T prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config) - # print(prop_one) - # print(prop_two) - # print(prop_three) - # print(prop_four) - # print(prop_five) - # print((prop_one + prop_two + prop_three)) - # print((prop_one + prop_two + prop_three + prop_four)) - - self._propagator_one: _Prop = prop_one + prop_two + prop_three - self._propagator_between = prop_four + if self.trainable_cylind_lens: + self._propagator_one: _Prop = prop_one + prop_two + prop_three + self._propagator_between = prop_four + else: + self._propagator_one: _Prop = prop_one + prop_two + prop_three + prop_four self._propagator_two: _Prop = prop_five + prop_six + prop_seven - # print(self._propagator_one) - # print(self._propagator_between) - # print(self._propagator_between.operator_X) - # print(self._propagator_between.operator_Y) - # print(self._propagator_two) - kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x)) kron_mat_utils = _torch.ones((config.matrix_split_x, config.matrix_split_y)) self.register_buffer('_kron_vec_utils', kron_vec_utils, persistent=True) @@ -125,8 +115,11 @@ class OpticalMul(_nn.Module): vec_field = self.prepare_vector(input) mat_field = self.prepare_matrix(other) - vec_field = self._propagator_one(vec_field) - vec_field = self._propagator_between(vec_field) + 
if self.trainable_cylind_lens: + vec_field = self._propagator_one(vec_field) + vec_field = self._propagator_between(vec_field) + else: + vec_field = self._propagator_one(vec_field) vec_field = self._propagator_two(vec_field * mat_field) return self.prepare_out(vec_field) \ No newline at end of file diff --git a/src/optical_matrix_multiplication/parallel.py b/src/optical_matrix_multiplication/parallel.py index 631a1cd..b707415 100644 --- a/src/optical_matrix_multiplication/parallel.py +++ b/src/optical_matrix_multiplication/parallel.py @@ -94,4 +94,71 @@ class DataParallel(_nn.Module): outputs = _nn.parallel.parallel_apply(replicas, stacked_input) - return _nn.parallel.gather(outputs, self.output_device, dim) \ No newline at end of file + return _nn.parallel.gather(outputs, self.output_device, dim) + +class ScatterDataParallel(_nn.Module): + """ + Оптимизированный DataParallel для работы с attention матрицами разных размеров. + Эквивалентно DataParallel от PyTorch? + """ + def __init__(self, module: _nn.Module, devices: Union[None, List[Union[int, _torch.device]]] = None, + output_device: Union[int, _torch.device] = None) -> None: + super(ScatterDataParallel, self).__init__() + + if not _torch.cuda.is_available(): + raise EnvironmentError("cuda is not available.") + + if not devices: + devices = [_torch.device(f'cuda:{x}') for x in range(_torch.cuda.device_count())] + + if not output_device: + output_device = devices[0] + + self.module = module + self.devices = devices + self.output_device = output_device + + def buffers(self, *inputs) -> Iterator[_torch.Tensor]: + return self.module.buffers(*inputs) + + def parameters(self, *inputs) -> Iterator[_nn.parameter.Parameter]: + return self.module.parameters(*inputs) + + def forward(self, input: _torch.Tensor, other: _torch.Tensor, **kwargs: Any) -> _torch.Tensor: + ''' + Оптимизированный forward для attention матриц. 
+ + Особенности: + - Scatter по batch dimension (0) вместо произвольного dim + - Оба тензора scatter'ятся для согласованности размерностей + - Поддержка многомерных attention тензоров [batch, heads, seq, dim] + ''' + + # Определяем dimension для scatter на основе структуры тензоров + if input.dim() >= 3 and other.dim() >= 3: + # Для attention матриц scatter по batch dimension + scatter_dim = 0 + else: + # Для обычных 2D матриц используем dim из kwargs или по умолчанию 2 + scatter_dim = kwargs.get('dim', 2) + + # Подготовка модуля и данных + self.module = self.module.to(self.devices[0]) + + # Scatter ОБОИХ тензоров для согласованности размерностей + scattered_input = _nn.parallel.scatter(input, self.devices, scatter_dim) + scattered_other = _nn.parallel.scatter(other, self.devices, scatter_dim) + + # Создаем реплики модуля + replicas = _nn.parallel.replicate(self.module, self.devices) + + # Формируем входные данные для каждого устройства + # Убедимся, что все списки одинаковой длины + min_len = min(len(scattered_input), len(scattered_other), len(replicas)) + stacked_input = [(scattered_input[i], scattered_other[i]) for i in range(min_len)] + + # Параллельное вычисление + outputs = _nn.parallel.parallel_apply(replicas[:min_len], stacked_input) + + # Сбор результатов + return _nn.parallel.gather(outputs, self.output_device, scatter_dim) \ No newline at end of file diff --git a/src/optics_char_gpt2.py b/src/optics_char_gpt2.py index 1c68c7a..3989ed6 100644 --- a/src/optics_char_gpt2.py +++ b/src/optics_char_gpt2.py @@ -121,7 +121,37 @@ class OpticGPT2(nn.Module): pixel_size = 3.6e-6): super().__init__() self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) - self.sim_scores = omm.OpticalMul( + if max_seq_len < 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // 
num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + if max_seq_len >= 512: + self.sim_scores = omm.OpticalMul( omm.Config(right_matrix_count_columns = max_seq_len, right_matrix_count_rows = h_dim // num_heads, right_matrix_width = pixel_size * max_seq_len, @@ -132,10 +162,11 @@ class OpticGPT2(nn.Module): left_matrix_split_x = 2, left_matrix_split_y = 2, result_matrix_split = 2, - distance = 0.01) + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) ) - - self.sim_output = omm.OpticalMul( + self.sim_output = omm.OpticalMul( omm.Config(right_matrix_count_columns = h_dim // num_heads, right_matrix_count_rows = max_seq_len, right_matrix_width = pixel_size * (h_dim // num_heads), @@ -146,8 +177,11 @@ class OpticGPT2(nn.Module): left_matrix_split_x = 2, left_matrix_split_y = 2, result_matrix_split = 2, - distance = 0.01) + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) ) + self.layers = nn.ModuleList([ TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) diff --git a/src/optics_char_gpt2_ff.py b/src/optics_char_gpt2_ff.py index 2347bf9..59bbd93 100644 --- a/src/optics_char_gpt2_ff.py +++ b/src/optics_char_gpt2_ff.py @@ -96,7 +96,6 @@ class 
OpticLinear(nn.Module): self.weight = nn.Parameter( torch.empty((in_features, out_features), **factory_kwargs) ) - # print(self.weight.shape) if bias: self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs)) else: diff --git a/src/optics_char_gpt2_new_formula.py b/src/optics_char_gpt2_new_formula.py new file mode 100644 index 0000000..6b59105 --- /dev/null +++ b/src/optics_char_gpt2_new_formula.py @@ -0,0 +1,227 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import optical_matrix_multiplication as omm +from optical_matrix_multiplication import propagator +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +from pathlib import Path +import sys +torch.manual_seed(1337) + +#################################### Model ######################################### + +# def norm(matrix: torch.Tensor, max_val: float = 1) -> torch.Tensor: +# return matrix / (max_val + 1e-10) + +def new_formula(sim, tensor_1, tensor_2): + tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1 + tensor_2 = tensor_2[None,:,:,:] if len(tensor_2.shape) < 4 else tensor_2 + device = tensor_1.device + + A_pos = torch.clamp(tensor_1, min=0) # A⁺ = max(A, 0) + A_neg = torch.clamp(-tensor_1, min=0) # A⁻ = max(-A, 0) + B_pos = torch.clamp(tensor_2, min=0) # B⁺ = max(B, 0) + B_neg = torch.clamp(-tensor_2, min=0) # B⁻ = max(-B, 0) + + max_A_pos = torch.max(A_pos) # Может быть 0, если нет положительных значений + max_A_neg = torch.max(A_neg) # Может быть 0, если нет отрицательных значений + max_B_pos = torch.max(B_pos) + max_B_neg = torch.max(B_neg) + + zero_template = torch.zeros_like( + torch.empty(tensor_1.shape[0],tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])) + + if max_A_pos > 0 and max_B_pos > 0: + t1 = sim(A_pos / max_A_pos, B_pos / max_B_pos) * max_A_pos * max_B_pos + else: + t1 = zero_template.clone().to(device) + + if max_A_pos > 0 and max_B_neg > 
0: + t2 = sim(A_pos / max_A_pos, B_neg / max_B_neg) * max_A_pos * max_B_neg + else: + t2 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_pos > 0: + t3 = sim(A_neg / max_A_neg, B_pos / max_B_pos) * max_A_neg * max_B_pos + else: + t3 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_neg > 0: + t4 = sim(A_neg / max_A_neg, B_neg / max_B_neg) * max_A_neg * max_B_neg + else: + t4 = zero_template.clone().to(device) + + return (t1 - t2 - t3 + t4)[0,:,:,:] + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in 
locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + self.k1 = nn.Parameter(torch.randn(1)) + self.k2 = nn.Parameter(torch.randn(1)) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5) + tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) + scores = scores.masked_fill(tril == 0, float('-inf')) + attention = nn.functional.softmax(scores, dim=2) + output = self.k2 * new_formula(self.sim_output, attention, v) + return self.o_proj(self.gather_heads(output, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +class OpticGPT2NewFormula(nn.Module): + def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, + pixel_size = 3.6e-6): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + if max_seq_len != 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim 
// num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + if max_seq_len == 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + self.sim_scores = omm.DataParallel(self.sim_scores) + self.sim_output = omm.DataParallel(self.sim_output) + + self.layers = nn.ModuleList([ + 
TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads, + dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) + for _ in range(layers_num)]) + self.tok_embeds = nn.Embedding(vocab_size, h_dim) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + self.lm_head = nn.Linear(h_dim, vocab_size) + + def forward(self, x, targets=None): + x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + logits = self.lm_head(x) # B,T,C + loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None + return logits, loss + + # what is the purpose? autoregressive inference! + def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx \ No newline at end of file diff --git a/src/optics_char_gpt2_nokoef.py b/src/optics_char_gpt2_nokoef.py new file mode 100644 index 0000000..af1e518 --- /dev/null +++ b/src/optics_char_gpt2_nokoef.py @@ -0,0 +1,209 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import optical_matrix_multiplication as omm +from optical_matrix_multiplication import propagator +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +from pathlib import Path +import sys +torch.manual_seed(1337) + +#################################### Model ######################################### + +def norm(matrix: torch.Tensor, max_val: float = 1) -> torch.Tensor: + return matrix / (max_val + 1e-10) + +def optics_matmul_shift(sim, tensor_1, tensor_2): + tensor_1 = tensor_1[None,:,:,:] + tensor_2 = 
tensor_2[None,:,:,:] + + if torch.min(tensor_1) >= 0 and torch.min(tensor_2) >= 0: + max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + a, b = norm(tensor_1, max_abs), norm(tensor_2, max_abs) + return sim(a, b)[0,:,:,:] * max_abs **2 + + min_abs = abs(min(torch.min(tensor_1), torch.min(tensor_2))) + max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + min_abs + + shift_a = min_abs * torch.ones(tensor_1.shape).to(tensor_1.device) + shift_b = min_abs * torch.ones(tensor_2.shape).to(tensor_1.device) + a_a_sh = tensor_1 + shift_a + b_b_sh = tensor_2 + shift_b + + a_a_sh_norm, b_b_sh_norm = norm(a_a_sh, max_abs), norm(b_b_sh, max_abs) + shift_a_norm, shift_b_norm = norm(shift_a, max_abs), norm(shift_b, max_abs) + a_a_sh_b_b_sh = sim(a_a_sh_norm, b_b_sh_norm) + a_a_sh_b_sh = sim(a_a_sh_norm, shift_b_norm) + a_sh_b_b_sh = sim(shift_a_norm, b_b_sh_norm) + a_sh_b_sh = sim(shift_a_norm, shift_b_norm) + + return (a_a_sh_b_b_sh - a_a_sh_b_sh - a_sh_b_b_sh + a_sh_b_sh)[0,:,:,:] * max_abs ** 2 + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = 
nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = optics_matmul_shift(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5) + tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) + scores = scores.masked_fill(tril == 0, float('-inf')) + attention = nn.functional.softmax(scores, dim=2) + output = optics_matmul_shift(self.sim_output, attention, v) + return self.o_proj(self.gather_heads(output, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), 
p=self.dropout_rate) + return x + +class OpticGPT2NOKoef(nn.Module): + def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, + pixel_size = 3.6e-6): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + if max_seq_len < 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + if max_seq_len >= 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = 
pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + + self.layers = nn.ModuleList([ + TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads, + dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) + for _ in range(layers_num)]) + self.tok_embeds = nn.Embedding(vocab_size, h_dim) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + self.lm_head = nn.Linear(h_dim, vocab_size) + + def forward(self, x, targets=None): + x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + logits = self.lm_head(x) # B,T,C + loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None + return logits, loss + + # what is the purpose? autoregressive inference! 
+ def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx \ No newline at end of file diff --git a/src/optics_char_gpt2_nokoef_newf.py b/src/optics_char_gpt2_nokoef_newf.py new file mode 100644 index 0000000..f865c18 --- /dev/null +++ b/src/optics_char_gpt2_nokoef_newf.py @@ -0,0 +1,220 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import optical_matrix_multiplication as omm +from optical_matrix_multiplication import propagator +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +from pathlib import Path +import sys +torch.manual_seed(1337) + +#################################### Model ######################################### + +def new_formula(sim, tensor_1, tensor_2): + tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1 + tensor_2 = tensor_2[None,:,:,:] if len(tensor_2.shape) < 4 else tensor_2 + device = tensor_1.device + + A_pos = torch.clamp(tensor_1, min=0) # A⁺ = max(A, 0) + A_neg = torch.clamp(-tensor_1, min=0) # A⁻ = max(-A, 0) + B_pos = torch.clamp(tensor_2, min=0) # B⁺ = max(B, 0) + B_neg = torch.clamp(-tensor_2, min=0) # B⁻ = max(-B, 0) + + max_A_pos = torch.max(A_pos) # Может быть 0, если нет положительных значений + max_A_neg = torch.max(A_neg) # Может быть 0, если нет отрицательных значений + max_B_pos = torch.max(B_pos) + max_B_neg = torch.max(B_neg) + + zero_template = torch.zeros_like( + torch.empty(tensor_1.shape[0],tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])) + + if max_A_pos > 0 and max_B_pos > 0: + t1 = sim(A_pos / max_A_pos, B_pos / max_B_pos) * max_A_pos * max_B_pos + else: + t1 = 
zero_template.clone().to(device) + + if max_A_pos > 0 and max_B_neg > 0: + t2 = sim(A_pos / max_A_pos, B_neg / max_B_neg) * max_A_pos * max_B_neg + else: + t2 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_pos > 0: + t3 = sim(A_neg / max_A_neg, B_pos / max_B_pos) * max_A_neg * max_B_pos + else: + t3 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_neg > 0: + t4 = sim(A_neg / max_A_neg, B_neg / max_B_neg) * max_A_neg * max_B_neg + else: + t4 = zero_template.clone().to(device) + + return (t1 - t2 - t3 + t4)[0,:,:,:] + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): 
+ super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5) + tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) + scores = scores.masked_fill(tril == 0, float('-inf')) + attention = nn.functional.softmax(scores, dim=2) + output = new_formula(self.sim_output, attention, v) + return self.o_proj(self.gather_heads(output, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +class OpticGPT2NOKoefNewF(nn.Module): + def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, + pixel_size = 3.6e-6): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + if max_seq_len < 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = 
pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + if max_seq_len >= 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + + self.layers = nn.ModuleList([ + TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads, + dropout_rate=self.dropout_rate, 
max_seq_len=self.max_seq_len) + for _ in range(layers_num)]) + self.tok_embeds = nn.Embedding(vocab_size, h_dim) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + self.lm_head = nn.Linear(h_dim, vocab_size) + + def forward(self, x, targets=None): + x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + logits = self.lm_head(x) # B,T,C + loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None + return logits, loss + + # what is the purpose? autoregressive inference! + def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx \ No newline at end of file diff --git a/src/optics_char_gpt2_traindiag.py b/src/optics_char_gpt2_traindiag.py index 2cdfc81..26de9d1 100644 --- a/src/optics_char_gpt2_traindiag.py +++ b/src/optics_char_gpt2_traindiag.py @@ -77,7 +77,7 @@ class DyT(nn.Module): # Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 # NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 class TransformerLayer(nn.Module): - def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128, pixel_size=3.6e-6): super().__init__() self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) self.q_proj = nn.Linear(h_dim, h_dim) @@ -91,6 +91,66 @@ class TransformerLayer(nn.Module): self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) self.k1 = nn.Parameter(torch.randn(1)) self.k2 = nn.Parameter(torch.randn(1)) + if max_seq_len < 512: + self.sim_scores = omm.OpticalMul( + 
omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=True) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=True) + ) + if max_seq_len >= 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2) + ) + self.sim_scores = omm.DataParallel(self.sim_scores) + self.sim_output = omm.DataParallel(self.sim_output) def split_to_heads(self, x, B, T, H): if 
self.num_heads <= 1: return x @@ -121,38 +181,9 @@ class OpticGPT2TrainDiag(nn.Module): pixel_size = 3.6e-6): super().__init__() self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) - self.sim_scores = omm.OpticalMul( - omm.Config(right_matrix_count_columns = max_seq_len, - right_matrix_count_rows = h_dim // num_heads, - right_matrix_width = pixel_size * max_seq_len, - right_matrix_height = pixel_size * (h_dim // num_heads), - min_height_gap = pixel_size, - right_matrix_split_x = 2, - right_matrix_split_y = 2, - left_matrix_split_x = 2, - left_matrix_split_y = 2, - result_matrix_split = 2, - distance = 0.01, - trainable_cylind_lens=True) - ) - - self.sim_output = omm.OpticalMul( - omm.Config(right_matrix_count_columns = h_dim // num_heads, - right_matrix_count_rows = max_seq_len, - right_matrix_width = pixel_size * (h_dim // num_heads), - right_matrix_height = pixel_size * max_seq_len, - min_height_gap = pixel_size, - right_matrix_split_x = 2, - right_matrix_split_y = 2, - left_matrix_split_x = 2, - left_matrix_split_y = 2, - result_matrix_split = 2, - distance = 0.01, - trainable_cylind_lens=True) - ) self.layers = nn.ModuleList([ - TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads, - dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) + TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, + dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len, pixel_size=pixel_size) for _ in range(layers_num)]) self.tok_embeds = nn.Embedding(vocab_size, h_dim) self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) From 51147b36b301a9293e792bcca155693915cfadf2 Mon Sep 17 00:00:00 2001 From: Vladimir Protsenko Date: Sun, 18 Jan 2026 19:22:09 +0000 Subject: [PATCH 3/9] parallel dim is not 0, but 1. diagonal nontrainable lens bug fixed. 
--- src/main.py | 6 ++--- .../optical_mul.py | 2 +- src/optical_matrix_multiplication/parallel.py | 15 +++++------- .../propagator.py | 23 +++++++++++-------- src/optics_char_gpt2_nokoef_newf.py | 2 ++ 5 files changed, 25 insertions(+), 23 deletions(-) diff --git a/src/main.py b/src/main.py index 52fd104..d6a06c4 100644 --- a/src/main.py +++ b/src/main.py @@ -30,7 +30,7 @@ models = { } batch_size = 50 -gradient_accumulation_steps = 2 # check this impl for correctness https://unsloth.ai/blog/gradient +gradient_accumulation_steps = 5 # check this impl for correctness https://unsloth.ai/blog/gradient max_iters = 40000 eval_interval = 300 learning_rate = 1e-3 @@ -38,7 +38,7 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu' eval_iters = 200 layers_num = 2 h_dim = 64 -max_seq_len = 256 +max_seq_len = 512 num_heads = 1 dropout_rate = 0.1 pixel_size = 3.6e-6 @@ -149,7 +149,7 @@ for i in range(max_iters): m.eval() ppl = perplexity(model=m, data=val_data) writer.add_scalar('val_perplexity', ppl.item(), i) - print(f"\rPerplexity at {i}: {ppl}") + print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") writer.add_text('completions', complete(m, encode("\n"*max_seq_len), 2*max_seq_len), i) m.eval() diff --git a/src/optical_matrix_multiplication/optical_mul.py b/src/optical_matrix_multiplication/optical_mul.py index 473cd74..11f52dc 100644 --- a/src/optical_matrix_multiplication/optical_mul.py +++ b/src/optical_matrix_multiplication/optical_mul.py @@ -114,12 +114,12 @@ class OpticalMul(_nn.Module): """ vec_field = self.prepare_vector(input) mat_field = self.prepare_matrix(other) - if self.trainable_cylind_lens: vec_field = self._propagator_one(vec_field) vec_field = self._propagator_between(vec_field) else: vec_field = self._propagator_one(vec_field) + vec_field = self._propagator_two(vec_field * mat_field) return self.prepare_out(vec_field) \ No newline at end of file diff --git a/src/optical_matrix_multiplication/parallel.py 
b/src/optical_matrix_multiplication/parallel.py index b707415..7621c6f 100644 --- a/src/optical_matrix_multiplication/parallel.py +++ b/src/optical_matrix_multiplication/parallel.py @@ -129,34 +129,31 @@ class ScatterDataParallel(_nn.Module): Оптимизированный forward для attention матриц. Особенности: - - Scatter по batch dimension (0) вместо произвольного dim + - Scatter по batch dimension (1) вместо произвольного dim - Оба тензора scatter'ятся для согласованности размерностей - - Поддержка многомерных attention тензоров [batch, heads, seq, dim] + - Поддержка многомерных attention тензоров [batch, heads, seq, dim] ?? ''' # Определяем dimension для scatter на основе структуры тензоров if input.dim() >= 3 and other.dim() >= 3: # Для attention матриц scatter по batch dimension - scatter_dim = 0 + scatter_dim = 1 else: # Для обычных 2D матриц используем dim из kwargs или по умолчанию 2 scatter_dim = kwargs.get('dim', 2) - + # Подготовка модуля и данных self.module = self.module.to(self.devices[0]) + replicas = _nn.parallel.replicate(self.module, self.devices) # Scatter ОБОИХ тензоров для согласованности размерностей scattered_input = _nn.parallel.scatter(input, self.devices, scatter_dim) scattered_other = _nn.parallel.scatter(other, self.devices, scatter_dim) - - # Создаем реплики модуля - replicas = _nn.parallel.replicate(self.module, self.devices) - + # Формируем входные данные для каждого устройства # Убедимся, что все списки одинаковой длины min_len = min(len(scattered_input), len(scattered_other), len(replicas)) stacked_input = [(scattered_input[i], scattered_other[i]) for i in range(min_len)] - # Параллельное вычисление outputs = _nn.parallel.parallel_apply(replicas[:min_len], stacked_input) diff --git a/src/optical_matrix_multiplication/propagator.py b/src/optical_matrix_multiplication/propagator.py index 68ee2c4..eaa061f 100644 --- a/src/optical_matrix_multiplication/propagator.py +++ b/src/optical_matrix_multiplication/propagator.py @@ -16,7 +16,7 @@ class 
Propagator(_ABC, _nn.Module): operator_X: оператор отображающий распроcтранение светового поля вдоль оси абсцисс operator_Y: оператор отображающий распроcтранение светового поля вдоль оси ординат """ - def __init__(self, operator_X: _torch.Tensor, operator_Y: _torch.Tensor, trainable = False, diagonal = False): + def __init__(self, operator_X: _torch.Tensor, operator_Y: _torch.Tensor, trainable = False): super(Propagator, self).__init__() operator_X: _torch.Tensor = _torch.view_as_real(operator_X) operator_Y: _torch.Tensor = _torch.view_as_real(operator_Y) @@ -24,12 +24,10 @@ class Propagator(_ABC, _nn.Module): self._operator_X = _nn.Parameter(operator_X) self._operator_Y = _nn.Parameter(operator_Y) self._trainable = trainable - self._diagonal = diagonal else: self.register_buffer('_operator_X', operator_X, persistent=True) self.register_buffer('_operator_Y', operator_Y, persistent=True) self._trainable = trainable - self._diagonal = diagonal @property def operator_X(self) -> _torch.Tensor: @@ -111,13 +109,13 @@ class Propagator(_ABC, _nn.Module): Распределение комплексной амплитуды светового поля, после распространения. """ - if self._diagonal: + if self._trainable: return _torch.diag_embed(self.operator_Y) @ field @ _torch.diag_embed(self.operator_X) else: return self.operator_Y @ field @ self.operator_X def __repr__(self): - return f"Diag: {self._diagonal} Trainable: {self._trainable} Y shape: {self.operator_Y.shape}, X shape: {self.operator_X.shape}" + return f"Trainable: {self._trainable} Y shape: {self.operator_Y.shape}, X shape: {self.operator_X.shape}" class PropagatorLens(Propagator): """ @@ -167,7 +165,8 @@ class PropagatorСylindLens(PropagatorLens): представленной тонким оптическим элементом. """ def __init__(self, plane: _ConfigDesignPlane, - config: _ConfigOpticBase, trainable = False): + config: _ConfigOpticBase, + trainable = False): """ Конструктор класса цилиндрической линзы. 
@@ -177,10 +176,14 @@ class PropagatorСylindLens(PropagatorLens): """ operator_X = _torch.exp(-1j * config.K / config.distance * plane.linspace_by_x**2) operator_Y = _torch.ones_like(plane.linspace_by_y, dtype=_torch.cfloat) - super(PropagatorСylindLens, self).__init__(operator_X, - operator_Y, - trainable, - diagonal=True) + if trainable: + super(PropagatorСylindLens, self).__init__(operator_X, + operator_Y, + trainable) + else: + super(PropagatorСylindLens, self).__init__(_torch.diag_embed(operator_X), + _torch.diag_embed(operator_Y), + trainable) class PropagatorSinc(Propagator): """ diff --git a/src/optics_char_gpt2_nokoef_newf.py b/src/optics_char_gpt2_nokoef_newf.py index f865c18..d99f149 100644 --- a/src/optics_char_gpt2_nokoef_newf.py +++ b/src/optics_char_gpt2_nokoef_newf.py @@ -190,6 +190,8 @@ class OpticGPT2NOKoefNewF(nn.Module): lens_size = 8192 * 2, trainable_cylind_lens=False) ) + self.sim_scores = omm.ScatterDataParallel(self.sim_scores) + self.sim_output = omm.ScatterDataParallel(self.sim_output) self.layers = nn.ModuleList([ TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads, From 58b3271cc8920a6571d5cc59f55b6ca73dd8c8ab Mon Sep 17 00:00:00 2001 From: Vladimir Protsenko Date: Sun, 8 Feb 2026 20:59:51 +0000 Subject: [PATCH 4/9] Trainable lens experiments. Refactor perplexity. Bert experiments. 
--- .gitignore | 2 + src/bert_optica_koef.py | 475 +++++++++++++++ src/bert_optica_koef_newf.py | 483 +++++++++++++++ src/bert_optica_nokoef.py | 473 +++++++++++++++ src/bert_optica_nokoef_newf.py | 483 +++++++++++++++ src/optical_matrix_multiplication/__init__.py | 8 +- .../optical_mul.py | 559 +++++++++++++++++- .../propagator.py | 206 +++++-- src/optics_char_gpt2_nokoef.py | 2 + src/train_gpt2.py | 341 +++++++++++ ...ain_optics_trainable_focal_dist_lens_64.py | 399 +++++++++++++ src/train_optics_trainable_lens_128.py | 464 +++++++++++++++ src/train_optics_trainable_lens_256.py | 464 +++++++++++++++ src/train_optics_trainable_lens_512.py | 464 +++++++++++++++ src/train_optics_trainable_lens_64.py | 464 +++++++++++++++ 15 files changed, 5237 insertions(+), 50 deletions(-) create mode 100644 src/bert_optica_koef.py create mode 100644 src/bert_optica_koef_newf.py create mode 100644 src/bert_optica_nokoef.py create mode 100644 src/bert_optica_nokoef_newf.py create mode 100644 src/train_gpt2.py create mode 100644 src/train_optics_trainable_focal_dist_lens_64.py create mode 100644 src/train_optics_trainable_lens_128.py create mode 100644 src/train_optics_trainable_lens_256.py create mode 100644 src/train_optics_trainable_lens_512.py create mode 100644 src/train_optics_trainable_lens_64.py diff --git a/.gitignore b/.gitignore index e15106e..d2133bf 100644 --- a/.gitignore +++ b/.gitignore @@ -214,3 +214,5 @@ __marimo__/ # Streamlit .streamlit/secrets.toml + +checkpoints/ \ No newline at end of file diff --git a/src/bert_optica_koef.py b/src/bert_optica_koef.py new file mode 100644 index 0000000..634c671 --- /dev/null +++ b/src/bert_optica_koef.py @@ -0,0 +1,475 @@ +import os +import sys +#os.environ['CUDA_VISIBLE_DEVICES'] = f"[0,1,2,3,4,5,6,7]" # f"{sys.argv[1]}" + +from torch import nn +import torch +import pandas as pd +import numpy as np +from matplotlib import pyplot as plt +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime, timedelta 
+from torchmetrics import AUROC +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import Dataset, DataLoader +from pathlib import Path +import schedulefree +from einops import rearrange, repeat +import torch.nn.functional as F +from torch import autograd +import optical_matrix_multiplication as omm + +step = 1 +pixel_size: float = 3.6e-6 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +print('Current device - ', device) + +def norm(matrix: torch.Tensor, max_val: float = 1) -> torch.Tensor: + return matrix / (max_val + 1e-10) + +def optics_matmul_shift(sim, tensor_1, tensor_2): + tensor_1 = tensor_1[None,:,:,:] + tensor_2 = tensor_2[None,:,:,:] + + if torch.min(tensor_1) >= 0 and torch.min(tensor_2) >= 0: + max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + a, b = norm(tensor_1, max_abs), norm(tensor_2, max_abs) + return sim(a, b)[0,:,:,:] * max_abs **2 + + min_abs = abs(min(torch.min(tensor_1), torch.min(tensor_2))) + max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + min_abs + + shift_a = min_abs * torch.ones(tensor_1.shape).to(device) + shift_b = min_abs * torch.ones(tensor_2.shape).to(device) + a_a_sh = tensor_1 + shift_a + b_b_sh = tensor_2 + shift_b + # (a_shifted - shift_a)(b_shifted - shift_b) = + # a_shifted*b_shifted - a_shifted*shift_b - b_shifted*shift_a + shift_a*shift_b + + a_a_sh_norm, b_b_sh_norm = norm(a_a_sh, max_abs), norm(b_b_sh, max_abs) + shift_a_norm, shift_b_norm = norm(shift_a, max_abs), norm(shift_b, max_abs) + + a_a_sh_b_b_sh = sim(a_a_sh_norm, b_b_sh_norm) + a_a_sh_b_sh = sim(a_a_sh_norm, shift_b_norm) + a_sh_b_b_sh = sim(shift_a_norm, b_b_sh_norm) + a_sh_b_sh = sim(shift_a_norm, shift_b_norm) + + return (a_a_sh_b_b_sh - a_a_sh_b_sh - a_sh_b_b_sh + a_sh_b_sh)[0,:,:,:] * max_abs ** 2 + +comment = Path(__file__).stem # sys.argv[2] +checkpoint_file = None +logs_dir = 
f'/wd/finbert_results/runs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +сhekpoints_dir = f'/wd/finbert_results/сhekpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +Path(сhekpoints_dir).mkdir(parents=True, exist_ok=True) +print("Logs dir:", logs_dir) +print("Chekpoints dir:", сhekpoints_dir) +writer = SummaryWriter(logs_dir) +Path(logs_dir + "bert_training.py").write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script + +def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, сhekpoints_dir="checkpoints/"): + checkpoint = { + 'encoder': { + 'state_dict': encoder.state_dict(), + **{k:v for k,v in encoder.__dict__.items() if k[0] != '_' and k != 'training'} + }, + 'model': { + 'state_dict': model.state_dict(), + **{k:v for k,v in model.__dict__.items() if k[0] != '_' and k != 'training'} + }, + 'epoch': epoch, + 'optimizer': { + 'state_dict': optimizer.state_dict(), + }, + 'loss': loss, + 'rocauc': rocauc, + 'train_uniq_client_ids_path': credit_dataset.train_uniq_client_ids_path, + 'test_uniq_client_ids_path': credit_dataset.test_uniq_client_ids_path + } + path = сhekpoints_dir + f"epoch_{epoch}_{loss:.5f}.pth" + torch.save(checkpoint, path) + print(f"\nCheckpoint saved to {path}") + +def load_checkpoint(checkpoint_file): + if os.path.exists(checkpoint_file): + checkpoint = torch.load(checkpoint_file) + + #optimizer.load_state_dict(checkpoint['optimizer']) + encoder.load_state_dict(checkpoint['encoder']['state_dict']) + model.load_state_dict(checkpoint['model']['state_dict']) + +class CreditProductsDataset: + def __init__(self, + features_path, targets_path, train_test_split_ratio=0.9, + train_uniq_client_ids_path=None, test_uniq_client_ids_path=None + ): + self.train_uniq_client_ids_path = train_uniq_client_ids_path + self.test_uniq_client_ids_path = test_uniq_client_ids_path 
+ if Path(self.train_uniq_client_ids_path).exists(): + self.train_uniq_client_ids = pd.read_csv(self.train_uniq_client_ids_path).iloc[:,0].values + print("Loaded", self.train_uniq_client_ids_path) + else: + raise Exception(f"No {self.train_uniq_client_ids_path}") + if Path(self.test_uniq_client_ids_path).exists(): + self.test_uniq_client_ids = pd.read_csv(self.test_uniq_client_ids_path).iloc[:,0].values + print("Loaded", self.test_uniq_client_ids_path) + else: + raise Exception(f"No {self.test_uniq_client_ids_path}") + self.features_df = pd.read_parquet(features_path) + self.targets_df = pd.read_csv(targets_path) + self.uniq_client_ids = self.features_df.id.unique() + self.max_user_history = self.features_df.rn.max() + self.id_columns = ['id', 'rn'] + self.cat_columns = [ + 'pre_since_opened', 'pre_since_confirmed', 'pre_pterm', 'pre_fterm', 'pre_till_pclose', 'pre_till_fclose', + 'pre_loans_credit_limit', 'pre_loans_next_pay_summ', 'pre_loans_outstanding', 'pre_loans_total_overdue', + 'pre_loans_max_overdue_sum', 'pre_loans_credit_cost_rate', 'is_zero_loans5', 'is_zero_loans530', 'is_zero_loans3060', + 'is_zero_loans6090', 'is_zero_loans90', 'pre_util', 'pre_over2limit', 'pre_maxover2limit', 'is_zero_util', 'is_zero_over2limit', + 'is_zero_maxover2limit', 'enc_paym_0', 'enc_paym_1', 'enc_paym_2', 'enc_paym_3', 'enc_paym_4', 'enc_paym_5', 'enc_paym_6', 'enc_paym_7', + 'enc_paym_8', 'enc_paym_9', 'enc_paym_10', 'enc_paym_11', 'enc_paym_12', 'enc_paym_13', 'enc_paym_14', 'enc_paym_15', 'enc_paym_16', + 'enc_paym_17', 'enc_paym_18', 'enc_paym_19', 'enc_paym_20', 'enc_paym_21', 'enc_paym_22', 'enc_paym_23', 'enc_paym_24', + 'enc_loans_account_holder_type', 'enc_loans_credit_status', 'enc_loans_credit_type', 'enc_loans_account_cur', 'pclose_flag', + 'fclose_flag' + ] + self.num_columns = list(set(self.features_df.columns).difference(set(self.id_columns)).difference(self.cat_columns)) + # make unified category index for embeddings for all columns. 
zero index embedding for padding will be zeroed during training + self.cat_cardinalities = self.features_df.max(axis=0)[self.cat_columns] + 1 + self.cat_cardinalities_integral = self.cat_cardinalities.cumsum() + self.features_df[self.cat_columns[1:]] = self.features_df[self.cat_columns[1:]] + self.cat_cardinalities_integral[1:] + self.features_df[self.cat_columns] = self.features_df[self.cat_columns] + 1 # zero embedding is for padding + + self.features_df = self.features_df.sort_values(self.id_columns, ascending=[True, True]) + self.features_df = self.features_df.set_index('id') + self.targets_df = self.targets_df.set_index('id') + self.targets_df = self.targets_df.sort_index() + + self.user_seq_lengths = self.features_df.index.value_counts().sort_index() + + self.cat_features = torch.tensor(self.features_df[self.cat_columns].values, dtype=torch.int16) + self.cat_features = pad_sequence(torch.split(self.cat_features, self.user_seq_lengths.tolist()), batch_first=True) # implicit max seq + self.num_features = torch.tensor(self.features_df[self.num_columns].values, dtype=torch.float32) + self.num_features = pad_sequence(torch.split(self.num_features, self.user_seq_lengths.tolist()), batch_first=True) + self.targets = torch.tensor(self.targets_df.flag.values).type(torch.float32) + + def get_batch(self, batch_size=4): + sampled_ids = np.random.choice(self.train_uniq_client_ids, batch_size, replace=False) # think about replace=True + cat_features_batch = self.cat_features[sampled_ids] * torch.empty_like(self.cat_features[sampled_ids]).bernoulli_(0.9) # dropout features with prob 0.1 + num_features_batch = self.num_features[sampled_ids] * torch.empty_like(self.num_features[sampled_ids]).bernoulli_(0.9) # dropout features with prob 0.1 + targets_batch = self.targets[sampled_ids] + return cat_features_batch, num_features_batch, targets_batch + + def get_test_batch_iterator(self, batch_size=4): + for i in range(0, len(self.test_uniq_client_ids), batch_size): + ids = 
self.test_uniq_client_ids[i:i+batch_size] + cat_features_batch = self.cat_features[ids] + num_features_batch = self.num_features[ids] + targets_batch = self.targets[ids] + yield cat_features_batch, num_features_batch, targets_batch + +class Encoder(nn.Module): + def __init__(self, cat_columns, num_columns, cat_features_max_id, category_feature_dim=4, out_dim=64): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.total_h_dim = len(self.cat_columns) * category_feature_dim + len(self.num_columns) + self.cat_embeds = nn.Embedding(cat_features_max_id + 1, self.category_feature_dim, padding_idx=0) + self.num_scales = nn.Parameter(torch.randn(1, len(self.num_columns))) + self.num_shifts = nn.Parameter(torch.randn(1, len(self.num_columns))) + self.proj = nn.Linear(self.total_h_dim, self.out_dim, bias=False) + + @property + def device(self): + return next(self.parameters()).device + + def forward(self, cat_features_batch, num_features_batch, targets_batch): + cat_features_batch = cat_features_batch.to(self.device) + num_features_batch = num_features_batch.to(self.device) + + cat_embed_tensor = self.cat_embeds(cat_features_batch.type(torch.int32)) + cat_embed_tensor = cat_embed_tensor.reshape(cat_features_batch.shape[0], cat_features_batch.shape[1], -1) + num_embed_tensor = self.num_scales * num_features_batch + self.num_shifts + embed_tensor = torch.concat([cat_embed_tensor, num_embed_tensor], dim=-1) + + inputs = self.proj(embed_tensor) + targets = targets_batch.to(self.device) + + return inputs, targets + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + 
def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + self.k1 = nn.Parameter(torch.randn(1)) + self.k2 = nn.Parameter(torch.randn(1)) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), 
*x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = self.k1 * optics_matmul_shift(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5) + attention = nn.functional.softmax(scores, dim=2) + output = self.k2 * optics_matmul_shift(self.sim_output, attention, v) + return self.o_proj(self.gather_heads(output, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +# Vision Transformers Need Registers https://arxiv.org/html/2309.16588v2 +class BertClassifier(nn.Module): + def __init__(self, layers_num=1, h_dim=64, class_num=2, max_seq_len=128, num_heads=4, num_reg=0, dropout_rate = 0.1): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.cls_token = nn.Parameter(torch.randn(1,1+num_reg,h_dim)) # reg tokens can be added by second dim >1 + self.max_seq_len = max_seq_len + self.cls_token.shape[1] + print(h_dim, self.max_seq_len) + if max_seq_len < 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = self.max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * self.max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = self.max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * self.max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + 
distance = 0.01, + trainable_cylind_lens=False) + ) + if max_seq_len >= 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = self.max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * self.max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = self.max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * self.max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + self.layers = nn.ModuleList([TransformerLayer(h_dim=h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=num_heads, dropout_rate = dropout_rate, max_seq_len=self.max_seq_len) for _ in range(layers_num)]) + self.classifier_head = nn.Sequential(nn.Linear(h_dim, h_dim), nn.GELU(), nn.Linear(h_dim, class_num)) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + + def forward(self, x): + x = torch.concat([self.cls_token.expand([x.shape[0], self.cls_token.shape[1], self.cls_token.shape[2]]), x], dim=1) + x = x + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + x = self.classifier_head(x[:,0,:]) + return x[:,:] if self.class_num > 1 else x[:,0] + +start_prep_time = datetime.now() +credit_train_dataset = CreditProductsDataset( + features_path="/wd/finbert_data/train_data", + targets_path="/wd/finbert_data/train_target .csv", + 
train_uniq_client_ids_path="/wd/finbert_data/train_uniq_client_ids.csv", + test_uniq_client_ids_path="/wd/finbert_data/test_uniq_client_ids.csv", +) +print(f"Dataset preparation time: {datetime.now() - start_prep_time}") + +h_dim = 64 +category_feature_dim = 8 +layers_num = 2 +num_heads = 1 +class_num = 1 +dropout_rate = 0.1 + +encoder = Encoder( + cat_columns=credit_train_dataset.cat_columns, + num_columns=credit_train_dataset.num_columns, + cat_features_max_id=credit_train_dataset.cat_features.max(), + category_feature_dim=category_feature_dim, + out_dim=h_dim, +).to(device) + +model = BertClassifier( + layers_num=layers_num, + num_heads=num_heads, + h_dim=h_dim, + class_num=class_num, + max_seq_len=credit_train_dataset.max_user_history, + dropout_rate = dropout_rate +).to(device) + +print(f'Parameters model - {sum(p.numel() for p in model.parameters())}, parameters encoder - {sum(p.numel() for p in encoder.parameters())}') +optimizer = schedulefree.AdamWScheduleFree(list(model.parameters()) + list(encoder.parameters())) + +if checkpoint_file is not None: + load_checkpoint(checkpoint_file) + +positive_counts = credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids].values.sum() +negative_counts = len(credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids]) - positive_counts +pos_weight = negative_counts / (positive_counts + 1e-15) +print(f"Class imbalance: {negative_counts} {positive_counts}. 
Pos weight: {pos_weight}") +criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight)) + +epochs = 50 +batch_size = 300 +batches_per_epoch = len(credit_train_dataset.uniq_client_ids) // batch_size + +# for parallel data selection +class WrapperDataset(Dataset): + def __init__(self, credit_dataset, encoder, batch_size): + self.credit_dataset = credit_dataset + self.encoder = encoder + self.batch_size = batch_size + + def __len__(self): + return len(self.credit_dataset.uniq_client_ids) // self.batch_size + + def __getitem__(self, idx): + cat_inputs, num_inputs, targets = credit_train_dataset.get_batch(batch_size=self.batch_size) + return cat_inputs, num_inputs, targets + +training_data = WrapperDataset(credit_train_dataset, encoder, batch_size=batch_size) +dataloader = DataLoader(training_data, batch_size=1, shuffle=False, num_workers=8, pin_memory=True) + +val_auroc = AUROC(task='binary') +test_auroc = AUROC(task='binary') + +def test(epoch): + model.eval() + encoder.eval() + optimizer.eval() + with torch.no_grad(): + test_iterator = credit_train_dataset.get_test_batch_iterator(batch_size=batch_size) + for test_batch_id, (test_cat_inputs, test_num_inputs, test_targets) in enumerate(test_iterator): + inputs, targets = encoder(test_cat_inputs, test_num_inputs, test_targets) + outputs = model(inputs) + test_auroc.update(outputs, targets.long()) + print(f"\r {test_batch_id}/{len(credit_train_dataset.test_uniq_client_ids)//batch_size} {test_auroc.compute().item():.4f}", end = " "*40) + + writer.add_scalar('test_roc_auc', test_auroc.compute().item(), epoch * batches_per_epoch) + print(f"\r {datetime.now() - start_time} {epoch}/{epochs} Test rocauc: {test_auroc.compute().item():.4f}", end = " "*40) + print() + +start_time = datetime.now() +print("Started at:", start_time) +last_display_time = start_time +last_checkpoint_time = start_time +for epoch in range(epochs): + test(epoch) + for batch_id, (cat_inputs, num_inputs, targets) in enumerate(dataloader): + 
# with autograd.detect_anomaly(True): + model.train() + encoder.train() + optimizer.train() + inputs, targets = encoder(cat_inputs[0], num_inputs[0], targets[0]) + outputs = model(inputs) + loss = criterion(outputs, targets) + loss.backward() + current_time = datetime.now() + if current_time - last_display_time > timedelta(seconds=1): + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1e6).item(), epoch*batches_per_epoch+batch_id) + optimizer.step() + optimizer.zero_grad() + if current_time - last_display_time > timedelta(seconds=1): + model.eval() + encoder.eval() + optimizer.eval() + last_display_time = current_time + writer.add_scalar('Loss', loss.item(), epoch*batches_per_epoch+batch_id) + print(f"\r {current_time-start_time} {epoch+1}/{epochs} {batch_id}/{batches_per_epoch} loss: {loss.item()}", end = " "*40) + if current_time - last_checkpoint_time > timedelta(hours=4): + last_checkpoint_time = current_time + save_checkpoint( + credit_dataset=credit_train_dataset, + encoder = encoder, model=model, optimizer=optimizer, epoch=epoch, + loss=loss.item(), rocauc=test_auroc.compute().item(), сhekpoints_dir=сhekpoints_dir) + +test(epochs) +print() + +save_checkpoint( + credit_dataset=credit_train_dataset, + encoder = encoder, model=model, optimizer=optimizer, epoch=epochs, + loss=loss.item(), rocauc=test_auroc.compute().item(), сhekpoints_dir=сhekpoints_dir) + +writer.close() \ No newline at end of file diff --git a/src/bert_optica_koef_newf.py b/src/bert_optica_koef_newf.py new file mode 100644 index 0000000..9e5fb1d --- /dev/null +++ b/src/bert_optica_koef_newf.py @@ -0,0 +1,483 @@ +import os +import sys +#os.environ['CUDA_VISIBLE_DEVICES'] = f"[0,1,2,3,4,5,6,7]" # f"{sys.argv[1]}" + +from torch import nn +import torch +import pandas as pd +import numpy as np +from matplotlib import pyplot as plt +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime, timedelta +from torchmetrics import 
AUROC +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import Dataset, DataLoader +from pathlib import Path +import schedulefree +from einops import rearrange, repeat +import torch.nn.functional as F +from torch import autograd +import optical_matrix_multiplication as omm + +step = 1 +pixel_size: float = 3.6e-6 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +print('Current device - ', device) + +def new_formula(sim, tensor_1, tensor_2): + tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1 + tensor_2 = tensor_2[None,:,:,:] if len(tensor_2.shape) < 4 else tensor_2 + device = tensor_1.device + + A_pos = torch.clamp(tensor_1, min=0) # A⁺ = max(A, 0) + A_neg = torch.clamp(-tensor_1, min=0) # A⁻ = max(-A, 0) + B_pos = torch.clamp(tensor_2, min=0) # B⁺ = max(B, 0) + B_neg = torch.clamp(-tensor_2, min=0) # B⁻ = max(-B, 0) + + max_A_pos = torch.max(A_pos) # Может быть 0, если нет положительных значений + max_A_neg = torch.max(A_neg) # Может быть 0, если нет отрицательных значений + max_B_pos = torch.max(B_pos) + max_B_neg = torch.max(B_neg) + + zero_template = torch.zeros_like( + torch.empty(tensor_1.shape[0],tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])) + + if max_A_pos > 0 and max_B_pos > 0: + t1 = sim(A_pos / max_A_pos, B_pos / max_B_pos) * max_A_pos * max_B_pos + else: + t1 = zero_template.clone().to(device) + + if max_A_pos > 0 and max_B_neg > 0: + t2 = sim(A_pos / max_A_pos, B_neg / max_B_neg) * max_A_pos * max_B_neg + else: + t2 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_pos > 0: + t3 = sim(A_neg / max_A_neg, B_pos / max_B_pos) * max_A_neg * max_B_pos + else: + t3 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_neg > 0: + t4 = sim(A_neg / max_A_neg, B_neg / max_B_neg) * max_A_neg * max_B_neg + else: + t4 = zero_template.clone().to(device) + + return (t1 - t2 - t3 + t4)[0,:,:,:] + +comment = Path(__file__).stem # sys.argv[2] +checkpoint_file = None +logs_dir = 
f'/wd/finbert_results/runs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +сhekpoints_dir = f'/wd/finbert_results/сhekpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +Path(сhekpoints_dir).mkdir(parents=True, exist_ok=True) +print("Logs dir:", logs_dir) +print("Chekpoints dir:", сhekpoints_dir) +writer = SummaryWriter(logs_dir) +Path(logs_dir + "bert_training.py").write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script + +def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, сhekpoints_dir="checkpoints/"): + checkpoint = { + 'encoder': { + 'state_dict': encoder.state_dict(), + **{k:v for k,v in encoder.__dict__.items() if k[0] != '_' and k != 'training'} + }, + 'model': { + 'state_dict': model.state_dict(), + **{k:v for k,v in model.__dict__.items() if k[0] != '_' and k != 'training'} + }, + 'epoch': epoch, + 'optimizer': { + 'state_dict': optimizer.state_dict(), + }, + 'loss': loss, + 'rocauc': rocauc, + 'train_uniq_client_ids_path': credit_dataset.train_uniq_client_ids_path, + 'test_uniq_client_ids_path': credit_dataset.test_uniq_client_ids_path + } + path = сhekpoints_dir + f"epoch_{epoch}_{loss:.5f}.pth" + torch.save(checkpoint, path) + print(f"\nCheckpoint saved to {path}") + +def load_checkpoint(checkpoint_file): + if os.path.exists(checkpoint_file): + checkpoint = torch.load(checkpoint_file) + + #optimizer.load_state_dict(checkpoint['optimizer']) + encoder.load_state_dict(checkpoint['encoder']['state_dict']) + model.load_state_dict(checkpoint['model']['state_dict']) + +class CreditProductsDataset: + def __init__(self, + features_path, targets_path, train_test_split_ratio=0.9, + train_uniq_client_ids_path=None, test_uniq_client_ids_path=None + ): + self.train_uniq_client_ids_path = train_uniq_client_ids_path + self.test_uniq_client_ids_path = test_uniq_client_ids_path 
+ if Path(self.train_uniq_client_ids_path).exists(): + self.train_uniq_client_ids = pd.read_csv(self.train_uniq_client_ids_path).iloc[:,0].values + print("Loaded", self.train_uniq_client_ids_path) + else: + raise Exception(f"No {self.train_uniq_client_ids_path}") + if Path(self.test_uniq_client_ids_path).exists(): + self.test_uniq_client_ids = pd.read_csv(self.test_uniq_client_ids_path).iloc[:,0].values + print("Loaded", self.test_uniq_client_ids_path) + else: + raise Exception(f"No {self.test_uniq_client_ids_path}") + self.features_df = pd.read_parquet(features_path) + self.targets_df = pd.read_csv(targets_path) + self.uniq_client_ids = self.features_df.id.unique() + self.max_user_history = self.features_df.rn.max() + self.id_columns = ['id', 'rn'] + self.cat_columns = [ + 'pre_since_opened', 'pre_since_confirmed', 'pre_pterm', 'pre_fterm', 'pre_till_pclose', 'pre_till_fclose', + 'pre_loans_credit_limit', 'pre_loans_next_pay_summ', 'pre_loans_outstanding', 'pre_loans_total_overdue', + 'pre_loans_max_overdue_sum', 'pre_loans_credit_cost_rate', 'is_zero_loans5', 'is_zero_loans530', 'is_zero_loans3060', + 'is_zero_loans6090', 'is_zero_loans90', 'pre_util', 'pre_over2limit', 'pre_maxover2limit', 'is_zero_util', 'is_zero_over2limit', + 'is_zero_maxover2limit', 'enc_paym_0', 'enc_paym_1', 'enc_paym_2', 'enc_paym_3', 'enc_paym_4', 'enc_paym_5', 'enc_paym_6', 'enc_paym_7', + 'enc_paym_8', 'enc_paym_9', 'enc_paym_10', 'enc_paym_11', 'enc_paym_12', 'enc_paym_13', 'enc_paym_14', 'enc_paym_15', 'enc_paym_16', + 'enc_paym_17', 'enc_paym_18', 'enc_paym_19', 'enc_paym_20', 'enc_paym_21', 'enc_paym_22', 'enc_paym_23', 'enc_paym_24', + 'enc_loans_account_holder_type', 'enc_loans_credit_status', 'enc_loans_credit_type', 'enc_loans_account_cur', 'pclose_flag', + 'fclose_flag' + ] + self.num_columns = list(set(self.features_df.columns).difference(set(self.id_columns)).difference(self.cat_columns)) + # make unified category index for embeddings for all columns. 
zero index embedding for padding will be zeroed during training + self.cat_cardinalities = self.features_df.max(axis=0)[self.cat_columns] + 1 + self.cat_cardinalities_integral = self.cat_cardinalities.cumsum() + self.features_df[self.cat_columns[1:]] = self.features_df[self.cat_columns[1:]] + self.cat_cardinalities_integral[1:] + self.features_df[self.cat_columns] = self.features_df[self.cat_columns] + 1 # zero embedding is for padding + + self.features_df = self.features_df.sort_values(self.id_columns, ascending=[True, True]) + self.features_df = self.features_df.set_index('id') + self.targets_df = self.targets_df.set_index('id') + self.targets_df = self.targets_df.sort_index() + + self.user_seq_lengths = self.features_df.index.value_counts().sort_index() + + self.cat_features = torch.tensor(self.features_df[self.cat_columns].values, dtype=torch.int16) + self.cat_features = pad_sequence(torch.split(self.cat_features, self.user_seq_lengths.tolist()), batch_first=True) # implicit max seq + self.num_features = torch.tensor(self.features_df[self.num_columns].values, dtype=torch.float32) + self.num_features = pad_sequence(torch.split(self.num_features, self.user_seq_lengths.tolist()), batch_first=True) + self.targets = torch.tensor(self.targets_df.flag.values).type(torch.float32) + + def get_batch(self, batch_size=4): + sampled_ids = np.random.choice(self.train_uniq_client_ids, batch_size, replace=False) # think about replace=True + cat_features_batch = self.cat_features[sampled_ids] * torch.empty_like(self.cat_features[sampled_ids]).bernoulli_(0.9) # dropout features with prob 0.1 + num_features_batch = self.num_features[sampled_ids] * torch.empty_like(self.num_features[sampled_ids]).bernoulli_(0.9) # dropout features with prob 0.1 + targets_batch = self.targets[sampled_ids] + return cat_features_batch, num_features_batch, targets_batch + + def get_test_batch_iterator(self, batch_size=4): + for i in range(0, len(self.test_uniq_client_ids), batch_size): + ids = 
self.test_uniq_client_ids[i:i+batch_size] + cat_features_batch = self.cat_features[ids] + num_features_batch = self.num_features[ids] + targets_batch = self.targets[ids] + yield cat_features_batch, num_features_batch, targets_batch + +class Encoder(nn.Module): + def __init__(self, cat_columns, num_columns, cat_features_max_id, category_feature_dim=4, out_dim=64): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.total_h_dim = len(self.cat_columns) * category_feature_dim + len(self.num_columns) + self.cat_embeds = nn.Embedding(cat_features_max_id + 1, self.category_feature_dim, padding_idx=0) + self.num_scales = nn.Parameter(torch.randn(1, len(self.num_columns))) + self.num_shifts = nn.Parameter(torch.randn(1, len(self.num_columns))) + self.proj = nn.Linear(self.total_h_dim, self.out_dim, bias=False) + + @property + def device(self): + return next(self.parameters()).device + + def forward(self, cat_features_batch, num_features_batch, targets_batch): + cat_features_batch = cat_features_batch.to(self.device) + num_features_batch = num_features_batch.to(self.device) + + cat_embed_tensor = self.cat_embeds(cat_features_batch.type(torch.int32)) + cat_embed_tensor = cat_embed_tensor.reshape(cat_features_batch.shape[0], cat_features_batch.shape[1], -1) + num_embed_tensor = self.num_scales * num_features_batch + self.num_shifts + embed_tensor = torch.concat([cat_embed_tensor, num_embed_tensor], dim=-1) + + inputs = self.proj(embed_tensor) + targets = targets_batch.to(self.device) + + return inputs, targets + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + 
def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + self.k1 = nn.Parameter(torch.randn(1)) + self.k2 = nn.Parameter(torch.randn(1)) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), 
*x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5) + attention = nn.functional.softmax(scores, dim=2) + output = self.k2 * new_formula(self.sim_output, attention, v) + return self.o_proj(self.gather_heads(output, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +# Vision Transformers Need Registers https://arxiv.org/html/2309.16588v2 +class BertClassifier(nn.Module): + def __init__(self, layers_num=1, h_dim=64, class_num=2, max_seq_len=128, num_heads=4, num_reg=0, dropout_rate = 0.1): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.cls_token = nn.Parameter(torch.randn(1,1+num_reg,h_dim)) # reg tokens can be added by second dim >1 + self.max_seq_len = max_seq_len + self.cls_token.shape[1] + print(h_dim, self.max_seq_len) + if max_seq_len < 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = self.max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * self.max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = self.max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * self.max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + 
trainable_cylind_lens=False) + ) + if max_seq_len >= 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = self.max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * self.max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = self.max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * self.max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + self.layers = nn.ModuleList([TransformerLayer(h_dim=h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=num_heads, dropout_rate = dropout_rate, max_seq_len=self.max_seq_len) for _ in range(layers_num)]) + self.classifier_head = nn.Sequential(nn.Linear(h_dim, h_dim), nn.GELU(), nn.Linear(h_dim, class_num)) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + + def forward(self, x): + x = torch.concat([self.cls_token.expand([x.shape[0], self.cls_token.shape[1], self.cls_token.shape[2]]), x], dim=1) + x = x + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + x = self.classifier_head(x[:,0,:]) + return x[:,:] if self.class_num > 1 else x[:,0] + +start_prep_time = datetime.now() +credit_train_dataset = CreditProductsDataset( + features_path="/wd/finbert_data/train_data", + targets_path="/wd/finbert_data/train_target .csv", + 
train_uniq_client_ids_path="/wd/finbert_data/train_uniq_client_ids.csv", + test_uniq_client_ids_path="/wd/finbert_data/test_uniq_client_ids.csv", +) +print(f"Dataset preparation time: {datetime.now() - start_prep_time}") + +h_dim = 64 +category_feature_dim = 8 +layers_num = 2 +num_heads = 1 +class_num = 1 +dropout_rate = 0.1 + +encoder = Encoder( + cat_columns=credit_train_dataset.cat_columns, + num_columns=credit_train_dataset.num_columns, + cat_features_max_id=credit_train_dataset.cat_features.max(), + category_feature_dim=category_feature_dim, + out_dim=h_dim, +).to(device) + +model = BertClassifier( + layers_num=layers_num, + num_heads=num_heads, + h_dim=h_dim, + class_num=class_num, + max_seq_len=credit_train_dataset.max_user_history, + dropout_rate = dropout_rate +).to(device) + +print(f'Parameters model - {sum(p.numel() for p in model.parameters())}, parameters encoder - {sum(p.numel() for p in encoder.parameters())}') +optimizer = schedulefree.AdamWScheduleFree(list(model.parameters()) + list(encoder.parameters())) + +if checkpoint_file is not None: + load_checkpoint(checkpoint_file) + +positive_counts = credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids].values.sum() +negative_counts = len(credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids]) - positive_counts +pos_weight = negative_counts / (positive_counts + 1e-15) +print(f"Class imbalance: {negative_counts} {positive_counts}. 
Pos weight: {pos_weight}") +criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight)) + +epochs = 50 +batch_size = 300 +batches_per_epoch = len(credit_train_dataset.uniq_client_ids) // batch_size + +# for parallel data selection +class WrapperDataset(Dataset): + def __init__(self, credit_dataset, encoder, batch_size): + self.credit_dataset = credit_dataset + self.encoder = encoder + self.batch_size = batch_size + + def __len__(self): + return len(self.credit_dataset.uniq_client_ids) // self.batch_size + + def __getitem__(self, idx): + cat_inputs, num_inputs, targets = credit_train_dataset.get_batch(batch_size=self.batch_size) + return cat_inputs, num_inputs, targets + +training_data = WrapperDataset(credit_train_dataset, encoder, batch_size=batch_size) +dataloader = DataLoader(training_data, batch_size=1, shuffle=False, num_workers=8, pin_memory=True) + +val_auroc = AUROC(task='binary') +test_auroc = AUROC(task='binary') + +def test(epoch): + model.eval() + encoder.eval() + optimizer.eval() + with torch.no_grad(): + test_iterator = credit_train_dataset.get_test_batch_iterator(batch_size=batch_size) + for test_batch_id, (test_cat_inputs, test_num_inputs, test_targets) in enumerate(test_iterator): + inputs, targets = encoder(test_cat_inputs, test_num_inputs, test_targets) + outputs = model(inputs) + test_auroc.update(outputs, targets.long()) + print(f"\r {test_batch_id}/{len(credit_train_dataset.test_uniq_client_ids)//batch_size} {test_auroc.compute().item():.4f}", end = " "*40) + + writer.add_scalar('test_roc_auc', test_auroc.compute().item(), epoch * batches_per_epoch) + print(f"\r {datetime.now() - start_time} {epoch}/{epochs} Test rocauc: {test_auroc.compute().item():.4f}", end = " "*40) + print() + +start_time = datetime.now() +print("Started at:", start_time) +last_display_time = start_time +last_checkpoint_time = start_time +for epoch in range(epochs): + test(epoch) + for batch_id, (cat_inputs, num_inputs, targets) in enumerate(dataloader): + 
# with autograd.detect_anomaly(True): + model.train() + encoder.train() + optimizer.train() + inputs, targets = encoder(cat_inputs[0], num_inputs[0], targets[0]) + outputs = model(inputs) + loss = criterion(outputs, targets) + loss.backward() + current_time = datetime.now() + if current_time - last_display_time > timedelta(seconds=1): + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1e6).item(), epoch*batches_per_epoch+batch_id) + optimizer.step() + optimizer.zero_grad() + if current_time - last_display_time > timedelta(seconds=1): + model.eval() + encoder.eval() + optimizer.eval() + last_display_time = current_time + writer.add_scalar('Loss', loss.item(), epoch*batches_per_epoch+batch_id) + print(f"\r {current_time-start_time} {epoch+1}/{epochs} {batch_id}/{batches_per_epoch} loss: {loss.item()}", end = " "*40) + if current_time - last_checkpoint_time > timedelta(hours=4): + last_checkpoint_time = current_time + save_checkpoint( + credit_dataset=credit_train_dataset, + encoder = encoder, model=model, optimizer=optimizer, epoch=epoch, + loss=loss.item(), rocauc=test_auroc.compute().item(), сhekpoints_dir=сhekpoints_dir) + +test(epochs) +print() + +save_checkpoint( + credit_dataset=credit_train_dataset, + encoder = encoder, model=model, optimizer=optimizer, epoch=epochs, + loss=loss.item(), rocauc=test_auroc.compute().item(), сhekpoints_dir=сhekpoints_dir) + +writer.close() \ No newline at end of file diff --git a/src/bert_optica_nokoef.py b/src/bert_optica_nokoef.py new file mode 100644 index 0000000..2c5b249 --- /dev/null +++ b/src/bert_optica_nokoef.py @@ -0,0 +1,473 @@ +import os +import sys +#os.environ['CUDA_VISIBLE_DEVICES'] = f"[0,1,2,3,4,5,6,7]" # f"{sys.argv[1]}" + +from torch import nn +import torch +import pandas as pd +import numpy as np +from matplotlib import pyplot as plt +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime, timedelta +from torchmetrics import AUROC 
+from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import Dataset, DataLoader +from pathlib import Path +import schedulefree +from einops import rearrange, repeat +import torch.nn.functional as F +from torch import autograd +import optical_matrix_multiplication as omm + +step = 1 +pixel_size: float = 3.6e-6 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +print('Current device - ', device) + +def norm(matrix: torch.Tensor, max_val: float = 1) -> torch.Tensor: + return matrix / (max_val + 1e-10) + +def optics_matmul_shift(sim, tensor_1, tensor_2): + tensor_1 = tensor_1[None,:,:,:] + tensor_2 = tensor_2[None,:,:,:] + + if torch.min(tensor_1) >= 0 and torch.min(tensor_2) >= 0: + max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + a, b = norm(tensor_1, max_abs), norm(tensor_2, max_abs) + return sim(a, b)[0,:,:,:] * max_abs **2 + + min_abs = abs(min(torch.min(tensor_1), torch.min(tensor_2))) + max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + min_abs + + shift_a = min_abs * torch.ones(tensor_1.shape).to(device) + shift_b = min_abs * torch.ones(tensor_2.shape).to(device) + a_a_sh = tensor_1 + shift_a + b_b_sh = tensor_2 + shift_b + + a_a_sh_norm, b_b_sh_norm = norm(a_a_sh, max_abs), norm(b_b_sh, max_abs) + shift_a_norm, shift_b_norm = norm(shift_a, max_abs), norm(shift_b, max_abs) + + a_a_sh_b_b_sh = sim(a_a_sh_norm, b_b_sh_norm) + a_a_sh_b_sh = sim(a_a_sh_norm, shift_b_norm) + a_sh_b_b_sh = sim(shift_a_norm, b_b_sh_norm) + a_sh_b_sh = sim(shift_a_norm, shift_b_norm) + + return (a_a_sh_b_b_sh - a_a_sh_b_sh - a_sh_b_b_sh + a_sh_b_sh)[0,:,:,:] * max_abs ** 2 + +comment = Path(__file__).stem # sys.argv[2] +checkpoint_file = None +logs_dir = f'/wd/finbert_results/runs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +сhekpoints_dir = 
f'/wd/finbert_results/сhekpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +Path(сhekpoints_dir).mkdir(parents=True, exist_ok=True) +print("Logs dir:", logs_dir) +print("Chekpoints dir:", сhekpoints_dir) +writer = SummaryWriter(logs_dir) +Path(logs_dir + "bert_training.py").write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script + +def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, сhekpoints_dir="checkpoints/"): + checkpoint = { + 'encoder': { + 'state_dict': encoder.state_dict(), + **{k:v for k,v in encoder.__dict__.items() if k[0] != '_' and k != 'training'} + }, + 'model': { + 'state_dict': model.state_dict(), + **{k:v for k,v in model.__dict__.items() if k[0] != '_' and k != 'training'} + }, + 'epoch': epoch, + 'optimizer': { + 'state_dict': optimizer.state_dict(), + }, + 'loss': loss, + 'rocauc': rocauc, + 'train_uniq_client_ids_path': credit_dataset.train_uniq_client_ids_path, + 'test_uniq_client_ids_path': credit_dataset.test_uniq_client_ids_path + } + path = сhekpoints_dir + f"epoch_{epoch}_{loss:.5f}.pth" + torch.save(checkpoint, path) + print(f"\nCheckpoint saved to {path}") + +def load_checkpoint(checkpoint_file): + if os.path.exists(checkpoint_file): + checkpoint = torch.load(checkpoint_file) + + #optimizer.load_state_dict(checkpoint['optimizer']) + encoder.load_state_dict(checkpoint['encoder']['state_dict']) + model.load_state_dict(checkpoint['model']['state_dict']) + +class CreditProductsDataset: + def __init__(self, + features_path, targets_path, train_test_split_ratio=0.9, + train_uniq_client_ids_path=None, test_uniq_client_ids_path=None + ): + self.train_uniq_client_ids_path = train_uniq_client_ids_path + self.test_uniq_client_ids_path = test_uniq_client_ids_path + if Path(self.train_uniq_client_ids_path).exists(): + self.train_uniq_client_ids = pd.read_csv(self.train_uniq_client_ids_path).iloc[:,0].values + 
print("Loaded", self.train_uniq_client_ids_path) + else: + raise Exception(f"No {self.train_uniq_client_ids_path}") + if Path(self.test_uniq_client_ids_path).exists(): + self.test_uniq_client_ids = pd.read_csv(self.test_uniq_client_ids_path).iloc[:,0].values + print("Loaded", self.test_uniq_client_ids_path) + else: + raise Exception(f"No {self.test_uniq_client_ids_path}") + self.features_df = pd.read_parquet(features_path) + self.targets_df = pd.read_csv(targets_path) + self.uniq_client_ids = self.features_df.id.unique() + self.max_user_history = self.features_df.rn.max() + self.id_columns = ['id', 'rn'] + self.cat_columns = [ + 'pre_since_opened', 'pre_since_confirmed', 'pre_pterm', 'pre_fterm', 'pre_till_pclose', 'pre_till_fclose', + 'pre_loans_credit_limit', 'pre_loans_next_pay_summ', 'pre_loans_outstanding', 'pre_loans_total_overdue', + 'pre_loans_max_overdue_sum', 'pre_loans_credit_cost_rate', 'is_zero_loans5', 'is_zero_loans530', 'is_zero_loans3060', + 'is_zero_loans6090', 'is_zero_loans90', 'pre_util', 'pre_over2limit', 'pre_maxover2limit', 'is_zero_util', 'is_zero_over2limit', + 'is_zero_maxover2limit', 'enc_paym_0', 'enc_paym_1', 'enc_paym_2', 'enc_paym_3', 'enc_paym_4', 'enc_paym_5', 'enc_paym_6', 'enc_paym_7', + 'enc_paym_8', 'enc_paym_9', 'enc_paym_10', 'enc_paym_11', 'enc_paym_12', 'enc_paym_13', 'enc_paym_14', 'enc_paym_15', 'enc_paym_16', + 'enc_paym_17', 'enc_paym_18', 'enc_paym_19', 'enc_paym_20', 'enc_paym_21', 'enc_paym_22', 'enc_paym_23', 'enc_paym_24', + 'enc_loans_account_holder_type', 'enc_loans_credit_status', 'enc_loans_credit_type', 'enc_loans_account_cur', 'pclose_flag', + 'fclose_flag' + ] + self.num_columns = list(set(self.features_df.columns).difference(set(self.id_columns)).difference(self.cat_columns)) + # make unified category index for embeddings for all columns. 
zero index embedding for padding will be zeroed during training + self.cat_cardinalities = self.features_df.max(axis=0)[self.cat_columns] + 1 + self.cat_cardinalities_integral = self.cat_cardinalities.cumsum() + self.features_df[self.cat_columns[1:]] = self.features_df[self.cat_columns[1:]] + self.cat_cardinalities_integral[1:] + self.features_df[self.cat_columns] = self.features_df[self.cat_columns] + 1 # zero embedding is for padding + + self.features_df = self.features_df.sort_values(self.id_columns, ascending=[True, True]) + self.features_df = self.features_df.set_index('id') + self.targets_df = self.targets_df.set_index('id') + self.targets_df = self.targets_df.sort_index() + + self.user_seq_lengths = self.features_df.index.value_counts().sort_index() + + self.cat_features = torch.tensor(self.features_df[self.cat_columns].values, dtype=torch.int16) + self.cat_features = pad_sequence(torch.split(self.cat_features, self.user_seq_lengths.tolist()), batch_first=True) # implicit max seq + self.num_features = torch.tensor(self.features_df[self.num_columns].values, dtype=torch.float32) + self.num_features = pad_sequence(torch.split(self.num_features, self.user_seq_lengths.tolist()), batch_first=True) + self.targets = torch.tensor(self.targets_df.flag.values).type(torch.float32) + + def get_batch(self, batch_size=4): + sampled_ids = np.random.choice(self.train_uniq_client_ids, batch_size, replace=False) # think about replace=True + cat_features_batch = self.cat_features[sampled_ids] * torch.empty_like(self.cat_features[sampled_ids]).bernoulli_(0.9) # dropout features with prob 0.1 + num_features_batch = self.num_features[sampled_ids] * torch.empty_like(self.num_features[sampled_ids]).bernoulli_(0.9) # dropout features with prob 0.1 + targets_batch = self.targets[sampled_ids] + return cat_features_batch, num_features_batch, targets_batch + + def get_test_batch_iterator(self, batch_size=4): + for i in range(0, len(self.test_uniq_client_ids), batch_size): + ids = 
self.test_uniq_client_ids[i:i+batch_size] + cat_features_batch = self.cat_features[ids] + num_features_batch = self.num_features[ids] + targets_batch = self.targets[ids] + yield cat_features_batch, num_features_batch, targets_batch + +class Encoder(nn.Module): + def __init__(self, cat_columns, num_columns, cat_features_max_id, category_feature_dim=4, out_dim=64): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.total_h_dim = len(self.cat_columns) * category_feature_dim + len(self.num_columns) + self.cat_embeds = nn.Embedding(cat_features_max_id + 1, self.category_feature_dim, padding_idx=0) + self.num_scales = nn.Parameter(torch.randn(1, len(self.num_columns))) + self.num_shifts = nn.Parameter(torch.randn(1, len(self.num_columns))) + self.proj = nn.Linear(self.total_h_dim, self.out_dim, bias=False) + + @property + def device(self): + return next(self.parameters()).device + + def forward(self, cat_features_batch, num_features_batch, targets_batch): + cat_features_batch = cat_features_batch.to(self.device) + num_features_batch = num_features_batch.to(self.device) + + cat_embed_tensor = self.cat_embeds(cat_features_batch.type(torch.int32)) + cat_embed_tensor = cat_embed_tensor.reshape(cat_features_batch.shape[0], cat_features_batch.shape[1], -1) + num_embed_tensor = self.num_scales * num_features_batch + self.num_shifts + embed_tensor = torch.concat([cat_embed_tensor, num_embed_tensor], dim=-1) + + inputs = self.proj(embed_tensor) + targets = targets_batch.to(self.device) + + return inputs, targets + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + 
def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + # self.k1 = nn.Parameter(torch.randn(1)) + # self.k2 = nn.Parameter(torch.randn(1)) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = 
self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = optics_matmul_shift(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5) + attention = nn.functional.softmax(scores, dim=2) + output = optics_matmul_shift(self.sim_output, attention, v) + return self.o_proj(self.gather_heads(output, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +# Vision Transformers Need Registers https://arxiv.org/html/2309.16588v2 +class BertClassifier(nn.Module): + def __init__(self, layers_num=1, h_dim=64, class_num=2, max_seq_len=128, num_heads=4, num_reg=0, dropout_rate = 0.1): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.cls_token = nn.Parameter(torch.randn(1,1+num_reg,h_dim)) # reg tokens can be added by second dim >1 + self.max_seq_len = max_seq_len + self.cls_token.shape[1] + print(h_dim, self.max_seq_len) + if max_seq_len < 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = self.max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * self.max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = self.max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * self.max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + 
result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + if max_seq_len >= 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = self.max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * self.max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = self.max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * self.max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + self.layers = nn.ModuleList([TransformerLayer(h_dim=h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=num_heads, dropout_rate = dropout_rate, max_seq_len=self.max_seq_len) for _ in range(layers_num)]) + self.classifier_head = nn.Sequential(nn.Linear(h_dim, h_dim), nn.GELU(), nn.Linear(h_dim, class_num)) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + + def forward(self, x): + x = torch.concat([self.cls_token.expand([x.shape[0], self.cls_token.shape[1], self.cls_token.shape[2]]), x], dim=1) + x = x + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + x = self.classifier_head(x[:,0,:]) + return x[:,:] if self.class_num > 1 else x[:,0] + +start_prep_time = datetime.now() +credit_train_dataset = CreditProductsDataset( + features_path="/wd/finbert_data/train_data", + 
targets_path="/wd/finbert_data/train_target.csv",
+    train_uniq_client_ids_path="/wd/finbert_data/train_uniq_client_ids.csv",
+    test_uniq_client_ids_path="/wd/finbert_data/test_uniq_client_ids.csv",
+)
+print(f"Dataset preparation time: {datetime.now() - start_prep_time}")
+
+h_dim = 64
+category_feature_dim = 8
+layers_num = 2
+num_heads = 1
+class_num = 1
+dropout_rate = 0.1
+
+encoder = Encoder(
+    cat_columns=credit_train_dataset.cat_columns,
+    num_columns=credit_train_dataset.num_columns,
+    cat_features_max_id=credit_train_dataset.cat_features.max(),
+    category_feature_dim=category_feature_dim,
+    out_dim=h_dim,
+).to(device)
+
+model = BertClassifier(
+    layers_num=layers_num,
+    num_heads=num_heads,
+    h_dim=h_dim,
+    class_num=class_num,
+    max_seq_len=credit_train_dataset.max_user_history,
+    dropout_rate = dropout_rate
+).to(device)
+
+print(f'Parameters model - {sum(p.numel() for p in model.parameters())}, parameters encoder - {sum(p.numel() for p in encoder.parameters())}')
+optimizer = schedulefree.AdamWScheduleFree(list(model.parameters()) + list(encoder.parameters()))
+
+if checkpoint_file is not None:
+    load_checkpoint(checkpoint_file)
+
+positive_counts = credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids].values.sum()
+negative_counts = len(credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids]) - positive_counts
+pos_weight = negative_counts / (positive_counts + 1e-15)
+print(f"Class imbalance: {negative_counts} {positive_counts}. 
Pos weight: {pos_weight}") +criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight)) + +epochs = 50 +batch_size = 300 +batches_per_epoch = len(credit_train_dataset.uniq_client_ids) // batch_size + +# for parallel data selection +class WrapperDataset(Dataset): + def __init__(self, credit_dataset, encoder, batch_size): + self.credit_dataset = credit_dataset + self.encoder = encoder + self.batch_size = batch_size + + def __len__(self): + return len(self.credit_dataset.uniq_client_ids) // self.batch_size + + def __getitem__(self, idx): + cat_inputs, num_inputs, targets = credit_train_dataset.get_batch(batch_size=self.batch_size) + return cat_inputs, num_inputs, targets + +training_data = WrapperDataset(credit_train_dataset, encoder, batch_size=batch_size) +dataloader = DataLoader(training_data, batch_size=1, shuffle=False, num_workers=8, pin_memory=True) + +val_auroc = AUROC(task='binary') +test_auroc = AUROC(task='binary') + +def test(epoch): + model.eval() + encoder.eval() + optimizer.eval() + with torch.no_grad(): + test_iterator = credit_train_dataset.get_test_batch_iterator(batch_size=batch_size) + for test_batch_id, (test_cat_inputs, test_num_inputs, test_targets) in enumerate(test_iterator): + inputs, targets = encoder(test_cat_inputs, test_num_inputs, test_targets) + outputs = model(inputs) + test_auroc.update(outputs, targets.long()) + print(f"\r {test_batch_id}/{len(credit_train_dataset.test_uniq_client_ids)//batch_size} {test_auroc.compute().item():.4f}", end = " "*40) + + writer.add_scalar('test_roc_auc', test_auroc.compute().item(), epoch * batches_per_epoch) + print(f"\r {datetime.now() - start_time} {epoch}/{epochs} Test rocauc: {test_auroc.compute().item():.4f}", end = " "*40) + print() + +start_time = datetime.now() +print("Started at:", start_time) +last_display_time = start_time +last_checkpoint_time = start_time +for epoch in range(epochs): + test(epoch) + for batch_id, (cat_inputs, num_inputs, targets) in enumerate(dataloader): + 
# with autograd.detect_anomaly(True): + model.train() + encoder.train() + optimizer.train() + inputs, targets = encoder(cat_inputs[0], num_inputs[0], targets[0]) + outputs = model(inputs) + loss = criterion(outputs, targets) + loss.backward() + current_time = datetime.now() + if current_time - last_display_time > timedelta(seconds=1): + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1e6).item(), epoch*batches_per_epoch+batch_id) + optimizer.step() + optimizer.zero_grad() + if current_time - last_display_time > timedelta(seconds=1): + model.eval() + encoder.eval() + optimizer.eval() + last_display_time = current_time + writer.add_scalar('Loss', loss.item(), epoch*batches_per_epoch+batch_id) + print(f"\r {current_time-start_time} {epoch+1}/{epochs} {batch_id}/{batches_per_epoch} loss: {loss.item()}", end = " "*40) + if current_time - last_checkpoint_time > timedelta(hours=4): + last_checkpoint_time = current_time + save_checkpoint( + credit_dataset=credit_train_dataset, + encoder = encoder, model=model, optimizer=optimizer, epoch=epoch, + loss=loss.item(), rocauc=test_auroc.compute().item(), сhekpoints_dir=сhekpoints_dir) + +test(epochs) +print() + +save_checkpoint( + credit_dataset=credit_train_dataset, + encoder = encoder, model=model, optimizer=optimizer, epoch=epochs, + loss=loss.item(), rocauc=test_auroc.compute().item(), сhekpoints_dir=сhekpoints_dir) + +writer.close() \ No newline at end of file diff --git a/src/bert_optica_nokoef_newf.py b/src/bert_optica_nokoef_newf.py new file mode 100644 index 0000000..dcfede5 --- /dev/null +++ b/src/bert_optica_nokoef_newf.py @@ -0,0 +1,483 @@ +import os +import sys +#os.environ['CUDA_VISIBLE_DEVICES'] = f"[0,1,2,3,4,5,6,7]" # f"{sys.argv[1]}" + +from torch import nn +import torch +import pandas as pd +import numpy as np +from matplotlib import pyplot as plt +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime, timedelta +from torchmetrics 
import AUROC
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import Dataset, DataLoader
+from pathlib import Path
+import schedulefree
+from einops import rearrange, repeat
+import torch.nn.functional as F
+from torch import autograd
+import optical_matrix_multiplication as omm
+
+step = 1
+pixel_size: float = 3.6e-6
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+print('Current device - ', device)
+
+def new_formula(sim, tensor_1, tensor_2):
+    tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1
+    tensor_2 = tensor_2[None,:,:,:] if len(tensor_2.shape) < 4 else tensor_2
+    device = tensor_1.device
+
+    A_pos = torch.clamp(tensor_1, min=0) # A⁺ = max(A, 0)
+    A_neg = torch.clamp(-tensor_1, min=0) # A⁻ = max(-A, 0)
+    B_pos = torch.clamp(tensor_2, min=0) # B⁺ = max(B, 0)
+    B_neg = torch.clamp(-tensor_2, min=0) # B⁻ = max(-B, 0)
+
+    max_A_pos = torch.max(A_pos) # may be 0 if there are no positive values
+    max_A_neg = torch.max(A_neg) # may be 0 if there are no negative values
+    max_B_pos = torch.max(B_pos)
+    max_B_neg = torch.max(B_neg)
+
+    zero_template = torch.zeros_like(
+        torch.empty(tensor_1.shape[0],tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3]))
+
+    if max_A_pos > 0 and max_B_pos > 0:
+        t1 = sim(A_pos / max_A_pos, B_pos / max_B_pos) * max_A_pos * max_B_pos
+    else:
+        t1 = zero_template.clone().to(device)
+
+    if max_A_pos > 0 and max_B_neg > 0:
+        t2 = sim(A_pos / max_A_pos, B_neg / max_B_neg) * max_A_pos * max_B_neg
+    else:
+        t2 = zero_template.clone().to(device)
+
+    if max_A_neg > 0 and max_B_pos > 0:
+        t3 = sim(A_neg / max_A_neg, B_pos / max_B_pos) * max_A_neg * max_B_pos
+    else:
+        t3 = zero_template.clone().to(device)
+
+    if max_A_neg > 0 and max_B_neg > 0:
+        t4 = sim(A_neg / max_A_neg, B_neg / max_B_neg) * max_A_neg * max_B_neg
+    else:
+        t4 = zero_template.clone().to(device)
+
+    return (t1 - t2 - t3 + t4)[0,:,:,:]
+
+comment = Path(__file__).stem # sys.argv[2]
+checkpoint_file = None
+logs_dir 
= f'/wd/finbert_results/runs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +сhekpoints_dir = f'/wd/finbert_results/сhekpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +Path(сhekpoints_dir).mkdir(parents=True, exist_ok=True) +print("Logs dir:", logs_dir) +print("Chekpoints dir:", сhekpoints_dir) +writer = SummaryWriter(logs_dir) +Path(logs_dir + "bert_training.py").write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script + +def save_checkpoint(credit_dataset, encoder, model, optimizer, epoch, loss, rocauc, сhekpoints_dir="checkpoints/"): + checkpoint = { + 'encoder': { + 'state_dict': encoder.state_dict(), + **{k:v for k,v in encoder.__dict__.items() if k[0] != '_' and k != 'training'} + }, + 'model': { + 'state_dict': model.state_dict(), + **{k:v for k,v in model.__dict__.items() if k[0] != '_' and k != 'training'} + }, + 'epoch': epoch, + 'optimizer': { + 'state_dict': optimizer.state_dict(), + }, + 'loss': loss, + 'rocauc': rocauc, + 'train_uniq_client_ids_path': credit_dataset.train_uniq_client_ids_path, + 'test_uniq_client_ids_path': credit_dataset.test_uniq_client_ids_path + } + path = сhekpoints_dir + f"epoch_{epoch}_{loss:.5f}.pth" + torch.save(checkpoint, path) + print(f"\nCheckpoint saved to {path}") + +def load_checkpoint(checkpoint_file): + if os.path.exists(checkpoint_file): + checkpoint = torch.load(checkpoint_file) + + #optimizer.load_state_dict(checkpoint['optimizer']) + encoder.load_state_dict(checkpoint['encoder']['state_dict']) + model.load_state_dict(checkpoint['model']['state_dict']) + +class CreditProductsDataset: + def __init__(self, + features_path, targets_path, train_test_split_ratio=0.9, + train_uniq_client_ids_path=None, test_uniq_client_ids_path=None + ): + self.train_uniq_client_ids_path = train_uniq_client_ids_path + self.test_uniq_client_ids_path = 
test_uniq_client_ids_path + if Path(self.train_uniq_client_ids_path).exists(): + self.train_uniq_client_ids = pd.read_csv(self.train_uniq_client_ids_path).iloc[:,0].values + print("Loaded", self.train_uniq_client_ids_path) + else: + raise Exception(f"No {self.train_uniq_client_ids_path}") + if Path(self.test_uniq_client_ids_path).exists(): + self.test_uniq_client_ids = pd.read_csv(self.test_uniq_client_ids_path).iloc[:,0].values + print("Loaded", self.test_uniq_client_ids_path) + else: + raise Exception(f"No {self.test_uniq_client_ids_path}") + self.features_df = pd.read_parquet(features_path) + self.targets_df = pd.read_csv(targets_path) + self.uniq_client_ids = self.features_df.id.unique() + self.max_user_history = self.features_df.rn.max() + self.id_columns = ['id', 'rn'] + self.cat_columns = [ + 'pre_since_opened', 'pre_since_confirmed', 'pre_pterm', 'pre_fterm', 'pre_till_pclose', 'pre_till_fclose', + 'pre_loans_credit_limit', 'pre_loans_next_pay_summ', 'pre_loans_outstanding', 'pre_loans_total_overdue', + 'pre_loans_max_overdue_sum', 'pre_loans_credit_cost_rate', 'is_zero_loans5', 'is_zero_loans530', 'is_zero_loans3060', + 'is_zero_loans6090', 'is_zero_loans90', 'pre_util', 'pre_over2limit', 'pre_maxover2limit', 'is_zero_util', 'is_zero_over2limit', + 'is_zero_maxover2limit', 'enc_paym_0', 'enc_paym_1', 'enc_paym_2', 'enc_paym_3', 'enc_paym_4', 'enc_paym_5', 'enc_paym_6', 'enc_paym_7', + 'enc_paym_8', 'enc_paym_9', 'enc_paym_10', 'enc_paym_11', 'enc_paym_12', 'enc_paym_13', 'enc_paym_14', 'enc_paym_15', 'enc_paym_16', + 'enc_paym_17', 'enc_paym_18', 'enc_paym_19', 'enc_paym_20', 'enc_paym_21', 'enc_paym_22', 'enc_paym_23', 'enc_paym_24', + 'enc_loans_account_holder_type', 'enc_loans_credit_status', 'enc_loans_credit_type', 'enc_loans_account_cur', 'pclose_flag', + 'fclose_flag' + ] + self.num_columns = list(set(self.features_df.columns).difference(set(self.id_columns)).difference(self.cat_columns)) + # make unified category index for embeddings for all 
columns. zero index embedding for padding will be zeroed during training + self.cat_cardinalities = self.features_df.max(axis=0)[self.cat_columns] + 1 + self.cat_cardinalities_integral = self.cat_cardinalities.cumsum() + self.features_df[self.cat_columns[1:]] = self.features_df[self.cat_columns[1:]] + self.cat_cardinalities_integral[1:] + self.features_df[self.cat_columns] = self.features_df[self.cat_columns] + 1 # zero embedding is for padding + + self.features_df = self.features_df.sort_values(self.id_columns, ascending=[True, True]) + self.features_df = self.features_df.set_index('id') + self.targets_df = self.targets_df.set_index('id') + self.targets_df = self.targets_df.sort_index() + + self.user_seq_lengths = self.features_df.index.value_counts().sort_index() + + self.cat_features = torch.tensor(self.features_df[self.cat_columns].values, dtype=torch.int16) + self.cat_features = pad_sequence(torch.split(self.cat_features, self.user_seq_lengths.tolist()), batch_first=True) # implicit max seq + self.num_features = torch.tensor(self.features_df[self.num_columns].values, dtype=torch.float32) + self.num_features = pad_sequence(torch.split(self.num_features, self.user_seq_lengths.tolist()), batch_first=True) + self.targets = torch.tensor(self.targets_df.flag.values).type(torch.float32) + + def get_batch(self, batch_size=4): + sampled_ids = np.random.choice(self.train_uniq_client_ids, batch_size, replace=False) # think about replace=True + cat_features_batch = self.cat_features[sampled_ids] * torch.empty_like(self.cat_features[sampled_ids]).bernoulli_(0.9) # dropout features with prob 0.1 + num_features_batch = self.num_features[sampled_ids] * torch.empty_like(self.num_features[sampled_ids]).bernoulli_(0.9) # dropout features with prob 0.1 + targets_batch = self.targets[sampled_ids] + return cat_features_batch, num_features_batch, targets_batch + + def get_test_batch_iterator(self, batch_size=4): + for i in range(0, len(self.test_uniq_client_ids), batch_size): + ids 
= self.test_uniq_client_ids[i:i+batch_size] + cat_features_batch = self.cat_features[ids] + num_features_batch = self.num_features[ids] + targets_batch = self.targets[ids] + yield cat_features_batch, num_features_batch, targets_batch + +class Encoder(nn.Module): + def __init__(self, cat_columns, num_columns, cat_features_max_id, category_feature_dim=4, out_dim=64): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.total_h_dim = len(self.cat_columns) * category_feature_dim + len(self.num_columns) + self.cat_embeds = nn.Embedding(cat_features_max_id + 1, self.category_feature_dim, padding_idx=0) + self.num_scales = nn.Parameter(torch.randn(1, len(self.num_columns))) + self.num_shifts = nn.Parameter(torch.randn(1, len(self.num_columns))) + self.proj = nn.Linear(self.total_h_dim, self.out_dim, bias=False) + + @property + def device(self): + return next(self.parameters()).device + + def forward(self, cat_features_batch, num_features_batch, targets_batch): + cat_features_batch = cat_features_batch.to(self.device) + num_features_batch = num_features_batch.to(self.device) + + cat_embed_tensor = self.cat_embeds(cat_features_batch.type(torch.int32)) + cat_embed_tensor = cat_embed_tensor.reshape(cat_features_batch.shape[0], cat_features_batch.shape[1], -1) + num_embed_tensor = self.num_scales * num_features_batch + self.num_shifts + embed_tensor = torch.concat([cat_embed_tensor, num_embed_tensor], dim=-1) + + inputs = self.proj(embed_tensor) + targets = targets_batch.to(self.device) + + return inputs, targets + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + 
def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + # self.k1 = nn.Parameter(torch.randn(1)) + # self.k2 = nn.Parameter(torch.randn(1)) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = 
self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5) + attention = nn.functional.softmax(scores, dim=2) + output = new_formula(self.sim_output, attention, v) + return self.o_proj(self.gather_heads(output, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +# Vision Transformers Need Registers https://arxiv.org/html/2309.16588v2 +class BertClassifier(nn.Module): + def __init__(self, layers_num=1, h_dim=64, class_num=2, max_seq_len=128, num_heads=4, num_reg=0, dropout_rate = 0.1): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.cls_token = nn.Parameter(torch.randn(1,1+num_reg,h_dim)) # reg tokens can be added by second dim >1 + self.max_seq_len = max_seq_len + self.cls_token.shape[1] + print(h_dim, self.max_seq_len) + if max_seq_len < 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = self.max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * self.max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = self.max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * self.max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 
2, + distance = 0.01, + trainable_cylind_lens=False) + ) + if max_seq_len >= 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = self.max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * self.max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = self.max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * self.max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + self.layers = nn.ModuleList([TransformerLayer(h_dim=h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=num_heads, dropout_rate = dropout_rate, max_seq_len=self.max_seq_len) for _ in range(layers_num)]) + self.classifier_head = nn.Sequential(nn.Linear(h_dim, h_dim), nn.GELU(), nn.Linear(h_dim, class_num)) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + + def forward(self, x): + x = torch.concat([self.cls_token.expand([x.shape[0], self.cls_token.shape[1], self.cls_token.shape[2]]), x], dim=1) + x = x + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + x = self.classifier_head(x[:,0,:]) + return x[:,:] if self.class_num > 1 else x[:,0] + +start_prep_time = datetime.now() +credit_train_dataset = CreditProductsDataset( + features_path="/wd/finbert_data/train_data", + targets_path="/wd/finbert_data/train_target .csv", + 
train_uniq_client_ids_path="/wd/finbert_data/train_uniq_client_ids.csv", + test_uniq_client_ids_path="/wd/finbert_data/test_uniq_client_ids.csv", +) +print(f"Dataset preparation time: {datetime.now() - start_prep_time}") + +h_dim = 64 +category_feature_dim = 8 +layers_num = 2 +num_heads = 1 +class_num = 1 +dropout_rate = 0.1 + +encoder = Encoder( + cat_columns=credit_train_dataset.cat_columns, + num_columns=credit_train_dataset.num_columns, + cat_features_max_id=credit_train_dataset.cat_features.max(), + category_feature_dim=category_feature_dim, + out_dim=h_dim, +).to(device) + +model = BertClassifier( + layers_num=layers_num, + num_heads=num_heads, + h_dim=h_dim, + class_num=class_num, + max_seq_len=credit_train_dataset.max_user_history, + dropout_rate = dropout_rate +).to(device) + +print(f'Parameters model - {sum(p.numel() for p in model.parameters())}, parameters encoder - {sum(p.numel() for p in encoder.parameters())}') +optimizer = schedulefree.AdamWScheduleFree(list(model.parameters()) + list(encoder.parameters())) + +if checkpoint_file is not None: + load_checkpoint(checkpoint_file) + +positive_counts = credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids].values.sum() +negative_counts = len(credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids]) - positive_counts +pos_weight = negative_counts / (positive_counts + 1e-15) +print(f"Class imbalance: {negative_counts} {positive_counts}. 
Pos weight: {pos_weight}") +criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight)) + +epochs = 50 +batch_size = 300 +batches_per_epoch = len(credit_train_dataset.uniq_client_ids) // batch_size + +# for parallel data selection +class WrapperDataset(Dataset): + def __init__(self, credit_dataset, encoder, batch_size): + self.credit_dataset = credit_dataset + self.encoder = encoder + self.batch_size = batch_size + + def __len__(self): + return len(self.credit_dataset.uniq_client_ids) // self.batch_size + + def __getitem__(self, idx): + cat_inputs, num_inputs, targets = credit_train_dataset.get_batch(batch_size=self.batch_size) + return cat_inputs, num_inputs, targets + +training_data = WrapperDataset(credit_train_dataset, encoder, batch_size=batch_size) +dataloader = DataLoader(training_data, batch_size=1, shuffle=False, num_workers=8, pin_memory=True) + +val_auroc = AUROC(task='binary') +test_auroc = AUROC(task='binary') + +def test(epoch): + model.eval() + encoder.eval() + optimizer.eval() + with torch.no_grad(): + test_iterator = credit_train_dataset.get_test_batch_iterator(batch_size=batch_size) + for test_batch_id, (test_cat_inputs, test_num_inputs, test_targets) in enumerate(test_iterator): + inputs, targets = encoder(test_cat_inputs, test_num_inputs, test_targets) + outputs = model(inputs) + test_auroc.update(outputs, targets.long()) + print(f"\r {test_batch_id}/{len(credit_train_dataset.test_uniq_client_ids)//batch_size} {test_auroc.compute().item():.4f}", end = " "*40) + + writer.add_scalar('test_roc_auc', test_auroc.compute().item(), epoch * batches_per_epoch) + print(f"\r {datetime.now() - start_time} {epoch}/{epochs} Test rocauc: {test_auroc.compute().item():.4f}", end = " "*40) + print() + +start_time = datetime.now() +print("Started at:", start_time) +last_display_time = start_time +last_checkpoint_time = start_time +for epoch in range(epochs): + test(epoch) + for batch_id, (cat_inputs, num_inputs, targets) in enumerate(dataloader): + 
# with autograd.detect_anomaly(True): + model.train() + encoder.train() + optimizer.train() + inputs, targets = encoder(cat_inputs[0], num_inputs[0], targets[0]) + outputs = model(inputs) + loss = criterion(outputs, targets) + loss.backward() + current_time = datetime.now() + if current_time - last_display_time > timedelta(seconds=1): + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1e6).item(), epoch*batches_per_epoch+batch_id) + optimizer.step() + optimizer.zero_grad() + if current_time - last_display_time > timedelta(seconds=1): + model.eval() + encoder.eval() + optimizer.eval() + last_display_time = current_time + writer.add_scalar('Loss', loss.item(), epoch*batches_per_epoch+batch_id) + print(f"\r {current_time-start_time} {epoch+1}/{epochs} {batch_id}/{batches_per_epoch} loss: {loss.item()}", end = " "*40) + if current_time - last_checkpoint_time > timedelta(hours=4): + last_checkpoint_time = current_time + save_checkpoint( + credit_dataset=credit_train_dataset, + encoder = encoder, model=model, optimizer=optimizer, epoch=epoch, + loss=loss.item(), rocauc=test_auroc.compute().item(), сhekpoints_dir=сhekpoints_dir) + +test(epochs) +print() + +save_checkpoint( + credit_dataset=credit_train_dataset, + encoder = encoder, model=model, optimizer=optimizer, epoch=epochs, + loss=loss.item(), rocauc=test_auroc.compute().item(), сhekpoints_dir=сhekpoints_dir) + +writer.close() \ No newline at end of file diff --git a/src/optical_matrix_multiplication/__init__.py b/src/optical_matrix_multiplication/__init__.py index 9a1844c..d5f8779 100644 --- a/src/optical_matrix_multiplication/__init__.py +++ b/src/optical_matrix_multiplication/__init__.py @@ -5,6 +5,12 @@ __version__ = "3.0.0" from .config import Config from . 
import propagator -from .optical_mul import OpticalMul +from .optical_mul import ( + OpticalMul, + TrainableLensOpticalMul, + TrainableScalarOpticalMul, + TrainableScalarAndLensOpticalMul, + TrainableFocalDistLensOpticalMul +) from .parallel import DataParallel from .parallel import ScatterDataParallel \ No newline at end of file diff --git a/src/optical_matrix_multiplication/optical_mul.py b/src/optical_matrix_multiplication/optical_mul.py index 11f52dc..cbe5f57 100644 --- a/src/optical_matrix_multiplication/optical_mul.py +++ b/src/optical_matrix_multiplication/optical_mul.py @@ -1,7 +1,15 @@ import torch as _torch import torch.nn as _nn from .config import Config as _Config -from .propagator import PropagatorCrossLens as _PropCrossLens, PropagatorСylindLens as _PropСylindLens, PropagatorSinc as _PropSinc, Propagator as _Prop +from .propagator import PropagatorCrossLens as _PropCrossLens, PropagatorCylindLens as _PropCylindLens, PropagatorSinc as _PropSinc, Propagator as _Prop +from .propagator import ( + PropagatorTrainableCylindLens as _PropagatorTrainableCylindLens, + PropagatorTrainableFocalDistCylindLens as _PropagatorTrainableFocalDistCylindLens +) +from torch.utils.tensorboard import SummaryWriter +from typing import Optional +import matplotlib.pyplot as plt + class OpticalMul(_nn.Module): """ @@ -14,22 +22,129 @@ class OpticalMul(_nn.Module): Args: config: конфигурация расчётной системы. 
""" - super(OpticalMul, self).__init__() - self.trainable_cylind_lens = config._trainable_cylind_lens + super().__init__() + + prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config) + prop_two = _PropCrossLens(config.first_lens_plane, config) + prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config) + prop_four = _PropCylindLens(config.matrix_plane, config) + prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config) + prop_six = _PropCrossLens(config.second_lens_plane, config).T + prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config) + + self._propagator_one: _Prop = prop_one + prop_two + prop_three + prop_four + self._propagator_two: _Prop = prop_five + prop_six + prop_seven + + kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x)) + kron_mat_utils = _torch.ones((config.matrix_split_x, config.matrix_split_y)) + self.register_buffer('_kron_vec_utils', kron_vec_utils, persistent=True) + self.register_buffer('_kron_mat_utils', kron_mat_utils, persistent=True) + + self._avg_pool = _nn.AvgPool2d((1, config.result_vector_split)) + + def prepare_vector(self, data: _torch.Tensor) -> _torch.Tensor: + """ + Метод подготовки матрицы левой матрицы, как набора векторов столбцов, к подаче на вход системы. + + Args: + data: матрица комплексной амплитуды распределений световых полей. + + Returns: + Матрицы содержащие вектора левой матрицы. + """ + data = data.cfloat().flip(-1) + data = data.unsqueeze(-2) + data = _torch.kron(data.contiguous(), self._kron_vec_utils) + return data + + def prepare_matrix(self, data: _torch.Tensor) -> _torch.Tensor: + """ + Метод подготовки правой матрицы к подаче на вход системы. + + Args: + data: матрица комплексной амплитуды распределения светового поля. + + Returns: + Матрица - оптический элемент в центре модели. 
+ """ + if (data.dim() > 4) and data.size(-1) == 2: + data = _torch.view_as_complex(data) + + data = data.cfloat().transpose(-2, -1) + data = data.unsqueeze(-3) + data = _torch.kron(data.contiguous(), self._kron_mat_utils) + return data + + def prepare_out(self, field: _torch.Tensor) -> _torch.Tensor: + """ + Метод получения результата матричного умножения. + + Args: + data: матрицы выходого распределения светового поля системы. + + Returns: + Вектор столбец (амплитудное распределение). + """ + ### Закоментированная часть кода - более физически корректный вариант работы модели, + ### однако, данный вариант кода будет требовать большое кол-во памяти во время обучения + field = field.abs().squeeze(-1) #**2 + field = self._avg_pool(field) + return field.flip(-1) #**0.5 + + def forward(self, + input: _torch.Tensor, + other: _torch.Tensor) -> _torch.Tensor: + """ + Метод выполения матричного умножения. + + Args: + input: матрица (B, C, H, W). + other: матрица (B, C, W, K). + + Returns: + Рензультат матричного умножения (B, C, H, K). + + Example: + >>> mul = OpticalMul(...) + >>> A = torch.rand((1, 1, 256, 256)) > 0.5 + >>> B = torch.rand((1, 1, 256, 256)) > 0.5 + >>> mul(A, B).shape + torch.Size([1, 1, 256, 256]) + >>> A = torch.rand((1, 1, 64, 256)) > 0.5 + >>> B = torch.rand((1, 1, 256, 128)) > 0.5 + >>> mul(A, B).shape + torch.Size([1, 1, 64, 128]) + """ + vec_field = self.prepare_vector(input) + mat_field = self.prepare_matrix(other) + + vec_field = self._propagator_one(vec_field, mat_field.shape[-2:]) + vec_field = self._propagator_two(vec_field * mat_field, (mat_field.size(-2), 1)) + + return self.prepare_out(vec_field) + +class TrainableScalarOpticalMul(_nn.Module): + """ + Класс системы, выполняющей оптически операцию умножения матрицы на матрицу. + """ + def __init__(self, config: _Config): + """ + Конструктор класса. + + Args: + config: конфигурация расчётной системы. 
+ """ + super().__init__() prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config) prop_two = _PropCrossLens(config.first_lens_plane, config) prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config) - prop_four = _PropСylindLens(config.matrix_plane, config, trainable=self.trainable_cylind_lens) + prop_four = _PropCylindLens(config.matrix_plane, config) prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config) prop_six = _PropCrossLens(config.second_lens_plane, config).T prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config) - if self.trainable_cylind_lens: - self._propagator_one: _Prop = prop_one + prop_two + prop_three - self._propagator_between = prop_four - else: - self._propagator_one: _Prop = prop_one + prop_two + prop_three + prop_four + self._propagator_one: _Prop = prop_one + prop_two + prop_three + prop_four self._propagator_two: _Prop = prop_five + prop_six + prop_seven kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x)) @@ -38,6 +153,421 @@ class OpticalMul(_nn.Module): self.register_buffer('_kron_mat_utils', kron_mat_utils, persistent=True) self._avg_pool = _nn.AvgPool2d((1, config.result_vector_split)) + self.k = nn.Parameter(_torch.tensor(1)) + + def prepare_vector(self, data: _torch.Tensor) -> _torch.Tensor: + """ + Метод подготовки матрицы левой матрицы, как набора векторов столбцов, к подаче на вход системы. + + Args: + data: матрица комплексной амплитуды распределений световых полей. + + Returns: + Матрицы содержащие вектора левой матрицы. + """ + data = data.cfloat().flip(-1) + data = data.unsqueeze(-2) + data = _torch.kron(data.contiguous(), self._kron_vec_utils) + return data + + def prepare_matrix(self, data: _torch.Tensor) -> _torch.Tensor: + """ + Метод подготовки правой матрицы к подаче на вход системы. + + Args: + data: матрица комплексной амплитуды распределения светового поля. 
+ + Returns: + Матрица - оптический элемент в центре модели. + """ + if (data.dim() > 4) and data.size(-1) == 2: + data = _torch.view_as_complex(data) + + data = data.cfloat().transpose(-2, -1) + data = data.unsqueeze(-3) + data = _torch.kron(data.contiguous(), self._kron_mat_utils) + return data + + def prepare_out(self, field: _torch.Tensor) -> _torch.Tensor: + """ + Метод получения результата матричного умножения. + + Args: + data: матрицы выходого распределения светового поля системы. + + Returns: + Вектор столбец (амплитудное распределение). + """ + ### Закоментированная часть кода - более физически корректный вариант работы модели, + ### однако, данный вариант кода будет требовать большое кол-во памяти во время обучения + field = field.abs().squeeze(-1) #**2 + field = self._avg_pool(field) + return field.flip(-1) #**0.5 + + def forward(self, + input: _torch.Tensor, + other: _torch.Tensor) -> _torch.Tensor: + """ + Метод выполения матричного умножения. + + Args: + input: матрица (B, C, H, W). + other: матрица (B, C, W, K). + + Returns: + Рензультат матричного умножения (B, C, H, K). + + Example: + >>> mul = OpticalMul(...) + >>> A = torch.rand((1, 1, 256, 256)) > 0.5 + >>> B = torch.rand((1, 1, 256, 256)) > 0.5 + >>> mul(A, B).shape + torch.Size([1, 1, 256, 256]) + >>> A = torch.rand((1, 1, 64, 256)) > 0.5 + >>> B = torch.rand((1, 1, 256, 128)) > 0.5 + >>> mul(A, B).shape + torch.Size([1, 1, 64, 128]) + """ + vec_field = self.prepare_vector(input) + mat_field = self.prepare_matrix(other) + + vec_field = self._propagator_one(vec_field, mat_field.shape[-2:]) + vec_field = self._propagator_two(vec_field * mat_field, (mat_field.size(-2), 1)) + + return self.k * self.prepare_out(vec_field) + +class TrainableLensOpticalMul(_nn.Module): + """ + Класс системы, выполняющей оптически операцию умножения матрицы на матрицу. + """ + def __init__(self, config: _Config): + """ + Конструктор класса. + + Args: + config: конфигурация расчётной системы. 
+ """ + super().__init__() + + prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config) + prop_two = _PropCrossLens(config.first_lens_plane, config) + prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config) + prop_four = _PropagatorTrainableCylindLens(config.matrix_plane, config) + prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config) + prop_six = _PropCrossLens(config.second_lens_plane, config).T + prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config) + + self._propagator_one: _Prop = prop_one + prop_two + prop_three + self._propagator_cylind_lens: _Prop = prop_four + self._propagator_three: _Prop = prop_five + prop_six + prop_seven + + kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x)) + kron_mat_utils = _torch.ones((config.matrix_split_x, config.matrix_split_y)) + self.register_buffer('_kron_vec_utils', kron_vec_utils, persistent=True) + self.register_buffer('_kron_mat_utils', kron_mat_utils, persistent=True) + + self._avg_pool = _nn.AvgPool2d((1, config.result_vector_split)) + + def prepare_vector(self, data: _torch.Tensor) -> _torch.Tensor: + """ + Метод подготовки матрицы левой матрицы, как набора векторов столбцов, к подаче на вход системы. + + Args: + data: матрица комплексной амплитуды распределений световых полей. + + Returns: + Матрицы содержащие вектора левой матрицы. + """ + data = data.cfloat().flip(-1) + data = data.unsqueeze(-2) + data = _torch.kron(data.contiguous(), self._kron_vec_utils) + return data + + def prepare_matrix(self, data: _torch.Tensor) -> _torch.Tensor: + """ + Метод подготовки правой матрицы к подаче на вход системы. + + Args: + data: матрица комплексной амплитуды распределения светового поля. + + Returns: + Матрица - оптический элемент в центре модели. 
+ """ + if (data.dim() > 4) and data.size(-1) == 2: + data = _torch.view_as_complex(data) + + data = data.cfloat().transpose(-2, -1) + data = data.unsqueeze(-3) + # TODO data should be at least two seq length. For one we get + # untimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead. + data = _torch.kron(data.contiguous(), self._kron_mat_utils) + return data + + def prepare_out(self, field: _torch.Tensor) -> _torch.Tensor: + """ + Метод получения результата матричного умножения. + + Args: + data: матрицы выходого распределения светового поля системы. + + Returns: + Вектор столбец (амплитудное распределение). + """ + ### Закоментированная часть кода - более физически корректный вариант работы модели, + ### однако, данный вариант кода будет требовать большое кол-во памяти во время обучения + field = field.abs().squeeze(-1) #**2 + field = self._avg_pool(field) + return field.flip(-1) #**0.5 + + def forward(self, + input: _torch.Tensor, + other: _torch.Tensor) -> _torch.Tensor: + """ + Метод выполения матричного умножения. + + Args: + input: матрица (B, C, H, W). + other: матрица (B, C, W, K). + + Returns: + Рензультат матричного умножения (B, C, H, K). + + Example: + >>> mul = OpticalMul(...) 
+ >>> A = torch.rand((1, 1, 256, 256)) > 0.5 + >>> B = torch.rand((1, 1, 256, 256)) > 0.5 + >>> mul(A, B).shape + torch.Size([1, 1, 256, 256]) + >>> A = torch.rand((1, 1, 64, 256)) > 0.5 + >>> B = torch.rand((1, 1, 256, 128)) > 0.5 + >>> mul(A, B).shape + torch.Size([1, 1, 64, 128]) + """ + vec_field = self.prepare_vector(input) + mat_field = self.prepare_matrix(other) + vec_field = self._propagator_one(vec_field, mat_field.shape[-2:]) + vec_field = self._propagator_cylind_lens(vec_field, mat_field.shape[-2:]) + vec_field = self._propagator_three(vec_field * mat_field, (mat_field.size(-2), 1)) + + return self.prepare_out(vec_field) + + @_torch.no_grad() + def log_cylind_lens_operator_x( + self, + writer: SummaryWriter, + tag: str, + global_step: Optional[int] = None, + ): + # 1. Apply exp to get the wrapped phase as it would be physically seen + # This ensures values outside [-pi, pi] wrap correctly + complex_op = _torch.exp(-1j * self._propagator_cylind_lens._operator_X_phi) + wrapped_phase = _torch.angle(complex_op).float() # Range: [-π, π] + + # 2. Normalize for Image Visualization [0, 1] + phase_normalized = (wrapped_phase + _torch.pi) / (2 * _torch.pi) + + # 3. 
Log as a 1-pixel high row + # Shape: [1, 1, Width] + phase_row = phase_normalized.unsqueeze(0).unsqueeze(0) + writer.add_image(f"{tag}/phase_row", phase_row, global_step, dataformats='CHW') + + fig, ax = plt.subplots(figsize=(6, 4)) + ax.plot(wrapped_phase.detach().cpu().numpy(), label=f'Step {global_step}') + ax.set_title(f"Cylindrical Lens Phase Profile (x) {tag}") + ax.set_xlabel("Pixel Index") + ax.set_ylabel("Phase (rad)") + ax.set_ylim([-_torch.pi - 0.5, _torch.pi + 0.5]) + ax.grid(True, linestyle='--', alpha=0.6) + + + # Send the figure to the "Plots" or "Images" tab in TensorBoard + writer.add_figure(f"{tag}/phase_profile", fig, global_step) + plt.close(fig) # Important: prevent memory leaks + + +class TrainableFocalDistLensOpticalMul(_nn.Module): + """ + Класс системы, выполняющей оптически операцию умножения матрицы на матрицу. + """ + def __init__(self, config: _Config): + """ + Конструктор класса. + + Args: + config: конфигурация расчётной системы. + """ + super().__init__() + + prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config) + prop_two = _PropCrossLens(config.first_lens_plane, config) + prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config) + prop_four = _PropagatorTrainableFocalDistCylindLens(config.matrix_plane, config) + prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config) + prop_six = _PropCrossLens(config.second_lens_plane, config).T + prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config) + + self._propagator_one: _Prop = prop_one + prop_two + prop_three + self._propagator_cylind_lens: _Prop = prop_four + self._propagator_three: _Prop = prop_five + prop_six + prop_seven + + kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x)) + kron_mat_utils = _torch.ones((config.matrix_split_x, config.matrix_split_y)) + self.register_buffer('_kron_vec_utils', kron_vec_utils, persistent=True) + 
self.register_buffer('_kron_mat_utils', kron_mat_utils, persistent=True) + + self._avg_pool = _nn.AvgPool2d((1, config.result_vector_split)) + + def prepare_vector(self, data: _torch.Tensor) -> _torch.Tensor: + """ + Метод подготовки матрицы левой матрицы, как набора векторов столбцов, к подаче на вход системы. + + Args: + data: матрица комплексной амплитуды распределений световых полей. + + Returns: + Матрицы содержащие вектора левой матрицы. + """ + data = data.cfloat().flip(-1) + data = data.unsqueeze(-2) + data = _torch.kron(data.contiguous(), self._kron_vec_utils) + return data + + def prepare_matrix(self, data: _torch.Tensor) -> _torch.Tensor: + """ + Метод подготовки правой матрицы к подаче на вход системы. + + Args: + data: матрица комплексной амплитуды распределения светового поля. + + Returns: + Матрица - оптический элемент в центре модели. + """ + if (data.dim() > 4) and data.size(-1) == 2: + data = _torch.view_as_complex(data) + + data = data.cfloat().transpose(-2, -1) + data = data.unsqueeze(-3) + # TODO data should be at least two seq length. For one we get + # untimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead. + data = _torch.kron(data.contiguous(), self._kron_mat_utils) + return data + + def prepare_out(self, field: _torch.Tensor) -> _torch.Tensor: + """ + Метод получения результата матричного умножения. + + Args: + data: матрицы выходого распределения светового поля системы. + + Returns: + Вектор столбец (амплитудное распределение). 
+ """ + ### Закоментированная часть кода - более физически корректный вариант работы модели, + ### однако, данный вариант кода будет требовать большое кол-во памяти во время обучения + field = field.abs().squeeze(-1) #**2 + field = self._avg_pool(field) + return field.flip(-1) #**0.5 + + def forward(self, + input: _torch.Tensor, + other: _torch.Tensor) -> _torch.Tensor: + """ + Метод выполения матричного умножения. + + Args: + input: матрица (B, C, H, W). + other: матрица (B, C, W, K). + + Returns: + Рензультат матричного умножения (B, C, H, K). + + Example: + >>> mul = OpticalMul(...) + >>> A = torch.rand((1, 1, 256, 256)) > 0.5 + >>> B = torch.rand((1, 1, 256, 256)) > 0.5 + >>> mul(A, B).shape + torch.Size([1, 1, 256, 256]) + >>> A = torch.rand((1, 1, 64, 256)) > 0.5 + >>> B = torch.rand((1, 1, 256, 128)) > 0.5 + >>> mul(A, B).shape + torch.Size([1, 1, 64, 128]) + """ + vec_field = self.prepare_vector(input) + mat_field = self.prepare_matrix(other) + vec_field = self._propagator_one(vec_field, mat_field.shape[-2:]) + vec_field = self._propagator_cylind_lens(vec_field, mat_field.shape[-2:]) + vec_field = self._propagator_three(vec_field * mat_field, (mat_field.size(-2), 1)) + + return self.prepare_out(vec_field) + + @_torch.no_grad() + def log_cylind_lens_operator_x( + self, + writer: SummaryWriter, + tag: str, + global_step: Optional[int] = None, + ): + # 1. Apply exp to get the wrapped phase as it would be physically seen + # This ensures values outside [-pi, pi] wrap correctly + lens = self._propagator_cylind_lens + writer.add_scalar(f"{tag}/focal_distance", lens._distance.detach().cpu().numpy(), global_step) + + complex_op = _torch.exp(-1j * lens._K / lens._distance * lens._linspace_by_x**2) + wrapped_phase = _torch.angle(complex_op).float() # Range: [-π, π] + + # 2. Normalize for Image Visualization [0, 1] + phase_normalized = (wrapped_phase + _torch.pi) / (2 * _torch.pi) + + # 3. 
Log as a 1-pixel high row + # Shape: [1, 1, Width] + phase_row = phase_normalized.unsqueeze(0).unsqueeze(0) + writer.add_image(f"{tag}/phase_row", phase_row, global_step, dataformats='CHW') + + fig, ax = plt.subplots(figsize=(6, 4)) + ax.plot(wrapped_phase.detach().cpu().numpy(), label=f'Step {global_step}') + ax.set_title(f"Cylindrical Lens Phase Profile (x) {tag}\nFocal distance: {lens._distance.detach().cpu().numpy()}") + ax.set_xlabel("Pixel Index") + ax.set_ylabel("Phase (rad)") + ax.set_ylim([-_torch.pi - 0.5, _torch.pi + 0.5]) + ax.grid(True, linestyle='--', alpha=0.6) + + # Send the figure to the "Plots" or "Images" tab in TensorBoard + writer.add_figure(f"{tag}/phase_profile", fig, global_step) + plt.close(fig) # Important: prevent memory leaks + + +class TrainableScalarAndLensOpticalMul(_nn.Module): + """ + Класс системы, выполняющей оптически операцию умножения матрицы на матрицу. + """ + def __init__(self, config: _Config): + """ + Конструктор класса. + + Args: + config: конфигурация расчётной системы. 
+ """ + super().__init__() + + prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config) + prop_two = _PropCrossLens(config.first_lens_plane, config) + prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config) + prop_four = _PropagatorTrainableCylindLens(config.matrix_plane, config) + prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config) + prop_six = _PropCrossLens(config.second_lens_plane, config).T + prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config) + + self._propagator_one: _Prop = prop_one + prop_two + prop_three + self._propagator_two: _Prop = prop_four + self._propagator_three: _Prop = prop_five + prop_six + prop_seven + + kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x)) + kron_mat_utils = _torch.ones((config.matrix_split_x, config.matrix_split_y)) + self.register_buffer('_kron_vec_utils', kron_vec_utils, persistent=True) + self.register_buffer('_kron_mat_utils', kron_mat_utils, persistent=True) + + self._avg_pool = _nn.AvgPool2d((1, config.result_vector_split)) + self.k = nn.Parameter(torch.tensor(1)) def prepare_vector(self, data: _torch.Tensor) -> _torch.Tensor: """ @@ -114,12 +644,9 @@ class OpticalMul(_nn.Module): """ vec_field = self.prepare_vector(input) mat_field = self.prepare_matrix(other) - if self.trainable_cylind_lens: - vec_field = self._propagator_one(vec_field) - vec_field = self._propagator_between(vec_field) - else: - vec_field = self._propagator_one(vec_field) - vec_field = self._propagator_two(vec_field * mat_field) + vec_field = self._propagator_one(vec_field, mat_field.shape[-2:]) + vec_field = self._propagator_two(vec_field, mat_field.shape[-2:]) + vec_field = self._propagator_three(vec_field * mat_field, (mat_field.size(-2), 1)) - return self.prepare_out(vec_field) \ No newline at end of file + return self.k * self.prepare_out(vec_field) \ No newline at end of file diff --git 
a/src/optical_matrix_multiplication/propagator.py b/src/optical_matrix_multiplication/propagator.py index eaa061f..3079d13 100644 --- a/src/optical_matrix_multiplication/propagator.py +++ b/src/optical_matrix_multiplication/propagator.py @@ -7,6 +7,7 @@ from typing import Tuple as _Tuple, Sequence as _Sequence from abc import ABC as _ABC import collections as _collections +import copy as _copy class Propagator(_ABC, _nn.Module): """ @@ -16,18 +17,12 @@ class Propagator(_ABC, _nn.Module): operator_X: оператор отображающий распроcтранение светового поля вдоль оси абсцисс operator_Y: оператор отображающий распроcтранение светового поля вдоль оси ординат """ - def __init__(self, operator_X: _torch.Tensor, operator_Y: _torch.Tensor, trainable = False): + def __init__(self, operator_X: _torch.Tensor, operator_Y: _torch.Tensor): super(Propagator, self).__init__() operator_X: _torch.Tensor = _torch.view_as_real(operator_X) operator_Y: _torch.Tensor = _torch.view_as_real(operator_Y) - if trainable: - self._operator_X = _nn.Parameter(operator_X) - self._operator_Y = _nn.Parameter(operator_Y) - self._trainable = trainable - else: - self.register_buffer('_operator_X', operator_X, persistent=True) - self.register_buffer('_operator_Y', operator_Y, persistent=True) - self._trainable = trainable + self.register_buffer('_operator_X', operator_X, persistent=True) + self.register_buffer('_operator_Y', operator_Y, persistent=True) @property def operator_X(self) -> _torch.Tensor: @@ -98,7 +93,14 @@ class Propagator(_ABC, _nn.Module): """ return self.cat(propagator) - def forward(self, field: _torch.Tensor) -> _torch.Tensor: + @staticmethod + def __slice_calculation(total_rows: int, num_to_take: int) -> slice: + start = (total_rows - num_to_take) // 2 + end = start + num_to_take + return slice(start, end) + + def forward(self, + field: _torch.Tensor, resul_shape: None | _Tuple[int, int] | _torch.Size) -> _torch.Tensor: """ Метод распространения светового поля в среде. 
@@ -109,13 +111,23 @@ class Propagator(_ABC, _nn.Module): Распределение комплексной амплитуды светового поля, после распространения. """ - if self._trainable: - return _torch.diag_embed(self.operator_Y) @ field @ _torch.diag_embed(self.operator_X) - else: - return self.operator_Y @ field @ self.operator_X + + if (resul_shape is not None): + field_shape = field.shape[-2:] + operator_Y_shape = self.operator_Y.shape[-2:] + operator_X_shape = self.operator_X.shape[-2:] + + slice_one = Propagator.__slice_calculation(operator_Y_shape[0], resul_shape[0]) + slice_two = Propagator.__slice_calculation(operator_Y_shape[1], field_shape[0]) + slice_three = Propagator.__slice_calculation(operator_X_shape[0], field_shape[1]) + slice_four = Propagator.__slice_calculation(operator_X_shape[1], resul_shape[1]) + + return self.operator_Y[..., slice_one, slice_two] @ field @ self.operator_X[..., slice_three, slice_four] + + return self.operator_Y @ field @ self.operator_X def __repr__(self): - return f"Trainable: {self._trainable} Y shape: {self.operator_Y.shape}, X shape: {self.operator_X.shape}" + return f"Y shape: {self.operator_Y.shape}, X shape: {self.operator_X.shape}" class PropagatorLens(Propagator): """ @@ -145,7 +157,7 @@ class PropagatorCrossLens(PropagatorLens): представленной тонким оптическим элементом. """ def __init__(self, plane: _ConfigDesignPlane, - config: _ConfigOpticBase, trainable = False): + config: _ConfigOpticBase): """ Конструктор класса скрещенной линзы. 
@@ -155,18 +167,17 @@ class PropagatorCrossLens(PropagatorLens): """ operator_X = _torch.exp(-1j * config.K / config.distance * plane.linspace_by_x**2) operator_Y = _torch.exp(-1j * config.K / 2 / config.distance * plane.linspace_by_y**2) - super(PropagatorCrossLens, self).__init__(_torch.diag_embed(operator_X), - _torch.diag_embed(operator_Y), - trainable) + super(PropagatorCrossLens, self).__init__( + _torch.diag_embed(operator_X), + _torch.diag_embed(operator_Y)) -class PropagatorСylindLens(PropagatorLens): +class PropagatorCylindLens(PropagatorLens): """ Класс распространения света в цилиндрической линзе, представленной тонким оптическим элементом. """ def __init__(self, plane: _ConfigDesignPlane, - config: _ConfigOpticBase, - trainable = False): + config: _ConfigOpticBase): """ Конструктор класса цилиндрической линзы. @@ -176,14 +187,10 @@ class PropagatorСylindLens(PropagatorLens): """ operator_X = _torch.exp(-1j * config.K / config.distance * plane.linspace_by_x**2) operator_Y = _torch.ones_like(plane.linspace_by_y, dtype=_torch.cfloat) - if trainable: - super(PropagatorСylindLens, self).__init__(operator_X, - operator_Y, - trainable) - else: - super(PropagatorСylindLens, self).__init__(_torch.diag_embed(operator_X), - _torch.diag_embed(operator_Y), - trainable) + super(PropagatorCylindLens, self).__init__( + _torch.diag_embed(operator_X), + _torch.diag_embed(operator_Y)) + class PropagatorSinc(Propagator): """ @@ -192,7 +199,7 @@ class PropagatorSinc(Propagator): """ def __init__(self, first_plane: _ConfigDesignPlane, second_plane: _ConfigDesignPlane, - config: _ConfigOpticBase, trainable = False): + config: _ConfigOpticBase): """ Конструктор класса распространения в свободном пространстве. 
@@ -204,7 +211,7 @@ class PropagatorSinc(Propagator): operator_X, operator_Y = self.__get_operators(first_plane, second_plane, config) - super(PropagatorSinc, self).__init__(operator_X, operator_Y, trainable) + super(PropagatorSinc, self).__init__(operator_X, operator_Y) def __get_operator_for_dim(self, pixel_size_in: float, @@ -237,4 +244,137 @@ class PropagatorSinc(Propagator): second_plane.pixel_size_by_y, difference_y, config) - return operator_X, operator_Y \ No newline at end of file + return operator_X, operator_Y + + +####################################################################################################################### + +class PropagatorTrainableCylindLens(_ABC, _nn.Module): + """ + Класс распространения света в обучаемой цилиндрической линзе, + представленной тонким прозрачным оптическим элементом. + """ + def __init__(self, + plane: _ConfigDesignPlane, + config: _ConfigOpticBase + ): + super().__init__() + # non smooth profile after training. better to train only focal length? 
+ self._operator_X_phi = _nn.Parameter(config.K / config.distance * plane.linspace_by_x**2) + operator_Y = _torch.diag_embed(_torch.ones_like(plane.linspace_by_y, dtype=_torch.cfloat)) + operator_Y = _torch.view_as_real(operator_Y) + self.register_buffer('_operator_Y', operator_Y, persistent=True) + + @property + def operator_X(self) -> _torch.Tensor: + """ + Returns: + оператор отображающий распроcтранение светового поля вдоль оси абсцисс + """ + return _torch.diag_embed(_torch.exp(-1j * self._operator_X_phi)) + + @property + def operator_Y(self) -> _torch.Tensor: + """ + Returns: + оператор отображающий распроcтранение светового поля вдоль оси ординат + """ + return _torch.view_as_complex(self._operator_Y) + + @staticmethod + def __slice_calculation(total_rows: int, num_to_take: int) -> slice: + start = (total_rows - num_to_take) // 2 + end = start + num_to_take + return slice(start, end) + + def forward(self, + field: _torch.Tensor, resul_shape: None | _Tuple[int, int] | _torch.Size) -> _torch.Tensor: + """ + Метод распространения светового поля в среде. + + Args: + field: распределение комплексной амплитуды светового поля. + + Returns: + Распределение комплексной амплитуды светового поля, + после распространения. 
+ """ + + if (resul_shape is not None): + field_shape = field.shape[-2:] + operator_Y_shape = self.operator_Y.shape[-2:] + operator_X_shape = self.operator_X.shape[-2:] + + slice_one = PropagatorTrainableCylindLens.__slice_calculation(operator_Y_shape[0], resul_shape[0]) + slice_two = PropagatorTrainableCylindLens.__slice_calculation(operator_Y_shape[1], field_shape[0]) + slice_three = PropagatorTrainableCylindLens.__slice_calculation(operator_X_shape[0], field_shape[1]) + slice_four = PropagatorTrainableCylindLens.__slice_calculation(operator_X_shape[1], resul_shape[1]) + return self.operator_Y[..., slice_one, slice_two] @ field @ self.operator_X[..., slice_three, slice_four] + + return self.operator_Y @ field @ self.operator_X + + +class PropagatorTrainableFocalDistCylindLens(_ABC, _nn.Module): + """ + Класс распространения света в обучаемой цилиндрической линзе, + представленной тонким прозрачным оптическим элементом. + """ + def __init__(self, + plane: _ConfigDesignPlane, + config: _ConfigOpticBase + ): + super().__init__() + self._distance = _nn.Parameter(_torch.tensor(config.distance)) + self.register_buffer('_K', _torch.tensor(config.K), persistent=True) + self.register_buffer('_linspace_by_x', plane.linspace_by_x.detach().clone(), persistent=True) + operator_Y = _torch.diag_embed(_torch.ones_like(plane.linspace_by_y, dtype=_torch.cfloat)) + operator_Y = _torch.view_as_real(operator_Y) + self.register_buffer('_operator_Y', operator_Y, persistent=True) + + @property + def operator_X(self) -> _torch.Tensor: + """ + Returns: + оператор отображающий распроcтранение светового поля вдоль оси абсцисс + """ + return _torch.diag_embed(_torch.exp(-1j * self._K / self._distance * self._linspace_by_x**2)) + + @property + def operator_Y(self) -> _torch.Tensor: + """ + Returns: + оператор отображающий распроcтранение светового поля вдоль оси ординат + """ + return _torch.view_as_complex(self._operator_Y) + + @staticmethod + def __slice_calculation(total_rows: int, 
num_to_take: int) -> slice: + start = (total_rows - num_to_take) // 2 + end = start + num_to_take + return slice(start, end) + + def forward(self, + field: _torch.Tensor, resul_shape: None | _Tuple[int, int] | _torch.Size) -> _torch.Tensor: + """ + Метод распространения светового поля в среде. + + Args: + field: распределение комплексной амплитуды светового поля. + + Returns: + Распределение комплексной амплитуды светового поля, + после распространения. + """ + + if (resul_shape is not None): + field_shape = field.shape[-2:] + operator_Y_shape = self.operator_Y.shape[-2:] + operator_X_shape = self.operator_X.shape[-2:] + + slice_one = PropagatorTrainableFocalDistCylindLens.__slice_calculation(operator_Y_shape[0], resul_shape[0]) + slice_two = PropagatorTrainableFocalDistCylindLens.__slice_calculation(operator_Y_shape[1], field_shape[0]) + slice_three = PropagatorTrainableFocalDistCylindLens.__slice_calculation(operator_X_shape[0], field_shape[1]) + slice_four = PropagatorTrainableFocalDistCylindLens.__slice_calculation(operator_X_shape[1], resul_shape[1]) + return self.operator_Y[..., slice_one, slice_two] @ field @ self.operator_X[..., slice_three, slice_four] + + return self.operator_Y @ field @ self.operator_X \ No newline at end of file diff --git a/src/optics_char_gpt2_nokoef.py b/src/optics_char_gpt2_nokoef.py index af1e518..0cd65d3 100644 --- a/src/optics_char_gpt2_nokoef.py +++ b/src/optics_char_gpt2_nokoef.py @@ -179,6 +179,8 @@ class OpticGPT2NOKoef(nn.Module): lens_size = 8192 * 2, trainable_cylind_lens=False) ) + self.sim_scores = omm.ScatterDataParallel(self.sim_scores) + self.sim_output = omm.ScatterDataParallel(self.sim_output) self.layers = nn.ModuleList([ TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads, diff --git a/src/train_gpt2.py b/src/train_gpt2.py new file mode 100644 index 0000000..edee343 --- /dev/null +++ b/src/train_gpt2.py @@ -0,0 +1,341 @@ +import torch +import torch.nn 
as nn +from torch.nn import functional as F +from einops import rearrange +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +import sys +from pathlib import Path +import shutil +seed = 1337 +torch.manual_seed(seed) + +############################### MODEL ############################################################# + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, 
h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5) + tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) + scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line + attention = nn.functional.softmax(scores, dim=2) + return self.o_proj(self.gather_heads(attention @ v, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +class GPT2(nn.Module): + def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.layers = nn.ModuleList([ + TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) + for _ in range(layers_num)]) + self.tok_embeds = nn.Embedding(vocab_size, h_dim) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + self.lm_head = nn.Linear(h_dim, vocab_size) + + def forward(self, x, targets=None): + x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] 
+ for l in self.layers: + x = l(x) + logits = self.lm_head(x) # B,T,C + loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None + return logits, loss + + # what is the purpose? autoregressive inference! + def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx + +################################################################################################### + + +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 64 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 +# CUDA_VISIBLE_DEVICES=1 python src/main.py + +MODEL_CLASS = GPT2 +train_data_path = Path("./data/wiki.train.tokens") +val_data_path = Path("./data/wiki.valid.tokens") +test_data_path = Path("./data/wiki.test.tokens") +comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}" + +logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +writer = SummaryWriter(logs_dir) +script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name) +print("Logs dir:", logs_dir) +# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script +shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository 
+script_snapshot_path.chmod(0o500) # with read-only permission + +# Create standalone checkpoints directory with your specified format +checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +Path(checkpoints_dir).mkdir(parents=True, exist_ok=True) +print("Checkpoints dir:", checkpoints_dir) + +#################################### Dataset ######################################### + +# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt +with open(train_data_path, encoding='utf-8') as f: + train_text = f.read() + +with open(val_data_path, encoding='utf-8') as f: + val_text = f.read() + +with open(test_data_path, encoding='utf-8') as f: + test_text = f.read() + +text = train_text + val_text + test_text + +chars = sorted(set(text)) +vocab_size = len(chars) + +wtoi = {w:i for i,w in enumerate(chars)} +itow = {i:w for i,w in enumerate(chars)} + +encode = lambda s: [wtoi[w] for w in s] +decode = lambda idx: ''.join([itow[i] for i in idx]) + +def get_batch(data, seq_len, batch_size): + ix = torch.randint(len(data)-seq_len, (batch_size, )) + x = torch.stack([data[i:i+seq_len] for i in ix]).to(device) + y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device) + return x, y + +train_data = torch.tensor(encode(train_text), dtype=torch.long) +val_data = torch.tensor(encode(val_text), dtype=torch.long) +test_data = torch.tensor(encode(test_text), dtype=torch.long) + +@torch.no_grad() +def perplexity(model, data): + model.eval() + stride = max(1, len(data) // 10000) + total_loss_sum = 0.0 + total_tokens_count = 0 + losses = [] + for i in range(0, len(data)-max_seq_len-1, stride): + x = data[i:(i+max_seq_len)].to(device) + y = data[(i+1):(i+max_seq_len+1)].to(device) + logits, mean_loss = model(x[None,...], y[None,...]) + num_tokens = y.numel() + total_loss_sum += mean_loss.item() * num_tokens + total_tokens_count += num_tokens + 
print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="") + return np.exp(total_loss_sum / total_tokens_count) + +#################################### Model #########################################mo +def complete(m, start_idxs=[0], max_new_tokens=100): + start_idx = torch.tensor([start_idxs]).to(device) + generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens) + return decode(generated_tokens[0].tolist()) + +m = MODEL_CLASS( + vocab_size=vocab_size, + h_dim=h_dim, + max_seq_len=max_seq_len, + num_heads=num_heads, + pixel_size=pixel_size, + layers_num=layers_num +) +m = m.to(device) +writer.add_text('model', str(m), 0) + +optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) + +#################################### Checkpoint Function ######################################### + +def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir): + """Save model checkpoint with complete training state""" + checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt' + torch.save({ + 'step': step, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss, + 'config': config, + 'wtoi': wtoi, + 'itow': itow, + }, checkpoint_path) + +# Training config for checkpointing +training_config = { + 'vocab_size': vocab_size, + 'layers_num': layers_num, + 'h_dim': h_dim, + 'max_seq_len': max_seq_len, + 'num_heads': num_heads, + 'dropout_rate': dropout_rate, + 'batch_size': batch_size, + 'learning_rate': learning_rate, + 'gradient_accumulation_steps': gradient_accumulation_steps, + 'pixel_size': pixel_size, + 'max_iters': max_iters, +} +#################################### Train ######################################### + +start_time = datetime.now() +print("Started at:", start_time) + +m.eval() +task_prompts = [ + "1 2 3 4 5", + "The capital of France is", + "The chemical symbol of gold is", + "If yesterday was Friday, then tomorrow 
will be", + "The opposite of hot is", + "The planets of the solar system are:", + "My favorite color is", + "If 5*x + 3 = 13, then x is", +] +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, 0) + +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, 0) + +for i in range(max_iters): + m.train() + optimizer.zero_grad(set_to_none=True) + accumulated_loss = 0.0 + for j in range(gradient_accumulation_steps): + xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps) + logits, loss = m(xb, yb) + loss = loss / gradient_accumulation_steps + loss.backward() + accumulated_loss += loss.item() + if i % 100 == 0: + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i) + optimizer.step() + writer.add_scalar('loss', accumulated_loss, i) + print(f"\r{i}/{max_iters} {accumulated_loss}", end="") + if i % 5000 == 0: + m.eval() + ppl = perplexity(model=m, data=val_data) + writer.add_scalar('val_perplexity', ppl.item(), i) + print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") + writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) + task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) + writer.add_text('completions/task', task_results, i) + save_checkpoint( + model=m, + optimizer=optimizer, + step=i, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir + ) + +m.eval() +ppl = perplexity(model=m, data=val_data) +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") +writer.add_scalar('val_perplexity', ppl.item(), i+1) +writer.add_scalar('loss', accumulated_loss, i) + +ppl = perplexity(model=m, data=test_data) 
+writer.add_scalar('test_perplexity', ppl.item(), i+1) +print(f"\rTest Perplexity at {i}: {ppl}") + +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, i+1) +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, i+1) + +# Save final checkpoint +save_checkpoint( + model=m, + optimizer=optimizer, + step=max_iters, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir +) +print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}") \ No newline at end of file diff --git a/src/train_optics_trainable_focal_dist_lens_64.py b/src/train_optics_trainable_focal_dist_lens_64.py new file mode 100644 index 0000000..a0aac71 --- /dev/null +++ b/src/train_optics_trainable_focal_dist_lens_64.py @@ -0,0 +1,399 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +import sys +from pathlib import Path +import optical_matrix_multiplication as omm +import shutil +seed = 1337 +torch.manual_seed(seed) + +############################### MODEL ############################################################# + +def new_formula(sim, tensor_1, tensor_2): + tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1 + tensor_2 = tensor_2[None,:,:,:] if len(tensor_2.shape) < 4 else tensor_2 + device = tensor_1.device + + A_pos = torch.clamp(tensor_1, min=0) # A⁺ = max(A, 0) + A_neg = torch.clamp(-tensor_1, min=0) # A⁻ = max(-A, 0) + B_pos = torch.clamp(tensor_2, min=0) # B⁺ = max(B, 0) + B_neg = torch.clamp(-tensor_2, min=0) # B⁻ = max(-B, 0) + + max_A_pos = torch.max(A_pos) # Может быть 0, если нет положительных значений + max_A_neg = torch.max(A_neg) # Может быть 0, если нет 
отрицательных значений + max_B_pos = torch.max(B_pos) + max_B_neg = torch.max(B_neg) + + zero_template = torch.zeros_like( + torch.empty(tensor_1.shape[0],tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])) + + if max_A_pos > 0 and max_B_pos > 0: + t1 = sim(A_pos / max_A_pos, B_pos / max_B_pos) * max_A_pos * max_B_pos + else: + t1 = zero_template.clone().to(device) + + if max_A_pos > 0 and max_B_neg > 0: + t2 = sim(A_pos / max_A_pos, B_neg / max_B_neg) * max_A_pos * max_B_neg + else: + t2 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_pos > 0: + t3 = sim(A_neg / max_A_neg, B_pos / max_B_pos) * max_A_neg * max_B_pos + else: + t3 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_neg > 0: + t4 = sim(A_neg / max_A_neg, B_neg / max_B_neg) * max_A_neg * max_B_neg + else: + t4 = zero_template.clone().to(device) + + return (t1 - t2 - t3 + t4)[0,:,:,:] + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = 
torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + # now trainable parameters are in optical mat mul class, one scalar per simulator + # here we use only TrainableLensOpticalMul and scalars for each layer + self.k1 = nn.Parameter(torch.randn(1)) + self.k2 = nn.Parameter(torch.randn(1)) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5) + tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) + scores = scores.masked_fill(tril == 0, float('-inf')) + attention = nn.functional.softmax(scores, dim=2) + output = self.k2 * new_formula(self.sim_output, attention, v) + return self.o_proj(self.gather_heads(output, *x.shape)) + + def forward(self, x): + x = x + 
F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +class OpticGPT2TrainableScalarAndFocalDistLens(nn.Module): + def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, + pixel_size = 3.6e-6): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + if max_seq_len != 512: + self.sim_scores = omm.TrainableFocalDistLensOpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01) + ) + self.sim_output = omm.TrainableFocalDistLensOpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01) + ) + if max_seq_len == 512: + self.sim_scores = omm.TrainableFocalDistLensOpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2) + ) + self.sim_output = omm.TrainableFocalDistLensOpticalMul( + omm.Config(right_matrix_count_columns = 
h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2) + ) + self.sim_scores = omm.ScatterDataParallel(self.sim_scores) + self.sim_output = omm.ScatterDataParallel(self.sim_output) + + self.layers = nn.ModuleList([ + TransformerLayer( + h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads, + dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) + for _ in range(layers_num)]) + self.tok_embeds = nn.Embedding(vocab_size, h_dim) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + self.lm_head = nn.Linear(h_dim, vocab_size) + + def forward(self, x, targets=None): + x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + logits = self.lm_head(x) # B,T,C + loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None + return logits, loss + + # what is the purpose? autoregressive inference! 
+ def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx + + def log_trainable_optic_params( + self, + writer: SummaryWriter, + global_step, + ): + self.sim_scores.module.log_cylind_lens_operator_x(writer, "sim_scores", global_step) + self.sim_output.module.log_cylind_lens_operator_x(writer, "sim_output", global_step) + for i, layer in enumerate(self.layers): + # Using f-string tags to group them nicely in TensorBoard (e.g., Layer_0/k1) + writer.add_scalar(f"optic_scalars/layer_{i}/k1", layer.k1.item(), global_step) + writer.add_scalar(f"optic_scalars/layer_{i}/k2", layer.k2.item(), global_step) + + +################################################################################################### + +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = 40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 64 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 +# CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff seq_128_hdim_64 + +MODEL_CLASS = OpticGPT2TrainableScalarAndFocalDistLens +train_data_path = Path("./data/wiki.train.tokens") +val_data_path = Path("./data/wiki.valid.tokens") +test_data_path = Path("./data/wiki.test.tokens") +comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}" + +logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +writer = 
SummaryWriter(logs_dir) +script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name) +print("Logs dir:", logs_dir) +# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script +shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository +script_snapshot_path.chmod(0o500) # with read-only permission + + +#################################### Dataset ######################################### + +# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt +with open(train_data_path, encoding='utf-8') as f: + train_text = f.read() + +with open(val_data_path, encoding='utf-8') as f: + val_text = f.read() + +with open(test_data_path, encoding='utf-8') as f: + test_text = f.read() + +text = train_text + val_text + test_text + +chars = sorted(set(text)) +vocab_size = len(chars) + +wtoi = {w:i for i,w in enumerate(chars)} +itow = {i:w for i,w in enumerate(chars)} + +encode = lambda s: [wtoi[w] for w in s] +decode = lambda idx: ''.join([itow[i] for i in idx]) + +def get_batch(data, seq_len, batch_size): + ix = torch.randint(len(data)-seq_len, (batch_size, )) + x = torch.stack([data[i:i+seq_len] for i in ix]).to(device) + y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device) + return x, y + +train_data = torch.tensor(encode(train_text), dtype=torch.long) +val_data = torch.tensor(encode(val_text), dtype=torch.long) +test_data = torch.tensor(encode(test_text), dtype=torch.long) + +@torch.no_grad() +def perplexity(model, data): + stride = max(1, len(data) // 10000) + losses = [] + for i in range(0, len(data)-max_seq_len-1, stride): + x = data[i:(i+max_seq_len)].to(device) + y = data[(i+1):(i+max_seq_len+1)].to(device) + logits, loss = model(x[None,...], y[None,...]) + losses.append(loss.item()) + print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="") + return np.exp(np.mean(losses)) + +#################################### Model 
#########################################
+def complete(m, start_idxs=[0], max_new_tokens=100):
+    start_idx = torch.tensor([start_idxs]).to(device)
+    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
+    return decode(generated_tokens[0].tolist())
+
+m = MODEL_CLASS(
+    vocab_size=vocab_size,
+    h_dim=h_dim,
+    max_seq_len=max_seq_len,
+    num_heads=num_heads,
+    pixel_size=pixel_size,
+    layers_num=layers_num
+)
+m = m.to(device)
+writer.add_text('model', str(m), 0)
+
+#################################### Train #########################################
+
+optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
+
+m.eval()
+
+task_prompts = [
+    "1 2 3 4 5",
+    "The capital of France is",
+    "The chemical symbol of gold is",
+    "If yesterday was Friday, then tomorrow will be",
+    "The opposite of hot is",
+    "The planets of the solar system are:",
+    "My favorite color is",
+    "If 5*x + 3 = 13, then x is",
+]
+completion = complete(m, encode("\n\n"), max_seq_len)
+print(completion)
+writer.add_text('completions', completion, 0)
+
+task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts])
+print(task_results)
+writer.add_text('completions/task', task_results, 0)
+
+for i in range(max_iters):
+    m.train()
+    optimizer.zero_grad(set_to_none=True)
+    accumulated_loss = 0.0
+    for j in range(gradient_accumulation_steps):
+        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps)
+        logits, loss = m(xb, yb)
+        loss = loss / gradient_accumulation_steps
+        loss.backward()
+        accumulated_loss += loss.item()
+    if i % 100 == 0:
+        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)
+    optimizer.step()
+    writer.add_scalar('loss', accumulated_loss, i)
+    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")
+    if i % 5000 == 0:
+        m.eval()
+        ppl = perplexity(model=m, data=val_data) 
+ writer.add_scalar('val_perplexity', ppl.item(), i) + print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") + writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) + task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) + writer.add_text('completions/task', task_results, i) + m.log_trainable_optic_params(writer, i) + +m.eval() +ppl = perplexity(model=m, data=val_data) +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") +writer.add_scalar('val_perplexity', ppl.item(), i+1) +writer.add_scalar('loss', accumulated_loss, i) + +ppl = perplexity(model=m, data=test_data) +writer.add_scalar('test_perplexity', ppl.item(), i+1) +print(f"\rTest Perplexity at {i}: {ppl}") + +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, i+1) +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, i+1) \ No newline at end of file diff --git a/src/train_optics_trainable_lens_128.py b/src/train_optics_trainable_lens_128.py new file mode 100644 index 0000000..799f312 --- /dev/null +++ b/src/train_optics_trainable_lens_128.py @@ -0,0 +1,464 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +import sys +from pathlib import Path +import optical_matrix_multiplication as omm +import shutil +seed = 1337 +torch.manual_seed(seed) + +############################### MODEL ############################################################# + +def new_formula(sim, tensor_1, tensor_2): + tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1 + tensor_2 = tensor_2[None,:,:,:] if len(tensor_2.shape) < 4 else tensor_2 + device = tensor_1.device + 
+    A_pos = torch.clamp(tensor_1, min=0) # A⁺ = max(A, 0)
+    A_neg = torch.clamp(-tensor_1, min=0) # A⁻ = max(-A, 0)
+    B_pos = torch.clamp(tensor_2, min=0) # B⁺ = max(B, 0)
+    B_neg = torch.clamp(-tensor_2, min=0) # B⁻ = max(-B, 0)
+
+    max_A_pos = torch.max(A_pos) # May be 0 if there are no positive values
+    max_A_neg = torch.max(A_neg) # May be 0 if there are no negative values
+    max_B_pos = torch.max(B_pos)
+    max_B_neg = torch.max(B_neg)
+
+    zero_template = torch.zeros_like(
+        torch.empty(tensor_1.shape[0],tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3]))
+
+    if max_A_pos > 0 and max_B_pos > 0:
+        t1 = sim(A_pos / max_A_pos, B_pos / max_B_pos) * max_A_pos * max_B_pos
+    else:
+        t1 = zero_template.clone().to(device)
+
+    if max_A_pos > 0 and max_B_neg > 0:
+        t2 = sim(A_pos / max_A_pos, B_neg / max_B_neg) * max_A_pos * max_B_neg
+    else:
+        t2 = zero_template.clone().to(device)
+
+    if max_A_neg > 0 and max_B_pos > 0:
+        t3 = sim(A_neg / max_A_neg, B_pos / max_B_pos) * max_A_neg * max_B_pos
+    else:
+        t3 = zero_template.clone().to(device)
+
+    if max_A_neg > 0 and max_B_neg > 0:
+        t4 = sim(A_neg / max_A_neg, B_neg / max_B_neg) * max_A_neg * max_B_neg
+    else:
+        t4 = zero_template.clone().to(device)
+
+    return (t1 - t2 - t3 + t4)[0,:,:,:]
+
+# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
+class RoPE(nn.Module):
+    def __init__(self, dim, max_seq_len=512):
+        super().__init__()
+        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+        t = torch.arange(max_seq_len).type_as(inv_freq)
+        freqs = torch.einsum('i,j->ij', t, inv_freq)
+        self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1))
+
+    def rotate_half(self, x):
+        x1, x2 = x.chunk(2, dim=-1)
+        return torch.cat((-x2, x1), dim=-1)
+
+    def forward(self, x, offset=0):
+        seq_len = x.size(1)
+        emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1)
+        cos = emb.cos()
+        sin = emb.sin()
+        return (x * cos) + (self.rotate_half(x) * sin)
+
+# 
Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + # now trainable parameters are in optical mat mul class, one scalar per simulator + # here we use only TrainableLensOpticalMul and scalars for each layer + self.k1 = nn.Parameter(torch.randn(1)) + self.k2 = nn.Parameter(torch.randn(1)) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * 
(self.h_dim ** -0.5) + tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) + scores = scores.masked_fill(tril == 0, float('-inf')) + attention = nn.functional.softmax(scores, dim=2) + output = self.k2 * new_formula(self.sim_output, attention, v) + return self.o_proj(self.gather_heads(output, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +class OpticGPT2TrainableScalarAndLens(nn.Module): + def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, + pixel_size = 3.6e-6): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + if max_seq_len != 512: + self.sim_scores = omm.TrainableLensOpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01) + ) + self.sim_output = omm.TrainableLensOpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01) + ) + if max_seq_len == 512: + self.sim_scores = omm.TrainableLensOpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim 
// num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2) + ) + self.sim_output = omm.TrainableLensOpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2) + ) + self.sim_scores = omm.ScatterDataParallel(self.sim_scores) + self.sim_output = omm.ScatterDataParallel(self.sim_output) + + self.layers = nn.ModuleList([ + TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads, + dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) + for _ in range(layers_num)]) + self.tok_embeds = nn.Embedding(vocab_size, h_dim) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + self.lm_head = nn.Linear(h_dim, vocab_size) + + def forward(self, x, targets=None): + x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + logits = self.lm_head(x) # B,T,C + loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None + return logits, loss + + # what is the purpose? autoregressive inference! 
+ def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx + + def log_trainable_optic_params( + self, + writer: SummaryWriter, + global_step, + ): + self.sim_scores.module.log_cylind_lens_operator_x(writer, "sim_scores", global_step) + self.sim_output.module.log_cylind_lens_operator_x(writer, "sim_output", global_step) + for i, layer in enumerate(self.layers): + # Using f-string tags to group them nicely in TensorBoard (e.g., Layer_0/k1) + writer.add_scalar(f"optic_scalars/layer_{i}/k1", layer.k1.item(), global_step) + writer.add_scalar(f"optic_scalars/layer_{i}/k2", layer.k2.item(), global_step) + + +################################################################################################### + +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 128 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 +# CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff seq_128_hdim_64 + +MODEL_CLASS = OpticGPT2TrainableScalarAndLens +train_data_path = Path("./data/wiki.train.tokens") +val_data_path = Path("./data/wiki.valid.tokens") +test_data_path = Path("./data/wiki.test.tokens") +comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}" + +logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +writer 
= SummaryWriter(logs_dir) +script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name) +print("Logs dir:", logs_dir) +# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script +shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository +script_snapshot_path.chmod(0o500) # with read-only permission + +# Create standalone checkpoints directory with your specified format +checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +Path(checkpoints_dir).mkdir(parents=True, exist_ok=True) +print("Checkpoints dir:", checkpoints_dir) + +#################################### Dataset ######################################### + +# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt +with open(train_data_path, encoding='utf-8') as f: + train_text = f.read() + +with open(val_data_path, encoding='utf-8') as f: + val_text = f.read() + +with open(test_data_path, encoding='utf-8') as f: + test_text = f.read() + +text = train_text + val_text + test_text + +chars = sorted(set(text)) +vocab_size = len(chars) + +wtoi = {w:i for i,w in enumerate(chars)} +itow = {i:w for i,w in enumerate(chars)} + +encode = lambda s: [wtoi[w] for w in s] +decode = lambda idx: ''.join([itow[i] for i in idx]) + +def get_batch(data, seq_len, batch_size): + ix = torch.randint(len(data)-seq_len, (batch_size, )) + x = torch.stack([data[i:i+seq_len] for i in ix]).to(device) + y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device) + return x, y + +train_data = torch.tensor(encode(train_text), dtype=torch.long) +val_data = torch.tensor(encode(val_text), dtype=torch.long) +test_data = torch.tensor(encode(test_text), dtype=torch.long) + +@torch.no_grad() +def perplexity(model, data): + model.eval() + stride = max(1, len(data) // 10000) + total_loss_sum = 0.0 + total_tokens_count = 0 
+ losses = [] + for i in range(0, len(data)-max_seq_len-1, stride): + x = data[i:(i+max_seq_len)].to(device) + y = data[(i+1):(i+max_seq_len+1)].to(device) + logits, mean_loss = model(x[None,...], y[None,...]) + num_tokens = y.numel() + total_loss_sum += mean_loss.item() * num_tokens + total_tokens_count += num_tokens + print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="") + return np.exp(total_loss_sum / total_tokens_count) + +#################################### Model #########################################mo +def complete(m, start_idxs=[0], max_new_tokens=100): + start_idx = torch.tensor([start_idxs]).to(device) + generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens) + return decode(generated_tokens[0].tolist()) + +m = MODEL_CLASS( + vocab_size=vocab_size, + h_dim=h_dim, + max_seq_len=max_seq_len, + num_heads=num_heads, + pixel_size=pixel_size, + layers_num=layers_num +) +m = m.to(device) +writer.add_text('model', str(m), 0) + +#################################### Train ######################################### + +optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) + +#################################### Checkpoint Function ######################################### + +def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir): + """Save model checkpoint with complete training state""" + checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt' + torch.save({ + 'step': step, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss, + 'config': config, + 'wtoi': wtoi, + 'itow': itow, + }, checkpoint_path) + +# Training config for checkpointing +training_config = { + 'vocab_size': vocab_size, + 'layers_num': layers_num, + 'h_dim': h_dim, + 'max_seq_len': max_seq_len, + 'num_heads': num_heads, + 'dropout_rate': dropout_rate, + 'batch_size': batch_size, + 'learning_rate': learning_rate, + 
'gradient_accumulation_steps': gradient_accumulation_steps, + 'pixel_size': pixel_size, + 'max_iters': max_iters, +} +#################################### Train ######################################### + +start_time = datetime.now() +print("Started at:", start_time) + +m.eval() + +task_prompts = [ + "1 2 3 4 5", + "The capital of France is", + "The chemical symbol of gold is", + "If yesterday was Friday, then tomorrow will be", + "The opposite of hot is", + "The planets of the solar system are:", + "My favorite color is", + "If 5*x + 3 = 13, then x is", +] +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, 0) + +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, 0) + +for i in range(max_iters): + m.train() + optimizer.zero_grad(set_to_none=True) + accumulated_loss = 0.0 + for j in range(gradient_accumulation_steps): + xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps) + logits, loss = m(xb, yb) + loss = loss / gradient_accumulation_steps + loss.backward() + accumulated_loss += loss.item() + if i % 100 == 0: + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i) + optimizer.step() + writer.add_scalar('loss', accumulated_loss, i) + print(f"\r{i}/{max_iters} {accumulated_loss}", end="") + if i % 5000 == 0: + m.eval() + ppl = perplexity(model=m, data=val_data) + writer.add_scalar('val_perplexity', ppl.item(), i) + print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") + writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) + task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) + writer.add_text('completions/task', task_results, i) + m.log_trainable_optic_params(writer, i) + save_checkpoint( + 
model=m, + optimizer=optimizer, + step=i, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir + ) + +m.eval() +ppl = perplexity(model=m, data=val_data) +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") +writer.add_scalar('val_perplexity', ppl.item(), i+1) +writer.add_scalar('loss', accumulated_loss, i) + +ppl = perplexity(model=m, data=test_data) +writer.add_scalar('test_perplexity', ppl.item(), i+1) +print(f"\rTest Perplexity at {i}: {ppl}") + +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, i+1) +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, i+1) + +# Save final checkpoint +save_checkpoint( + model=m, + optimizer=optimizer, + step=max_iters, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir +) +print(f"\n✓ Training complete. 
Final checkpoint saved to {checkpoints_dir}") \ No newline at end of file diff --git a/src/train_optics_trainable_lens_256.py b/src/train_optics_trainable_lens_256.py new file mode 100644 index 0000000..d7df7f1 --- /dev/null +++ b/src/train_optics_trainable_lens_256.py @@ -0,0 +1,464 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +import sys +from pathlib import Path +import optical_matrix_multiplication as omm +import shutil +seed = 1337 +torch.manual_seed(seed) + +############################### MODEL ############################################################# + +def new_formula(sim, tensor_1, tensor_2): + tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1 + tensor_2 = tensor_2[None,:,:,:] if len(tensor_2.shape) < 4 else tensor_2 + device = tensor_1.device + + A_pos = torch.clamp(tensor_1, min=0) # A⁺ = max(A, 0) + A_neg = torch.clamp(-tensor_1, min=0) # A⁻ = max(-A, 0) + B_pos = torch.clamp(tensor_2, min=0) # B⁺ = max(B, 0) + B_neg = torch.clamp(-tensor_2, min=0) # B⁻ = max(-B, 0) + + max_A_pos = torch.max(A_pos) # Может быть 0, если нет положительных значений + max_A_neg = torch.max(A_neg) # Может быть 0, если нет отрицательных значений + max_B_pos = torch.max(B_pos) + max_B_neg = torch.max(B_neg) + + zero_template = torch.zeros_like( + torch.empty(tensor_1.shape[0],tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])) + + if max_A_pos > 0 and max_B_pos > 0: + t1 = sim(A_pos / max_A_pos, B_pos / max_B_pos) * max_A_pos * max_B_pos + else: + t1 = zero_template.clone().to(device) + + if max_A_pos > 0 and max_B_neg > 0: + t2 = sim(A_pos / max_A_pos, B_neg / max_B_neg) * max_A_pos * max_B_neg + else: + t2 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_pos > 0: + t3 = sim(A_neg / max_A_neg, B_pos / max_B_pos) * max_A_neg * max_B_pos + else: + t3 = 
zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_neg > 0: + t4 = sim(A_neg / max_A_neg, B_neg / max_B_neg) * max_A_neg * max_B_neg + else: + t4 = zero_template.clone().to(device) + + return (t1 - t2 - t3 + t4)[0,:,:,:] + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = 
nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + # now trainable parameters are in optical mat mul class, one scalar per simulator + # here we use only TrainableLensOpticalMul and scalars for each layer + self.k1 = nn.Parameter(torch.randn(1)) + self.k2 = nn.Parameter(torch.randn(1)) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5) + tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) + scores = scores.masked_fill(tril == 0, float('-inf')) + attention = nn.functional.softmax(scores, dim=2) + output = self.k2 * new_formula(self.sim_output, attention, v) + return self.o_proj(self.gather_heads(output, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +class OpticGPT2TrainableScalarAndLens(nn.Module): + def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, + pixel_size = 3.6e-6): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + if max_seq_len != 512: + self.sim_scores = omm.TrainableLensOpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + 
right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01) + ) + self.sim_output = omm.TrainableLensOpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01) + ) + if max_seq_len == 512: + self.sim_scores = omm.TrainableLensOpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2) + ) + self.sim_output = omm.TrainableLensOpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2) + ) + self.sim_scores = omm.ScatterDataParallel(self.sim_scores) + self.sim_output = omm.ScatterDataParallel(self.sim_output) + + self.layers = nn.ModuleList([ + TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads, + dropout_rate=self.dropout_rate, 
max_seq_len=self.max_seq_len) + for _ in range(layers_num)]) + self.tok_embeds = nn.Embedding(vocab_size, h_dim) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + self.lm_head = nn.Linear(h_dim, vocab_size) + + def forward(self, x, targets=None): + x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + logits = self.lm_head(x) # B,T,C + loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None + return logits, loss + + # what is the purpose? autoregressive inference! + def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx + + def log_trainable_optic_params( + self, + writer: SummaryWriter, + global_step, + ): + self.sim_scores.module.log_cylind_lens_operator_x(writer, "sim_scores", global_step) + self.sim_output.module.log_cylind_lens_operator_x(writer, "sim_output", global_step) + for i, layer in enumerate(self.layers): + # Using f-string tags to group them nicely in TensorBoard (e.g., Layer_0/k1) + writer.add_scalar(f"optic_scalars/layer_{i}/k1", layer.k1.item(), global_step) + writer.add_scalar(f"optic_scalars/layer_{i}/k2", layer.k2.item(), global_step) + + +################################################################################################### + +batch_size = 50 +gradient_accumulation_steps = 5 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 256 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert 
batch_size % gradient_accumulation_steps == 0 +# CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff seq_128_hdim_64 + +MODEL_CLASS = OpticGPT2TrainableScalarAndLens +train_data_path = Path("./data/wiki.train.tokens") +val_data_path = Path("./data/wiki.valid.tokens") +test_data_path = Path("./data/wiki.test.tokens") +comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}" + +logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +writer = SummaryWriter(logs_dir) +script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name) +print("Logs dir:", logs_dir) +# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script +shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository +script_snapshot_path.chmod(0o500) # with read-only permission + +# Create standalone checkpoints directory with your specified format +checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +Path(checkpoints_dir).mkdir(parents=True, exist_ok=True) +print("Checkpoints dir:", checkpoints_dir) + +#################################### Dataset ######################################### + +# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt +with open(train_data_path, encoding='utf-8') as f: + train_text = f.read() + +with open(val_data_path, encoding='utf-8') as f: + val_text = f.read() + +with open(test_data_path, encoding='utf-8') as f: + test_text = f.read() + +text = train_text + val_text + test_text + +chars = sorted(set(text)) +vocab_size = len(chars) + +wtoi = {w:i for i,w in enumerate(chars)} +itow = {i:w for i,w in enumerate(chars)} + +encode = lambda s: [wtoi[w] for w in s] +decode = lambda idx: ''.join([itow[i] for i 
in idx]) + +def get_batch(data, seq_len, batch_size): + ix = torch.randint(len(data)-seq_len, (batch_size, )) + x = torch.stack([data[i:i+seq_len] for i in ix]).to(device) + y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device) + return x, y + +train_data = torch.tensor(encode(train_text), dtype=torch.long) +val_data = torch.tensor(encode(val_text), dtype=torch.long) +test_data = torch.tensor(encode(test_text), dtype=torch.long) + +@torch.no_grad() +def perplexity(model, data): + model.eval() + stride = max(1, len(data) // 10000) + total_loss_sum = 0.0 + total_tokens_count = 0 + losses = [] + for i in range(0, len(data)-max_seq_len-1, stride): + x = data[i:(i+max_seq_len)].to(device) + y = data[(i+1):(i+max_seq_len+1)].to(device) + logits, mean_loss = model(x[None,...], y[None,...]) + num_tokens = y.numel() + total_loss_sum += mean_loss.item() * num_tokens + total_tokens_count += num_tokens + print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="") + return np.exp(total_loss_sum / total_tokens_count) + +#################################### Model #########################################mo +def complete(m, start_idxs=[0], max_new_tokens=100): + start_idx = torch.tensor([start_idxs]).to(device) + generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens) + return decode(generated_tokens[0].tolist()) + +m = MODEL_CLASS( + vocab_size=vocab_size, + h_dim=h_dim, + max_seq_len=max_seq_len, + num_heads=num_heads, + pixel_size=pixel_size, + layers_num=layers_num +) +m = m.to(device) +writer.add_text('model', str(m), 0) + +#################################### Train ######################################### + +optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) + +#################################### Checkpoint Function ######################################### + +def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir): + """Save model checkpoint with complete 
training state""" + checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt' + torch.save({ + 'step': step, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss, + 'config': config, + 'wtoi': wtoi, + 'itow': itow, + }, checkpoint_path) + +# Training config for checkpointing +training_config = { + 'vocab_size': vocab_size, + 'layers_num': layers_num, + 'h_dim': h_dim, + 'max_seq_len': max_seq_len, + 'num_heads': num_heads, + 'dropout_rate': dropout_rate, + 'batch_size': batch_size, + 'learning_rate': learning_rate, + 'gradient_accumulation_steps': gradient_accumulation_steps, + 'pixel_size': pixel_size, + 'max_iters': max_iters, +} +#################################### Train ######################################### + +start_time = datetime.now() +print("Started at:", start_time) + +m.eval() + +task_prompts = [ + "1 2 3 4 5", + "The capital of France is", + "The chemical symbol of gold is", + "If yesterday was Friday, then tomorrow will be", + "The opposite of hot is", + "The planets of the solar system are:", + "My favorite color is", + "If 5*x + 3 = 13, then x is", +] +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, 0) + +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, 0) + +for i in range(max_iters): + m.train() + optimizer.zero_grad(set_to_none=True) + accumulated_loss = 0.0 + for j in range(gradient_accumulation_steps): + xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps) + logits, loss = m(xb, yb) + loss = loss / gradient_accumulation_steps + loss.backward() + accumulated_loss += loss.item() + if i % 100 == 0: + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i) + optimizer.step() + 
writer.add_scalar('loss', accumulated_loss, i) + print(f"\r{i}/{max_iters} {accumulated_loss}", end="") + if i % 5000 == 0: + m.eval() + ppl = perplexity(model=m, data=val_data) + writer.add_scalar('val_perplexity', ppl.item(), i) + print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") + writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) + task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) + writer.add_text('completions/task', task_results, i) + m.log_trainable_optic_params(writer, i) + save_checkpoint( + model=m, + optimizer=optimizer, + step=i, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir + ) + +m.eval() +ppl = perplexity(model=m, data=val_data) +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") +writer.add_scalar('val_perplexity', ppl.item(), i+1) +writer.add_scalar('loss', accumulated_loss, i) + +ppl = perplexity(model=m, data=test_data) +writer.add_scalar('test_perplexity', ppl.item(), i+1) +print(f"\rTest Perplexity at {i}: {ppl}") + +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, i+1) +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, i+1) + +# Save final checkpoint +save_checkpoint( + model=m, + optimizer=optimizer, + step=max_iters, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir +) +print(f"\n✓ Training complete. 
Final checkpoint saved to {checkpoints_dir}") \ No newline at end of file diff --git a/src/train_optics_trainable_lens_512.py b/src/train_optics_trainable_lens_512.py new file mode 100644 index 0000000..16bd46e --- /dev/null +++ b/src/train_optics_trainable_lens_512.py @@ -0,0 +1,464 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +import sys +from pathlib import Path +import optical_matrix_multiplication as omm +import shutil +seed = 1337 +torch.manual_seed(seed) + +############################### MODEL ############################################################# + +def new_formula(sim, tensor_1, tensor_2): + tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1 + tensor_2 = tensor_2[None,:,:,:] if len(tensor_2.shape) < 4 else tensor_2 + device = tensor_1.device + + A_pos = torch.clamp(tensor_1, min=0) # A⁺ = max(A, 0) + A_neg = torch.clamp(-tensor_1, min=0) # A⁻ = max(-A, 0) + B_pos = torch.clamp(tensor_2, min=0) # B⁺ = max(B, 0) + B_neg = torch.clamp(-tensor_2, min=0) # B⁻ = max(-B, 0) + + max_A_pos = torch.max(A_pos) # Может быть 0, если нет положительных значений + max_A_neg = torch.max(A_neg) # Может быть 0, если нет отрицательных значений + max_B_pos = torch.max(B_pos) + max_B_neg = torch.max(B_neg) + + zero_template = torch.zeros_like( + torch.empty(tensor_1.shape[0],tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])) + + if max_A_pos > 0 and max_B_pos > 0: + t1 = sim(A_pos / max_A_pos, B_pos / max_B_pos) * max_A_pos * max_B_pos + else: + t1 = zero_template.clone().to(device) + + if max_A_pos > 0 and max_B_neg > 0: + t2 = sim(A_pos / max_A_pos, B_neg / max_B_neg) * max_A_pos * max_B_neg + else: + t2 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_pos > 0: + t3 = sim(A_neg / max_A_neg, B_pos / max_B_pos) * max_A_neg * max_B_pos + else: + t3 = 
zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_neg > 0: + t4 = sim(A_neg / max_A_neg, B_neg / max_B_neg) * max_A_neg * max_B_neg + else: + t4 = zero_template.clone().to(device) + + return (t1 - t2 - t3 + t4)[0,:,:,:] + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = 
nn.Linear(4*h_dim, h_dim)
+ self.ln1 = DyT(h_dim)
+ self.ln2 = DyT(h_dim)
+ self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
+ # now trainable parameters are in optical mat mul class, one scalar per simulator
+ # here we use only TrainableLensOpticalMul and scalars for each layer
+ self.k1 = nn.Parameter(torch.randn(1))
+ self.k2 = nn.Parameter(torch.randn(1))
+
+ def split_to_heads(self, x, B, T, H):
+ if self.num_heads <= 1: return x
+ return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
+
+ def gather_heads(self, x, B, T, H):
+ if self.num_heads <= 1: return x
+ return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
+
+ def attention(self, x):
+ q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
+ k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
+ v = self.split_to_heads(self.v_proj(x), *x.shape)
+ scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
+ tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
+ scores = scores.masked_fill(tril == 0, float('-inf'))
+ attention = nn.functional.softmax(scores, dim=2)
+ output = self.k2 * new_formula(self.sim_output, attention, v)
+ return self.o_proj(self.gather_heads(output, *x.shape))
+
+ def forward(self, x):
+ x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
+ x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
+ return x
+
+class OpticGPT2TrainableScalarAndLens(nn.Module):
+ def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
+ pixel_size = 3.6e-6):
+ super().__init__()
+ self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
+ if max_seq_len != 512:
+ self.sim_scores = omm.TrainableLensOpticalMul(
+ omm.Config(right_matrix_count_columns = max_seq_len,
+ right_matrix_count_rows = h_dim // num_heads,
+ right_matrix_width = pixel_size * max_seq_len,
+ 
right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01) + ) + self.sim_output = omm.TrainableLensOpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01) + ) + if max_seq_len == 512: + self.sim_scores = omm.TrainableLensOpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2) + ) + self.sim_output = omm.TrainableLensOpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2) + ) + self.sim_scores = omm.ScatterDataParallel(self.sim_scores) + self.sim_output = omm.ScatterDataParallel(self.sim_output) + + self.layers = nn.ModuleList([ + TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads, + dropout_rate=self.dropout_rate, 
max_seq_len=self.max_seq_len) + for _ in range(layers_num)]) + self.tok_embeds = nn.Embedding(vocab_size, h_dim) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + self.lm_head = nn.Linear(h_dim, vocab_size) + + def forward(self, x, targets=None): + x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + logits = self.lm_head(x) # B,T,C + loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None + return logits, loss + + # what is the purpose? autoregressive inference! + def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx + + def log_trainable_optic_params( + self, + writer: SummaryWriter, + global_step, + ): + self.sim_scores.module.log_cylind_lens_operator_x(writer, "sim_scores", global_step) + self.sim_output.module.log_cylind_lens_operator_x(writer, "sim_output", global_step) + for i, layer in enumerate(self.layers): + # Using f-string tags to group them nicely in TensorBoard (e.g., Layer_0/k1) + writer.add_scalar(f"optic_scalars/layer_{i}/k1", layer.k1.item(), global_step) + writer.add_scalar(f"optic_scalars/layer_{i}/k2", layer.k2.item(), global_step) + + +################################################################################################### + +batch_size = 50 +gradient_accumulation_steps = 10 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 512 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert 
batch_size % gradient_accumulation_steps == 0 +# CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff seq_128_hdim_64 + +MODEL_CLASS = OpticGPT2TrainableScalarAndLens +train_data_path = Path("./data/wiki.train.tokens") +val_data_path = Path("./data/wiki.valid.tokens") +test_data_path = Path("./data/wiki.test.tokens") +comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}" + +logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +writer = SummaryWriter(logs_dir) +script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name) +print("Logs dir:", logs_dir) +# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script +shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository +script_snapshot_path.chmod(0o500) # with read-only permission + +# Create standalone checkpoints directory with your specified format +checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +Path(checkpoints_dir).mkdir(parents=True, exist_ok=True) +print("Checkpoints dir:", checkpoints_dir) + +#################################### Dataset ######################################### + +# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt +with open(train_data_path, encoding='utf-8') as f: + train_text = f.read() + +with open(val_data_path, encoding='utf-8') as f: + val_text = f.read() + +with open(test_data_path, encoding='utf-8') as f: + test_text = f.read() + +text = train_text + val_text + test_text + +chars = sorted(set(text)) +vocab_size = len(chars) + +wtoi = {w:i for i,w in enumerate(chars)} +itow = {i:w for i,w in enumerate(chars)} + +encode = lambda s: [wtoi[w] for w in s] +decode = lambda idx: ''.join([itow[i] for i 
in idx]) + +def get_batch(data, seq_len, batch_size): + ix = torch.randint(len(data)-seq_len, (batch_size, )) + x = torch.stack([data[i:i+seq_len] for i in ix]).to(device) + y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device) + return x, y + +train_data = torch.tensor(encode(train_text), dtype=torch.long) +val_data = torch.tensor(encode(val_text), dtype=torch.long) +test_data = torch.tensor(encode(test_text), dtype=torch.long) + +@torch.no_grad() +def perplexity(model, data): + model.eval() + stride = max(1, len(data) // 10000) + total_loss_sum = 0.0 + total_tokens_count = 0 + losses = [] + for i in range(0, len(data)-max_seq_len-1, stride): + x = data[i:(i+max_seq_len)].to(device) + y = data[(i+1):(i+max_seq_len+1)].to(device) + logits, mean_loss = model(x[None,...], y[None,...]) + num_tokens = y.numel() + total_loss_sum += mean_loss.item() * num_tokens + total_tokens_count += num_tokens + print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="") + return np.exp(total_loss_sum / total_tokens_count) + +#################################### Model #########################################mo +def complete(m, start_idxs=[0], max_new_tokens=100): + start_idx = torch.tensor([start_idxs]).to(device) + generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens) + return decode(generated_tokens[0].tolist()) + +m = MODEL_CLASS( + vocab_size=vocab_size, + h_dim=h_dim, + max_seq_len=max_seq_len, + num_heads=num_heads, + pixel_size=pixel_size, + layers_num=layers_num +) +m = m.to(device) +writer.add_text('model', str(m), 0) + +#################################### Train ######################################### + +optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) + +#################################### Checkpoint Function ######################################### + +def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir): + """Save model checkpoint with complete 
training state""" + checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt' + torch.save({ + 'step': step, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss, + 'config': config, + 'wtoi': wtoi, + 'itow': itow, + }, checkpoint_path) + +# Training config for checkpointing +training_config = { + 'vocab_size': vocab_size, + 'layers_num': layers_num, + 'h_dim': h_dim, + 'max_seq_len': max_seq_len, + 'num_heads': num_heads, + 'dropout_rate': dropout_rate, + 'batch_size': batch_size, + 'learning_rate': learning_rate, + 'gradient_accumulation_steps': gradient_accumulation_steps, + 'pixel_size': pixel_size, + 'max_iters': max_iters, +} +#################################### Train ######################################### + +start_time = datetime.now() +print("Started at:", start_time) + +m.eval() + +task_prompts = [ + "1 2 3 4 5", + "The capital of France is", + "The chemical symbol of gold is", + "If yesterday was Friday, then tomorrow will be", + "The opposite of hot is", + "The planets of the solar system are:", + "My favorite color is", + "If 5*x + 3 = 13, then x is", +] +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, 0) + +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, 0) + +for i in range(max_iters): + m.train() + optimizer.zero_grad(set_to_none=True) + accumulated_loss = 0.0 + for j in range(gradient_accumulation_steps): + xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps) + logits, loss = m(xb, yb) + loss = loss / gradient_accumulation_steps + loss.backward() + accumulated_loss += loss.item() + if i % 100 == 0: + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i) + optimizer.step() + 
writer.add_scalar('loss', accumulated_loss, i) + print(f"\r{i}/{max_iters} {accumulated_loss}", end="") + if i % 5000 == 0: + m.eval() + ppl = perplexity(model=m, data=val_data) + writer.add_scalar('val_perplexity', ppl.item(), i) + print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") + writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) + task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) + writer.add_text('completions/task', task_results, i) + m.log_trainable_optic_params(writer, i) + save_checkpoint( + model=m, + optimizer=optimizer, + step=i, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir + ) + +m.eval() +ppl = perplexity(model=m, data=val_data) +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") +writer.add_scalar('val_perplexity', ppl.item(), i+1) +writer.add_scalar('loss', accumulated_loss, i) + +ppl = perplexity(model=m, data=test_data) +writer.add_scalar('test_perplexity', ppl.item(), i+1) +print(f"\rTest Perplexity at {i}: {ppl}") + +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, i+1) +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, i+1) + +# Save final checkpoint +save_checkpoint( + model=m, + optimizer=optimizer, + step=max_iters, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir +) +print(f"\n✓ Training complete. 
Final checkpoint saved to {checkpoints_dir}") \ No newline at end of file diff --git a/src/train_optics_trainable_lens_64.py b/src/train_optics_trainable_lens_64.py new file mode 100644 index 0000000..b3f4c3c --- /dev/null +++ b/src/train_optics_trainable_lens_64.py @@ -0,0 +1,464 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +import sys +from pathlib import Path +import optical_matrix_multiplication as omm +import shutil +seed = 1337 +torch.manual_seed(seed) + +############################### MODEL ############################################################# + +def new_formula(sim, tensor_1, tensor_2): + tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1 + tensor_2 = tensor_2[None,:,:,:] if len(tensor_2.shape) < 4 else tensor_2 + device = tensor_1.device + + A_pos = torch.clamp(tensor_1, min=0) # A⁺ = max(A, 0) + A_neg = torch.clamp(-tensor_1, min=0) # A⁻ = max(-A, 0) + B_pos = torch.clamp(tensor_2, min=0) # B⁺ = max(B, 0) + B_neg = torch.clamp(-tensor_2, min=0) # B⁻ = max(-B, 0) + + max_A_pos = torch.max(A_pos) # Может быть 0, если нет положительных значений + max_A_neg = torch.max(A_neg) # Может быть 0, если нет отрицательных значений + max_B_pos = torch.max(B_pos) + max_B_neg = torch.max(B_neg) + + zero_template = torch.zeros_like( + torch.empty(tensor_1.shape[0],tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])) + + if max_A_pos > 0 and max_B_pos > 0: + t1 = sim(A_pos / max_A_pos, B_pos / max_B_pos) * max_A_pos * max_B_pos + else: + t1 = zero_template.clone().to(device) + + if max_A_pos > 0 and max_B_neg > 0: + t2 = sim(A_pos / max_A_pos, B_neg / max_B_neg) * max_A_pos * max_B_neg + else: + t2 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_pos > 0: + t3 = sim(A_neg / max_A_neg, B_pos / max_B_pos) * max_A_neg * max_B_pos + else: + t3 = 
zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_neg > 0: + t4 = sim(A_neg / max_A_neg, B_neg / max_B_neg) * max_A_neg * max_B_neg + else: + t4 = zero_template.clone().to(device) + + return (t1 - t2 - t3 + t4)[0,:,:,:] + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = 
nn.Linear(4*h_dim, h_dim)
+ self.ln1 = DyT(h_dim)
+ self.ln2 = DyT(h_dim)
+ self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
+ # now trainable parameters are in optical mat mul class, one scalar per simulator
+ # here we use only TrainableLensOpticalMul and scalars for each layer
+ self.k1 = nn.Parameter(torch.randn(1))
+ self.k2 = nn.Parameter(torch.randn(1))
+
+ def split_to_heads(self, x, B, T, H):
+ if self.num_heads <= 1: return x
+ return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
+
+ def gather_heads(self, x, B, T, H):
+ if self.num_heads <= 1: return x
+ return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
+
+ def attention(self, x):
+ q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
+ k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
+ v = self.split_to_heads(self.v_proj(x), *x.shape)
+ scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
+ tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
+ scores = scores.masked_fill(tril == 0, float('-inf'))
+ attention = nn.functional.softmax(scores, dim=2)
+ output = self.k2 * new_formula(self.sim_output, attention, v)
+ return self.o_proj(self.gather_heads(output, *x.shape))
+
+ def forward(self, x):
+ x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
+ x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
+ return x
+
+class OpticGPT2TrainableScalarAndLens(nn.Module):
+ def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
+ pixel_size = 3.6e-6):
+ super().__init__()
+ self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
+ if max_seq_len != 512:
+ self.sim_scores = omm.TrainableLensOpticalMul(
+ omm.Config(right_matrix_count_columns = max_seq_len,
+ right_matrix_count_rows = h_dim // num_heads,
+ right_matrix_width = pixel_size * max_seq_len,
+ 
right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01) + ) + self.sim_output = omm.TrainableLensOpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01) + ) + if max_seq_len == 512: + self.sim_scores = omm.TrainableLensOpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2) + ) + self.sim_output = omm.TrainableLensOpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2) + ) + self.sim_scores = omm.ScatterDataParallel(self.sim_scores) + self.sim_output = omm.ScatterDataParallel(self.sim_output) + + self.layers = nn.ModuleList([ + TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads, + dropout_rate=self.dropout_rate, 
max_seq_len=self.max_seq_len) + for _ in range(layers_num)]) + self.tok_embeds = nn.Embedding(vocab_size, h_dim) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + self.lm_head = nn.Linear(h_dim, vocab_size) + + def forward(self, x, targets=None): + x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + logits = self.lm_head(x) # B,T,C + loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None + return logits, loss + + # what is the purpose? autoregressive inference! + def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx + + def log_trainable_optic_params( + self, + writer: SummaryWriter, + global_step, + ): + self.sim_scores.module.log_cylind_lens_operator_x(writer, "sim_scores", global_step) + self.sim_output.module.log_cylind_lens_operator_x(writer, "sim_output", global_step) + for i, layer in enumerate(self.layers): + # Using f-string tags to group them nicely in TensorBoard (e.g., Layer_0/k1) + writer.add_scalar(f"optic_scalars/layer_{i}/k1", layer.k1.item(), global_step) + writer.add_scalar(f"optic_scalars/layer_{i}/k2", layer.k2.item(), global_step) + + +################################################################################################### + +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 64 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size 
% gradient_accumulation_steps == 0 +# CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff seq_128_hdim_64 + +MODEL_CLASS = OpticGPT2TrainableScalarAndLens +train_data_path = Path("./data/wiki.train.tokens") +val_data_path = Path("./data/wiki.valid.tokens") +test_data_path = Path("./data/wiki.test.tokens") +comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}" + +logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +writer = SummaryWriter(logs_dir) +script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name) +print("Logs dir:", logs_dir) +# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script +shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository +script_snapshot_path.chmod(0o500) # with read-only permission + +# Create standalone checkpoints directory with your specified format +checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +Path(checkpoints_dir).mkdir(parents=True, exist_ok=True) +print("Checkpoints dir:", checkpoints_dir) + +#################################### Dataset ######################################### + +# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt +with open(train_data_path, encoding='utf-8') as f: + train_text = f.read() + +with open(val_data_path, encoding='utf-8') as f: + val_text = f.read() + +with open(test_data_path, encoding='utf-8') as f: + test_text = f.read() + +text = train_text + val_text + test_text + +chars = sorted(set(text)) +vocab_size = len(chars) + +wtoi = {w:i for i,w in enumerate(chars)} +itow = {i:w for i,w in enumerate(chars)} + +encode = lambda s: [wtoi[w] for w in s] +decode = lambda idx: ''.join([itow[i] for i in idx]) + 
+def get_batch(data, seq_len, batch_size): + ix = torch.randint(len(data)-seq_len, (batch_size, )) + x = torch.stack([data[i:i+seq_len] for i in ix]).to(device) + y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device) + return x, y + +train_data = torch.tensor(encode(train_text), dtype=torch.long) +val_data = torch.tensor(encode(val_text), dtype=torch.long) +test_data = torch.tensor(encode(test_text), dtype=torch.long) + +@torch.no_grad() +def perplexity(model, data): + model.eval() + stride = max(1, len(data) // 10000) + total_loss_sum = 0.0 + total_tokens_count = 0 + losses = [] + for i in range(0, len(data)-max_seq_len-1, stride): + x = data[i:(i+max_seq_len)].to(device) + y = data[(i+1):(i+max_seq_len+1)].to(device) + logits, mean_loss = model(x[None,...], y[None,...]) + num_tokens = y.numel() + total_loss_sum += mean_loss.item() * num_tokens + total_tokens_count += num_tokens + print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="") + return np.exp(total_loss_sum / total_tokens_count) + +#################################### Model #########################################mo +def complete(m, start_idxs=[0], max_new_tokens=100): + start_idx = torch.tensor([start_idxs]).to(device) + generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens) + return decode(generated_tokens[0].tolist()) + +m = MODEL_CLASS( + vocab_size=vocab_size, + h_dim=h_dim, + max_seq_len=max_seq_len, + num_heads=num_heads, + pixel_size=pixel_size, + layers_num=layers_num +) +m = m.to(device) +writer.add_text('model', str(m), 0) + +#################################### Train ######################################### + +optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) + +#################################### Checkpoint Function ######################################### + +def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir): + """Save model checkpoint with complete training state""" 
+ checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt' + torch.save({ + 'step': step, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss, + 'config': config, + 'wtoi': wtoi, + 'itow': itow, + }, checkpoint_path) + +# Training config for checkpointing +training_config = { + 'vocab_size': vocab_size, + 'layers_num': layers_num, + 'h_dim': h_dim, + 'max_seq_len': max_seq_len, + 'num_heads': num_heads, + 'dropout_rate': dropout_rate, + 'batch_size': batch_size, + 'learning_rate': learning_rate, + 'gradient_accumulation_steps': gradient_accumulation_steps, + 'pixel_size': pixel_size, + 'max_iters': max_iters, +} +#################################### Train ######################################### + +start_time = datetime.now() +print("Started at:", start_time) + +m.eval() + +task_prompts = [ + "1 2 3 4 5", + "The capital of France is", + "The chemical symbol of gold is", + "If yesterday was Friday, then tomorrow will be", + "The opposite of hot is", + "The planets of the solar system are:", + "My favorite color is", + "If 5*x + 3 = 13, then x is", +] +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, 0) + +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, 0) + +for i in range(max_iters): + m.train() + optimizer.zero_grad(set_to_none=True) + accumulated_loss = 0.0 + for j in range(gradient_accumulation_steps): + xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps) + logits, loss = m(xb, yb) + loss = loss / gradient_accumulation_steps + loss.backward() + accumulated_loss += loss.item() + if i % 100 == 0: + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i) + optimizer.step() + writer.add_scalar('loss', 
accumulated_loss, i) + print(f"\r{i}/{max_iters} {accumulated_loss}", end="") + if i % 5000 == 0: + m.eval() + ppl = perplexity(model=m, data=val_data) + writer.add_scalar('val_perplexity', ppl.item(), i) + print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") + writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) + task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) + writer.add_text('completions/task', task_results, i) + m.log_trainable_optic_params(writer, i) + save_checkpoint( + model=m, + optimizer=optimizer, + step=i, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir + ) + +m.eval() +ppl = perplexity(model=m, data=val_data) +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") +writer.add_scalar('val_perplexity', ppl.item(), i+1) +writer.add_scalar('loss', accumulated_loss, i) + +ppl = perplexity(model=m, data=test_data) +writer.add_scalar('test_perplexity', ppl.item(), i+1) +print(f"\rTest Perplexity at {i}: {ppl}") + +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, i+1) +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, i+1) + +# Save final checkpoint +save_checkpoint( + model=m, + optimizer=optimizer, + step=max_iters, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir +) +print(f"\n✓ Training complete. 
Final checkpoint saved to {checkpoints_dir}") \ No newline at end of file From 064b9e14c8a03dd270fb40050cc4a920e4c6848a Mon Sep 17 00:00:00 2001 From: Vladimir Protsenko Date: Tue, 10 Feb 2026 14:35:31 +0000 Subject: [PATCH 5/9] added new logs --- src/bert_optica_koef.py | 2 ++ src/bert_optica_koef_newf.py | 2 ++ src/bert_optica_nokoef.py | 2 ++ src/bert_optica_nokoef_newf.py | 2 ++ src/optical_matrix_multiplication/config.py | 5 +---- src/optical_matrix_multiplication/optical_mul.py | 2 +- src/optical_matrix_multiplication/propagator.py | 1 - src/train_gpt2.py | 5 +++-- src/train_optics_trainable_focal_dist_lens_64.py | 3 ++- src/train_optics_trainable_lens_128.py | 3 +++ src/train_optics_trainable_lens_256.py | 6 +++++- src/train_optics_trainable_lens_512.py | 6 +++++- src/train_optics_trainable_lens_64.py | 4 +++- 13 files changed, 31 insertions(+), 12 deletions(-) diff --git a/src/bert_optica_koef.py b/src/bert_optica_koef.py index 634c671..9098b80 100644 --- a/src/bert_optica_koef.py +++ b/src/bert_optica_koef.py @@ -379,6 +379,8 @@ model = BertClassifier( ).to(device) print(f'Parameters model - {sum(p.numel() for p in model.parameters())}, parameters encoder - {sum(p.numel() for p in encoder.parameters())}') +model_description = str(model) + f'\nParameters count - {sum(p.numel() for p in model.parameters())}, parameters encoder - {sum(p.numel() for p in encoder.parameters())}' +writer.add_text('model', model_description, 0) optimizer = schedulefree.AdamWScheduleFree(list(model.parameters()) + list(encoder.parameters())) if checkpoint_file is not None: diff --git a/src/bert_optica_koef_newf.py b/src/bert_optica_koef_newf.py index 9e5fb1d..a46e5e3 100644 --- a/src/bert_optica_koef_newf.py +++ b/src/bert_optica_koef_newf.py @@ -387,6 +387,8 @@ model = BertClassifier( ).to(device) print(f'Parameters model - {sum(p.numel() for p in model.parameters())}, parameters encoder - {sum(p.numel() for p in encoder.parameters())}') +model_description = str(model) + 
f'\nParameters count - {sum(p.numel() for p in model.parameters())}, parameters encoder - {sum(p.numel() for p in encoder.parameters())}' +writer.add_text('model', model_description, 0) optimizer = schedulefree.AdamWScheduleFree(list(model.parameters()) + list(encoder.parameters())) if checkpoint_file is not None: diff --git a/src/bert_optica_nokoef.py b/src/bert_optica_nokoef.py index 2c5b249..4f83cc6 100644 --- a/src/bert_optica_nokoef.py +++ b/src/bert_optica_nokoef.py @@ -377,6 +377,8 @@ model = BertClassifier( ).to(device) print(f'Parameters model - {sum(p.numel() for p in model.parameters())}, parameters encoder - {sum(p.numel() for p in encoder.parameters())}') +model_description = str(model) + f'\nParameters count - {sum(p.numel() for p in model.parameters())}, parameters encoder - {sum(p.numel() for p in encoder.parameters())}' +writer.add_text('model', model_description, 0) optimizer = schedulefree.AdamWScheduleFree(list(model.parameters()) + list(encoder.parameters())) if checkpoint_file is not None: diff --git a/src/bert_optica_nokoef_newf.py b/src/bert_optica_nokoef_newf.py index dcfede5..7728aa0 100644 --- a/src/bert_optica_nokoef_newf.py +++ b/src/bert_optica_nokoef_newf.py @@ -387,6 +387,8 @@ model = BertClassifier( ).to(device) print(f'Parameters model - {sum(p.numel() for p in model.parameters())}, parameters encoder - {sum(p.numel() for p in encoder.parameters())}') +model_description = str(model) + f'\nParameters count - {sum(p.numel() for p in model.parameters())}, parameters encoder - {sum(p.numel() for p in encoder.parameters())}' +writer.add_text('model', model_description, 0) optimizer = schedulefree.AdamWScheduleFree(list(model.parameters()) + list(encoder.parameters())) if checkpoint_file is not None: diff --git a/src/optical_matrix_multiplication/config.py b/src/optical_matrix_multiplication/config.py index 7b84265..26c8589 100644 --- a/src/optical_matrix_multiplication/config.py +++ b/src/optical_matrix_multiplication/config.py @@ 
-274,8 +274,7 @@ class Config(ConfigOpticBase, ConfigModelBase): wavelength: float = 532e-9, distance: float = 0.03, lens_pixel_size: float = 1.8e-6, - lens_size: int = 8192, - trainable_cylind_lens = False): + lens_size: int = 8192): """ Конструктор класса. @@ -295,7 +294,6 @@ class Config(ConfigOpticBase, ConfigModelBase): distance: дистанция в метрах распространения светового поля между плоскостями. lens_pixel_size: размер пикселя в метрах скрещенных линз в оптической системе (нужен исключительно для моделирования). lens_size: размер скрещенных линз в метрах в оптической системе (нужен исключительно для моделирования). - trainable_cylind_lens: обучаемые диагональные матрицы, линза перед фурье плоскостью """ ConfigOpticBase.__init__(self, wavelength, distance) @@ -322,7 +320,6 @@ class Config(ConfigOpticBase, ConfigModelBase): self._input_vector_split_x: int = left_matrix_split_x self._input_vector_split_y: int = left_matrix_split_y self._result_vector_split: int = result_matrix_split - self._trainable_cylind_lens = trainable_cylind_lens @property def matrix_split_x(self) -> int: diff --git a/src/optical_matrix_multiplication/optical_mul.py b/src/optical_matrix_multiplication/optical_mul.py index cbe5f57..d7759d1 100644 --- a/src/optical_matrix_multiplication/optical_mul.py +++ b/src/optical_matrix_multiplication/optical_mul.py @@ -371,7 +371,7 @@ class TrainableLensOpticalMul(_nn.Module): phase_row = phase_normalized.unsqueeze(0).unsqueeze(0) writer.add_image(f"{tag}/phase_row", phase_row, global_step, dataformats='CHW') - fig, ax = plt.subplots(figsize=(6, 4)) + fig, ax = plt.subplots(figsize=(12, 4)) ax.plot(wrapped_phase.detach().cpu().numpy(), label=f'Step {global_step}') ax.set_title(f"Cylindrical Lens Phase Profile (x) {tag}") ax.set_xlabel("Pixel Index") diff --git a/src/optical_matrix_multiplication/propagator.py b/src/optical_matrix_multiplication/propagator.py index 3079d13..66111a6 100644 --- a/src/optical_matrix_multiplication/propagator.py +++ 
b/src/optical_matrix_multiplication/propagator.py @@ -111,7 +111,6 @@ class Propagator(_ABC, _nn.Module): Распределение комплексной амплитуды светового поля, после распространения. """ - if (resul_shape is not None): field_shape = field.shape[-2:] operator_Y_shape = self.operator_Y.shape[-2:] diff --git a/src/train_gpt2.py b/src/train_gpt2.py index edee343..a5aa3d7 100644 --- a/src/train_gpt2.py +++ b/src/train_gpt2.py @@ -218,8 +218,9 @@ m = MODEL_CLASS( layers_num=layers_num ) m = m.to(device) -writer.add_text('model', str(m), 0) - +model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}' +writer.add_text('model', model_description, 0) +# TODO for all experiments optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) #################################### Checkpoint Function ######################################### diff --git a/src/train_optics_trainable_focal_dist_lens_64.py b/src/train_optics_trainable_focal_dist_lens_64.py index a0aac71..87a9558 100644 --- a/src/train_optics_trainable_focal_dist_lens_64.py +++ b/src/train_optics_trainable_focal_dist_lens_64.py @@ -329,7 +329,8 @@ m = MODEL_CLASS( layers_num=layers_num ) m = m.to(device) -writer.add_text('model', str(m), 0) +model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}' +writer.add_text('model', model_description, 0) #################################### Train ######################################### diff --git a/src/train_optics_trainable_lens_128.py b/src/train_optics_trainable_lens_128.py index 799f312..8c3f2dc 100644 --- a/src/train_optics_trainable_lens_128.py +++ b/src/train_optics_trainable_lens_128.py @@ -338,6 +338,8 @@ m = MODEL_CLASS( ) m = m.to(device) writer.add_text('model', str(m), 0) +model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}' +writer.add_text('model', model_description, 0) #################################### Train 
######################################### @@ -450,6 +452,7 @@ task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt i print(task_results) writer.add_text('completions/task', task_results, i+1) +m.log_trainable_optic_params(writer, max_iters) # Save final checkpoint save_checkpoint( model=m, diff --git a/src/train_optics_trainable_lens_256.py b/src/train_optics_trainable_lens_256.py index d7df7f1..59e5f31 100644 --- a/src/train_optics_trainable_lens_256.py +++ b/src/train_optics_trainable_lens_256.py @@ -238,7 +238,7 @@ class OpticGPT2TrainableScalarAndLens(nn.Module): ################################################################################################### batch_size = 50 -gradient_accumulation_steps = 5 # check this impl for correctness https://unsloth.ai/blog/gradient +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient max_iters = int(4e4) #40000 eval_interval = 300 learning_rate = 1e-3 @@ -338,6 +338,8 @@ m = MODEL_CLASS( ) m = m.to(device) writer.add_text('model', str(m), 0) +model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}' +writer.add_text('model', model_description, 0) #################################### Train ######################################### @@ -450,6 +452,8 @@ task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt i print(task_results) writer.add_text('completions/task', task_results, i+1) +m.log_trainable_optic_params(writer, max_iters) + # Save final checkpoint save_checkpoint( model=m, diff --git a/src/train_optics_trainable_lens_512.py b/src/train_optics_trainable_lens_512.py index 16bd46e..9d081e2 100644 --- a/src/train_optics_trainable_lens_512.py +++ b/src/train_optics_trainable_lens_512.py @@ -238,7 +238,7 @@ class OpticGPT2TrainableScalarAndLens(nn.Module): ################################################################################################### batch_size = 50 
-gradient_accumulation_steps = 10 # check this impl for correctness https://unsloth.ai/blog/gradient +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient max_iters = int(4e4) #40000 eval_interval = 300 learning_rate = 1e-3 @@ -338,6 +338,8 @@ m = MODEL_CLASS( ) m = m.to(device) writer.add_text('model', str(m), 0) +model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}' +writer.add_text('model', model_description, 0) #################################### Train ######################################### @@ -450,6 +452,8 @@ task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt i print(task_results) writer.add_text('completions/task', task_results, i+1) +m.log_trainable_optic_params(writer, max_iters) + # Save final checkpoint save_checkpoint( model=m, diff --git a/src/train_optics_trainable_lens_64.py b/src/train_optics_trainable_lens_64.py index b3f4c3c..bbe5a73 100644 --- a/src/train_optics_trainable_lens_64.py +++ b/src/train_optics_trainable_lens_64.py @@ -337,7 +337,8 @@ m = MODEL_CLASS( layers_num=layers_num ) m = m.to(device) -writer.add_text('model', str(m), 0) +model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}' +writer.add_text('model', model_description, 0) #################################### Train ######################################### @@ -450,6 +451,7 @@ task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt i print(task_results) writer.add_text('completions/task', task_results, i+1) +m.log_trainable_optic_params(writer, max_iters) # Save final checkpoint save_checkpoint( model=m, From 85249dfbba9deb8f62e84f5ee9c42ade9a8e654d Mon Sep 17 00:00:00 2001 From: Vladimir Protsenko Date: Sat, 14 Feb 2026 14:48:15 +0000 Subject: [PATCH 6/9] checkpointing, batched perplexity. 
--- src/train_gpt2.py | 53 ++++++-- ...ain_optics_trainable_focal_dist_lens_64.py | 120 +++++++++++++++--- src/train_optics_trainable_lens_128.py | 77 +++++++---- src/train_optics_trainable_lens_256.py | 77 +++++++---- src/train_optics_trainable_lens_512.py | 77 +++++++---- src/train_optics_trainable_lens_64.py | 80 ++++++++---- 6 files changed, 347 insertions(+), 137 deletions(-) diff --git a/src/train_gpt2.py b/src/train_gpt2.py index a5aa3d7..75de26f 100644 --- a/src/train_gpt2.py +++ b/src/train_gpt2.py @@ -127,7 +127,7 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu' eval_iters = 200 layers_num = 2 h_dim = 64 -max_seq_len = 64 +max_seq_len = 512 num_heads = 1 dropout_rate = 0.1 pixel_size = 3.6e-6 @@ -187,20 +187,44 @@ val_data = torch.tensor(encode(val_text), dtype=torch.long) test_data = torch.tensor(encode(test_text), dtype=torch.long) @torch.no_grad() -def perplexity(model, data): +def perplexity(model, data, batch_size=32): model.eval() stride = max(1, len(data) // 10000) total_loss_sum = 0.0 total_tokens_count = 0 - losses = [] - for i in range(0, len(data)-max_seq_len-1, stride): - x = data[i:(i+max_seq_len)].to(device) - y = data[(i+1):(i+max_seq_len+1)].to(device) - logits, mean_loss = model(x[None,...], y[None,...]) - num_tokens = y.numel() + + # Precompute all valid start positions + start_positions = list(range(0, len(data) - max_seq_len - 1, stride)) + total_sequences = len(start_positions) + + # Process sequences in batches + for i in range(0, total_sequences, batch_size): + batch_starts = start_positions[i:min(i + batch_size, total_sequences)] + + # Efficiently stack sequences into batch tensors + x_batch = torch.stack([ + data[start:start + max_seq_len] + for start in batch_starts + ]).to(device) + + y_batch = torch.stack([ + data[start + 1:start + max_seq_len + 1] + for start in batch_starts + ]).to(device) + + # Forward pass (model should return mean loss averaged over all tokens in batch) + _, mean_loss = model(x_batch, y_batch) + 
+ # Accumulate weighted loss (mean_loss is already averaged over tokens) + num_tokens = y_batch.numel() total_loss_sum += mean_loss.item() * num_tokens total_tokens_count += num_tokens - print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="") + + # Progress update + processed = min(i + batch_size, total_sequences) + print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True) + + print() # Final newline return np.exp(total_loss_sum / total_tokens_count) #################################### Model #########################################mo @@ -220,7 +244,10 @@ m = MODEL_CLASS( m = m.to(device) model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}' writer.add_text('model', model_description, 0) -# TODO for all experiments + +print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.") +print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.") + optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) #################################### Checkpoint Function ######################################### @@ -293,7 +320,7 @@ for i in range(max_iters): print(f"\r{i}/{max_iters} {accumulated_loss}", end="") if i % 5000 == 0: m.eval() - ppl = perplexity(model=m, data=val_data) + ppl = perplexity(model=m, data=val_data, batch_size=batch_size) writer.add_scalar('val_perplexity', ppl.item(), i) print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) @@ -311,13 +338,13 @@ for i in range(max_iters): ) m.eval() -ppl = perplexity(model=m, data=val_data) +ppl = perplexity(model=m, data=val_data, batch_size=batch_size) print(f"\r{i+1}/{max_iters} {accumulated_loss}") print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") writer.add_scalar('val_perplexity', ppl.item(), i+1) 
writer.add_scalar('loss', accumulated_loss, i) -ppl = perplexity(model=m, data=test_data) +ppl = perplexity(model=m, data=test_data, batch_size=batch_size) writer.add_scalar('test_perplexity', ppl.item(), i+1) print(f"\rTest Perplexity at {i}: {ppl}") diff --git a/src/train_optics_trainable_focal_dist_lens_64.py b/src/train_optics_trainable_focal_dist_lens_64.py index 87a9558..883e705 100644 --- a/src/train_optics_trainable_focal_dist_lens_64.py +++ b/src/train_optics_trainable_focal_dist_lens_64.py @@ -268,6 +268,10 @@ print("Logs dir:", logs_dir) shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository script_snapshot_path.chmod(0o500) # with read-only permission +# Create standalone checkpoints directory with your specified format +checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +Path(checkpoints_dir).mkdir(parents=True, exist_ok=True) +print("Checkpoints dir:", checkpoints_dir) #################################### Dataset ######################################### @@ -303,18 +307,48 @@ val_data = torch.tensor(encode(val_text), dtype=torch.long) test_data = torch.tensor(encode(test_text), dtype=torch.long) @torch.no_grad() -def perplexity(model, data): +def perplexity(model, data, batch_size=32): + model.eval() stride = max(1, len(data) // 10000) - losses = [] - for i in range(0, len(data)-max_seq_len-1, stride): - x = data[i:(i+max_seq_len)].to(device) - y = data[(i+1):(i+max_seq_len+1)].to(device) - logits, loss = model(x[None,...], y[None,...]) - losses.append(loss.item()) - print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="") - return np.exp(np.mean(losses)) - -#################################### Model #########################################mo + total_loss_sum = 0.0 + total_tokens_count = 0 + + # Precompute all valid start positions + start_positions = list(range(0, len(data) - max_seq_len - 1, stride)) + 
total_sequences = len(start_positions) + + # Process sequences in batches + for i in range(0, total_sequences, batch_size): + batch_starts = start_positions[i:min(i + batch_size, total_sequences)] + + # Efficiently stack sequences into batch tensors + x_batch = torch.stack([ + data[start:start + max_seq_len] + for start in batch_starts + ]).to(device) + + y_batch = torch.stack([ + data[start + 1:start + max_seq_len + 1] + for start in batch_starts + ]).to(device) + + # Forward pass (model should return mean loss averaged over all tokens in batch) + _, mean_loss = model(x_batch, y_batch) + + # Accumulate weighted loss (mean_loss is already averaged over tokens) + num_tokens = y_batch.numel() + total_loss_sum += mean_loss.item() * num_tokens + total_tokens_count += num_tokens + + # Progress update + processed = min(i + batch_size, total_sequences) + print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True) + + print() # Final newline + return np.exp(total_loss_sum / total_tokens_count) + +#################################### Model ######################################### + def complete(m, start_idxs=[0], max_new_tokens=100): start_idx = torch.tensor([start_idxs]).to(device) generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens) @@ -336,6 +370,37 @@ writer.add_text('model', model_description, 0) optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) +#################################### Checkpoint Function ######################################### + +def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir): + """Save model checkpoint with complete training state""" + checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt' + torch.save({ + 'step': step, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss, + 'config': config, + 'wtoi': wtoi, + 'itow': 
itow, + }, checkpoint_path) + +# Training config for checkpointing +training_config = { + 'vocab_size': vocab_size, + 'layers_num': layers_num, + 'h_dim': h_dim, + 'max_seq_len': max_seq_len, + 'num_heads': num_heads, + 'dropout_rate': dropout_rate, + 'batch_size': batch_size, + 'learning_rate': learning_rate, + 'gradient_accumulation_steps': gradient_accumulation_steps, + 'pixel_size': pixel_size, + 'max_iters': max_iters, +} + +#################################### Train ######################################### m.eval() task_prompts = [ @@ -373,22 +438,32 @@ for i in range(max_iters): print(f"\r{i}/{max_iters} {accumulated_loss}", end="") if i % 5000 == 0: m.eval() - ppl = perplexity(model=m, data=val_data) + ppl = perplexity(model=m, data=val_data, batch_size=batch_size) writer.add_scalar('val_perplexity', ppl.item(), i) print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) writer.add_text('completions/task', task_results, i) m.log_trainable_optic_params(writer, i) + save_checkpoint( + model=m, + optimizer=optimizer, + step=i, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir + ) m.eval() -ppl = perplexity(model=m, data=val_data) +ppl = perplexity(model=m, data=val_data, batch_size=batch_size) print(f"\r{i+1}/{max_iters} {accumulated_loss}") print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") writer.add_scalar('val_perplexity', ppl.item(), i+1) writer.add_scalar('loss', accumulated_loss, i) -ppl = perplexity(model=m, data=test_data) +ppl = perplexity(model=m, data=test_data, batch_size=batch_size) writer.add_scalar('test_perplexity', ppl.item(), i+1) print(f"\rTest Perplexity at {i}: {ppl}") @@ -397,4 +472,19 @@ print(completion) writer.add_text('completions', completion, i+1) task_results = "\n".join([complete(m, 
encode(task_prompt), 32) for task_prompt in task_prompts]) print(task_results) -writer.add_text('completions/task', task_results, i+1) \ No newline at end of file +writer.add_text('completions/task', task_results, i+1) + +m.log_trainable_optic_params(writer, max_iters) + +# Save final checkpoint +save_checkpoint( + model=m, + optimizer=optimizer, + step=max_iters, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir +) +print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}") \ No newline at end of file diff --git a/src/train_optics_trainable_lens_128.py b/src/train_optics_trainable_lens_128.py index 8c3f2dc..9a551d3 100644 --- a/src/train_optics_trainable_lens_128.py +++ b/src/train_optics_trainable_lens_128.py @@ -12,6 +12,21 @@ import shutil seed = 1337 torch.manual_seed(seed) +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 128 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 + ############################### MODEL ############################################################# def new_formula(sim, tensor_1, tensor_2): @@ -237,22 +252,6 @@ class OpticGPT2TrainableScalarAndLens(nn.Module): ################################################################################################### -batch_size = 50 -gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient -max_iters = int(4e4) #40000 -eval_interval = 300 -learning_rate = 1e-3 -device = 'cuda' if torch.cuda.is_available() else 'cpu' -eval_iters = 200 -layers_num = 2 -h_dim = 64 -max_seq_len = 128 -num_heads = 1 -dropout_rate = 0.1 -pixel_size = 3.6e-6 -assert batch_size % 
gradient_accumulation_steps == 0 -# CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff seq_128_hdim_64 - MODEL_CLASS = OpticGPT2TrainableScalarAndLens train_data_path = Path("./data/wiki.train.tokens") val_data_path = Path("./data/wiki.valid.tokens") @@ -306,20 +305,44 @@ val_data = torch.tensor(encode(val_text), dtype=torch.long) test_data = torch.tensor(encode(test_text), dtype=torch.long) @torch.no_grad() -def perplexity(model, data): +def perplexity(model, data, batch_size=32): model.eval() stride = max(1, len(data) // 10000) total_loss_sum = 0.0 total_tokens_count = 0 - losses = [] - for i in range(0, len(data)-max_seq_len-1, stride): - x = data[i:(i+max_seq_len)].to(device) - y = data[(i+1):(i+max_seq_len+1)].to(device) - logits, mean_loss = model(x[None,...], y[None,...]) - num_tokens = y.numel() + + # Precompute all valid start positions + start_positions = list(range(0, len(data) - max_seq_len - 1, stride)) + total_sequences = len(start_positions) + + # Process sequences in batches + for i in range(0, total_sequences, batch_size): + batch_starts = start_positions[i:min(i + batch_size, total_sequences)] + + # Efficiently stack sequences into batch tensors + x_batch = torch.stack([ + data[start:start + max_seq_len] + for start in batch_starts + ]).to(device) + + y_batch = torch.stack([ + data[start + 1:start + max_seq_len + 1] + for start in batch_starts + ]).to(device) + + # Forward pass (model should return mean loss averaged over all tokens in batch) + _, mean_loss = model(x_batch, y_batch) + + # Accumulate weighted loss (mean_loss is already averaged over tokens) + num_tokens = y_batch.numel() total_loss_sum += mean_loss.item() * num_tokens total_tokens_count += num_tokens - print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="") + + # Progress update + processed = min(i + batch_size, total_sequences) + print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True) + + print() # Final newline return 
np.exp(total_loss_sum / total_tokens_count) #################################### Model #########################################mo @@ -416,7 +439,7 @@ for i in range(max_iters): print(f"\r{i}/{max_iters} {accumulated_loss}", end="") if i % 5000 == 0: m.eval() - ppl = perplexity(model=m, data=val_data) + ppl = perplexity(model=m, data=val_data, batch_size=batch_size) writer.add_scalar('val_perplexity', ppl.item(), i) print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) @@ -435,13 +458,13 @@ for i in range(max_iters): ) m.eval() -ppl = perplexity(model=m, data=val_data) +ppl = perplexity(model=m, data=val_data, batch_size=batch_size) print(f"\r{i+1}/{max_iters} {accumulated_loss}") print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") writer.add_scalar('val_perplexity', ppl.item(), i+1) writer.add_scalar('loss', accumulated_loss, i) -ppl = perplexity(model=m, data=test_data) +ppl = perplexity(model=m, data=test_data, batch_size=batch_size) writer.add_scalar('test_perplexity', ppl.item(), i+1) print(f"\rTest Perplexity at {i}: {ppl}") diff --git a/src/train_optics_trainable_lens_256.py b/src/train_optics_trainable_lens_256.py index 59e5f31..c802f4d 100644 --- a/src/train_optics_trainable_lens_256.py +++ b/src/train_optics_trainable_lens_256.py @@ -12,6 +12,21 @@ import shutil seed = 1337 torch.manual_seed(seed) +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 256 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 + ############################### MODEL ############################################################# def new_formula(sim, tensor_1, tensor_2): @@ -237,22 
+252,6 @@ class OpticGPT2TrainableScalarAndLens(nn.Module): ################################################################################################### -batch_size = 50 -gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient -max_iters = int(4e4) #40000 -eval_interval = 300 -learning_rate = 1e-3 -device = 'cuda' if torch.cuda.is_available() else 'cpu' -eval_iters = 200 -layers_num = 2 -h_dim = 64 -max_seq_len = 256 -num_heads = 1 -dropout_rate = 0.1 -pixel_size = 3.6e-6 -assert batch_size % gradient_accumulation_steps == 0 -# CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff seq_128_hdim_64 - MODEL_CLASS = OpticGPT2TrainableScalarAndLens train_data_path = Path("./data/wiki.train.tokens") val_data_path = Path("./data/wiki.valid.tokens") @@ -306,20 +305,44 @@ val_data = torch.tensor(encode(val_text), dtype=torch.long) test_data = torch.tensor(encode(test_text), dtype=torch.long) @torch.no_grad() -def perplexity(model, data): +def perplexity(model, data, batch_size=32): model.eval() stride = max(1, len(data) // 10000) total_loss_sum = 0.0 total_tokens_count = 0 - losses = [] - for i in range(0, len(data)-max_seq_len-1, stride): - x = data[i:(i+max_seq_len)].to(device) - y = data[(i+1):(i+max_seq_len+1)].to(device) - logits, mean_loss = model(x[None,...], y[None,...]) - num_tokens = y.numel() + + # Precompute all valid start positions + start_positions = list(range(0, len(data) - max_seq_len - 1, stride)) + total_sequences = len(start_positions) + + # Process sequences in batches + for i in range(0, total_sequences, batch_size): + batch_starts = start_positions[i:min(i + batch_size, total_sequences)] + + # Efficiently stack sequences into batch tensors + x_batch = torch.stack([ + data[start:start + max_seq_len] + for start in batch_starts + ]).to(device) + + y_batch = torch.stack([ + data[start + 1:start + max_seq_len + 1] + for start in batch_starts + ]).to(device) + + # Forward pass (model should return mean 
loss averaged over all tokens in batch) + _, mean_loss = model(x_batch, y_batch) + + # Accumulate weighted loss (mean_loss is already averaged over tokens) + num_tokens = y_batch.numel() total_loss_sum += mean_loss.item() * num_tokens total_tokens_count += num_tokens - print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="") + + # Progress update + processed = min(i + batch_size, total_sequences) + print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True) + + print() # Final newline return np.exp(total_loss_sum / total_tokens_count) #################################### Model #########################################mo @@ -416,7 +439,7 @@ for i in range(max_iters): print(f"\r{i}/{max_iters} {accumulated_loss}", end="") if i % 5000 == 0: m.eval() - ppl = perplexity(model=m, data=val_data) + ppl = perplexity(model=m, data=val_data, batch_size=batch_size) writer.add_scalar('val_perplexity', ppl.item(), i) print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) @@ -435,13 +458,13 @@ for i in range(max_iters): ) m.eval() -ppl = perplexity(model=m, data=val_data) +ppl = perplexity(model=m, data=val_data, batch_size=batch_size) print(f"\r{i+1}/{max_iters} {accumulated_loss}") print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") writer.add_scalar('val_perplexity', ppl.item(), i+1) writer.add_scalar('loss', accumulated_loss, i) -ppl = perplexity(model=m, data=test_data) +ppl = perplexity(model=m, data=test_data, batch_size=batch_size) writer.add_scalar('test_perplexity', ppl.item(), i+1) print(f"\rTest Perplexity at {i}: {ppl}") diff --git a/src/train_optics_trainable_lens_512.py b/src/train_optics_trainable_lens_512.py index 9d081e2..b090e0c 100644 --- a/src/train_optics_trainable_lens_512.py +++ b/src/train_optics_trainable_lens_512.py @@ -12,6 +12,21 @@ import shutil seed = 1337 torch.manual_seed(seed) +batch_size = 50 
+gradient_accumulation_steps = 5 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 512 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 + ############################### MODEL ############################################################# def new_formula(sim, tensor_1, tensor_2): @@ -237,22 +252,6 @@ class OpticGPT2TrainableScalarAndLens(nn.Module): ################################################################################################### -batch_size = 50 -gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient -max_iters = int(4e4) #40000 -eval_interval = 300 -learning_rate = 1e-3 -device = 'cuda' if torch.cuda.is_available() else 'cpu' -eval_iters = 200 -layers_num = 2 -h_dim = 64 -max_seq_len = 512 -num_heads = 1 -dropout_rate = 0.1 -pixel_size = 3.6e-6 -assert batch_size % gradient_accumulation_steps == 0 -# CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff seq_128_hdim_64 - MODEL_CLASS = OpticGPT2TrainableScalarAndLens train_data_path = Path("./data/wiki.train.tokens") val_data_path = Path("./data/wiki.valid.tokens") @@ -306,20 +305,44 @@ val_data = torch.tensor(encode(val_text), dtype=torch.long) test_data = torch.tensor(encode(test_text), dtype=torch.long) @torch.no_grad() -def perplexity(model, data): +def perplexity(model, data, batch_size=32): model.eval() stride = max(1, len(data) // 10000) total_loss_sum = 0.0 total_tokens_count = 0 - losses = [] - for i in range(0, len(data)-max_seq_len-1, stride): - x = data[i:(i+max_seq_len)].to(device) - y = data[(i+1):(i+max_seq_len+1)].to(device) - logits, mean_loss = model(x[None,...], y[None,...]) - num_tokens = y.numel() + + # Precompute all valid start positions + start_positions = 
list(range(0, len(data) - max_seq_len - 1, stride)) + total_sequences = len(start_positions) + + # Process sequences in batches + for i in range(0, total_sequences, batch_size): + batch_starts = start_positions[i:min(i + batch_size, total_sequences)] + + # Efficiently stack sequences into batch tensors + x_batch = torch.stack([ + data[start:start + max_seq_len] + for start in batch_starts + ]).to(device) + + y_batch = torch.stack([ + data[start + 1:start + max_seq_len + 1] + for start in batch_starts + ]).to(device) + + # Forward pass (model should return mean loss averaged over all tokens in batch) + _, mean_loss = model(x_batch, y_batch) + + # Accumulate weighted loss (mean_loss is already averaged over tokens) + num_tokens = y_batch.numel() total_loss_sum += mean_loss.item() * num_tokens total_tokens_count += num_tokens - print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="") + + # Progress update + processed = min(i + batch_size, total_sequences) + print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True) + + print() # Final newline return np.exp(total_loss_sum / total_tokens_count) #################################### Model #########################################mo @@ -416,7 +439,7 @@ for i in range(max_iters): print(f"\r{i}/{max_iters} {accumulated_loss}", end="") if i % 5000 == 0: m.eval() - ppl = perplexity(model=m, data=val_data) + ppl = perplexity(model=m, data=val_data, batch_size=batch_size) writer.add_scalar('val_perplexity', ppl.item(), i) print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) @@ -435,13 +458,13 @@ for i in range(max_iters): ) m.eval() -ppl = perplexity(model=m, data=val_data) +ppl = perplexity(model=m, data=val_data, batch_size=batch_size) print(f"\r{i+1}/{max_iters} {accumulated_loss}") print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") writer.add_scalar('val_perplexity', ppl.item(), i+1) 
writer.add_scalar('loss', accumulated_loss, i) -ppl = perplexity(model=m, data=test_data) +ppl = perplexity(model=m, data=test_data, batch_size=batch_size) writer.add_scalar('test_perplexity', ppl.item(), i+1) print(f"\rTest Perplexity at {i}: {ppl}") diff --git a/src/train_optics_trainable_lens_64.py b/src/train_optics_trainable_lens_64.py index bbe5a73..751147b 100644 --- a/src/train_optics_trainable_lens_64.py +++ b/src/train_optics_trainable_lens_64.py @@ -12,6 +12,21 @@ import shutil seed = 1337 torch.manual_seed(seed) +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 64 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 + ############################### MODEL ############################################################# def new_formula(sim, tensor_1, tensor_2): @@ -237,22 +252,6 @@ class OpticGPT2TrainableScalarAndLens(nn.Module): ################################################################################################### -batch_size = 50 -gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient -max_iters = int(4e4) #40000 -eval_interval = 300 -learning_rate = 1e-3 -device = 'cuda' if torch.cuda.is_available() else 'cpu' -eval_iters = 200 -layers_num = 2 -h_dim = 64 -max_seq_len = 64 -num_heads = 1 -dropout_rate = 0.1 -pixel_size = 3.6e-6 -assert batch_size % gradient_accumulation_steps == 0 -# CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff seq_128_hdim_64 - MODEL_CLASS = OpticGPT2TrainableScalarAndLens train_data_path = Path("./data/wiki.train.tokens") val_data_path = Path("./data/wiki.valid.tokens") @@ -306,23 +305,48 @@ val_data = torch.tensor(encode(val_text), dtype=torch.long) test_data 
= torch.tensor(encode(test_text), dtype=torch.long) @torch.no_grad() -def perplexity(model, data): +def perplexity(model, data, batch_size=32): model.eval() stride = max(1, len(data) // 10000) total_loss_sum = 0.0 total_tokens_count = 0 - losses = [] - for i in range(0, len(data)-max_seq_len-1, stride): - x = data[i:(i+max_seq_len)].to(device) - y = data[(i+1):(i+max_seq_len+1)].to(device) - logits, mean_loss = model(x[None,...], y[None,...]) - num_tokens = y.numel() + + # Precompute all valid start positions + start_positions = list(range(0, len(data) - max_seq_len - 1, stride)) + total_sequences = len(start_positions) + + # Process sequences in batches + for i in range(0, total_sequences, batch_size): + batch_starts = start_positions[i:min(i + batch_size, total_sequences)] + + # Efficiently stack sequences into batch tensors + x_batch = torch.stack([ + data[start:start + max_seq_len] + for start in batch_starts + ]).to(device) + + y_batch = torch.stack([ + data[start + 1:start + max_seq_len + 1] + for start in batch_starts + ]).to(device) + + # Forward pass (model should return mean loss averaged over all tokens in batch) + _, mean_loss = model(x_batch, y_batch) + + # Accumulate weighted loss (mean_loss is already averaged over tokens) + num_tokens = y_batch.numel() total_loss_sum += mean_loss.item() * num_tokens total_tokens_count += num_tokens - print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="") + + # Progress update + processed = min(i + batch_size, total_sequences) + print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True) + + print() # Final newline return np.exp(total_loss_sum / total_tokens_count) -#################################### Model #########################################mo +#################################### Model ######################################### + def complete(m, start_idxs=[0], max_new_tokens=100): start_idx = torch.tensor([start_idxs]).to(device) generated_tokens = 
m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens) @@ -415,7 +439,7 @@ for i in range(max_iters): print(f"\r{i}/{max_iters} {accumulated_loss}", end="") if i % 5000 == 0: m.eval() - ppl = perplexity(model=m, data=val_data) + ppl = perplexity(model=m, data=val_data, batch_size=batch_size) writer.add_scalar('val_perplexity', ppl.item(), i) print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) @@ -434,13 +458,13 @@ for i in range(max_iters): ) m.eval() -ppl = perplexity(model=m, data=val_data) +ppl = perplexity(model=m, data=val_data, batch_size=batch_size) print(f"\r{i+1}/{max_iters} {accumulated_loss}") print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") writer.add_scalar('val_perplexity', ppl.item(), i+1) writer.add_scalar('loss', accumulated_loss, i) -ppl = perplexity(model=m, data=test_data) +ppl = perplexity(model=m, data=test_data, batch_size=batch_size) writer.add_scalar('test_perplexity', ppl.item(), i+1) print(f"\rTest Perplexity at {i}: {ppl}") From 1ace114d0c39f28c22222833f9fa6ccfca22c24c Mon Sep 17 00:00:00 2001 From: Vladimir Protsenko Date: Sat, 14 Feb 2026 14:48:59 +0000 Subject: [PATCH 7/9] optic mul error test --- src/basic_optic_mm_test.py | 93 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 src/basic_optic_mm_test.py diff --git a/src/basic_optic_mm_test.py b/src/basic_optic_mm_test.py new file mode 100644 index 0000000..bc0a883 --- /dev/null +++ b/src/basic_optic_mm_test.py @@ -0,0 +1,93 @@ +import torch +import torch.nn as nn +import optical_matrix_multiplication as omm +import matplotlib.pyplot as plt +device = 'cpu' + +h_dim = 64 +pixel_size = 3.6e-6 +batch_size = 100 +test_lengths = [59, 64, 128, 256, 512] + +for max_seq_len in test_lengths: + if max_seq_len < 512: + sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim, + 
right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * h_dim, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + lens_size = 8192) + ) + sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * h_dim, + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + lens_size = 8192) + ) + else: + sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * h_dim, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2) + ) + sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * h_dim, + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2) + ) + + def cko(x,y): + x = x**2 + y = y**2 + return (((x / x.mean() - y / y.mean())**2).mean())**0.5 * 100 + + sim_scores = sim_scores.to(device=device) + sim_output = sim_output.to(device=device) + + q = torch.rand((batch_size, 1, max_seq_len, h_dim)).to(device=device) + k = torch.rand((batch_size, 1, max_seq_len, h_dim)).to(device=device).transpose(-2, -1) + true_scores = q @ k + opt_scores = 
sim_scores(q, k) + CKO_scores = cko(true_scores, opt_scores).detach().cpu().numpy() + + scores = torch.rand((batch_size, 1, max_seq_len, max_seq_len)).to(device=device) + v = torch.rand((batch_size, 1, max_seq_len, h_dim)).to(device=device) + true_o = scores @ v + opt_o = sim_output(scores, v) + CKO_o = cko(true_o, opt_o).detach().cpu().numpy() + + print(f"CKO sim_scores[{h_dim},{max_seq_len}] [{q.shape[-2]}, {q.shape[-1]}]x[{k.shape[-2]}, {k.shape[-1]}]: {CKO_scores}") + print(f"CKO sim_output[{max_seq_len},{h_dim}] [{true_scores.shape[-2]}, {true_scores.shape[-1]}]x[{v.shape[-2]}, {v.shape[-1]}]: {CKO_o}") \ No newline at end of file From 1b8ddc31d4bae38fa596f3e0c65ebdacf1417790 Mon Sep 17 00:00:00 2001 From: Vladimir Protsenko Date: Sat, 14 Feb 2026 18:08:06 +0000 Subject: [PATCH 8/9] scripts version --- src/bpe_main.py | 168 ------- src/bpe_tokenizer.py | 116 ----- src/char_gpt2.py | 115 ----- src/main.py | 168 ------- src/optics_char_gpt2.py | 211 -------- src/optics_char_gpt2_ff.py | 217 -------- src/optics_char_gpt2_nokoef.py | 211 -------- src/optics_char_gpt2_traindiag.py | 210 -------- src/train_char_gpt2_128.py | 367 ++++++++++++++ src/train_char_gpt2_256.py | 367 ++++++++++++++ src/{train_gpt2.py => train_char_gpt2_512.py} | 32 +- src/train_char_gpt2_64.py | 367 ++++++++++++++ src/train_char_gpt2_koef_128.py | 369 ++++++++++++++ src/train_char_gpt2_koef_256.py | 369 ++++++++++++++ src/train_char_gpt2_koef_512.py | 369 ++++++++++++++ src/train_char_gpt2_koef_64.py | 369 ++++++++++++++ src/train_optics_char_gpt2_128.py | 473 ++++++++++++++++++ src/train_optics_char_gpt2_256.py | 473 ++++++++++++++++++ src/train_optics_char_gpt2_512.py | 473 ++++++++++++++++++ ...ormula.py => train_optics_char_gpt2_64.py} | 272 +++++++++- src/train_optics_char_gpt2_ff.py | 472 +++++++++++++++++ src/train_optics_char_gpt2_nokoef_128.py | 471 +++++++++++++++++ src/train_optics_char_gpt2_nokoef_256.py | 471 +++++++++++++++++ src/train_optics_char_gpt2_nokoef_512.py | 471 
+++++++++++++++++ ...py => train_optics_char_gpt2_nokoef_64.py} | 267 +++++++++- 25 files changed, 6413 insertions(+), 1455 deletions(-) delete mode 100644 src/bpe_main.py delete mode 100644 src/bpe_tokenizer.py delete mode 100644 src/char_gpt2.py delete mode 100644 src/main.py delete mode 100644 src/optics_char_gpt2.py delete mode 100644 src/optics_char_gpt2_ff.py delete mode 100644 src/optics_char_gpt2_nokoef.py delete mode 100644 src/optics_char_gpt2_traindiag.py create mode 100644 src/train_char_gpt2_128.py create mode 100644 src/train_char_gpt2_256.py rename src/{train_gpt2.py => train_char_gpt2_512.py} (99%) create mode 100644 src/train_char_gpt2_64.py create mode 100644 src/train_char_gpt2_koef_128.py create mode 100644 src/train_char_gpt2_koef_256.py create mode 100644 src/train_char_gpt2_koef_512.py create mode 100644 src/train_char_gpt2_koef_64.py create mode 100644 src/train_optics_char_gpt2_128.py create mode 100644 src/train_optics_char_gpt2_256.py create mode 100644 src/train_optics_char_gpt2_512.py rename src/{optics_char_gpt2_new_formula.py => train_optics_char_gpt2_64.py} (51%) create mode 100644 src/train_optics_char_gpt2_ff.py create mode 100644 src/train_optics_char_gpt2_nokoef_128.py create mode 100644 src/train_optics_char_gpt2_nokoef_256.py create mode 100644 src/train_optics_char_gpt2_nokoef_512.py rename src/{optics_char_gpt2_nokoef_newf.py => train_optics_char_gpt2_nokoef_64.py} (51%) diff --git a/src/bpe_main.py b/src/bpe_main.py deleted file mode 100644 index 6bd30e5..0000000 --- a/src/bpe_main.py +++ /dev/null @@ -1,168 +0,0 @@ -import torch -import torch.nn as nn -from torch.nn import functional as F -from einops import rearrange -import numpy as np -from torch.utils.tensorboard import SummaryWriter -from datetime import datetime -import sys -from pathlib import Path -from char_gpt2 import GPT2 -from optics_char_gpt2 import OpticGPT2 -from bpe_tokenizer import byte_pair_init, byte_pair_encode, byte_pair_decode - -seed = 1337 
-torch.manual_seed(seed) -models = {'gpt2': GPT2, 'optic_gpt2': OpticGPT2} - -batch_size = 50 -max_iters = 40000*10 -eval_interval = 300 -learning_rate = 1e-3 -device = 'cuda' if torch.cuda.is_available() else 'cpu' -eval_iters = 200 -layers_num = 22 -h_dim = 64 -max_seq_len = 256 -num_heads = 4 -dropout_rate = 0.1 -pixel_size = 3.6e-6 -merges_count = 20 - -# CUDA_VISIBLE_DEVICES=1 python .src/main.py gpt2|optic_gpt2 ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens comment -MODEL_CLASS = models[sys.argv[1]] -train_data_path = Path(sys.argv[2]) -val_data_path = Path(sys.argv[3]) -test_data_path = Path(sys.argv[4]) -comment = f"bpe_{sys.argv[1]}_{train_data_path.name}_{sys.argv[5]}_{seed}" - -logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' -writer = SummaryWriter(logs_dir) -script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).name) -print("Logs dir:", logs_dir) -script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script -script_snapshot_path.chmod(0o400) # with read-only permission -#################################### Dataset ######################################### - -# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt -with open(train_data_path, encoding='utf-8') as f: - train_text = f.read() - -with open(val_data_path, encoding='utf-8') as f: - val_text = f.read() - -with open(test_data_path, encoding='utf-8') as f: - test_text = f.read() - -text = train_text + val_text + test_text - -chars = sorted(set(text)) -print(f"Len chars: {len(chars)}") - -wtoi = {w:i for i,w in enumerate(chars)} -itow = {i:w for i,w in enumerate(chars)} - -import pickle - -start_time = datetime.now() -if Path("./data/bpe_text.pkl").exists() and Path("./data/merges.pkl").exists(): - with open("./data/bpe_text.pkl", 'rb') as f: bpe_text = pickle.load(f) - with open("./data/merges.pkl", 'rb') as 
f: merges = pickle.load(f) -else: - bpe_text, merges = byte_pair_init([wtoi[w] for w in text], vocab_size=len(chars), merges_count=20) - with open("./data/bpe_text.pkl", 'wb') as f: pickle.dump(bpe_text, f) - with open("./data/merges.pkl", 'wb') as f: pickle.dump(merges, f) - -print(f"Compression ratio: {len(text)/len(bpe_text)}, init took {datetime.now()-start_time}") - -vocab_size = len(chars) + merges_count - -encode = lambda s: byte_pair_encode([wtoi[w] for w in s], merges) -decode = lambda idx: "".join([itow[i] for i in byte_pair_decode(idx, merges)]) - -def get_batch(data, seq_len, batch_size): - ix = torch.randint(len(data)-seq_len, (batch_size, )) - x = torch.stack([data[i:i+seq_len] for i in ix]).to(device) - y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device) - return x, y - -start_time = datetime.now() -train_bpe_encoded_path = Path("./data/train_bpe_encoded.pt") -val_bpe_encoded_path = Path("./data/val_bpe_encoded.pt") -test_bpe_encoded_path = Path("./data/test_bpe_encoded.pt") -if train_bpe_encoded_path.exists() and val_bpe_encoded_path.exists() and test_bpe_encoded_path.exists(): - train_data = torch.load(train_bpe_encoded_path) - val_data = torch.load(val_bpe_encoded_path) - test_data = torch.load(test_bpe_encoded_path) -else: - train_data = torch.tensor(encode(train_text), dtype=torch.long) - val_data = torch.tensor(encode(val_text), dtype=torch.long) - test_data = torch.tensor(encode(test_text), dtype=torch.long) - torch.save(train_data, train_bpe_encoded_path) - torch.save(val_data, val_bpe_encoded_path) - torch.save(test_data, test_bpe_encoded_path) -print(f"Encoded {datetime.now() - start_time}") - -@torch.no_grad() -def perplexity(model, data): - stride = max(1, len(data) // 10000) - losses = [] - for i in range(0, len(data)-max_seq_len-1, stride): - x = data[i:(i+max_seq_len)].to(device) - y = data[(i+1):(i+max_seq_len+1)].to(device) - logits, loss = model(x[None,...], y[None,...]) - losses.append(loss.item()) - print(f"\rppl 
{i}/{len(data)-max_seq_len-1}", end="") - return np.exp(np.mean(losses)) - -#################################### Model #########################################mo -def complete(m, start_idxs=[0], max_new_tokens=100): - start_idx = torch.tensor([start_idxs]).to(device) - generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens) - return decode(generated_tokens[0].tolist()) - -m = MODEL_CLASS( - vocab_size=vocab_size, - h_dim=h_dim, - max_seq_len=max_seq_len, - num_heads=num_heads, - pixel_size=pixel_size, - layers_num=layers_num -) -m = m.to(device) -print(m) -#################################### Train ######################################### - -optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) - -completion = complete(m, encode("\n"*max_seq_len), 2*max_seq_len) -print(completion) -writer.add_text('completions', completion, 0) - -for i in range(max_iters): - xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size) - logits, loss = m(xb, yb) - optimizer.zero_grad(set_to_none=True) - loss.backward() - optimizer.step() - writer.add_scalar('loss', loss.item(), i) - print(f"\r{i}/{max_iters} {loss.item()}", end="") - if i % 5000 == 0: - ppl = perplexity(model=m, data=val_data) - writer.add_scalar('val_perplexity', ppl.item(), i) - print(f"\rPerplexity at {i}: {ppl}") - writer.add_text('completions', complete(m, encode("\n"*max_seq_len), 2*max_seq_len), i) - -ppl = perplexity(model=m, data=val_data) -print(f"\r{i+1}/{max_iters} {loss.item()}") -print(f"\rPerplexity at {i}: {ppl}") -writer.add_scalar('val_perplexity', ppl.item(), i+1) -writer.add_scalar('loss', loss.item(), i) - -ppl = perplexity(model=m, data=test_data) -writer.add_scalar('test_perplexity', ppl.item(), i+1) -print(f"\rTest Perplexity at {i}: {ppl}") - -completion = complete(m, encode("\n"*max_seq_len), 2*max_seq_len) -print(completion) -writer.add_text('completions', completion, i+1) \ No newline at end of file 
diff --git a/src/bpe_tokenizer.py b/src/bpe_tokenizer.py deleted file mode 100644 index d66748a..0000000 --- a/src/bpe_tokenizer.py +++ /dev/null @@ -1,116 +0,0 @@ -import numpy as np -import pandas as pd - -def get_top_pair(tokens): - hist = pd.DataFrame(np.vstack([tokens[:-1], tokens[1:]]).T).value_counts().reset_index().astype(np.int32) - return list(hist.nlargest(1, columns='count').iloc[0, [0,1]]) - -def merge(tokens, pair, new_idx): - new_tokens = [] - skip = False - for a,b in zip(tokens[:-1], tokens[1:]): - if skip: - skip = False - continue - if a == pair[0] and b == pair[1]: - new_tokens.append(new_idx) - skip = True - else: - new_tokens.append(a) - if not skip: - new_tokens.append(b) - return np.array(new_tokens) - -def unmerge(tokens, pair_idx, pair): - new_tokens = [] - for idx in tokens: - if idx == pair_idx: - new_tokens.append(pair[0]) - new_tokens.append(pair[1]) - else: - new_tokens.append(idx) - return new_tokens - -def byte_pair_init(char_ids, vocab_size, merges_count=20): - byte_text = np.array(char_ids) - merges = [] - for i in range(merges_count): - top_pair = get_top_pair(byte_text) - new_idx = vocab_size + i - merges.append([top_pair, new_idx]) - print(f"{top_pair} {new_idx}") - byte_text = merge(byte_text, top_pair, new_idx) - return np.array(byte_text), merges - -def byte_pair_encode(char_ids, merges): - tokens = np.array(char_ids) - for pair, pair_idx in merges: - tokens = merge(tokens, pair, pair_idx) - return tokens - -def byte_pair_decode(tokens, merges): - for pair, pair_idx in reversed(merges): - tokens = unmerge(tokens, pair_idx, pair) - return tokens - - -# def get_top_pair(tokens): -# hist = pd.DataFrame(np.vstack([tokens[:-1], tokens[1:]]).T).value_counts().reset_index().astype(np.uint16) -# return np.array(hist.nlargest(1, columns='count').iloc[0, [0,1]]) - -# def merge(tokens, pair, new_idx): -# if len(tokens) % 2 != 0: -# tokens = np.append(tokens, np.array([0], dtype=np.uint16)) -# # print("not even") -# a = 
np.frombuffer(bytes(tokens), dtype=np.uint32).copy() -# b = np.frombuffer(bytes(pair), dtype=np.uint32) -# c = np.frombuffer(bytes(np.array([2**16-1, new_idx], dtype=np.uint16)), dtype=np.uint32) -# a[a==b] = c -# d = np.frombuffer(bytes(a), dtype=np.uint16) -# indices = np.where(d == 2**16-1) -# e = np.delete(d, indices) -# e = e[:-1] -# else: -# # print("even") -# a = np.frombuffer(bytes(tokens), dtype=np.uint32).copy() -# b = np.frombuffer(bytes(pair), dtype=np.uint32) -# c = np.frombuffer(bytes(np.array([2**16-1, new_idx], dtype=np.uint16)), dtype=np.uint32) -# a[a==b] = c -# d = np.frombuffer(bytes(a), dtype=np.uint16) -# indices = np.where(d == 2**16-1) -# e = np.delete(d, indices) -# return e - -# def unmerge(tokens, pair_idx, pair): -# new_tokens = [] -# for idx in tokens: -# if idx == pair_idx: -# new_tokens.append(pair[0]) -# new_tokens.append(pair[1]) -# else: -# new_tokens.append(idx) -# return new_tokens - -# def byte_pair_init(char_ids, vocab_size, merges_count=20): -# assert vocab_size < 2**16 -# byte_text = np.array(char_ids, dtype=np.uint16) -# merges = [] -# for i in range(merges_count): -# top_pair = get_top_pair(byte_text) -# new_idx = vocab_size + i -# print([top_pair, new_idx]) -# merges.append([top_pair, new_idx]) -# byte_text = merge(byte_text, top_pair, new_idx) -# byte_text = np.roll(merge(np.roll(byte_text, 1), top_pair, new_idx), -1) -# return byte_text, merges - -# def byte_pair_encode(char_ids, merges): -# tokens = np.array(char_ids, dtype=np.uint16) -# for pair, pair_idx in merges: -# tokens = merge(tokens, pair, pair_idx) -# return tokens - -# def byte_pair_decode(tokens, merges): -# for pair, pair_idx in reversed(merges): -# tokens = unmerge(tokens, pair_idx, pair) -# return tokens \ No newline at end of file diff --git a/src/char_gpt2.py b/src/char_gpt2.py deleted file mode 100644 index c1cd2cb..0000000 --- a/src/char_gpt2.py +++ /dev/null @@ -1,115 +0,0 @@ -import torch -import torch.nn as nn -from torch.nn import functional as F 
-from einops import rearrange -import numpy as np -from torch.utils.tensorboard import SummaryWriter -from datetime import datetime -import sys -from pathlib import Path -torch.manual_seed(1337) - -#################################### Model ######################################### - - -# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 -class RoPE(nn.Module): - def __init__(self, dim, max_seq_len=512): - super().__init__() - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) - t = torch.arange(max_seq_len).type_as(inv_freq) - freqs = torch.einsum('i,j->ij', t, inv_freq) - self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) - - def rotate_half(self, x): - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - - def forward(self, x, offset=0): - seq_len = x.size(1) - emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) - cos = emb.cos() - sin = emb.sin() - return (x * cos) + (self.rotate_half(x) * sin) - -# Transformers without Normalization https://jiachenzhu.github.io/DyT/ -class DyT(nn.Module): - def __init__(self, num_features, alpha_init_value=0.5): - super().__init__() - self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) - self.weight = nn.Parameter(torch.ones(num_features)) - self.bias = nn.Parameter(torch.zeros(num_features)) - - def forward(self, x): - x = torch.tanh(self.alpha * x) - return x * self.weight + self.bias - -# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 -# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 -class TransformerLayer(nn.Module): - def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): - super().__init__() - self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) - self.q_proj = nn.Linear(h_dim, h_dim) - self.k_proj = nn.Linear(h_dim, h_dim) - self.v_proj = nn.Linear(h_dim, h_dim) - self.o_proj = nn.Linear(h_dim, h_dim) - self.ff1 = nn.Linear(h_dim, 4*h_dim) - 
self.ff2 = nn.Linear(4*h_dim, h_dim) - self.ln1 = DyT(h_dim) - self.ln2 = DyT(h_dim) - self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) - - def split_to_heads(self, x, B, T, H): - if self.num_heads <= 1: return x - return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) - - def gather_heads(self, x, B, T, H): - if self.num_heads <= 1: return x - return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) - - def attention(self, x): - q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) - k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) - v = self.split_to_heads(self.v_proj(x), *x.shape) - scores = (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5) - tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) - scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line - attention = nn.functional.softmax(scores, dim=2) - return self.o_proj(self.gather_heads(attention @ v, *x.shape)) - - def forward(self, x): - x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) - x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) - return x - -class GPT2(nn.Module): - def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None): - super().__init__() - self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) - self.layers = nn.ModuleList([ - TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) - for _ in range(layers_num)]) - self.tok_embeds = nn.Embedding(vocab_size, h_dim) - self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) - self.lm_head = nn.Linear(h_dim, vocab_size) - - def forward(self, x, targets=None): - x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] - for l in self.layers: - x = l(x) - logits = self.lm_head(x) # B,T,C - loss = 
F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None - return logits, loss - - # what is the purpose? autoregressive inference! - def generate(self, start_idx, max_new_tokens): - idx = start_idx - for i in range(max_new_tokens): - idx_cond = idx[:,-self.max_seq_len:] - logits, loss = self(idx_cond) - logits = logits[:,-1,:] # B, C - probs = F.softmax(logits, dim=-1) - idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) - idx = torch.cat([idx, idx_next], dim=1) - return idx \ No newline at end of file diff --git a/src/main.py b/src/main.py deleted file mode 100644 index d6a06c4..0000000 --- a/src/main.py +++ /dev/null @@ -1,168 +0,0 @@ -import torch -import torch.nn as nn -from torch.nn import functional as F -from einops import rearrange -import numpy as np -from torch.utils.tensorboard import SummaryWriter -from datetime import datetime -import sys -from pathlib import Path -from char_gpt2 import GPT2 -from optics_char_gpt2 import OpticGPT2 -from optics_char_gpt2_traindiag import OpticGPT2TrainDiag -from optics_char_gpt2_ff import OpticGPT2FF -from optics_char_gpt2_new_formula import OpticGPT2NewFormula -from char_gpt2_scaledmatmul import GPT2ScaledMM -from optics_char_gpt2_nokoef import OpticGPT2NOKoef -from optics_char_gpt2_nokoef_newf import OpticGPT2NOKoefNewF -import shutil -seed = 1337 -torch.manual_seed(seed) -models = { - 'gpt2': GPT2, - 'optic_gpt2': OpticGPT2, - 'optic_gpt2_ff': OpticGPT2FF, - 'optic_gpt2_traindiag': OpticGPT2TrainDiag, - 'optic_gpt2_newformula': OpticGPT2NewFormula, - 'optic_gpt2_nokoef': OpticGPT2NOKoef, - 'optic_gpt2_nokoef_newformula': OpticGPT2NOKoefNewF, - 'gpt2_scaledmm': GPT2ScaledMM -} - -batch_size = 50 -gradient_accumulation_steps = 5 # check this impl for correctness https://unsloth.ai/blog/gradient -max_iters = 40000 -eval_interval = 300 -learning_rate = 1e-3 -device = 'cuda' if torch.cuda.is_available() else 'cpu' -eval_iters = 200 -layers_num = 2 -h_dim 
= 64 -max_seq_len = 512 -num_heads = 1 -dropout_rate = 0.1 -pixel_size = 3.6e-6 -assert batch_size % gradient_accumulation_steps == 0 -# CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens seq_128_hdim_64 - -MODEL_CLASS = models[sys.argv[1]] -train_data_path = Path(sys.argv[2]) -val_data_path = Path(sys.argv[3]) -test_data_path = Path(sys.argv[4]) -comment = f"{sys.argv[1]}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[5] if len(sys.argv) >= 6 else ''}" - -logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' -writer = SummaryWriter(logs_dir) -script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name) -print("Logs dir:", logs_dir) -# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script -shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository -script_snapshot_path.chmod(0o500) # with read-only permission - -#################################### Dataset ######################################### - -# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt -with open(train_data_path, encoding='utf-8') as f: - train_text = f.read() - -with open(val_data_path, encoding='utf-8') as f: - val_text = f.read() - -with open(test_data_path, encoding='utf-8') as f: - test_text = f.read() - -text = train_text + val_text + test_text - -chars = sorted(set(text)) -vocab_size = len(chars) - -wtoi = {w:i for i,w in enumerate(chars)} -itow = {i:w for i,w in enumerate(chars)} - -encode = lambda s: [wtoi[w] for w in s] -decode = lambda idx: ''.join([itow[i] for i in idx]) - -def get_batch(data, seq_len, batch_size): - ix = torch.randint(len(data)-seq_len, (batch_size, )) - x = torch.stack([data[i:i+seq_len] for i in ix]).to(device) - y = torch.stack([data[i+1:i+1+seq_len] for i in 
ix]).to(device) - return x, y - -train_data = torch.tensor(encode(train_text), dtype=torch.long) -val_data = torch.tensor(encode(val_text), dtype=torch.long) -test_data = torch.tensor(encode(test_text), dtype=torch.long) - -@torch.no_grad() -def perplexity(model, data): - stride = max(1, len(data) // 10000) - losses = [] - for i in range(0, len(data)-max_seq_len-1, stride): - x = data[i:(i+max_seq_len)].to(device) - y = data[(i+1):(i+max_seq_len+1)].to(device) - logits, loss = model(x[None,...], y[None,...]) - losses.append(loss.item()) - print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="") - return np.exp(np.mean(losses)) - -#################################### Model #########################################mo -def complete(m, start_idxs=[0], max_new_tokens=100): - start_idx = torch.tensor([start_idxs]).to(device) - generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens) - return decode(generated_tokens[0].tolist()) - -m = MODEL_CLASS( - vocab_size=vocab_size, - h_dim=h_dim, - max_seq_len=max_seq_len, - num_heads=num_heads, - pixel_size=pixel_size, - layers_num=layers_num -) -m = m.to(device) -writer.add_text('model', str(m), 0) - -#################################### Train ######################################### - -optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) - -m.eval() -completion = complete(m, encode("\n"*max_seq_len), 2*max_seq_len) -print(completion) -writer.add_text('completions', completion, 0) - -for i in range(max_iters): - m.train() - optimizer.zero_grad(set_to_none=True) - accumulated_loss = 0.0 - for j in range(gradient_accumulation_steps): - xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps) - logits, loss = m(xb, yb) - loss = loss / gradient_accumulation_steps - loss.backward() - accumulated_loss += loss.item() - if i % 100 == 0: - writer.add_scalar('Gradient_Norm/Total', 
torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i) - optimizer.step() - writer.add_scalar('loss', accumulated_loss, i) - print(f"\r{i}/{max_iters} {accumulated_loss}", end="") - if i % 5000 == 0: - m.eval() - ppl = perplexity(model=m, data=val_data) - writer.add_scalar('val_perplexity', ppl.item(), i) - print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") - writer.add_text('completions', complete(m, encode("\n"*max_seq_len), 2*max_seq_len), i) - -m.eval() -ppl = perplexity(model=m, data=val_data) -print(f"\r{i+1}/{max_iters} {accumulated_loss}") -print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") -writer.add_scalar('val_perplexity', ppl.item(), i+1) -writer.add_scalar('loss', accumulated_loss, i) - -ppl = perplexity(model=m, data=test_data) -writer.add_scalar('test_perplexity', ppl.item(), i+1) -print(f"\rTest Perplexity at {i}: {ppl}") - -completion = complete(m, encode("\n"*max_seq_len), 2*max_seq_len) -print(completion) -writer.add_text('completions', completion, i+1) \ No newline at end of file diff --git a/src/optics_char_gpt2.py b/src/optics_char_gpt2.py deleted file mode 100644 index 3989ed6..0000000 --- a/src/optics_char_gpt2.py +++ /dev/null @@ -1,211 +0,0 @@ -import torch -import torch.nn as nn -from torch.nn import functional as F -from einops import rearrange -import optical_matrix_multiplication as omm -from optical_matrix_multiplication import propagator -import numpy as np -from torch.utils.tensorboard import SummaryWriter -from datetime import datetime -from pathlib import Path -import sys -torch.manual_seed(1337) - -#################################### Model ######################################### - -def norm(matrix: torch.Tensor, max_val: float = 1) -> torch.Tensor: - return matrix / (max_val + 1e-10) - -def optics_matmul_shift(sim, tensor_1, tensor_2): - tensor_1 = tensor_1[None,:,:,:] - tensor_2 = tensor_2[None,:,:,:] - - if torch.min(tensor_1) >= 0 and torch.min(tensor_2) >= 0: - max_abs = abs(max(torch.max(tensor_1), 
torch.max(tensor_2))) - a, b = norm(tensor_1, max_abs), norm(tensor_2, max_abs) - return sim(a, b)[0,:,:,:] * max_abs **2 - - min_abs = abs(min(torch.min(tensor_1), torch.min(tensor_2))) - max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + min_abs - - shift_a = min_abs * torch.ones(tensor_1.shape).to(tensor_1.device) - shift_b = min_abs * torch.ones(tensor_2.shape).to(tensor_1.device) - a_a_sh = tensor_1 + shift_a - b_b_sh = tensor_2 + shift_b - - a_a_sh_norm, b_b_sh_norm = norm(a_a_sh, max_abs), norm(b_b_sh, max_abs) - shift_a_norm, shift_b_norm = norm(shift_a, max_abs), norm(shift_b, max_abs) - a_a_sh_b_b_sh = sim(a_a_sh_norm, b_b_sh_norm) - a_a_sh_b_sh = sim(a_a_sh_norm, shift_b_norm) - a_sh_b_b_sh = sim(shift_a_norm, b_b_sh_norm) - a_sh_b_sh = sim(shift_a_norm, shift_b_norm) - - return (a_a_sh_b_b_sh - a_a_sh_b_sh - a_sh_b_b_sh + a_sh_b_sh)[0,:,:,:] * max_abs ** 2 - -# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 -class RoPE(nn.Module): - def __init__(self, dim, max_seq_len=512): - super().__init__() - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) - t = torch.arange(max_seq_len).type_as(inv_freq) - freqs = torch.einsum('i,j->ij', t, inv_freq) - self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) - - def rotate_half(self, x): - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - - def forward(self, x, offset=0): - seq_len = x.size(1) - emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) - cos = emb.cos() - sin = emb.sin() - return (x * cos) + (self.rotate_half(x) * sin) - -# Transformers without Normalization https://jiachenzhu.github.io/DyT/ -class DyT(nn.Module): - def __init__(self, num_features, alpha_init_value=0.5): - super().__init__() - self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) - self.weight = nn.Parameter(torch.ones(num_features)) - self.bias = nn.Parameter(torch.zeros(num_features)) - - def forward(self, x): - x = 
torch.tanh(self.alpha * x) - return x * self.weight + self.bias - -# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 -# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 -class TransformerLayer(nn.Module): - def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): - super().__init__() - self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) - self.q_proj = nn.Linear(h_dim, h_dim) - self.k_proj = nn.Linear(h_dim, h_dim) - self.v_proj = nn.Linear(h_dim, h_dim) - self.o_proj = nn.Linear(h_dim, h_dim) - self.ff1 = nn.Linear(h_dim, 4*h_dim) - self.ff2 = nn.Linear(4*h_dim, h_dim) - self.ln1 = DyT(h_dim) - self.ln2 = DyT(h_dim) - self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) - self.k1 = nn.Parameter(torch.randn(1)) - self.k2 = nn.Parameter(torch.randn(1)) - - def split_to_heads(self, x, B, T, H): - if self.num_heads <= 1: return x - return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) - - def gather_heads(self, x, B, T, H): - if self.num_heads <= 1: return x - return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) - - def attention(self, x): - q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) - k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) - v = self.split_to_heads(self.v_proj(x), *x.shape) - scores = self.k1 * optics_matmul_shift(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5) - tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) - scores = scores.masked_fill(tril == 0, float('-inf')) - attention = nn.functional.softmax(scores, dim=2) - output = self.k2 * optics_matmul_shift(self.sim_output, attention, v) - return self.o_proj(self.gather_heads(output, *x.shape)) - - def forward(self, x): - x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) - x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) - return 
x - -class OpticGPT2(nn.Module): - def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, - pixel_size = 3.6e-6): - super().__init__() - self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) - if max_seq_len < 512: - self.sim_scores = omm.OpticalMul( - omm.Config(right_matrix_count_columns = max_seq_len, - right_matrix_count_rows = h_dim // num_heads, - right_matrix_width = pixel_size * max_seq_len, - right_matrix_height = pixel_size * (h_dim // num_heads), - min_height_gap = pixel_size, - right_matrix_split_x = 2, - right_matrix_split_y = 2, - left_matrix_split_x = 2, - left_matrix_split_y = 2, - result_matrix_split = 2, - distance = 0.01, - trainable_cylind_lens=False) - ) - self.sim_output = omm.OpticalMul( - omm.Config(right_matrix_count_columns = h_dim // num_heads, - right_matrix_count_rows = max_seq_len, - right_matrix_width = pixel_size * (h_dim // num_heads), - right_matrix_height = pixel_size * max_seq_len, - min_height_gap = pixel_size, - right_matrix_split_x = 2, - right_matrix_split_y = 2, - left_matrix_split_x = 2, - left_matrix_split_y = 2, - result_matrix_split = 2, - distance = 0.01, - trainable_cylind_lens=False) - ) - if max_seq_len >= 512: - self.sim_scores = omm.OpticalMul( - omm.Config(right_matrix_count_columns = max_seq_len, - right_matrix_count_rows = h_dim // num_heads, - right_matrix_width = pixel_size * max_seq_len, - right_matrix_height = pixel_size * (h_dim // num_heads), - min_height_gap = pixel_size, - right_matrix_split_x = 2, - right_matrix_split_y = 2, - left_matrix_split_x = 2, - left_matrix_split_y = 2, - result_matrix_split = 2, - distance = 0.15, - lens_size = 8192 * 2, - trainable_cylind_lens=False) - ) - self.sim_output = omm.OpticalMul( - omm.Config(right_matrix_count_columns = h_dim // num_heads, - right_matrix_count_rows = max_seq_len, - right_matrix_width = pixel_size * (h_dim // num_heads), - right_matrix_height = pixel_size * max_seq_len, - 
min_height_gap = pixel_size, - right_matrix_split_x = 2, - right_matrix_split_y = 2, - left_matrix_split_x = 2, - left_matrix_split_y = 2, - result_matrix_split = 2, - distance = 0.15, - lens_size = 8192 * 2, - trainable_cylind_lens=False) - ) - - self.layers = nn.ModuleList([ - TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads, - dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) - for _ in range(layers_num)]) - self.tok_embeds = nn.Embedding(vocab_size, h_dim) - self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) - self.lm_head = nn.Linear(h_dim, vocab_size) - - def forward(self, x, targets=None): - x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] - for l in self.layers: - x = l(x) - logits = self.lm_head(x) # B,T,C - loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None - return logits, loss - - # what is the purpose? autoregressive inference! 
- def generate(self, start_idx, max_new_tokens): - idx = start_idx - for i in range(max_new_tokens): - idx_cond = idx[:,-self.max_seq_len:] - logits, loss = self(idx_cond) - logits = logits[:,-1,:] # B, C - probs = F.softmax(logits, dim=-1) - idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) - idx = torch.cat([idx, idx_next], dim=1) - return idx \ No newline at end of file diff --git a/src/optics_char_gpt2_ff.py b/src/optics_char_gpt2_ff.py deleted file mode 100644 index 59bbd93..0000000 --- a/src/optics_char_gpt2_ff.py +++ /dev/null @@ -1,217 +0,0 @@ -import torch -import torch.nn as nn -from torch.nn import functional as F, init -from einops import rearrange -import optical_matrix_multiplication as omm -from optical_matrix_multiplication import propagator -import numpy as np -from torch.utils.tensorboard import SummaryWriter -from datetime import datetime -from pathlib import Path -import sys -import math -torch.manual_seed(1337) - -#################################### Model ######################################### - -def norm(matrix: torch.Tensor, max_val: float = 1) -> torch.Tensor: - return matrix / (max_val + 1e-10) - -def optics_matmul_shift(sim, tensor_1, tensor_2): - # print(tensor_1.shape, tensor_2.shape) - - tensor_1 = tensor_1[None,:,:,:] - tensor_2 = tensor_2[None,None,:,:] - # print(tensor_1.shape, tensor_2.shape) - # raise RuntimeError - - if torch.min(tensor_1) >= 0 and torch.min(tensor_2) >= 0: - max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) - a, b = norm(tensor_1, max_abs), norm(tensor_2, max_abs) - return sim(a, b)[0,:,:,:] * max_abs **2 - - min_abs = abs(min(torch.min(tensor_1), torch.min(tensor_2))) - max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + min_abs - - shift_a = min_abs * torch.ones(tensor_1.shape).to(tensor_1.device) - shift_b = min_abs * torch.ones(tensor_2.shape).to(tensor_1.device) - a_a_sh = tensor_1 + shift_a - b_b_sh = tensor_2 + shift_b - - a_a_sh_norm, b_b_sh_norm = 
norm(a_a_sh, max_abs), norm(b_b_sh, max_abs) - shift_a_norm, shift_b_norm = norm(shift_a, max_abs), norm(shift_b, max_abs) - a_a_sh_b_b_sh = sim(a_a_sh_norm, b_b_sh_norm) - a_a_sh_b_sh = sim(a_a_sh_norm, shift_b_norm) - a_sh_b_b_sh = sim(shift_a_norm, b_b_sh_norm) - a_sh_b_sh = sim(shift_a_norm, shift_b_norm) - - return (a_a_sh_b_b_sh - a_a_sh_b_sh - a_sh_b_b_sh + a_sh_b_sh)[0,:,:,:] * max_abs ** 2 - -# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 -class RoPE(nn.Module): - def __init__(self, dim, max_seq_len=512): - super().__init__() - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) - t = torch.arange(max_seq_len).type_as(inv_freq) - freqs = torch.einsum('i,j->ij', t, inv_freq) - self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) - - def rotate_half(self, x): - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - - def forward(self, x, offset=0): - seq_len = x.size(1) - emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) - cos = emb.cos() - sin = emb.sin() - return (x * cos) + (self.rotate_half(x) * sin) - -# Transformers without Normalization https://jiachenzhu.github.io/DyT/ -class DyT(nn.Module): - def __init__(self, num_features, alpha_init_value=0.5): - super().__init__() - self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) - self.weight = nn.Parameter(torch.ones(num_features)) - self.bias = nn.Parameter(torch.zeros(num_features)) - - def forward(self, x): - x = torch.tanh(self.alpha * x) - return x * self.weight + self.bias - -class OpticLinear(nn.Module): - def __init__( - self, - in_features, - out_features, - bias = True, - device = None, - dtype = None, - pixel_size = 3.6e-6 - ) -> None: - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.in_features = in_features - self.out_features = out_features - self.weight = nn.Parameter( - torch.empty((in_features, out_features), **factory_kwargs) - ) - if bias: - self.bias 
= nn.Parameter(torch.empty(out_features, **factory_kwargs)) - else: - self.register_parameter("bias", None) - self.k = nn.Parameter(torch.randn(1)) - self.sim = omm.OpticalMul( - omm.Config( - right_matrix_count_columns = out_features , - right_matrix_count_rows = in_features, - right_matrix_width = pixel_size * out_features , - right_matrix_height = pixel_size * in_features, - min_height_gap = pixel_size, - right_matrix_split_x = 2, - right_matrix_split_y = 2, - left_matrix_split_x = 2, - left_matrix_split_y = 2, - result_matrix_split = 2, - distance = 0.01) - ) - self.reset_parameters() - - def forward(self, input): - """ - Runs the forward pass. - """ - return self.k * optics_matmul_shift(self.sim, input, self.weight) + self.bias - - def reset_parameters(self) -> None: - """ - Resets parameters based on their initialization used in ``__init__``. - """ - # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with - # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see - # https://github.com/pytorch/pytorch/issues/57109 - init.kaiming_uniform_(self.weight, a=math.sqrt(5)) - if self.bias is not None: - fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) - bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - init.uniform_(self.bias, -bound, bound) - - def extra_repr(self) -> str: - """ - Return the extra representation of the module. 
- """ - return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}" - - - -# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 -# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 -class TransformerLayer(nn.Module): - def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): - super().__init__() - self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) - self.q_proj = nn.Linear(h_dim, h_dim) - self.k_proj = nn.Linear(h_dim, h_dim) - self.v_proj = nn.Linear(h_dim, h_dim) - self.o_proj = nn.Linear(h_dim, h_dim) - self.ff1 = OpticLinear(h_dim, 4*h_dim) - self.ff2 = OpticLinear(4*h_dim, h_dim) - self.ln1 = DyT(h_dim) - self.ln2 = DyT(h_dim) - self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) - - def split_to_heads(self, x, B, T, H): - if self.num_heads <= 1: return x - return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) - - def gather_heads(self, x, B, T, H): - if self.num_heads <= 1: return x - return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) - - def attention(self, x): - q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) - k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) - v = self.split_to_heads(self.v_proj(x), *x.shape) - scores = (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5) - tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) - scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line - attention = nn.functional.softmax(scores, dim=2) - return self.o_proj(self.gather_heads(attention @ v, *x.shape)) - - def forward(self, x): - x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) - x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) - return x - -class OpticGPT2FF(nn.Module): - def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, 
dropout_rate = 0.1, - pixel_size = 3.6e-6): - super().__init__() - self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) - self.layers = nn.ModuleList([ - TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, - dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) - for _ in range(layers_num)]) - self.tok_embeds = nn.Embedding(vocab_size, h_dim) - self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) - self.lm_head = nn.Linear(h_dim, vocab_size) - - def forward(self, x, targets=None): - x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] - for l in self.layers: - x = l(x) - logits = self.lm_head(x) # B,T,C - loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None - return logits, loss - - # what is the purpose? autoregressive inference! - def generate(self, start_idx, max_new_tokens): - idx = start_idx - for i in range(max_new_tokens): - idx_cond = idx[:,-self.max_seq_len:] - logits, loss = self(idx_cond) - logits = logits[:,-1,:] # B, C - probs = F.softmax(logits, dim=-1) - idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) - idx = torch.cat([idx, idx_next], dim=1) - return idx \ No newline at end of file diff --git a/src/optics_char_gpt2_nokoef.py b/src/optics_char_gpt2_nokoef.py deleted file mode 100644 index 0cd65d3..0000000 --- a/src/optics_char_gpt2_nokoef.py +++ /dev/null @@ -1,211 +0,0 @@ -import torch -import torch.nn as nn -from torch.nn import functional as F -from einops import rearrange -import optical_matrix_multiplication as omm -from optical_matrix_multiplication import propagator -import numpy as np -from torch.utils.tensorboard import SummaryWriter -from datetime import datetime -from pathlib import Path -import sys -torch.manual_seed(1337) - -#################################### Model ######################################### - -def norm(matrix: torch.Tensor, max_val: float = 1) -> torch.Tensor: - return matrix / 
(max_val + 1e-10) - -def optics_matmul_shift(sim, tensor_1, tensor_2): - tensor_1 = tensor_1[None,:,:,:] - tensor_2 = tensor_2[None,:,:,:] - - if torch.min(tensor_1) >= 0 and torch.min(tensor_2) >= 0: - max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) - a, b = norm(tensor_1, max_abs), norm(tensor_2, max_abs) - return sim(a, b)[0,:,:,:] * max_abs **2 - - min_abs = abs(min(torch.min(tensor_1), torch.min(tensor_2))) - max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + min_abs - - shift_a = min_abs * torch.ones(tensor_1.shape).to(tensor_1.device) - shift_b = min_abs * torch.ones(tensor_2.shape).to(tensor_1.device) - a_a_sh = tensor_1 + shift_a - b_b_sh = tensor_2 + shift_b - - a_a_sh_norm, b_b_sh_norm = norm(a_a_sh, max_abs), norm(b_b_sh, max_abs) - shift_a_norm, shift_b_norm = norm(shift_a, max_abs), norm(shift_b, max_abs) - a_a_sh_b_b_sh = sim(a_a_sh_norm, b_b_sh_norm) - a_a_sh_b_sh = sim(a_a_sh_norm, shift_b_norm) - a_sh_b_b_sh = sim(shift_a_norm, b_b_sh_norm) - a_sh_b_sh = sim(shift_a_norm, shift_b_norm) - - return (a_a_sh_b_b_sh - a_a_sh_b_sh - a_sh_b_b_sh + a_sh_b_sh)[0,:,:,:] * max_abs ** 2 - -# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 -class RoPE(nn.Module): - def __init__(self, dim, max_seq_len=512): - super().__init__() - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) - t = torch.arange(max_seq_len).type_as(inv_freq) - freqs = torch.einsum('i,j->ij', t, inv_freq) - self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) - - def rotate_half(self, x): - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - - def forward(self, x, offset=0): - seq_len = x.size(1) - emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) - cos = emb.cos() - sin = emb.sin() - return (x * cos) + (self.rotate_half(x) * sin) - -# Transformers without Normalization https://jiachenzhu.github.io/DyT/ -class DyT(nn.Module): - def __init__(self, num_features, 
alpha_init_value=0.5): - super().__init__() - self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) - self.weight = nn.Parameter(torch.ones(num_features)) - self.bias = nn.Parameter(torch.zeros(num_features)) - - def forward(self, x): - x = torch.tanh(self.alpha * x) - return x * self.weight + self.bias - -# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 -# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 -class TransformerLayer(nn.Module): - def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): - super().__init__() - self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) - self.q_proj = nn.Linear(h_dim, h_dim) - self.k_proj = nn.Linear(h_dim, h_dim) - self.v_proj = nn.Linear(h_dim, h_dim) - self.o_proj = nn.Linear(h_dim, h_dim) - self.ff1 = nn.Linear(h_dim, 4*h_dim) - self.ff2 = nn.Linear(4*h_dim, h_dim) - self.ln1 = DyT(h_dim) - self.ln2 = DyT(h_dim) - self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) - - def split_to_heads(self, x, B, T, H): - if self.num_heads <= 1: return x - return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) - - def gather_heads(self, x, B, T, H): - if self.num_heads <= 1: return x - return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) - - def attention(self, x): - q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) - k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) - v = self.split_to_heads(self.v_proj(x), *x.shape) - scores = optics_matmul_shift(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5) - tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) - scores = scores.masked_fill(tril == 0, float('-inf')) - attention = nn.functional.softmax(scores, dim=2) - output = optics_matmul_shift(self.sim_output, attention, v) - return self.o_proj(self.gather_heads(output, *x.shape)) - - def forward(self, x): - x = x + 
F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) - x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) - return x - -class OpticGPT2NOKoef(nn.Module): - def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, - pixel_size = 3.6e-6): - super().__init__() - self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) - if max_seq_len < 512: - self.sim_scores = omm.OpticalMul( - omm.Config(right_matrix_count_columns = max_seq_len, - right_matrix_count_rows = h_dim // num_heads, - right_matrix_width = pixel_size * max_seq_len, - right_matrix_height = pixel_size * (h_dim // num_heads), - min_height_gap = pixel_size, - right_matrix_split_x = 2, - right_matrix_split_y = 2, - left_matrix_split_x = 2, - left_matrix_split_y = 2, - result_matrix_split = 2, - distance = 0.01, - trainable_cylind_lens=False) - ) - self.sim_output = omm.OpticalMul( - omm.Config(right_matrix_count_columns = h_dim // num_heads, - right_matrix_count_rows = max_seq_len, - right_matrix_width = pixel_size * (h_dim // num_heads), - right_matrix_height = pixel_size * max_seq_len, - min_height_gap = pixel_size, - right_matrix_split_x = 2, - right_matrix_split_y = 2, - left_matrix_split_x = 2, - left_matrix_split_y = 2, - result_matrix_split = 2, - distance = 0.01, - trainable_cylind_lens=False) - ) - if max_seq_len >= 512: - self.sim_scores = omm.OpticalMul( - omm.Config(right_matrix_count_columns = max_seq_len, - right_matrix_count_rows = h_dim // num_heads, - right_matrix_width = pixel_size * max_seq_len, - right_matrix_height = pixel_size * (h_dim // num_heads), - min_height_gap = pixel_size, - right_matrix_split_x = 2, - right_matrix_split_y = 2, - left_matrix_split_x = 2, - left_matrix_split_y = 2, - result_matrix_split = 2, - distance = 0.15, - lens_size = 8192 * 2, - trainable_cylind_lens=False) - ) - self.sim_output = omm.OpticalMul( - omm.Config(right_matrix_count_columns = h_dim // num_heads, - 
right_matrix_count_rows = max_seq_len, - right_matrix_width = pixel_size * (h_dim // num_heads), - right_matrix_height = pixel_size * max_seq_len, - min_height_gap = pixel_size, - right_matrix_split_x = 2, - right_matrix_split_y = 2, - left_matrix_split_x = 2, - left_matrix_split_y = 2, - result_matrix_split = 2, - distance = 0.15, - lens_size = 8192 * 2, - trainable_cylind_lens=False) - ) - self.sim_scores = omm.ScatterDataParallel(self.sim_scores) - self.sim_output = omm.ScatterDataParallel(self.sim_output) - - self.layers = nn.ModuleList([ - TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads, - dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) - for _ in range(layers_num)]) - self.tok_embeds = nn.Embedding(vocab_size, h_dim) - self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) - self.lm_head = nn.Linear(h_dim, vocab_size) - - def forward(self, x, targets=None): - x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] - for l in self.layers: - x = l(x) - logits = self.lm_head(x) # B,T,C - loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None - return logits, loss - - # what is the purpose? autoregressive inference! 
- def generate(self, start_idx, max_new_tokens): - idx = start_idx - for i in range(max_new_tokens): - idx_cond = idx[:,-self.max_seq_len:] - logits, loss = self(idx_cond) - logits = logits[:,-1,:] # B, C - probs = F.softmax(logits, dim=-1) - idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) - idx = torch.cat([idx, idx_next], dim=1) - return idx \ No newline at end of file diff --git a/src/optics_char_gpt2_traindiag.py b/src/optics_char_gpt2_traindiag.py deleted file mode 100644 index 26de9d1..0000000 --- a/src/optics_char_gpt2_traindiag.py +++ /dev/null @@ -1,210 +0,0 @@ -import torch -import torch.nn as nn -from torch.nn import functional as F -from einops import rearrange -import optical_matrix_multiplication as omm -from optical_matrix_multiplication import propagator -import numpy as np -from torch.utils.tensorboard import SummaryWriter -from datetime import datetime -from pathlib import Path -import sys -torch.manual_seed(1337) - -#################################### Model ######################################### - -def norm(matrix: torch.Tensor, max_val: float = 1) -> torch.Tensor: - return matrix / (max_val + 1e-10) - -def optics_matmul_shift(sim, tensor_1, tensor_2): - tensor_1 = tensor_1[None,:,:,:] - tensor_2 = tensor_2[None,:,:,:] - - if torch.min(tensor_1) >= 0 and torch.min(tensor_2) >= 0: - max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) - a, b = norm(tensor_1, max_abs), norm(tensor_2, max_abs) - return sim(a, b)[0,:,:,:] * max_abs **2 - - min_abs = abs(min(torch.min(tensor_1), torch.min(tensor_2))) - max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + min_abs - - shift_a = min_abs * torch.ones(tensor_1.shape).to(tensor_1.device) - shift_b = min_abs * torch.ones(tensor_2.shape).to(tensor_1.device) - a_a_sh = tensor_1 + shift_a - b_b_sh = tensor_2 + shift_b - - a_a_sh_norm, b_b_sh_norm = norm(a_a_sh, max_abs), norm(b_b_sh, max_abs) - shift_a_norm, shift_b_norm = norm(shift_a, max_abs), 
norm(shift_b, max_abs) - a_a_sh_b_b_sh = sim(a_a_sh_norm, b_b_sh_norm) - a_a_sh_b_sh = sim(a_a_sh_norm, shift_b_norm) - a_sh_b_b_sh = sim(shift_a_norm, b_b_sh_norm) - a_sh_b_sh = sim(shift_a_norm, shift_b_norm) - - return (a_a_sh_b_b_sh - a_a_sh_b_sh - a_sh_b_b_sh + a_sh_b_sh)[0,:,:,:] * max_abs ** 2 - -# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 -class RoPE(nn.Module): - def __init__(self, dim, max_seq_len=512): - super().__init__() - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) - t = torch.arange(max_seq_len).type_as(inv_freq) - freqs = torch.einsum('i,j->ij', t, inv_freq) - self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) - - def rotate_half(self, x): - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - - def forward(self, x, offset=0): - seq_len = x.size(1) - emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) - cos = emb.cos() - sin = emb.sin() - return (x * cos) + (self.rotate_half(x) * sin) - -# Transformers without Normalization https://jiachenzhu.github.io/DyT/ -class DyT(nn.Module): - def __init__(self, num_features, alpha_init_value=0.5): - super().__init__() - self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) - self.weight = nn.Parameter(torch.ones(num_features)) - self.bias = nn.Parameter(torch.zeros(num_features)) - - def forward(self, x): - x = torch.tanh(self.alpha * x) - return x * self.weight + self.bias - -# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 -# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 -class TransformerLayer(nn.Module): - def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128, pixel_size=3.6e-6): - super().__init__() - self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) - self.q_proj = nn.Linear(h_dim, h_dim) - self.k_proj = nn.Linear(h_dim, h_dim) - self.v_proj = nn.Linear(h_dim, h_dim) - self.o_proj = nn.Linear(h_dim, h_dim) - 
self.ff1 = nn.Linear(h_dim, 4*h_dim) - self.ff2 = nn.Linear(4*h_dim, h_dim) - self.ln1 = DyT(h_dim) - self.ln2 = DyT(h_dim) - self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) - self.k1 = nn.Parameter(torch.randn(1)) - self.k2 = nn.Parameter(torch.randn(1)) - if max_seq_len < 512: - self.sim_scores = omm.OpticalMul( - omm.Config(right_matrix_count_columns = max_seq_len, - right_matrix_count_rows = h_dim // num_heads, - right_matrix_width = pixel_size * max_seq_len, - right_matrix_height = pixel_size * (h_dim // num_heads), - min_height_gap = pixel_size, - right_matrix_split_x = 2, - right_matrix_split_y = 2, - left_matrix_split_x = 2, - left_matrix_split_y = 2, - result_matrix_split = 2, - distance = 0.01, - trainable_cylind_lens=True) - ) - self.sim_output = omm.OpticalMul( - omm.Config(right_matrix_count_columns = h_dim // num_heads, - right_matrix_count_rows = max_seq_len, - right_matrix_width = pixel_size * (h_dim // num_heads), - right_matrix_height = pixel_size * max_seq_len, - min_height_gap = pixel_size, - right_matrix_split_x = 2, - right_matrix_split_y = 2, - left_matrix_split_x = 2, - left_matrix_split_y = 2, - result_matrix_split = 2, - distance = 0.01, - trainable_cylind_lens=True) - ) - if max_seq_len >= 512: - self.sim_scores = omm.OpticalMul( - omm.Config(right_matrix_count_columns = max_seq_len, - right_matrix_count_rows = h_dim // num_heads, - right_matrix_width = pixel_size * max_seq_len, - right_matrix_height = pixel_size * (h_dim // num_heads), - min_height_gap = pixel_size, - right_matrix_split_x = 2, - right_matrix_split_y = 2, - left_matrix_split_x = 2, - left_matrix_split_y = 2, - result_matrix_split = 2, - distance = 0.15, - lens_size = 8192 * 2) - ) - self.sim_output = omm.OpticalMul( - omm.Config(right_matrix_count_columns = h_dim // num_heads, - right_matrix_count_rows = max_seq_len, - right_matrix_width = pixel_size * (h_dim // num_heads), - right_matrix_height = pixel_size * max_seq_len, - min_height_gap = 
pixel_size, - right_matrix_split_x = 2, - right_matrix_split_y = 2, - left_matrix_split_x = 2, - left_matrix_split_y = 2, - result_matrix_split = 2, - distance = 0.15, - lens_size = 8192 * 2) - ) - self.sim_scores = omm.DataParallel(self.sim_scores) - self.sim_output = omm.DataParallel(self.sim_output) - - def split_to_heads(self, x, B, T, H): - if self.num_heads <= 1: return x - return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) - - def gather_heads(self, x, B, T, H): - if self.num_heads <= 1: return x - return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) - - def attention(self, x): - q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) - k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) - v = self.split_to_heads(self.v_proj(x), *x.shape) - scores = self.k1 * optics_matmul_shift(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5) - tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) - scores = scores.masked_fill(tril == 0, float('-inf')) - attention = nn.functional.softmax(scores, dim=2) - output = self.k2 * optics_matmul_shift(self.sim_output, attention, v) - return self.o_proj(self.gather_heads(output, *x.shape)) - - def forward(self, x): - x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) - x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) - return x - -class OpticGPT2TrainDiag(nn.Module): - def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, - pixel_size = 3.6e-6): - super().__init__() - self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) - self.layers = nn.ModuleList([ - TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, - dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len, pixel_size=pixel_size) - for _ in range(layers_num)]) - self.tok_embeds = nn.Embedding(vocab_size, h_dim) - self.pos_embeds = 
nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) - self.lm_head = nn.Linear(h_dim, vocab_size) - - def forward(self, x, targets=None): - x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] - for l in self.layers: - x = l(x) - logits = self.lm_head(x) # B,T,C - loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None - return logits, loss - - # what is the purpose? autoregressive inference! - def generate(self, start_idx, max_new_tokens): - idx = start_idx - for i in range(max_new_tokens): - idx_cond = idx[:,-self.max_seq_len:] - logits, loss = self(idx_cond) - logits = logits[:,-1,:] # B, C - probs = F.softmax(logits, dim=-1) - idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) - idx = torch.cat([idx, idx_next], dim=1) - return idx \ No newline at end of file diff --git a/src/train_char_gpt2_128.py b/src/train_char_gpt2_128.py new file mode 100644 index 0000000..a407bee --- /dev/null +++ b/src/train_char_gpt2_128.py @@ -0,0 +1,367 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +import sys +from pathlib import Path +import shutil +seed = 1337 +torch.manual_seed(seed) + +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 128 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 + +############################### MODEL ############################################################# + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def 
__init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t 
h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5) + tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) + scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line + attention = nn.functional.softmax(scores, dim=2) + return self.o_proj(self.gather_heads(attention @ v, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +class GPT2(nn.Module): + def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.layers = nn.ModuleList([ + TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) + for _ in range(layers_num)]) + self.tok_embeds = nn.Embedding(vocab_size, h_dim) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + self.lm_head = nn.Linear(h_dim, vocab_size) + + def forward(self, x, targets=None): + x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + logits = self.lm_head(x) # B,T,C + loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None + return logits, loss + + # what is the purpose? autoregressive inference! 
+ def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx + +################################################################################################### + +MODEL_CLASS = GPT2 +train_data_path = Path("./data/wiki.train.tokens") +val_data_path = Path("./data/wiki.valid.tokens") +test_data_path = Path("./data/wiki.test.tokens") +comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}" + +logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +writer = SummaryWriter(logs_dir) +script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name) +print("Logs dir:", logs_dir) +# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script +shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository +script_snapshot_path.chmod(0o500) # with read-only permission + +# Create standalone checkpoints directory with your specified format +checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +Path(checkpoints_dir).mkdir(parents=True, exist_ok=True) +print("Checkpoints dir:", checkpoints_dir) + +#################################### Dataset ######################################### + +# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt +with open(train_data_path, encoding='utf-8') as f: + train_text = f.read() + +with open(val_data_path, encoding='utf-8') as f: + val_text = f.read() + +with 
open(test_data_path, encoding='utf-8') as f: + test_text = f.read() + +text = train_text + val_text + test_text + +chars = sorted(set(text)) +vocab_size = len(chars) + +wtoi = {w:i for i,w in enumerate(chars)} +itow = {i:w for i,w in enumerate(chars)} + +encode = lambda s: [wtoi[w] for w in s] +decode = lambda idx: ''.join([itow[i] for i in idx]) + +def get_batch(data, seq_len, batch_size): + ix = torch.randint(len(data)-seq_len, (batch_size, )) + x = torch.stack([data[i:i+seq_len] for i in ix]).to(device) + y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device) + return x, y + +train_data = torch.tensor(encode(train_text), dtype=torch.long) +val_data = torch.tensor(encode(val_text), dtype=torch.long) +test_data = torch.tensor(encode(test_text), dtype=torch.long) + +@torch.no_grad() +def perplexity(model, data, batch_size=32): + model.eval() + stride = max(1, len(data) // 10000) + total_loss_sum = 0.0 + total_tokens_count = 0 + + # Precompute all valid start positions + start_positions = list(range(0, len(data) - max_seq_len - 1, stride)) + total_sequences = len(start_positions) + + # Process sequences in batches + for i in range(0, total_sequences, batch_size): + batch_starts = start_positions[i:min(i + batch_size, total_sequences)] + + # Efficiently stack sequences into batch tensors + x_batch = torch.stack([ + data[start:start + max_seq_len] + for start in batch_starts + ]).to(device) + + y_batch = torch.stack([ + data[start + 1:start + max_seq_len + 1] + for start in batch_starts + ]).to(device) + + # Forward pass (model should return mean loss averaged over all tokens in batch) + _, mean_loss = model(x_batch, y_batch) + + # Accumulate weighted loss (mean_loss is already averaged over tokens) + num_tokens = y_batch.numel() + total_loss_sum += mean_loss.item() * num_tokens + total_tokens_count += num_tokens + + # Progress update + processed = min(i + batch_size, total_sequences) + print(f"\rppl {processed}/{total_sequences} 
({processed/total_sequences*100:.1f}%)", end="", flush=True) + + print() # Final newline + return np.exp(total_loss_sum / total_tokens_count) + +#################################### Model #########################################mo +def complete(m, start_idxs=[0], max_new_tokens=100): + start_idx = torch.tensor([start_idxs]).to(device) + generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens) + return decode(generated_tokens[0].tolist()) + +m = MODEL_CLASS( + vocab_size=vocab_size, + h_dim=h_dim, + max_seq_len=max_seq_len, + num_heads=num_heads, + pixel_size=pixel_size, + layers_num=layers_num +) +m = m.to(device) +model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}' +writer.add_text('model', model_description, 0) + +print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.") +print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.") + +optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) + +#################################### Checkpoint Function ######################################### + +def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir): + """Save model checkpoint with complete training state""" + checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt' + torch.save({ + 'step': step, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss, + 'config': config, + 'wtoi': wtoi, + 'itow': itow, + }, checkpoint_path) + +# Training config for checkpointing +training_config = { + 'vocab_size': vocab_size, + 'layers_num': layers_num, + 'h_dim': h_dim, + 'max_seq_len': max_seq_len, + 'num_heads': num_heads, + 'dropout_rate': dropout_rate, + 'batch_size': batch_size, + 'learning_rate': learning_rate, + 'gradient_accumulation_steps': gradient_accumulation_steps, 
+ 'pixel_size': pixel_size, + 'max_iters': max_iters, +} +#################################### Train ######################################### + +start_time = datetime.now() +print("Started at:", start_time) + +m.eval() +task_prompts = [ + "1 2 3 4 5", + "The capital of France is", + "The chemical symbol of gold is", + "If yesterday was Friday, then tomorrow will be", + "The opposite of hot is", + "The planets of the solar system are:", + "My favorite color is", + "If 5*x + 3 = 13, then x is", +] +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, 0) + +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, 0) + +for i in range(max_iters): + m.train() + optimizer.zero_grad(set_to_none=True) + accumulated_loss = 0.0 + for j in range(gradient_accumulation_steps): + xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps) + logits, loss = m(xb, yb) + loss = loss / gradient_accumulation_steps + loss.backward() + accumulated_loss += loss.item() + if i % 100 == 0: + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i) + optimizer.step() + writer.add_scalar('loss', accumulated_loss, i) + print(f"\r{i}/{max_iters} {accumulated_loss}", end="") + if i % 5000 == 0: + m.eval() + ppl = perplexity(model=m, data=val_data, batch_size=batch_size) + writer.add_scalar('val_perplexity', ppl.item(), i) + print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") + writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) + task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) + writer.add_text('completions/task', task_results, i) + save_checkpoint( + model=m, + optimizer=optimizer, + step=i, + loss=accumulated_loss, + 
config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir + ) + +m.eval() +ppl = perplexity(model=m, data=val_data, batch_size=batch_size) +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") +writer.add_scalar('val_perplexity', ppl.item(), i+1) +writer.add_scalar('loss', accumulated_loss, i) + +ppl = perplexity(model=m, data=test_data, batch_size=batch_size) +writer.add_scalar('test_perplexity', ppl.item(), i+1) +print(f"\rTest Perplexity at {i}: {ppl}") + +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, i+1) +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, i+1) + +# Save final checkpoint +save_checkpoint( + model=m, + optimizer=optimizer, + step=max_iters, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir +) +print(f"\n✓ Training complete. 
Final checkpoint saved to {checkpoints_dir}") \ No newline at end of file diff --git a/src/train_char_gpt2_256.py b/src/train_char_gpt2_256.py new file mode 100644 index 0000000..b862708 --- /dev/null +++ b/src/train_char_gpt2_256.py @@ -0,0 +1,367 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +import sys +from pathlib import Path +import shutil +seed = 1337 +torch.manual_seed(seed) + +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 256 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 + +############################### MODEL ############################################################# + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = 
nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5) + tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) + scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line + attention = nn.functional.softmax(scores, dim=2) + return self.o_proj(self.gather_heads(attention @ v, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + 
return x + +class GPT2(nn.Module): + def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.layers = nn.ModuleList([ + TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) + for _ in range(layers_num)]) + self.tok_embeds = nn.Embedding(vocab_size, h_dim) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + self.lm_head = nn.Linear(h_dim, vocab_size) + + def forward(self, x, targets=None): + x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + logits = self.lm_head(x) # B,T,C + loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None + return logits, loss + + # what is the purpose? autoregressive inference! + def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx + +################################################################################################### + +MODEL_CLASS = GPT2 +train_data_path = Path("./data/wiki.train.tokens") +val_data_path = Path("./data/wiki.valid.tokens") +test_data_path = Path("./data/wiki.test.tokens") +comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}" + +logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +writer = SummaryWriter(logs_dir) +script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name) 
+print("Logs dir:", logs_dir) +# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script +shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository +script_snapshot_path.chmod(0o500) # with read-only permission + +# Create standalone checkpoints directory with your specified format +checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +Path(checkpoints_dir).mkdir(parents=True, exist_ok=True) +print("Checkpoints dir:", checkpoints_dir) + +#################################### Dataset ######################################### + +# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt +with open(train_data_path, encoding='utf-8') as f: + train_text = f.read() + +with open(val_data_path, encoding='utf-8') as f: + val_text = f.read() + +with open(test_data_path, encoding='utf-8') as f: + test_text = f.read() + +text = train_text + val_text + test_text + +chars = sorted(set(text)) +vocab_size = len(chars) + +wtoi = {w:i for i,w in enumerate(chars)} +itow = {i:w for i,w in enumerate(chars)} + +encode = lambda s: [wtoi[w] for w in s] +decode = lambda idx: ''.join([itow[i] for i in idx]) + +def get_batch(data, seq_len, batch_size): + ix = torch.randint(len(data)-seq_len, (batch_size, )) + x = torch.stack([data[i:i+seq_len] for i in ix]).to(device) + y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device) + return x, y + +train_data = torch.tensor(encode(train_text), dtype=torch.long) +val_data = torch.tensor(encode(val_text), dtype=torch.long) +test_data = torch.tensor(encode(test_text), dtype=torch.long) + +@torch.no_grad() +def perplexity(model, data, batch_size=32): + model.eval() + stride = max(1, len(data) // 10000) + total_loss_sum = 0.0 + total_tokens_count = 0 + + # Precompute all valid start positions + start_positions = list(range(0, 
len(data) - max_seq_len - 1, stride)) + total_sequences = len(start_positions) + + # Process sequences in batches + for i in range(0, total_sequences, batch_size): + batch_starts = start_positions[i:min(i + batch_size, total_sequences)] + + # Efficiently stack sequences into batch tensors + x_batch = torch.stack([ + data[start:start + max_seq_len] + for start in batch_starts + ]).to(device) + + y_batch = torch.stack([ + data[start + 1:start + max_seq_len + 1] + for start in batch_starts + ]).to(device) + + # Forward pass (model should return mean loss averaged over all tokens in batch) + _, mean_loss = model(x_batch, y_batch) + + # Accumulate weighted loss (mean_loss is already averaged over tokens) + num_tokens = y_batch.numel() + total_loss_sum += mean_loss.item() * num_tokens + total_tokens_count += num_tokens + + # Progress update + processed = min(i + batch_size, total_sequences) + print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True) + + print() # Final newline + return np.exp(total_loss_sum / total_tokens_count) + +#################################### Model #########################################mo +def complete(m, start_idxs=[0], max_new_tokens=100): + start_idx = torch.tensor([start_idxs]).to(device) + generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens) + return decode(generated_tokens[0].tolist()) + +m = MODEL_CLASS( + vocab_size=vocab_size, + h_dim=h_dim, + max_seq_len=max_seq_len, + num_heads=num_heads, + pixel_size=pixel_size, + layers_num=layers_num +) +m = m.to(device) +model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}' +writer.add_text('model', model_description, 0) + +print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.") +print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.") + +optimizer = 
torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) + +#################################### Checkpoint Function ######################################### + +def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir): + """Save model checkpoint with complete training state""" + checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt' + torch.save({ + 'step': step, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss, + 'config': config, + 'wtoi': wtoi, + 'itow': itow, + }, checkpoint_path) + +# Training config for checkpointing +training_config = { + 'vocab_size': vocab_size, + 'layers_num': layers_num, + 'h_dim': h_dim, + 'max_seq_len': max_seq_len, + 'num_heads': num_heads, + 'dropout_rate': dropout_rate, + 'batch_size': batch_size, + 'learning_rate': learning_rate, + 'gradient_accumulation_steps': gradient_accumulation_steps, + 'pixel_size': pixel_size, + 'max_iters': max_iters, +} +#################################### Train ######################################### + +start_time = datetime.now() +print("Started at:", start_time) + +m.eval() +task_prompts = [ + "1 2 3 4 5", + "The capital of France is", + "The chemical symbol of gold is", + "If yesterday was Friday, then tomorrow will be", + "The opposite of hot is", + "The planets of the solar system are:", + "My favorite color is", + "If 5*x + 3 = 13, then x is", +] +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, 0) + +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, 0) + +for i in range(max_iters): + m.train() + optimizer.zero_grad(set_to_none=True) + accumulated_loss = 0.0 + for j in range(gradient_accumulation_steps): + xb, yb = get_batch(train_data, seq_len=max_seq_len, 
batch_size=batch_size//gradient_accumulation_steps) + logits, loss = m(xb, yb) + loss = loss / gradient_accumulation_steps + loss.backward() + accumulated_loss += loss.item() + if i % 100 == 0: + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i) + optimizer.step() + writer.add_scalar('loss', accumulated_loss, i) + print(f"\r{i}/{max_iters} {accumulated_loss}", end="") + if i % 5000 == 0: + m.eval() + ppl = perplexity(model=m, data=val_data, batch_size=batch_size) + writer.add_scalar('val_perplexity', ppl.item(), i) + print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") + writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) + task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) + writer.add_text('completions/task', task_results, i) + save_checkpoint( + model=m, + optimizer=optimizer, + step=i, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir + ) + +m.eval() +ppl = perplexity(model=m, data=val_data, batch_size=batch_size) +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") +writer.add_scalar('val_perplexity', ppl.item(), i+1) +writer.add_scalar('loss', accumulated_loss, i) + +ppl = perplexity(model=m, data=test_data, batch_size=batch_size) +writer.add_scalar('test_perplexity', ppl.item(), i+1) +print(f"\rTest Perplexity at {i}: {ppl}") + +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, i+1) +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, i+1) + +# Save final checkpoint +save_checkpoint( + model=m, + optimizer=optimizer, + step=max_iters, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + 
checkpoint_dir=checkpoints_dir +) +print(f"\n✓ Training complete. Final checkpoint saved to {checkpoints_dir}") \ No newline at end of file diff --git a/src/train_gpt2.py b/src/train_char_gpt2_512.py similarity index 99% rename from src/train_gpt2.py rename to src/train_char_gpt2_512.py index 75de26f..b6d7137 100644 --- a/src/train_gpt2.py +++ b/src/train_char_gpt2_512.py @@ -11,6 +11,21 @@ import shutil seed = 1337 torch.manual_seed(seed) +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 512 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 + ############################### MODEL ############################################################# # RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 @@ -117,23 +132,6 @@ class GPT2(nn.Module): ################################################################################################### - -batch_size = 50 -gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient -max_iters = int(4e4) #40000 -eval_interval = 300 -learning_rate = 1e-3 -device = 'cuda' if torch.cuda.is_available() else 'cpu' -eval_iters = 200 -layers_num = 2 -h_dim = 64 -max_seq_len = 512 -num_heads = 1 -dropout_rate = 0.1 -pixel_size = 3.6e-6 -assert batch_size % gradient_accumulation_steps == 0 -# CUDA_VISIBLE_DEVICES=1 python src/main.py - MODEL_CLASS = GPT2 train_data_path = Path("./data/wiki.train.tokens") val_data_path = Path("./data/wiki.valid.tokens") diff --git a/src/train_char_gpt2_64.py b/src/train_char_gpt2_64.py new file mode 100644 index 0000000..6a12c5a --- /dev/null +++ b/src/train_char_gpt2_64.py @@ -0,0 +1,367 @@ +import torch +import 
torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +import sys +from pathlib import Path +import shutil +seed = 1337 +torch.manual_seed(seed) + +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 64 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 + +############################### MODEL ############################################################# + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need 
https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5) + tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) + scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line + attention = nn.functional.softmax(scores, dim=2) + return self.o_proj(self.gather_heads(attention @ v, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +class GPT2(nn.Module): + def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.layers = nn.ModuleList([ 
+ TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) + for _ in range(layers_num)]) + self.tok_embeds = nn.Embedding(vocab_size, h_dim) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + self.lm_head = nn.Linear(h_dim, vocab_size) + + def forward(self, x, targets=None): + x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + logits = self.lm_head(x) # B,T,C + loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None + return logits, loss + + # what is the purpose? autoregressive inference! + def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx + +################################################################################################### + +MODEL_CLASS = GPT2 +train_data_path = Path("./data/wiki.train.tokens") +val_data_path = Path("./data/wiki.valid.tokens") +test_data_path = Path("./data/wiki.test.tokens") +comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}" + +logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +writer = SummaryWriter(logs_dir) +script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name) +print("Logs dir:", logs_dir) +# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script +shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository +script_snapshot_path.chmod(0o500) # with read-only permission + 
+# Create standalone checkpoints directory with your specified format +checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +Path(checkpoints_dir).mkdir(parents=True, exist_ok=True) +print("Checkpoints dir:", checkpoints_dir) + +#################################### Dataset ######################################### + +# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt +with open(train_data_path, encoding='utf-8') as f: + train_text = f.read() + +with open(val_data_path, encoding='utf-8') as f: + val_text = f.read() + +with open(test_data_path, encoding='utf-8') as f: + test_text = f.read() + +text = train_text + val_text + test_text + +chars = sorted(set(text)) +vocab_size = len(chars) + +wtoi = {w:i for i,w in enumerate(chars)} +itow = {i:w for i,w in enumerate(chars)} + +encode = lambda s: [wtoi[w] for w in s] +decode = lambda idx: ''.join([itow[i] for i in idx]) + +def get_batch(data, seq_len, batch_size): + ix = torch.randint(len(data)-seq_len, (batch_size, )) + x = torch.stack([data[i:i+seq_len] for i in ix]).to(device) + y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device) + return x, y + +train_data = torch.tensor(encode(train_text), dtype=torch.long) +val_data = torch.tensor(encode(val_text), dtype=torch.long) +test_data = torch.tensor(encode(test_text), dtype=torch.long) + +@torch.no_grad() +def perplexity(model, data, batch_size=32): + model.eval() + stride = max(1, len(data) // 10000) + total_loss_sum = 0.0 + total_tokens_count = 0 + + # Precompute all valid start positions + start_positions = list(range(0, len(data) - max_seq_len - 1, stride)) + total_sequences = len(start_positions) + + # Process sequences in batches + for i in range(0, total_sequences, batch_size): + batch_starts = start_positions[i:min(i + batch_size, total_sequences)] + + # Efficiently stack sequences into batch tensors + x_batch 
= torch.stack([ + data[start:start + max_seq_len] + for start in batch_starts + ]).to(device) + + y_batch = torch.stack([ + data[start + 1:start + max_seq_len + 1] + for start in batch_starts + ]).to(device) + + # Forward pass (model should return mean loss averaged over all tokens in batch) + _, mean_loss = model(x_batch, y_batch) + + # Accumulate weighted loss (mean_loss is already averaged over tokens) + num_tokens = y_batch.numel() + total_loss_sum += mean_loss.item() * num_tokens + total_tokens_count += num_tokens + + # Progress update + processed = min(i + batch_size, total_sequences) + print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True) + + print() # Final newline + return np.exp(total_loss_sum / total_tokens_count) + +#################################### Model #########################################mo +def complete(m, start_idxs=[0], max_new_tokens=100): + start_idx = torch.tensor([start_idxs]).to(device) + generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens) + return decode(generated_tokens[0].tolist()) + +m = MODEL_CLASS( + vocab_size=vocab_size, + h_dim=h_dim, + max_seq_len=max_seq_len, + num_heads=num_heads, + pixel_size=pixel_size, + layers_num=layers_num +) +m = m.to(device) +model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}' +writer.add_text('model', model_description, 0) + +print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.") +print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.") + +optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) + +#################################### Checkpoint Function ######################################### + +def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir): + """Save model checkpoint with complete 
training state""" + checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt' + torch.save({ + 'step': step, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss, + 'config': config, + 'wtoi': wtoi, + 'itow': itow, + }, checkpoint_path) + +# Training config for checkpointing +training_config = { + 'vocab_size': vocab_size, + 'layers_num': layers_num, + 'h_dim': h_dim, + 'max_seq_len': max_seq_len, + 'num_heads': num_heads, + 'dropout_rate': dropout_rate, + 'batch_size': batch_size, + 'learning_rate': learning_rate, + 'gradient_accumulation_steps': gradient_accumulation_steps, + 'pixel_size': pixel_size, + 'max_iters': max_iters, +} +#################################### Train ######################################### + +start_time = datetime.now() +print("Started at:", start_time) + +m.eval() +task_prompts = [ + "1 2 3 4 5", + "The capital of France is", + "The chemical symbol of gold is", + "If yesterday was Friday, then tomorrow will be", + "The opposite of hot is", + "The planets of the solar system are:", + "My favorite color is", + "If 5*x + 3 = 13, then x is", +] +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, 0) + +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, 0) + +for i in range(max_iters): + m.train() + optimizer.zero_grad(set_to_none=True) + accumulated_loss = 0.0 + for j in range(gradient_accumulation_steps): + xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps) + logits, loss = m(xb, yb) + loss = loss / gradient_accumulation_steps + loss.backward() + accumulated_loss += loss.item() + if i % 100 == 0: + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i) + optimizer.step() + 
writer.add_scalar('loss', accumulated_loss, i) + print(f"\r{i}/{max_iters} {accumulated_loss}", end="") + if i % 5000 == 0: + m.eval() + ppl = perplexity(model=m, data=val_data, batch_size=batch_size) + writer.add_scalar('val_perplexity', ppl.item(), i) + print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") + writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) + task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) + writer.add_text('completions/task', task_results, i) + save_checkpoint( + model=m, + optimizer=optimizer, + step=i, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir + ) + +m.eval() +ppl = perplexity(model=m, data=val_data, batch_size=batch_size) +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") +writer.add_scalar('val_perplexity', ppl.item(), i+1) +writer.add_scalar('loss', accumulated_loss, i) + +ppl = perplexity(model=m, data=test_data, batch_size=batch_size) +writer.add_scalar('test_perplexity', ppl.item(), i+1) +print(f"\rTest Perplexity at {i}: {ppl}") + +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, i+1) +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, i+1) + +# Save final checkpoint +save_checkpoint( + model=m, + optimizer=optimizer, + step=max_iters, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir +) +print(f"\n✓ Training complete. 
Final checkpoint saved to {checkpoints_dir}") \ No newline at end of file diff --git a/src/train_char_gpt2_koef_128.py b/src/train_char_gpt2_koef_128.py new file mode 100644 index 0000000..ee5814c --- /dev/null +++ b/src/train_char_gpt2_koef_128.py @@ -0,0 +1,369 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +import sys +from pathlib import Path +import shutil +seed = 1337 +torch.manual_seed(seed) + +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 128 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 + +############################### MODEL ############################################################# + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + 
self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + self.k1 = nn.Parameter(torch.randn(1)) + self.k2 = nn.Parameter(torch.randn(1)) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = (self.k1 * (q @ k.transpose(1, 2))) * (self.h_dim ** -0.5) + tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) + scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line + attention = nn.functional.softmax(scores, dim=2) + return self.o_proj(self.gather_heads(self.k2 * (attention @ v), *x.shape)) + + def forward(self, x): + x = x + 
F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +class GPT2(nn.Module): + def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.layers = nn.ModuleList([ + TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) + for _ in range(layers_num)]) + self.tok_embeds = nn.Embedding(vocab_size, h_dim) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + self.lm_head = nn.Linear(h_dim, vocab_size) + + def forward(self, x, targets=None): + x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + logits = self.lm_head(x) # B,T,C + loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None + return logits, loss + + # what is the purpose? autoregressive inference! 
+ def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx + +################################################################################################### + +MODEL_CLASS = GPT2 +train_data_path = Path("./data/wiki.train.tokens") +val_data_path = Path("./data/wiki.valid.tokens") +test_data_path = Path("./data/wiki.test.tokens") +comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}" + +logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +writer = SummaryWriter(logs_dir) +script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name) +print("Logs dir:", logs_dir) +# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script +shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository +script_snapshot_path.chmod(0o500) # with read-only permission + +# Create standalone checkpoints directory with your specified format +checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +Path(checkpoints_dir).mkdir(parents=True, exist_ok=True) +print("Checkpoints dir:", checkpoints_dir) + +#################################### Dataset ######################################### + +# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt +with open(train_data_path, encoding='utf-8') as f: + train_text = f.read() + +with open(val_data_path, encoding='utf-8') as f: + val_text = f.read() + +with 
open(test_data_path, encoding='utf-8') as f: + test_text = f.read() + +text = train_text + val_text + test_text + +chars = sorted(set(text)) +vocab_size = len(chars) + +wtoi = {w:i for i,w in enumerate(chars)} +itow = {i:w for i,w in enumerate(chars)} + +encode = lambda s: [wtoi[w] for w in s] +decode = lambda idx: ''.join([itow[i] for i in idx]) + +def get_batch(data, seq_len, batch_size): + ix = torch.randint(len(data)-seq_len, (batch_size, )) + x = torch.stack([data[i:i+seq_len] for i in ix]).to(device) + y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device) + return x, y + +train_data = torch.tensor(encode(train_text), dtype=torch.long) +val_data = torch.tensor(encode(val_text), dtype=torch.long) +test_data = torch.tensor(encode(test_text), dtype=torch.long) + +@torch.no_grad() +def perplexity(model, data, batch_size=32): + model.eval() + stride = max(1, len(data) // 10000) + total_loss_sum = 0.0 + total_tokens_count = 0 + + # Precompute all valid start positions + start_positions = list(range(0, len(data) - max_seq_len - 1, stride)) + total_sequences = len(start_positions) + + # Process sequences in batches + for i in range(0, total_sequences, batch_size): + batch_starts = start_positions[i:min(i + batch_size, total_sequences)] + + # Efficiently stack sequences into batch tensors + x_batch = torch.stack([ + data[start:start + max_seq_len] + for start in batch_starts + ]).to(device) + + y_batch = torch.stack([ + data[start + 1:start + max_seq_len + 1] + for start in batch_starts + ]).to(device) + + # Forward pass (model should return mean loss averaged over all tokens in batch) + _, mean_loss = model(x_batch, y_batch) + + # Accumulate weighted loss (mean_loss is already averaged over tokens) + num_tokens = y_batch.numel() + total_loss_sum += mean_loss.item() * num_tokens + total_tokens_count += num_tokens + + # Progress update + processed = min(i + batch_size, total_sequences) + print(f"\rppl {processed}/{total_sequences} 
({processed/total_sequences*100:.1f}%)", end="", flush=True) + + print() # Final newline + return np.exp(total_loss_sum / total_tokens_count) + +#################################### Model #########################################mo +def complete(m, start_idxs=[0], max_new_tokens=100): + start_idx = torch.tensor([start_idxs]).to(device) + generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens) + return decode(generated_tokens[0].tolist()) + +m = MODEL_CLASS( + vocab_size=vocab_size, + h_dim=h_dim, + max_seq_len=max_seq_len, + num_heads=num_heads, + pixel_size=pixel_size, + layers_num=layers_num +) +m = m.to(device) +model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}' +writer.add_text('model', model_description, 0) + +print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.") +print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.") + +optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) + +#################################### Checkpoint Function ######################################### + +def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir): + """Save model checkpoint with complete training state""" + checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt' + torch.save({ + 'step': step, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss, + 'config': config, + 'wtoi': wtoi, + 'itow': itow, + }, checkpoint_path) + +# Training config for checkpointing +training_config = { + 'vocab_size': vocab_size, + 'layers_num': layers_num, + 'h_dim': h_dim, + 'max_seq_len': max_seq_len, + 'num_heads': num_heads, + 'dropout_rate': dropout_rate, + 'batch_size': batch_size, + 'learning_rate': learning_rate, + 'gradient_accumulation_steps': gradient_accumulation_steps, 
+ 'pixel_size': pixel_size, + 'max_iters': max_iters, +} +#################################### Train ######################################### + +start_time = datetime.now() +print("Started at:", start_time) + +m.eval() +task_prompts = [ + "1 2 3 4 5", + "The capital of France is", + "The chemical symbol of gold is", + "If yesterday was Friday, then tomorrow will be", + "The opposite of hot is", + "The planets of the solar system are:", + "My favorite color is", + "If 5*x + 3 = 13, then x is", +] +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, 0) + +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, 0) + +for i in range(max_iters): + m.train() + optimizer.zero_grad(set_to_none=True) + accumulated_loss = 0.0 + for j in range(gradient_accumulation_steps): + xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps) + logits, loss = m(xb, yb) + loss = loss / gradient_accumulation_steps + loss.backward() + accumulated_loss += loss.item() + if i % 100 == 0: + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i) + optimizer.step() + writer.add_scalar('loss', accumulated_loss, i) + print(f"\r{i}/{max_iters} {accumulated_loss}", end="") + if i % 5000 == 0: + m.eval() + ppl = perplexity(model=m, data=val_data, batch_size=batch_size) + writer.add_scalar('val_perplexity', ppl.item(), i) + print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") + writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) + task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) + writer.add_text('completions/task', task_results, i) + save_checkpoint( + model=m, + optimizer=optimizer, + step=i, + loss=accumulated_loss, + 
config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir + ) + +m.eval() +ppl = perplexity(model=m, data=val_data, batch_size=batch_size) +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") +writer.add_scalar('val_perplexity', ppl.item(), i+1) +writer.add_scalar('loss', accumulated_loss, i) + +ppl = perplexity(model=m, data=test_data, batch_size=batch_size) +writer.add_scalar('test_perplexity', ppl.item(), i+1) +print(f"\rTest Perplexity at {i}: {ppl}") + +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, i+1) +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, i+1) + +# Save final checkpoint +save_checkpoint( + model=m, + optimizer=optimizer, + step=max_iters, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir +) +print(f"\n✓ Training complete. 
Final checkpoint saved to {checkpoints_dir}") \ No newline at end of file diff --git a/src/train_char_gpt2_koef_256.py b/src/train_char_gpt2_koef_256.py new file mode 100644 index 0000000..d4aa19e --- /dev/null +++ b/src/train_char_gpt2_koef_256.py @@ -0,0 +1,369 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +import sys +from pathlib import Path +import shutil +seed = 1337 +torch.manual_seed(seed) + +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 256 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 + +############################### MODEL ############################################################# + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + 
self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + self.k1 = nn.Parameter(torch.randn(1)) + self.k2 = nn.Parameter(torch.randn(1)) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = (self.k1 * (q @ k.transpose(1, 2))) * (self.h_dim ** -0.5) + tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) + scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line + attention = nn.functional.softmax(scores, dim=2) + return self.o_proj(self.gather_heads(self.k2 * (attention @ v), *x.shape)) + + def forward(self, x): + x = x + 
F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +class GPT2(nn.Module): + def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.layers = nn.ModuleList([ + TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) + for _ in range(layers_num)]) + self.tok_embeds = nn.Embedding(vocab_size, h_dim) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + self.lm_head = nn.Linear(h_dim, vocab_size) + + def forward(self, x, targets=None): + x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + logits = self.lm_head(x) # B,T,C + loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None + return logits, loss + + # what is the purpose? autoregressive inference! 
+ def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx + +################################################################################################### + +MODEL_CLASS = GPT2 +train_data_path = Path("./data/wiki.train.tokens") +val_data_path = Path("./data/wiki.valid.tokens") +test_data_path = Path("./data/wiki.test.tokens") +comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}" + +logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +writer = SummaryWriter(logs_dir) +script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name) +print("Logs dir:", logs_dir) +# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script +shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository +script_snapshot_path.chmod(0o500) # with read-only permission + +# Create standalone checkpoints directory with your specified format +checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +Path(checkpoints_dir).mkdir(parents=True, exist_ok=True) +print("Checkpoints dir:", checkpoints_dir) + +#################################### Dataset ######################################### + +# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt +with open(train_data_path, encoding='utf-8') as f: + train_text = f.read() + +with open(val_data_path, encoding='utf-8') as f: + val_text = f.read() + +with 
open(test_data_path, encoding='utf-8') as f: + test_text = f.read() + +text = train_text + val_text + test_text + +chars = sorted(set(text)) +vocab_size = len(chars) + +wtoi = {w:i for i,w in enumerate(chars)} +itow = {i:w for i,w in enumerate(chars)} + +encode = lambda s: [wtoi[w] for w in s] +decode = lambda idx: ''.join([itow[i] for i in idx]) + +def get_batch(data, seq_len, batch_size): + ix = torch.randint(len(data)-seq_len, (batch_size, )) + x = torch.stack([data[i:i+seq_len] for i in ix]).to(device) + y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device) + return x, y + +train_data = torch.tensor(encode(train_text), dtype=torch.long) +val_data = torch.tensor(encode(val_text), dtype=torch.long) +test_data = torch.tensor(encode(test_text), dtype=torch.long) + +@torch.no_grad() +def perplexity(model, data, batch_size=32): + model.eval() + stride = max(1, len(data) // 10000) + total_loss_sum = 0.0 + total_tokens_count = 0 + + # Precompute all valid start positions + start_positions = list(range(0, len(data) - max_seq_len - 1, stride)) + total_sequences = len(start_positions) + + # Process sequences in batches + for i in range(0, total_sequences, batch_size): + batch_starts = start_positions[i:min(i + batch_size, total_sequences)] + + # Efficiently stack sequences into batch tensors + x_batch = torch.stack([ + data[start:start + max_seq_len] + for start in batch_starts + ]).to(device) + + y_batch = torch.stack([ + data[start + 1:start + max_seq_len + 1] + for start in batch_starts + ]).to(device) + + # Forward pass (model should return mean loss averaged over all tokens in batch) + _, mean_loss = model(x_batch, y_batch) + + # Accumulate weighted loss (mean_loss is already averaged over tokens) + num_tokens = y_batch.numel() + total_loss_sum += mean_loss.item() * num_tokens + total_tokens_count += num_tokens + + # Progress update + processed = min(i + batch_size, total_sequences) + print(f"\rppl {processed}/{total_sequences} 
({processed/total_sequences*100:.1f}%)", end="", flush=True) + + print() # Final newline + return np.exp(total_loss_sum / total_tokens_count) + +#################################### Model #########################################mo +def complete(m, start_idxs=[0], max_new_tokens=100): + start_idx = torch.tensor([start_idxs]).to(device) + generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens) + return decode(generated_tokens[0].tolist()) + +m = MODEL_CLASS( + vocab_size=vocab_size, + h_dim=h_dim, + max_seq_len=max_seq_len, + num_heads=num_heads, + pixel_size=pixel_size, + layers_num=layers_num +) +m = m.to(device) +model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}' +writer.add_text('model', model_description, 0) + +print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.") +print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.") + +optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) + +#################################### Checkpoint Function ######################################### + +def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir): + """Save model checkpoint with complete training state""" + checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt' + torch.save({ + 'step': step, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss, + 'config': config, + 'wtoi': wtoi, + 'itow': itow, + }, checkpoint_path) + +# Training config for checkpointing +training_config = { + 'vocab_size': vocab_size, + 'layers_num': layers_num, + 'h_dim': h_dim, + 'max_seq_len': max_seq_len, + 'num_heads': num_heads, + 'dropout_rate': dropout_rate, + 'batch_size': batch_size, + 'learning_rate': learning_rate, + 'gradient_accumulation_steps': gradient_accumulation_steps, 
+ 'pixel_size': pixel_size, + 'max_iters': max_iters, +} +#################################### Train ######################################### + +start_time = datetime.now() +print("Started at:", start_time) + +m.eval() +task_prompts = [ + "1 2 3 4 5", + "The capital of France is", + "The chemical symbol of gold is", + "If yesterday was Friday, then tomorrow will be", + "The opposite of hot is", + "The planets of the solar system are:", + "My favorite color is", + "If 5*x + 3 = 13, then x is", +] +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, 0) + +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, 0) + +for i in range(max_iters): + m.train() + optimizer.zero_grad(set_to_none=True) + accumulated_loss = 0.0 + for j in range(gradient_accumulation_steps): + xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps) + logits, loss = m(xb, yb) + loss = loss / gradient_accumulation_steps + loss.backward() + accumulated_loss += loss.item() + if i % 100 == 0: + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i) + optimizer.step() + writer.add_scalar('loss', accumulated_loss, i) + print(f"\r{i}/{max_iters} {accumulated_loss}", end="") + if i % 5000 == 0: + m.eval() + ppl = perplexity(model=m, data=val_data, batch_size=batch_size) + writer.add_scalar('val_perplexity', ppl.item(), i) + print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") + writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) + task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) + writer.add_text('completions/task', task_results, i) + save_checkpoint( + model=m, + optimizer=optimizer, + step=i, + loss=accumulated_loss, + 
config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir + ) + +m.eval() +ppl = perplexity(model=m, data=val_data, batch_size=batch_size) +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") +writer.add_scalar('val_perplexity', ppl.item(), i+1) +writer.add_scalar('loss', accumulated_loss, i) + +ppl = perplexity(model=m, data=test_data, batch_size=batch_size) +writer.add_scalar('test_perplexity', ppl.item(), i+1) +print(f"\rTest Perplexity at {i}: {ppl}") + +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, i+1) +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, i+1) + +# Save final checkpoint +save_checkpoint( + model=m, + optimizer=optimizer, + step=max_iters, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir +) +print(f"\n✓ Training complete. 
Final checkpoint saved to {checkpoints_dir}") \ No newline at end of file diff --git a/src/train_char_gpt2_koef_512.py b/src/train_char_gpt2_koef_512.py new file mode 100644 index 0000000..093084a --- /dev/null +++ b/src/train_char_gpt2_koef_512.py @@ -0,0 +1,369 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +import sys +from pathlib import Path +import shutil +seed = 1337 +torch.manual_seed(seed) + +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 512 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 + +############################### MODEL ############################################################# + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + 
self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + self.k1 = nn.Parameter(torch.randn(1)) + self.k2 = nn.Parameter(torch.randn(1)) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = (self.k1 * (q @ k.transpose(1, 2))) * (self.h_dim ** -0.5) + tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) + scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line + attention = nn.functional.softmax(scores, dim=2) + return self.o_proj(self.gather_heads(self.k2 * (attention @ v), *x.shape)) + + def forward(self, x): + x = x + 
F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +class GPT2(nn.Module): + def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.layers = nn.ModuleList([ + TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) + for _ in range(layers_num)]) + self.tok_embeds = nn.Embedding(vocab_size, h_dim) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + self.lm_head = nn.Linear(h_dim, vocab_size) + + def forward(self, x, targets=None): + x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + logits = self.lm_head(x) # B,T,C + loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None + return logits, loss + + # what is the purpose? autoregressive inference! 
+ def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx + +################################################################################################### + +MODEL_CLASS = GPT2 +train_data_path = Path("./data/wiki.train.tokens") +val_data_path = Path("./data/wiki.valid.tokens") +test_data_path = Path("./data/wiki.test.tokens") +comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}" + +logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +writer = SummaryWriter(logs_dir) +script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name) +print("Logs dir:", logs_dir) +# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script +shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository +script_snapshot_path.chmod(0o500) # with read-only permission + +# Create standalone checkpoints directory with your specified format +checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +Path(checkpoints_dir).mkdir(parents=True, exist_ok=True) +print("Checkpoints dir:", checkpoints_dir) + +#################################### Dataset ######################################### + +# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt +with open(train_data_path, encoding='utf-8') as f: + train_text = f.read() + +with open(val_data_path, encoding='utf-8') as f: + val_text = f.read() + +with 
open(test_data_path, encoding='utf-8') as f: + test_text = f.read() + +text = train_text + val_text + test_text + +chars = sorted(set(text)) +vocab_size = len(chars) + +wtoi = {w:i for i,w in enumerate(chars)} +itow = {i:w for i,w in enumerate(chars)} + +encode = lambda s: [wtoi[w] for w in s] +decode = lambda idx: ''.join([itow[i] for i in idx]) + +def get_batch(data, seq_len, batch_size): + ix = torch.randint(len(data)-seq_len, (batch_size, )) + x = torch.stack([data[i:i+seq_len] for i in ix]).to(device) + y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device) + return x, y + +train_data = torch.tensor(encode(train_text), dtype=torch.long) +val_data = torch.tensor(encode(val_text), dtype=torch.long) +test_data = torch.tensor(encode(test_text), dtype=torch.long) + +@torch.no_grad() +def perplexity(model, data, batch_size=32): + model.eval() + stride = max(1, len(data) // 10000) + total_loss_sum = 0.0 + total_tokens_count = 0 + + # Precompute all valid start positions + start_positions = list(range(0, len(data) - max_seq_len - 1, stride)) + total_sequences = len(start_positions) + + # Process sequences in batches + for i in range(0, total_sequences, batch_size): + batch_starts = start_positions[i:min(i + batch_size, total_sequences)] + + # Efficiently stack sequences into batch tensors + x_batch = torch.stack([ + data[start:start + max_seq_len] + for start in batch_starts + ]).to(device) + + y_batch = torch.stack([ + data[start + 1:start + max_seq_len + 1] + for start in batch_starts + ]).to(device) + + # Forward pass (model should return mean loss averaged over all tokens in batch) + _, mean_loss = model(x_batch, y_batch) + + # Accumulate weighted loss (mean_loss is already averaged over tokens) + num_tokens = y_batch.numel() + total_loss_sum += mean_loss.item() * num_tokens + total_tokens_count += num_tokens + + # Progress update + processed = min(i + batch_size, total_sequences) + print(f"\rppl {processed}/{total_sequences} 
({processed/total_sequences*100:.1f}%)", end="", flush=True) + + print() # Final newline + return np.exp(total_loss_sum / total_tokens_count) + +#################################### Model #########################################mo +def complete(m, start_idxs=[0], max_new_tokens=100): + start_idx = torch.tensor([start_idxs]).to(device) + generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens) + return decode(generated_tokens[0].tolist()) + +m = MODEL_CLASS( + vocab_size=vocab_size, + h_dim=h_dim, + max_seq_len=max_seq_len, + num_heads=num_heads, + pixel_size=pixel_size, + layers_num=layers_num +) +m = m.to(device) +model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}' +writer.add_text('model', model_description, 0) + +print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.") +print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.") + +optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) + +#################################### Checkpoint Function ######################################### + +def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir): + """Save model checkpoint with complete training state""" + checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt' + torch.save({ + 'step': step, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss, + 'config': config, + 'wtoi': wtoi, + 'itow': itow, + }, checkpoint_path) + +# Training config for checkpointing +training_config = { + 'vocab_size': vocab_size, + 'layers_num': layers_num, + 'h_dim': h_dim, + 'max_seq_len': max_seq_len, + 'num_heads': num_heads, + 'dropout_rate': dropout_rate, + 'batch_size': batch_size, + 'learning_rate': learning_rate, + 'gradient_accumulation_steps': gradient_accumulation_steps, 
+ 'pixel_size': pixel_size, + 'max_iters': max_iters, +} +#################################### Train ######################################### + +start_time = datetime.now() +print("Started at:", start_time) + +m.eval() +task_prompts = [ + "1 2 3 4 5", + "The capital of France is", + "The chemical symbol of gold is", + "If yesterday was Friday, then tomorrow will be", + "The opposite of hot is", + "The planets of the solar system are:", + "My favorite color is", + "If 5*x + 3 = 13, then x is", +] +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, 0) + +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, 0) + +for i in range(max_iters): + m.train() + optimizer.zero_grad(set_to_none=True) + accumulated_loss = 0.0 + for j in range(gradient_accumulation_steps): + xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps) + logits, loss = m(xb, yb) + loss = loss / gradient_accumulation_steps + loss.backward() + accumulated_loss += loss.item() + if i % 100 == 0: + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i) + optimizer.step() + writer.add_scalar('loss', accumulated_loss, i) + print(f"\r{i}/{max_iters} {accumulated_loss}", end="") + if i % 5000 == 0: + m.eval() + ppl = perplexity(model=m, data=val_data, batch_size=batch_size) + writer.add_scalar('val_perplexity', ppl.item(), i) + print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") + writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) + task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) + writer.add_text('completions/task', task_results, i) + save_checkpoint( + model=m, + optimizer=optimizer, + step=i, + loss=accumulated_loss, + 
config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir + ) + +m.eval() +ppl = perplexity(model=m, data=val_data, batch_size=batch_size) +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") +writer.add_scalar('val_perplexity', ppl.item(), i+1) +writer.add_scalar('loss', accumulated_loss, i) + +ppl = perplexity(model=m, data=test_data, batch_size=batch_size) +writer.add_scalar('test_perplexity', ppl.item(), i+1) +print(f"\rTest Perplexity at {i}: {ppl}") + +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, i+1) +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, i+1) + +# Save final checkpoint +save_checkpoint( + model=m, + optimizer=optimizer, + step=max_iters, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir +) +print(f"\n✓ Training complete. 
Final checkpoint saved to {checkpoints_dir}") \ No newline at end of file diff --git a/src/train_char_gpt2_koef_64.py b/src/train_char_gpt2_koef_64.py new file mode 100644 index 0000000..6b68bba --- /dev/null +++ b/src/train_char_gpt2_koef_64.py @@ -0,0 +1,369 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +import sys +from pathlib import Path +import shutil +seed = 1337 +torch.manual_seed(seed) + +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 64 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 + +############################### MODEL ############################################################# + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha 
= nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + self.k1 = nn.Parameter(torch.randn(1)) + self.k2 = nn.Parameter(torch.randn(1)) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = (self.k1 * (q @ k.transpose(1, 2))) * (self.h_dim ** -0.5) + tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) + scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line + attention = nn.functional.softmax(scores, dim=2) + return self.o_proj(self.gather_heads(self.k2 * (attention @ v), *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), 
p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +class GPT2(nn.Module): + def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.layers = nn.ModuleList([ + TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) + for _ in range(layers_num)]) + self.tok_embeds = nn.Embedding(vocab_size, h_dim) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + self.lm_head = nn.Linear(h_dim, vocab_size) + + def forward(self, x, targets=None): + x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + logits = self.lm_head(x) # B,T,C + loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None + return logits, loss + + # what is the purpose? autoregressive inference! 
+ def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx + +################################################################################################### + +MODEL_CLASS = GPT2 +train_data_path = Path("./data/wiki.train.tokens") +val_data_path = Path("./data/wiki.valid.tokens") +test_data_path = Path("./data/wiki.test.tokens") +comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}" + +logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +writer = SummaryWriter(logs_dir) +script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name) +print("Logs dir:", logs_dir) +# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script +shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository +script_snapshot_path.chmod(0o500) # with read-only permission + +# Create standalone checkpoints directory with your specified format +checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +Path(checkpoints_dir).mkdir(parents=True, exist_ok=True) +print("Checkpoints dir:", checkpoints_dir) + +#################################### Dataset ######################################### + +# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt +with open(train_data_path, encoding='utf-8') as f: + train_text = f.read() + +with open(val_data_path, encoding='utf-8') as f: + val_text = f.read() + +with 
open(test_data_path, encoding='utf-8') as f: + test_text = f.read() + +text = train_text + val_text + test_text + +chars = sorted(set(text)) +vocab_size = len(chars) + +wtoi = {w:i for i,w in enumerate(chars)} +itow = {i:w for i,w in enumerate(chars)} + +encode = lambda s: [wtoi[w] for w in s] +decode = lambda idx: ''.join([itow[i] for i in idx]) + +def get_batch(data, seq_len, batch_size): + ix = torch.randint(len(data)-seq_len, (batch_size, )) + x = torch.stack([data[i:i+seq_len] for i in ix]).to(device) + y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device) + return x, y + +train_data = torch.tensor(encode(train_text), dtype=torch.long) +val_data = torch.tensor(encode(val_text), dtype=torch.long) +test_data = torch.tensor(encode(test_text), dtype=torch.long) + +@torch.no_grad() +def perplexity(model, data, batch_size=32): + model.eval() + stride = max(1, len(data) // 10000) + total_loss_sum = 0.0 + total_tokens_count = 0 + + # Precompute all valid start positions + start_positions = list(range(0, len(data) - max_seq_len - 1, stride)) + total_sequences = len(start_positions) + + # Process sequences in batches + for i in range(0, total_sequences, batch_size): + batch_starts = start_positions[i:min(i + batch_size, total_sequences)] + + # Efficiently stack sequences into batch tensors + x_batch = torch.stack([ + data[start:start + max_seq_len] + for start in batch_starts + ]).to(device) + + y_batch = torch.stack([ + data[start + 1:start + max_seq_len + 1] + for start in batch_starts + ]).to(device) + + # Forward pass (model should return mean loss averaged over all tokens in batch) + _, mean_loss = model(x_batch, y_batch) + + # Accumulate weighted loss (mean_loss is already averaged over tokens) + num_tokens = y_batch.numel() + total_loss_sum += mean_loss.item() * num_tokens + total_tokens_count += num_tokens + + # Progress update + processed = min(i + batch_size, total_sequences) + print(f"\rppl {processed}/{total_sequences} 
({processed/total_sequences*100:.1f}%)", end="", flush=True) + + print() # Final newline + return np.exp(total_loss_sum / total_tokens_count) + +#################################### Model #########################################mo +def complete(m, start_idxs=[0], max_new_tokens=100): + start_idx = torch.tensor([start_idxs]).to(device) + generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens) + return decode(generated_tokens[0].tolist()) + +m = MODEL_CLASS( + vocab_size=vocab_size, + h_dim=h_dim, + max_seq_len=max_seq_len, + num_heads=num_heads, + pixel_size=pixel_size, + layers_num=layers_num +) +m = m.to(device) +model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}' +writer.add_text('model', model_description, 0) + +print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.") +print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.") + +optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) + +#################################### Checkpoint Function ######################################### + +def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir): + """Save model checkpoint with complete training state""" + checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt' + torch.save({ + 'step': step, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss, + 'config': config, + 'wtoi': wtoi, + 'itow': itow, + }, checkpoint_path) + +# Training config for checkpointing +training_config = { + 'vocab_size': vocab_size, + 'layers_num': layers_num, + 'h_dim': h_dim, + 'max_seq_len': max_seq_len, + 'num_heads': num_heads, + 'dropout_rate': dropout_rate, + 'batch_size': batch_size, + 'learning_rate': learning_rate, + 'gradient_accumulation_steps': gradient_accumulation_steps, 
+ 'pixel_size': pixel_size, + 'max_iters': max_iters, +} +#################################### Train ######################################### + +start_time = datetime.now() +print("Started at:", start_time) + +m.eval() +task_prompts = [ + "1 2 3 4 5", + "The capital of France is", + "The chemical symbol of gold is", + "If yesterday was Friday, then tomorrow will be", + "The opposite of hot is", + "The planets of the solar system are:", + "My favorite color is", + "If 5*x + 3 = 13, then x is", +] +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, 0) + +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, 0) + +for i in range(max_iters): + m.train() + optimizer.zero_grad(set_to_none=True) + accumulated_loss = 0.0 + for j in range(gradient_accumulation_steps): + xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps) + logits, loss = m(xb, yb) + loss = loss / gradient_accumulation_steps + loss.backward() + accumulated_loss += loss.item() + if i % 100 == 0: + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i) + optimizer.step() + writer.add_scalar('loss', accumulated_loss, i) + print(f"\r{i}/{max_iters} {accumulated_loss}", end="") + if i % 5000 == 0: + m.eval() + ppl = perplexity(model=m, data=val_data, batch_size=batch_size) + writer.add_scalar('val_perplexity', ppl.item(), i) + print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") + writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) + task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) + writer.add_text('completions/task', task_results, i) + save_checkpoint( + model=m, + optimizer=optimizer, + step=i, + loss=accumulated_loss, + 
config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir + ) + +m.eval() +ppl = perplexity(model=m, data=val_data, batch_size=batch_size) +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") +writer.add_scalar('val_perplexity', ppl.item(), i+1) +writer.add_scalar('loss', accumulated_loss, i) + +ppl = perplexity(model=m, data=test_data, batch_size=batch_size) +writer.add_scalar('test_perplexity', ppl.item(), i+1) +print(f"\rTest Perplexity at {i}: {ppl}") + +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, i+1) +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, i+1) + +# Save final checkpoint +save_checkpoint( + model=m, + optimizer=optimizer, + step=max_iters, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir +) +print(f"\n✓ Training complete. 
Final checkpoint saved to {checkpoints_dir}") \ No newline at end of file diff --git a/src/train_optics_char_gpt2_128.py b/src/train_optics_char_gpt2_128.py new file mode 100644 index 0000000..1933153 --- /dev/null +++ b/src/train_optics_char_gpt2_128.py @@ -0,0 +1,473 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +import sys +from pathlib import Path +import shutil +seed = 1337 +torch.manual_seed(seed) + +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 128 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 + +############################### MODEL ############################################################# + +def new_formula(sim, tensor_1, tensor_2): + tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1 + tensor_2 = tensor_2[None,:,:,:] if len(tensor_2.shape) < 4 else tensor_2 + device = tensor_1.device + + A_pos = torch.clamp(tensor_1, min=0) # A⁺ = max(A, 0) + A_neg = torch.clamp(-tensor_1, min=0) # A⁻ = max(-A, 0) + B_pos = torch.clamp(tensor_2, min=0) # B⁺ = max(B, 0) + B_neg = torch.clamp(-tensor_2, min=0) # B⁻ = max(-B, 0) + + max_A_pos = torch.max(A_pos) # Может быть 0, если нет положительных значений + max_A_neg = torch.max(A_neg) # Может быть 0, если нет отрицательных значений + max_B_pos = torch.max(B_pos) + max_B_neg = torch.max(B_neg) + + zero_template = torch.zeros_like( + torch.empty(tensor_1.shape[0],tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])) + + if max_A_pos > 0 and max_B_pos > 0: + t1 = sim(A_pos / max_A_pos, B_pos / max_B_pos) * max_A_pos 
* max_B_pos + else: + t1 = zero_template.clone().to(device) + + if max_A_pos > 0 and max_B_neg > 0: + t2 = sim(A_pos / max_A_pos, B_neg / max_B_neg) * max_A_pos * max_B_neg + else: + t2 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_pos > 0: + t3 = sim(A_neg / max_A_neg, B_pos / max_B_pos) * max_A_neg * max_B_pos + else: + t3 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_neg > 0: + t4 = sim(A_neg / max_A_neg, B_neg / max_B_neg) * max_A_neg * max_B_neg + else: + t4 = zero_template.clone().to(device) + + return (t1 - t2 - t3 + t4)[0,:,:,:] + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate 
= 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + self.k1 = nn.Parameter(torch.randn(1)) + self.k2 = nn.Parameter(torch.randn(1)) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5) + tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) + scores = scores.masked_fill(tril == 0, float('-inf')) + attention = nn.functional.softmax(scores, dim=2) + output = self.k2 * new_formula(self.sim_output, attention, v) + return self.o_proj(self.gather_heads(output, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +class OpticGPT2(nn.Module): + def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, + pixel_size = 3.6e-6): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + if max_seq_len < 512: + self.sim_scores = omm.OpticalMul( + 
omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + if max_seq_len >= 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + + self.layers = nn.ModuleList([ + TransformerLayer(h_dim=self.h_dim, 
sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads, + dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) + for _ in range(layers_num)]) + self.tok_embeds = nn.Embedding(vocab_size, h_dim) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + self.lm_head = nn.Linear(h_dim, vocab_size) + + def forward(self, x, targets=None): + x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + logits = self.lm_head(x) # B,T,C + loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None + return logits, loss + + # what is the purpose? autoregressive inference! + def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx + +################################################################################################### + +MODEL_CLASS = OpticGPT2 +train_data_path = Path("./data/wiki.train.tokens") +val_data_path = Path("./data/wiki.valid.tokens") +test_data_path = Path("./data/wiki.test.tokens") +comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}" + +logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +writer = SummaryWriter(logs_dir) +script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name) +print("Logs dir:", logs_dir) +# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script +shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository +script_snapshot_path.chmod(0o500) # 
with read-only permission + +# Create standalone checkpoints directory with your specified format +checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +Path(checkpoints_dir).mkdir(parents=True, exist_ok=True) +print("Checkpoints dir:", checkpoints_dir) + +#################################### Dataset ######################################### + +# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt +with open(train_data_path, encoding='utf-8') as f: + train_text = f.read() + +with open(val_data_path, encoding='utf-8') as f: + val_text = f.read() + +with open(test_data_path, encoding='utf-8') as f: + test_text = f.read() + +text = train_text + val_text + test_text + +chars = sorted(set(text)) +vocab_size = len(chars) + +wtoi = {w:i for i,w in enumerate(chars)} +itow = {i:w for i,w in enumerate(chars)} + +encode = lambda s: [wtoi[w] for w in s] +decode = lambda idx: ''.join([itow[i] for i in idx]) + +def get_batch(data, seq_len, batch_size): + ix = torch.randint(len(data)-seq_len, (batch_size, )) + x = torch.stack([data[i:i+seq_len] for i in ix]).to(device) + y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device) + return x, y + +train_data = torch.tensor(encode(train_text), dtype=torch.long) +val_data = torch.tensor(encode(val_text), dtype=torch.long) +test_data = torch.tensor(encode(test_text), dtype=torch.long) + +@torch.no_grad() +def perplexity(model, data, batch_size=32): + model.eval() + stride = max(1, len(data) // 10000) + total_loss_sum = 0.0 + total_tokens_count = 0 + + # Precompute all valid start positions + start_positions = list(range(0, len(data) - max_seq_len - 1, stride)) + total_sequences = len(start_positions) + + # Process sequences in batches + for i in range(0, total_sequences, batch_size): + batch_starts = start_positions[i:min(i + batch_size, total_sequences)] + + # Efficiently stack sequences 
into batch tensors + x_batch = torch.stack([ + data[start:start + max_seq_len] + for start in batch_starts + ]).to(device) + + y_batch = torch.stack([ + data[start + 1:start + max_seq_len + 1] + for start in batch_starts + ]).to(device) + + # Forward pass (model should return mean loss averaged over all tokens in batch) + _, mean_loss = model(x_batch, y_batch) + + # Accumulate weighted loss (mean_loss is already averaged over tokens) + num_tokens = y_batch.numel() + total_loss_sum += mean_loss.item() * num_tokens + total_tokens_count += num_tokens + + # Progress update + processed = min(i + batch_size, total_sequences) + print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True) + + print() # Final newline + return np.exp(total_loss_sum / total_tokens_count) + +#################################### Model #########################################mo +def complete(m, start_idxs=[0], max_new_tokens=100): + start_idx = torch.tensor([start_idxs]).to(device) + generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens) + return decode(generated_tokens[0].tolist()) + +m = MODEL_CLASS( + vocab_size=vocab_size, + h_dim=h_dim, + max_seq_len=max_seq_len, + num_heads=num_heads, + pixel_size=pixel_size, + layers_num=layers_num +) +m = m.to(device) +model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}' +writer.add_text('model', model_description, 0) + +print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.") +print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.") + +optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) + +#################################### Checkpoint Function ######################################### + +def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir): + """Save model 
checkpoint with complete training state""" + checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt' + torch.save({ + 'step': step, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss, + 'config': config, + 'wtoi': wtoi, + 'itow': itow, + }, checkpoint_path) + +# Training config for checkpointing +training_config = { + 'vocab_size': vocab_size, + 'layers_num': layers_num, + 'h_dim': h_dim, + 'max_seq_len': max_seq_len, + 'num_heads': num_heads, + 'dropout_rate': dropout_rate, + 'batch_size': batch_size, + 'learning_rate': learning_rate, + 'gradient_accumulation_steps': gradient_accumulation_steps, + 'pixel_size': pixel_size, + 'max_iters': max_iters, +} +#################################### Train ######################################### + +start_time = datetime.now() +print("Started at:", start_time) + +m.eval() +task_prompts = [ + "1 2 3 4 5", + "The capital of France is", + "The chemical symbol of gold is", + "If yesterday was Friday, then tomorrow will be", + "The opposite of hot is", + "The planets of the solar system are:", + "My favorite color is", + "If 5*x + 3 = 13, then x is", +] +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, 0) + +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, 0) + +for i in range(max_iters): + m.train() + optimizer.zero_grad(set_to_none=True) + accumulated_loss = 0.0 + for j in range(gradient_accumulation_steps): + xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps) + logits, loss = m(xb, yb) + loss = loss / gradient_accumulation_steps + loss.backward() + accumulated_loss += loss.item() + if i % 100 == 0: + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i) + 
optimizer.step() + writer.add_scalar('loss', accumulated_loss, i) + print(f"\r{i}/{max_iters} {accumulated_loss}", end="") + if i % 5000 == 0: + m.eval() + ppl = perplexity(model=m, data=val_data, batch_size=batch_size) + writer.add_scalar('val_perplexity', ppl.item(), i) + print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") + writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) + task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) + writer.add_text('completions/task', task_results, i) + save_checkpoint( + model=m, + optimizer=optimizer, + step=i, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir + ) + +m.eval() +ppl = perplexity(model=m, data=val_data, batch_size=batch_size) +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") +writer.add_scalar('val_perplexity', ppl.item(), i+1) +writer.add_scalar('loss', accumulated_loss, i) + +ppl = perplexity(model=m, data=test_data, batch_size=batch_size) +writer.add_scalar('test_perplexity', ppl.item(), i+1) +print(f"\rTest Perplexity at {i}: {ppl}") + +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, i+1) +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, i+1) + +# Save final checkpoint +save_checkpoint( + model=m, + optimizer=optimizer, + step=max_iters, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir +) +print(f"\n✓ Training complete. 
Final checkpoint saved to {checkpoints_dir}") \ No newline at end of file diff --git a/src/train_optics_char_gpt2_256.py b/src/train_optics_char_gpt2_256.py new file mode 100644 index 0000000..2634f9b --- /dev/null +++ b/src/train_optics_char_gpt2_256.py @@ -0,0 +1,473 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +import sys +from pathlib import Path +import shutil +seed = 1337 +torch.manual_seed(seed) + +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 256 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 + +############################### MODEL ############################################################# + +def new_formula(sim, tensor_1, tensor_2): + tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1 + tensor_2 = tensor_2[None,:,:,:] if len(tensor_2.shape) < 4 else tensor_2 + device = tensor_1.device + + A_pos = torch.clamp(tensor_1, min=0) # A⁺ = max(A, 0) + A_neg = torch.clamp(-tensor_1, min=0) # A⁻ = max(-A, 0) + B_pos = torch.clamp(tensor_2, min=0) # B⁺ = max(B, 0) + B_neg = torch.clamp(-tensor_2, min=0) # B⁻ = max(-B, 0) + + max_A_pos = torch.max(A_pos) # Может быть 0, если нет положительных значений + max_A_neg = torch.max(A_neg) # Может быть 0, если нет отрицательных значений + max_B_pos = torch.max(B_pos) + max_B_neg = torch.max(B_neg) + + zero_template = torch.zeros_like( + torch.empty(tensor_1.shape[0],tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])) + + if max_A_pos > 0 and max_B_pos > 0: + t1 = sim(A_pos / max_A_pos, B_pos / max_B_pos) * max_A_pos 
* max_B_pos + else: + t1 = zero_template.clone().to(device) + + if max_A_pos > 0 and max_B_neg > 0: + t2 = sim(A_pos / max_A_pos, B_neg / max_B_neg) * max_A_pos * max_B_neg + else: + t2 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_pos > 0: + t3 = sim(A_neg / max_A_neg, B_pos / max_B_pos) * max_A_neg * max_B_pos + else: + t3 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_neg > 0: + t4 = sim(A_neg / max_A_neg, B_neg / max_B_neg) * max_A_neg * max_B_neg + else: + t4 = zero_template.clone().to(device) + + return (t1 - t2 - t3 + t4)[0,:,:,:] + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate 
= 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + self.k1 = nn.Parameter(torch.randn(1)) + self.k2 = nn.Parameter(torch.randn(1)) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5) + tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) + scores = scores.masked_fill(tril == 0, float('-inf')) + attention = nn.functional.softmax(scores, dim=2) + output = self.k2 * new_formula(self.sim_output, attention, v) + return self.o_proj(self.gather_heads(output, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +class OpticGPT2(nn.Module): + def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, + pixel_size = 3.6e-6): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + if max_seq_len < 512: + self.sim_scores = omm.OpticalMul( + 
omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + if max_seq_len >= 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + + self.layers = nn.ModuleList([ + TransformerLayer(h_dim=self.h_dim, 
sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads, + dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) + for _ in range(layers_num)]) + self.tok_embeds = nn.Embedding(vocab_size, h_dim) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + self.lm_head = nn.Linear(h_dim, vocab_size) + + def forward(self, x, targets=None): + x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + logits = self.lm_head(x) # B,T,C + loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None + return logits, loss + + # what is the purpose? autoregressive inference! + def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx + +################################################################################################### + +MODEL_CLASS = OpticGPT2 +train_data_path = Path("./data/wiki.train.tokens") +val_data_path = Path("./data/wiki.valid.tokens") +test_data_path = Path("./data/wiki.test.tokens") +comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}" + +logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +writer = SummaryWriter(logs_dir) +script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name) +print("Logs dir:", logs_dir) +# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script +shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository +script_snapshot_path.chmod(0o500) # 
with read-only permission + +# Create standalone checkpoints directory with your specified format +checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +Path(checkpoints_dir).mkdir(parents=True, exist_ok=True) +print("Checkpoints dir:", checkpoints_dir) + +#################################### Dataset ######################################### + +# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt +with open(train_data_path, encoding='utf-8') as f: + train_text = f.read() + +with open(val_data_path, encoding='utf-8') as f: + val_text = f.read() + +with open(test_data_path, encoding='utf-8') as f: + test_text = f.read() + +text = train_text + val_text + test_text + +chars = sorted(set(text)) +vocab_size = len(chars) + +wtoi = {w:i for i,w in enumerate(chars)} +itow = {i:w for i,w in enumerate(chars)} + +encode = lambda s: [wtoi[w] for w in s] +decode = lambda idx: ''.join([itow[i] for i in idx]) + +def get_batch(data, seq_len, batch_size): + ix = torch.randint(len(data)-seq_len, (batch_size, )) + x = torch.stack([data[i:i+seq_len] for i in ix]).to(device) + y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device) + return x, y + +train_data = torch.tensor(encode(train_text), dtype=torch.long) +val_data = torch.tensor(encode(val_text), dtype=torch.long) +test_data = torch.tensor(encode(test_text), dtype=torch.long) + +@torch.no_grad() +def perplexity(model, data, batch_size=32): + model.eval() + stride = max(1, len(data) // 10000) + total_loss_sum = 0.0 + total_tokens_count = 0 + + # Precompute all valid start positions + start_positions = list(range(0, len(data) - max_seq_len - 1, stride)) + total_sequences = len(start_positions) + + # Process sequences in batches + for i in range(0, total_sequences, batch_size): + batch_starts = start_positions[i:min(i + batch_size, total_sequences)] + + # Efficiently stack sequences 
into batch tensors + x_batch = torch.stack([ + data[start:start + max_seq_len] + for start in batch_starts + ]).to(device) + + y_batch = torch.stack([ + data[start + 1:start + max_seq_len + 1] + for start in batch_starts + ]).to(device) + + # Forward pass (model should return mean loss averaged over all tokens in batch) + _, mean_loss = model(x_batch, y_batch) + + # Accumulate weighted loss (mean_loss is already averaged over tokens) + num_tokens = y_batch.numel() + total_loss_sum += mean_loss.item() * num_tokens + total_tokens_count += num_tokens + + # Progress update + processed = min(i + batch_size, total_sequences) + print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True) + + print() # Final newline + return np.exp(total_loss_sum / total_tokens_count) + +#################################### Model #########################################mo +def complete(m, start_idxs=[0], max_new_tokens=100): + start_idx = torch.tensor([start_idxs]).to(device) + generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens) + return decode(generated_tokens[0].tolist()) + +m = MODEL_CLASS( + vocab_size=vocab_size, + h_dim=h_dim, + max_seq_len=max_seq_len, + num_heads=num_heads, + pixel_size=pixel_size, + layers_num=layers_num +) +m = m.to(device) +model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}' +writer.add_text('model', model_description, 0) + +print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.") +print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.") + +optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) + +#################################### Checkpoint Function ######################################### + +def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir): + """Save model 
checkpoint with complete training state""" + checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt' + torch.save({ + 'step': step, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss, + 'config': config, + 'wtoi': wtoi, + 'itow': itow, + }, checkpoint_path) + +# Training config for checkpointing +training_config = { + 'vocab_size': vocab_size, + 'layers_num': layers_num, + 'h_dim': h_dim, + 'max_seq_len': max_seq_len, + 'num_heads': num_heads, + 'dropout_rate': dropout_rate, + 'batch_size': batch_size, + 'learning_rate': learning_rate, + 'gradient_accumulation_steps': gradient_accumulation_steps, + 'pixel_size': pixel_size, + 'max_iters': max_iters, +} +#################################### Train ######################################### + +start_time = datetime.now() +print("Started at:", start_time) + +m.eval() +task_prompts = [ + "1 2 3 4 5", + "The capital of France is", + "The chemical symbol of gold is", + "If yesterday was Friday, then tomorrow will be", + "The opposite of hot is", + "The planets of the solar system are:", + "My favorite color is", + "If 5*x + 3 = 13, then x is", +] +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, 0) + +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, 0) + +for i in range(max_iters): + m.train() + optimizer.zero_grad(set_to_none=True) + accumulated_loss = 0.0 + for j in range(gradient_accumulation_steps): + xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps) + logits, loss = m(xb, yb) + loss = loss / gradient_accumulation_steps + loss.backward() + accumulated_loss += loss.item() + if i % 100 == 0: + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i) + 
optimizer.step() + writer.add_scalar('loss', accumulated_loss, i) + print(f"\r{i}/{max_iters} {accumulated_loss}", end="") + if i % 5000 == 0: + m.eval() + ppl = perplexity(model=m, data=val_data, batch_size=batch_size) + writer.add_scalar('val_perplexity', ppl.item(), i) + print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") + writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) + task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) + writer.add_text('completions/task', task_results, i) + save_checkpoint( + model=m, + optimizer=optimizer, + step=i, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir + ) + +m.eval() +ppl = perplexity(model=m, data=val_data, batch_size=batch_size) +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") +writer.add_scalar('val_perplexity', ppl.item(), i+1) +writer.add_scalar('loss', accumulated_loss, i) + +ppl = perplexity(model=m, data=test_data, batch_size=batch_size) +writer.add_scalar('test_perplexity', ppl.item(), i+1) +print(f"\rTest Perplexity at {i}: {ppl}") + +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, i+1) +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, i+1) + +# Save final checkpoint +save_checkpoint( + model=m, + optimizer=optimizer, + step=max_iters, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir +) +print(f"\n✓ Training complete. 
Final checkpoint saved to {checkpoints_dir}") \ No newline at end of file diff --git a/src/train_optics_char_gpt2_512.py b/src/train_optics_char_gpt2_512.py new file mode 100644 index 0000000..e4eb00d --- /dev/null +++ b/src/train_optics_char_gpt2_512.py @@ -0,0 +1,473 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +import sys +from pathlib import Path +import shutil +seed = 1337 +torch.manual_seed(seed) + +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 512 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 + +############################### MODEL ############################################################# + +def new_formula(sim, tensor_1, tensor_2): + tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1 + tensor_2 = tensor_2[None,:,:,:] if len(tensor_2.shape) < 4 else tensor_2 + device = tensor_1.device + + A_pos = torch.clamp(tensor_1, min=0) # A⁺ = max(A, 0) + A_neg = torch.clamp(-tensor_1, min=0) # A⁻ = max(-A, 0) + B_pos = torch.clamp(tensor_2, min=0) # B⁺ = max(B, 0) + B_neg = torch.clamp(-tensor_2, min=0) # B⁻ = max(-B, 0) + + max_A_pos = torch.max(A_pos) # Может быть 0, если нет положительных значений + max_A_neg = torch.max(A_neg) # Может быть 0, если нет отрицательных значений + max_B_pos = torch.max(B_pos) + max_B_neg = torch.max(B_neg) + + zero_template = torch.zeros_like( + torch.empty(tensor_1.shape[0],tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])) + + if max_A_pos > 0 and max_B_pos > 0: + t1 = sim(A_pos / max_A_pos, B_pos / max_B_pos) * max_A_pos 
* max_B_pos + else: + t1 = zero_template.clone().to(device) + + if max_A_pos > 0 and max_B_neg > 0: + t2 = sim(A_pos / max_A_pos, B_neg / max_B_neg) * max_A_pos * max_B_neg + else: + t2 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_pos > 0: + t3 = sim(A_neg / max_A_neg, B_pos / max_B_pos) * max_A_neg * max_B_pos + else: + t3 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_neg > 0: + t4 = sim(A_neg / max_A_neg, B_neg / max_B_neg) * max_A_neg * max_B_neg + else: + t4 = zero_template.clone().to(device) + + return (t1 - t2 - t3 + t4)[0,:,:,:] + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate 
= 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + self.k1 = nn.Parameter(torch.randn(1)) + self.k2 = nn.Parameter(torch.randn(1)) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5) + tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) + scores = scores.masked_fill(tril == 0, float('-inf')) + attention = nn.functional.softmax(scores, dim=2) + output = self.k2 * new_formula(self.sim_output, attention, v) + return self.o_proj(self.gather_heads(output, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +class OpticGPT2(nn.Module): + def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, + pixel_size = 3.6e-6): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + if max_seq_len < 512: + self.sim_scores = omm.OpticalMul( + 
omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + if max_seq_len >= 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + + self.layers = nn.ModuleList([ + TransformerLayer(h_dim=self.h_dim, 
sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads, + dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) + for _ in range(layers_num)]) + self.tok_embeds = nn.Embedding(vocab_size, h_dim) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + self.lm_head = nn.Linear(h_dim, vocab_size) + + def forward(self, x, targets=None): + x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + logits = self.lm_head(x) # B,T,C + loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None + return logits, loss + + # what is the purpose? autoregressive inference! + def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx + +################################################################################################### + +MODEL_CLASS = OpticGPT2 +train_data_path = Path("./data/wiki.train.tokens") +val_data_path = Path("./data/wiki.valid.tokens") +test_data_path = Path("./data/wiki.test.tokens") +comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}" + +logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +writer = SummaryWriter(logs_dir) +script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name) +print("Logs dir:", logs_dir) +# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script +shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository +script_snapshot_path.chmod(0o500) # 
with read-only permission + +# Create standalone checkpoints directory with your specified format +checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +Path(checkpoints_dir).mkdir(parents=True, exist_ok=True) +print("Checkpoints dir:", checkpoints_dir) + +#################################### Dataset ######################################### + +# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt +with open(train_data_path, encoding='utf-8') as f: + train_text = f.read() + +with open(val_data_path, encoding='utf-8') as f: + val_text = f.read() + +with open(test_data_path, encoding='utf-8') as f: + test_text = f.read() + +text = train_text + val_text + test_text + +chars = sorted(set(text)) +vocab_size = len(chars) + +wtoi = {w:i for i,w in enumerate(chars)} +itow = {i:w for i,w in enumerate(chars)} + +encode = lambda s: [wtoi[w] for w in s] +decode = lambda idx: ''.join([itow[i] for i in idx]) + +def get_batch(data, seq_len, batch_size): + ix = torch.randint(len(data)-seq_len, (batch_size, )) + x = torch.stack([data[i:i+seq_len] for i in ix]).to(device) + y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device) + return x, y + +train_data = torch.tensor(encode(train_text), dtype=torch.long) +val_data = torch.tensor(encode(val_text), dtype=torch.long) +test_data = torch.tensor(encode(test_text), dtype=torch.long) + +@torch.no_grad() +def perplexity(model, data, batch_size=32): + model.eval() + stride = max(1, len(data) // 10000) + total_loss_sum = 0.0 + total_tokens_count = 0 + + # Precompute all valid start positions + start_positions = list(range(0, len(data) - max_seq_len - 1, stride)) + total_sequences = len(start_positions) + + # Process sequences in batches + for i in range(0, total_sequences, batch_size): + batch_starts = start_positions[i:min(i + batch_size, total_sequences)] + + # Efficiently stack sequences 
into batch tensors + x_batch = torch.stack([ + data[start:start + max_seq_len] + for start in batch_starts + ]).to(device) + + y_batch = torch.stack([ + data[start + 1:start + max_seq_len + 1] + for start in batch_starts + ]).to(device) + + # Forward pass (model should return mean loss averaged over all tokens in batch) + _, mean_loss = model(x_batch, y_batch) + + # Accumulate weighted loss (mean_loss is already averaged over tokens) + num_tokens = y_batch.numel() + total_loss_sum += mean_loss.item() * num_tokens + total_tokens_count += num_tokens + + # Progress update + processed = min(i + batch_size, total_sequences) + print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True) + + print() # Final newline + return np.exp(total_loss_sum / total_tokens_count) + +#################################### Model #########################################mo +def complete(m, start_idxs=[0], max_new_tokens=100): + start_idx = torch.tensor([start_idxs]).to(device) + generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens) + return decode(generated_tokens[0].tolist()) + +m = MODEL_CLASS( + vocab_size=vocab_size, + h_dim=h_dim, + max_seq_len=max_seq_len, + num_heads=num_heads, + pixel_size=pixel_size, + layers_num=layers_num +) +m = m.to(device) +model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}' +writer.add_text('model', model_description, 0) + +print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.") +print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.") + +optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) + +#################################### Checkpoint Function ######################################### + +def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir): + """Save model 
checkpoint with complete training state""" + checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt' + torch.save({ + 'step': step, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss, + 'config': config, + 'wtoi': wtoi, + 'itow': itow, + }, checkpoint_path) + +# Training config for checkpointing +training_config = { + 'vocab_size': vocab_size, + 'layers_num': layers_num, + 'h_dim': h_dim, + 'max_seq_len': max_seq_len, + 'num_heads': num_heads, + 'dropout_rate': dropout_rate, + 'batch_size': batch_size, + 'learning_rate': learning_rate, + 'gradient_accumulation_steps': gradient_accumulation_steps, + 'pixel_size': pixel_size, + 'max_iters': max_iters, +} +#################################### Train ######################################### + +start_time = datetime.now() +print("Started at:", start_time) + +m.eval() +task_prompts = [ + "1 2 3 4 5", + "The capital of France is", + "The chemical symbol of gold is", + "If yesterday was Friday, then tomorrow will be", + "The opposite of hot is", + "The planets of the solar system are:", + "My favorite color is", + "If 5*x + 3 = 13, then x is", +] +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, 0) + +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, 0) + +for i in range(max_iters): + m.train() + optimizer.zero_grad(set_to_none=True) + accumulated_loss = 0.0 + for j in range(gradient_accumulation_steps): + xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps) + logits, loss = m(xb, yb) + loss = loss / gradient_accumulation_steps + loss.backward() + accumulated_loss += loss.item() + if i % 100 == 0: + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i) + 
optimizer.step() + writer.add_scalar('loss', accumulated_loss, i) + print(f"\r{i}/{max_iters} {accumulated_loss}", end="") + if i % 5000 == 0: + m.eval() + ppl = perplexity(model=m, data=val_data, batch_size=batch_size) + writer.add_scalar('val_perplexity', ppl.item(), i) + print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") + writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) + task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) + writer.add_text('completions/task', task_results, i) + save_checkpoint( + model=m, + optimizer=optimizer, + step=i, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir + ) + +m.eval() +ppl = perplexity(model=m, data=val_data, batch_size=batch_size) +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") +writer.add_scalar('val_perplexity', ppl.item(), i+1) +writer.add_scalar('loss', accumulated_loss, i) + +ppl = perplexity(model=m, data=test_data, batch_size=batch_size) +writer.add_scalar('test_perplexity', ppl.item(), i+1) +print(f"\rTest Perplexity at {i}: {ppl}") + +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, i+1) +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, i+1) + +# Save final checkpoint +save_checkpoint( + model=m, + optimizer=optimizer, + step=max_iters, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir +) +print(f"\n✓ Training complete. 
Final checkpoint saved to {checkpoints_dir}") \ No newline at end of file diff --git a/src/optics_char_gpt2_new_formula.py b/src/train_optics_char_gpt2_64.py similarity index 51% rename from src/optics_char_gpt2_new_formula.py rename to src/train_optics_char_gpt2_64.py index 6b59105..877c41c 100644 --- a/src/optics_char_gpt2_new_formula.py +++ b/src/train_optics_char_gpt2_64.py @@ -2,19 +2,31 @@ import torch import torch.nn as nn from torch.nn import functional as F from einops import rearrange -import optical_matrix_multiplication as omm -from optical_matrix_multiplication import propagator import numpy as np from torch.utils.tensorboard import SummaryWriter from datetime import datetime -from pathlib import Path import sys -torch.manual_seed(1337) +from pathlib import Path +import shutil +seed = 1337 +torch.manual_seed(seed) -#################################### Model ######################################### +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 64 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 -# def norm(matrix: torch.Tensor, max_val: float = 1) -> torch.Tensor: -# return matrix / (max_val + 1e-10) +############################### MODEL ############################################################# def new_formula(sim, tensor_1, tensor_2): tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1 @@ -130,12 +142,12 @@ class TransformerLayer(nn.Module): x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) return x -class OpticGPT2NewFormula(nn.Module): +class OpticGPT2(nn.Module): def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size 
= 3.6e-6): super().__init__() self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) - if max_seq_len != 512: + if max_seq_len < 512: self.sim_scores = omm.OpticalMul( omm.Config(right_matrix_count_columns = max_seq_len, right_matrix_count_rows = h_dim // num_heads, @@ -164,7 +176,7 @@ class OpticGPT2NewFormula(nn.Module): distance = 0.01, trainable_cylind_lens=False) ) - if max_seq_len == 512: + if max_seq_len >= 512: self.sim_scores = omm.OpticalMul( omm.Config(right_matrix_count_columns = max_seq_len, right_matrix_count_rows = h_dim // num_heads, @@ -195,8 +207,6 @@ class OpticGPT2NewFormula(nn.Module): lens_size = 8192 * 2, trainable_cylind_lens=False) ) - self.sim_scores = omm.DataParallel(self.sim_scores) - self.sim_output = omm.DataParallel(self.sim_output) self.layers = nn.ModuleList([ TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads, @@ -224,4 +234,240 @@ class OpticGPT2NewFormula(nn.Module): probs = F.softmax(logits, dim=-1) idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) idx = torch.cat([idx, idx_next], dim=1) - return idx \ No newline at end of file + return idx + +################################################################################################### + +MODEL_CLASS = OpticGPT2 +train_data_path = Path("./data/wiki.train.tokens") +val_data_path = Path("./data/wiki.valid.tokens") +test_data_path = Path("./data/wiki.test.tokens") +comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}" + +logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +writer = SummaryWriter(logs_dir) +script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name) +print("Logs dir:", logs_dir) +# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script 
@torch.no_grad()
def perplexity(model, data, batch_size=32):
    """Compute corpus perplexity of `model` over `data`.

    Slides fixed-length windows over the token stream with a stride chosen so
    that at most ~10000 start positions are evaluated, batches them, and
    averages the per-token cross-entropy before exponentiating.

    NOTE(review): relies on the module globals `max_seq_len` and `device`;
    assumes `data` is a 1-D tensor of token ids and that `model(x, y)` returns
    `(logits, mean_loss)` with the loss averaged over all tokens — confirm
    against the model's forward().
    """
    model.eval()
    stride = max(1, len(data) // 10000)
    total_loss_sum = 0.0
    total_tokens_count = 0

    # Precompute all valid start positions
    # (a window needs max_seq_len inputs plus one shifted target token)
    start_positions = list(range(0, len(data) - max_seq_len - 1, stride))
    total_sequences = len(start_positions)

    # Process sequences in batches
    for i in range(0, total_sequences, batch_size):
        batch_starts = start_positions[i:min(i + batch_size, total_sequences)]

        # Efficiently stack sequences into batch tensors
        x_batch = torch.stack([
            data[start:start + max_seq_len]
            for start in batch_starts
        ]).to(device)

        # Targets are the same windows shifted one token to the right.
        y_batch = torch.stack([
            data[start + 1:start + max_seq_len + 1]
            for start in batch_starts
        ]).to(device)

        # Forward pass (model should return mean loss averaged over all tokens in batch)
        _, mean_loss = model(x_batch, y_batch)

        # Accumulate weighted loss (mean_loss is already averaged over tokens)
        num_tokens = y_batch.numel()
        total_loss_sum += mean_loss.item() * num_tokens
        total_tokens_count += num_tokens

        # Progress update
        processed = min(i + batch_size, total_sequences)
        print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True)

    print() # Final newline
    # perplexity = exp(average negative log-likelihood per token)
    return np.exp(total_loss_sum / total_tokens_count)
######################################### + +def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir): + """Save model checkpoint with complete training state""" + checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt' + torch.save({ + 'step': step, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss, + 'config': config, + 'wtoi': wtoi, + 'itow': itow, + }, checkpoint_path) + +# Training config for checkpointing +training_config = { + 'vocab_size': vocab_size, + 'layers_num': layers_num, + 'h_dim': h_dim, + 'max_seq_len': max_seq_len, + 'num_heads': num_heads, + 'dropout_rate': dropout_rate, + 'batch_size': batch_size, + 'learning_rate': learning_rate, + 'gradient_accumulation_steps': gradient_accumulation_steps, + 'pixel_size': pixel_size, + 'max_iters': max_iters, +} +#################################### Train ######################################### + +start_time = datetime.now() +print("Started at:", start_time) + +m.eval() +task_prompts = [ + "1 2 3 4 5", + "The capital of France is", + "The chemical symbol of gold is", + "If yesterday was Friday, then tomorrow will be", + "The opposite of hot is", + "The planets of the solar system are:", + "My favorite color is", + "If 5*x + 3 = 13, then x is", +] +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, 0) + +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, 0) + +for i in range(max_iters): + m.train() + optimizer.zero_grad(set_to_none=True) + accumulated_loss = 0.0 + for j in range(gradient_accumulation_steps): + xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps) + logits, loss = m(xb, yb) + loss = loss / gradient_accumulation_steps + loss.backward() + accumulated_loss += 
loss.item() + if i % 100 == 0: + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i) + optimizer.step() + writer.add_scalar('loss', accumulated_loss, i) + print(f"\r{i}/{max_iters} {accumulated_loss}", end="") + if i % 5000 == 0: + m.eval() + ppl = perplexity(model=m, data=val_data, batch_size=batch_size) + writer.add_scalar('val_perplexity', ppl.item(), i) + print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") + writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) + task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) + writer.add_text('completions/task', task_results, i) + save_checkpoint( + model=m, + optimizer=optimizer, + step=i, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir + ) + +m.eval() +ppl = perplexity(model=m, data=val_data, batch_size=batch_size) +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") +writer.add_scalar('val_perplexity', ppl.item(), i+1) +writer.add_scalar('loss', accumulated_loss, i) + +ppl = perplexity(model=m, data=test_data, batch_size=batch_size) +writer.add_scalar('test_perplexity', ppl.item(), i+1) +print(f"\rTest Perplexity at {i}: {ppl}") + +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, i+1) +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, i+1) + +# Save final checkpoint +save_checkpoint( + model=m, + optimizer=optimizer, + step=max_iters, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir +) +print(f"\n✓ Training complete. 
def new_formula(sim, tensor_1, tensor_2):
    """Signed matrix product computed with a non-negative-only multiplier.

    The optical simulator `sim` can only multiply non-negative matrices, so
    each operand is split into positive/negative parts (A = A+ - A-,
    B = B+ - B-) and the signed product is recombined as
    A@B = A+B+ - A+B- - A-B+ + A-B-.  Each part is normalised into [0, 1]
    before calling `sim` and rescaled afterwards.

    Args:
        sim: callable taking two 4-D non-negative tensors and returning their
            (batched) matrix product with shape (B, C, H, W2).
        tensor_1: 3-D or 4-D left operand; a leading batch dim is added when
            missing.
        tensor_2: 3-D or 4-D right operand, same convention.

    Returns:
        3-D tensor: the signed product with the added batch dim removed.
    """
    if tensor_1.dim() < 4:
        tensor_1 = tensor_1[None, :, :, :]
    if tensor_2.dim() < 4:
        tensor_2 = tensor_2[None, :, :, :]
    device = tensor_1.device
    dtype = tensor_1.dtype

    A_pos = torch.clamp(tensor_1, min=0)   # A+ = max(A, 0)
    A_neg = torch.clamp(-tensor_1, min=0)  # A- = max(-A, 0)
    B_pos = torch.clamp(tensor_2, min=0)   # B+ = max(B, 0)
    B_neg = torch.clamp(-tensor_2, min=0)  # B- = max(-B, 0)

    out_shape = (tensor_1.shape[0], tensor_1.shape[1],
                 tensor_1.shape[2], tensor_2.shape[3])

    def _term(A, B):
        # Skip the simulator when either factor is identically zero: its max
        # is 0, the normalisation would divide by zero, and the exact result
        # is a zero matrix anyway.  Fix vs original: zeros are allocated
        # directly on the right device/dtype instead of cloning a CPU
        # float32 template built via torch.zeros_like(torch.empty(...)).
        max_A, max_B = torch.max(A), torch.max(B)
        if max_A > 0 and max_B > 0:
            return sim(A / max_A, B / max_B) * max_A * max_B
        return torch.zeros(out_shape, device=device, dtype=dtype)

    t1 = _term(A_pos, B_pos)
    t2 = _term(A_pos, B_neg)
    t3 = _term(A_neg, B_pos)
    t4 = _term(A_neg, B_neg)

    return (t1 - t2 - t3 + t4)[0, :, :, :]
class OpticLinear(nn.Module):
    """Linear layer whose matrix product is evaluated by an optical simulator.

    Mirrors ``nn.Linear`` but stores the weight as (in_features, out_features)
    and computes ``k * new_formula(sim, input, weight) + bias`` where ``sim``
    simulates free-space optical matrix multiplication and ``k`` is a
    learnable scalar gain.

    NOTE(review): this class references the module-level names ``omm`` and
    ``new_formula``.  The file's import block does not import
    ``optical_matrix_multiplication as omm`` — construction will raise
    NameError until that import is added; verify against the sibling scripts.
    """

    def __init__(
        self,
        in_features,
        out_features,
        bias = True,
        device = None,
        dtype = None,
        pixel_size = 3.6e-6
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # NB: stored transposed relative to nn.Linear (in x out, not out x in),
        # because new_formula computes input @ weight directly.
        self.weight = nn.Parameter(
            torch.empty((in_features, out_features), **factory_kwargs)
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)
        self.k = nn.Parameter(torch.randn(1))  # learnable output gain
        self.sim = omm.OpticalMul(
            omm.Config(
                right_matrix_count_columns = out_features,
                right_matrix_count_rows = in_features,
                right_matrix_width = pixel_size * out_features,
                right_matrix_height = pixel_size * in_features,
                min_height_gap = pixel_size,
                right_matrix_split_x = 2,
                right_matrix_split_y = 2,
                left_matrix_split_x = 2,
                left_matrix_split_y = 2,
                result_matrix_split = 2,
                distance = 0.01)
        )
        self.reset_parameters()

    def forward(self, input):
        """Optical affine map: k * (input @ weight), plus bias when present."""
        out = self.k * new_formula(self.sim, input, self.weight)
        # Fix: the original unconditionally added self.bias, which raises a
        # TypeError when the layer is constructed with bias=False.
        return out if self.bias is None else out + self.bias

    def reset_parameters(self) -> None:
        """Kaiming-uniform init following torch.nn.Linear's scheme."""
        # Fix: this file imports neither ``math`` nor ``torch.nn.init``; the
        # original ``init.kaiming_uniform_`` / ``math.sqrt`` raised NameError.
        # nn.init is reachable through the existing ``torch.nn as nn`` import.
        # a=sqrt(5) is equivalent to uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)),
        # see https://github.com/pytorch/pytorch/issues/57109
        nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5)
        if self.bias is not None:
            # NOTE(review): because weight is stored (in, out), fan_in here
            # resolves to out_features — confirm this is the intended scale.
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = fan_in ** -0.5 if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def extra_repr(self) -> str:
        """Return the extra representation of the module."""
        return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"
class OpticGPT2FF(nn.Module):
    """Character-level GPT-2 variant whose feed-forward blocks are optical.

    Token + learned positional embeddings feed a stack of TransformerLayer
    blocks, followed by a linear language-model head.
    """

    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
                 pixel_size = 3.6e-6):
        super().__init__()
        # Keep every constructor argument around as a plain attribute.
        self.vocab_size = vocab_size
        self.layers_num = layers_num
        self.h_dim = h_dim
        self.max_seq_len = max_seq_len
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.pixel_size = pixel_size
        # Submodule construction order matches the original so the RNG stream
        # (and therefore initial weights under a fixed seed) is unchanged.
        self.layers = nn.ModuleList(
            TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads,
                             dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
            for _ in range(layers_num))
        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
        self.lm_head = nn.Linear(h_dim, vocab_size)

    def forward(self, x, targets=None):
        """Return (logits, loss); loss is None when no targets are given."""
        hidden = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
        for layer in self.layers:
            hidden = layer(hidden)
        logits = self.lm_head(hidden)  # B,T,C
        if targets is None:
            return logits, None
        return logits, F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets)

    # what is the purpose? autoregressive inference!
    def generate(self, start_idx, max_new_tokens):
        """Sample max_new_tokens continuations of start_idx, one token at a time."""
        idx = start_idx
        for _ in range(max_new_tokens):
            window = idx[:, -self.max_seq_len:]          # crop to context size
            logits, _ = self(window)
            last_step = F.softmax(logits[:, -1, :], dim=-1)  # B, C
            next_tok = torch.multinomial(last_step, num_samples=1).to(self.lm_head.bias.device)
            idx = torch.cat([idx, next_tok], dim=1)
        return idx
def get_batch(data, seq_len, batch_size):
    """Sample a random batch of (input, shifted-target) windows from `data`."""
    # Single randint call, same shape as the original: identical RNG draws.
    starts = torch.randint(len(data) - seq_len, (batch_size, ))
    inputs = torch.stack([data[s:s + seq_len] for s in starts]).to(device)
    targets = torch.stack([data[s + 1:s + 1 + seq_len] for s in starts]).to(device)
    return inputs, targets
def complete(m, start_idxs=None, max_new_tokens=100):
    """Decode a sampled continuation of the prompt token ids `start_idxs`.

    Args:
        m: model exposing ``generate(start_idx, max_new_tokens)``.
        start_idxs: list of prompt token ids; defaults to ``[0]``.
        max_new_tokens: number of tokens to sample.

    Returns:
        The generated text (prompt included), decoded with the global
        ``decode``; also relies on the global ``device``.
    """
    # Fix: the original used a mutable default argument (start_idxs=[0]);
    # use the None-sentinel idiom instead (behaviour unchanged for callers).
    if start_idxs is None:
        start_idxs = [0]
    prompt = torch.tensor([start_idxs]).to(device)
    generated_tokens = m.generate(start_idx=prompt, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())
def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir):
    """Save model checkpoint with complete training state.

    Writes ``checkpoint_<step>.pt`` under `checkpoint_dir`, containing model
    and optimizer state dicts, the training config, and both vocab maps so a
    run can be resumed or sampled from the file alone.
    """
    state = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config,
        'wtoi': wtoi,
        'itow': itow,
    }
    # Zero-padded step keeps checkpoints lexicographically sorted.
    torch.save(state, Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt')
loss.item() + if i % 100 == 0: + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i) + optimizer.step() + writer.add_scalar('loss', accumulated_loss, i) + print(f"\r{i}/{max_iters} {accumulated_loss}", end="") + if i % 5000 == 0: + m.eval() + ppl = perplexity(model=m, data=val_data, batch_size=batch_size) + writer.add_scalar('val_perplexity', ppl.item(), i) + print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") + writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) + task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) + writer.add_text('completions/task', task_results, i) + save_checkpoint( + model=m, + optimizer=optimizer, + step=i, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir + ) + +m.eval() +ppl = perplexity(model=m, data=val_data, batch_size=batch_size) +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") +writer.add_scalar('val_perplexity', ppl.item(), i+1) +writer.add_scalar('loss', accumulated_loss, i) + +ppl = perplexity(model=m, data=test_data, batch_size=batch_size) +writer.add_scalar('test_perplexity', ppl.item(), i+1) +print(f"\rTest Perplexity at {i}: {ppl}") + +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, i+1) +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, i+1) + +# Save final checkpoint +save_checkpoint( + model=m, + optimizer=optimizer, + step=max_iters, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir +) +print(f"\n✓ Training complete. 
Final checkpoint saved to {checkpoints_dir}") \ No newline at end of file diff --git a/src/train_optics_char_gpt2_nokoef_128.py b/src/train_optics_char_gpt2_nokoef_128.py new file mode 100644 index 0000000..b1d0977 --- /dev/null +++ b/src/train_optics_char_gpt2_nokoef_128.py @@ -0,0 +1,471 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +import sys +from pathlib import Path +import shutil +seed = 1337 +torch.manual_seed(seed) + +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 128 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 + +############################### MODEL ############################################################# + +def new_formula(sim, tensor_1, tensor_2): + tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1 + tensor_2 = tensor_2[None,:,:,:] if len(tensor_2.shape) < 4 else tensor_2 + device = tensor_1.device + + A_pos = torch.clamp(tensor_1, min=0) # A⁺ = max(A, 0) + A_neg = torch.clamp(-tensor_1, min=0) # A⁻ = max(-A, 0) + B_pos = torch.clamp(tensor_2, min=0) # B⁺ = max(B, 0) + B_neg = torch.clamp(-tensor_2, min=0) # B⁻ = max(-B, 0) + + max_A_pos = torch.max(A_pos) # Может быть 0, если нет положительных значений + max_A_neg = torch.max(A_neg) # Может быть 0, если нет отрицательных значений + max_B_pos = torch.max(B_pos) + max_B_neg = torch.max(B_neg) + + zero_template = torch.zeros_like( + torch.empty(tensor_1.shape[0],tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])) + + if max_A_pos > 0 and max_B_pos > 0: + t1 = sim(A_pos / max_A_pos, B_pos / 
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
    """Dynamic Tanh layer: y = tanh(alpha * x) * weight + bias.

    A normalization-free replacement for LayerNorm with a learnable scalar
    `alpha` and per-feature affine parameters.
    """

    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        self.alpha = nn.Parameter(torch.full((1,), float(alpha_init_value)))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
    """Pre-norm (DyT) causal transformer block with optical attention matmuls.

    Both attention matrix products (QK^T and attention@V) are routed through
    the optical simulators `sim_scores` / `sim_output` via ``new_formula``.
    """

    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
        super().__init__()
        # Stash all constructor args as attributes in one go.
        self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
        self.q_proj = nn.Linear(h_dim, h_dim)
        self.k_proj = nn.Linear(h_dim, h_dim)
        self.v_proj = nn.Linear(h_dim, h_dim)
        self.o_proj = nn.Linear(h_dim, h_dim)
        self.ff1 = nn.Linear(h_dim, 4*h_dim)
        self.ff2 = nn.Linear(4*h_dim, h_dim)
        self.ln1 = DyT(h_dim)
        self.ln2 = DyT(h_dim)
        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)

    def split_to_heads(self, x, B, T, H):
        """(B, T, n*h) -> (B*n, T, h); no-op for a single head."""
        if self.num_heads <= 1: return x
        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)

    def gather_heads(self, x, B, T, H):
        """Inverse of split_to_heads: (B*n, T, h) -> (B, T, n*h)."""
        if self.num_heads <= 1: return x
        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)

    def attention(self, x):
        """Causal softmax attention with both matmuls done optically."""
        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
        v = self.split_to_heads(self.v_proj(x), *x.shape)
        # NOTE(review): scaled by h_dim ** -0.5, not head_dim ** -0.5 —
        # confirm this is intentional for num_heads > 1.
        scores = new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
        tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
        scores = scores.masked_fill(tril == 0, float('-inf'))
        attention = nn.functional.softmax(scores, dim=2)
        output = new_formula(self.sim_output, attention, v)
        return self.o_proj(self.gather_heads(output, *x.shape))

    def forward(self, x):
        # Fix: functional dropout defaults to training=True, so the original
        # kept dropping activations even after model.eval() — corrupting
        # perplexity evaluation and generation.  Pass the module's mode.
        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
        return x
// num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + if max_seq_len >= 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + + self.layers = nn.ModuleList([ + TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads, + 
dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) + for _ in range(layers_num)]) + self.tok_embeds = nn.Embedding(vocab_size, h_dim) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + self.lm_head = nn.Linear(h_dim, vocab_size) + + def forward(self, x, targets=None): + x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + logits = self.lm_head(x) # B,T,C + loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None + return logits, loss + + # what is the purpose? autoregressive inference! + def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx + +################################################################################################### + +MODEL_CLASS = OpticGPT2 +train_data_path = Path("./data/wiki.train.tokens") +val_data_path = Path("./data/wiki.valid.tokens") +test_data_path = Path("./data/wiki.test.tokens") +comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}" + +logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +writer = SummaryWriter(logs_dir) +script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name) +print("Logs dir:", logs_dir) +# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script +shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository +script_snapshot_path.chmod(0o500) # with read-only permission + +# Create standalone checkpoints directory with your 
specified format +checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +Path(checkpoints_dir).mkdir(parents=True, exist_ok=True) +print("Checkpoints dir:", checkpoints_dir) + +#################################### Dataset ######################################### + +# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt +with open(train_data_path, encoding='utf-8') as f: + train_text = f.read() + +with open(val_data_path, encoding='utf-8') as f: + val_text = f.read() + +with open(test_data_path, encoding='utf-8') as f: + test_text = f.read() + +text = train_text + val_text + test_text + +chars = sorted(set(text)) +vocab_size = len(chars) + +wtoi = {w:i for i,w in enumerate(chars)} +itow = {i:w for i,w in enumerate(chars)} + +encode = lambda s: [wtoi[w] for w in s] +decode = lambda idx: ''.join([itow[i] for i in idx]) + +def get_batch(data, seq_len, batch_size): + ix = torch.randint(len(data)-seq_len, (batch_size, )) + x = torch.stack([data[i:i+seq_len] for i in ix]).to(device) + y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device) + return x, y + +train_data = torch.tensor(encode(train_text), dtype=torch.long) +val_data = torch.tensor(encode(val_text), dtype=torch.long) +test_data = torch.tensor(encode(test_text), dtype=torch.long) + +@torch.no_grad() +def perplexity(model, data, batch_size=32): + model.eval() + stride = max(1, len(data) // 10000) + total_loss_sum = 0.0 + total_tokens_count = 0 + + # Precompute all valid start positions + start_positions = list(range(0, len(data) - max_seq_len - 1, stride)) + total_sequences = len(start_positions) + + # Process sequences in batches + for i in range(0, total_sequences, batch_size): + batch_starts = start_positions[i:min(i + batch_size, total_sequences)] + + # Efficiently stack sequences into batch tensors + x_batch = torch.stack([ + data[start:start + max_seq_len] + 
for start in batch_starts + ]).to(device) + + y_batch = torch.stack([ + data[start + 1:start + max_seq_len + 1] + for start in batch_starts + ]).to(device) + + # Forward pass (model should return mean loss averaged over all tokens in batch) + _, mean_loss = model(x_batch, y_batch) + + # Accumulate weighted loss (mean_loss is already averaged over tokens) + num_tokens = y_batch.numel() + total_loss_sum += mean_loss.item() * num_tokens + total_tokens_count += num_tokens + + # Progress update + processed = min(i + batch_size, total_sequences) + print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True) + + print() # Final newline + return np.exp(total_loss_sum / total_tokens_count) + +#################################### Model #########################################mo +def complete(m, start_idxs=[0], max_new_tokens=100): + start_idx = torch.tensor([start_idxs]).to(device) + generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens) + return decode(generated_tokens[0].tolist()) + +m = MODEL_CLASS( + vocab_size=vocab_size, + h_dim=h_dim, + max_seq_len=max_seq_len, + num_heads=num_heads, + pixel_size=pixel_size, + layers_num=layers_num +) +m = m.to(device) +model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}' +writer.add_text('model', model_description, 0) + +print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.") +print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.") + +optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) + +#################################### Checkpoint Function ######################################### + +def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir): + """Save model checkpoint with complete training state""" + checkpoint_path = 
Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt' + torch.save({ + 'step': step, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss, + 'config': config, + 'wtoi': wtoi, + 'itow': itow, + }, checkpoint_path) + +# Training config for checkpointing +training_config = { + 'vocab_size': vocab_size, + 'layers_num': layers_num, + 'h_dim': h_dim, + 'max_seq_len': max_seq_len, + 'num_heads': num_heads, + 'dropout_rate': dropout_rate, + 'batch_size': batch_size, + 'learning_rate': learning_rate, + 'gradient_accumulation_steps': gradient_accumulation_steps, + 'pixel_size': pixel_size, + 'max_iters': max_iters, +} +#################################### Train ######################################### + +start_time = datetime.now() +print("Started at:", start_time) + +m.eval() +task_prompts = [ + "1 2 3 4 5", + "The capital of France is", + "The chemical symbol of gold is", + "If yesterday was Friday, then tomorrow will be", + "The opposite of hot is", + "The planets of the solar system are:", + "My favorite color is", + "If 5*x + 3 = 13, then x is", +] +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, 0) + +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, 0) + +for i in range(max_iters): + m.train() + optimizer.zero_grad(set_to_none=True) + accumulated_loss = 0.0 + for j in range(gradient_accumulation_steps): + xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps) + logits, loss = m(xb, yb) + loss = loss / gradient_accumulation_steps + loss.backward() + accumulated_loss += loss.item() + if i % 100 == 0: + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i) + optimizer.step() + writer.add_scalar('loss', accumulated_loss, i) + 
print(f"\r{i}/{max_iters} {accumulated_loss}", end="") + if i % 5000 == 0: + m.eval() + ppl = perplexity(model=m, data=val_data, batch_size=batch_size) + writer.add_scalar('val_perplexity', ppl.item(), i) + print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") + writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) + task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) + writer.add_text('completions/task', task_results, i) + save_checkpoint( + model=m, + optimizer=optimizer, + step=i, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir + ) + +m.eval() +ppl = perplexity(model=m, data=val_data, batch_size=batch_size) +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") +writer.add_scalar('val_perplexity', ppl.item(), i+1) +writer.add_scalar('loss', accumulated_loss, i) + +ppl = perplexity(model=m, data=test_data, batch_size=batch_size) +writer.add_scalar('test_perplexity', ppl.item(), i+1) +print(f"\rTest Perplexity at {i}: {ppl}") + +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, i+1) +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, i+1) + +# Save final checkpoint +save_checkpoint( + model=m, + optimizer=optimizer, + step=max_iters, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir +) +print(f"\n✓ Training complete. 
Final checkpoint saved to {checkpoints_dir}") \ No newline at end of file diff --git a/src/train_optics_char_gpt2_nokoef_256.py b/src/train_optics_char_gpt2_nokoef_256.py new file mode 100644 index 0000000..905408c --- /dev/null +++ b/src/train_optics_char_gpt2_nokoef_256.py @@ -0,0 +1,471 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +import sys +from pathlib import Path +import shutil +seed = 1337 +torch.manual_seed(seed) + +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 256 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 + +############################### MODEL ############################################################# + +def new_formula(sim, tensor_1, tensor_2): + tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1 + tensor_2 = tensor_2[None,:,:,:] if len(tensor_2.shape) < 4 else tensor_2 + device = tensor_1.device + + A_pos = torch.clamp(tensor_1, min=0) # A⁺ = max(A, 0) + A_neg = torch.clamp(-tensor_1, min=0) # A⁻ = max(-A, 0) + B_pos = torch.clamp(tensor_2, min=0) # B⁺ = max(B, 0) + B_neg = torch.clamp(-tensor_2, min=0) # B⁻ = max(-B, 0) + + max_A_pos = torch.max(A_pos) # Может быть 0, если нет положительных значений + max_A_neg = torch.max(A_neg) # Может быть 0, если нет отрицательных значений + max_B_pos = torch.max(B_pos) + max_B_neg = torch.max(B_neg) + + zero_template = torch.zeros_like( + torch.empty(tensor_1.shape[0],tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])) + + if max_A_pos > 0 and max_B_pos > 0: + t1 = sim(A_pos / max_A_pos, B_pos / 
max_B_pos) * max_A_pos * max_B_pos + else: + t1 = zero_template.clone().to(device) + + if max_A_pos > 0 and max_B_neg > 0: + t2 = sim(A_pos / max_A_pos, B_neg / max_B_neg) * max_A_pos * max_B_neg + else: + t2 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_pos > 0: + t3 = sim(A_neg / max_A_neg, B_pos / max_B_pos) * max_A_neg * max_B_pos + else: + t3 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_neg > 0: + t4 = sim(A_neg / max_A_neg, B_neg / max_B_neg) * max_A_neg * max_B_neg + else: + t4 = zero_template.clone().to(device) + + return (t1 - t2 - t3 + t4)[0,:,:,:] + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, sim_scores, sim_output, 
num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5) + tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) + scores = scores.masked_fill(tril == 0, float('-inf')) + attention = nn.functional.softmax(scores, dim=2) + output = new_formula(self.sim_output, attention, v) + return self.o_proj(self.gather_heads(output, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +class OpticGPT2(nn.Module): + def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, + pixel_size = 3.6e-6): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + if max_seq_len < 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim 
// num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + if max_seq_len >= 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + + self.layers = nn.ModuleList([ + TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads, + 
dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) + for _ in range(layers_num)]) + self.tok_embeds = nn.Embedding(vocab_size, h_dim) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + self.lm_head = nn.Linear(h_dim, vocab_size) + + def forward(self, x, targets=None): + x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + logits = self.lm_head(x) # B,T,C + loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None + return logits, loss + + # what is the purpose? autoregressive inference! + def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx + +################################################################################################### + +MODEL_CLASS = OpticGPT2 +train_data_path = Path("./data/wiki.train.tokens") +val_data_path = Path("./data/wiki.valid.tokens") +test_data_path = Path("./data/wiki.test.tokens") +comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}" + +logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +writer = SummaryWriter(logs_dir) +script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name) +print("Logs dir:", logs_dir) +# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script +shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository +script_snapshot_path.chmod(0o500) # with read-only permission + +# Create standalone checkpoints directory with your 
specified format +checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +Path(checkpoints_dir).mkdir(parents=True, exist_ok=True) +print("Checkpoints dir:", checkpoints_dir) + +#################################### Dataset ######################################### + +# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt +with open(train_data_path, encoding='utf-8') as f: + train_text = f.read() + +with open(val_data_path, encoding='utf-8') as f: + val_text = f.read() + +with open(test_data_path, encoding='utf-8') as f: + test_text = f.read() + +text = train_text + val_text + test_text + +chars = sorted(set(text)) +vocab_size = len(chars) + +wtoi = {w:i for i,w in enumerate(chars)} +itow = {i:w for i,w in enumerate(chars)} + +encode = lambda s: [wtoi[w] for w in s] +decode = lambda idx: ''.join([itow[i] for i in idx]) + +def get_batch(data, seq_len, batch_size): + ix = torch.randint(len(data)-seq_len, (batch_size, )) + x = torch.stack([data[i:i+seq_len] for i in ix]).to(device) + y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device) + return x, y + +train_data = torch.tensor(encode(train_text), dtype=torch.long) +val_data = torch.tensor(encode(val_text), dtype=torch.long) +test_data = torch.tensor(encode(test_text), dtype=torch.long) + +@torch.no_grad() +def perplexity(model, data, batch_size=32): + model.eval() + stride = max(1, len(data) // 10000) + total_loss_sum = 0.0 + total_tokens_count = 0 + + # Precompute all valid start positions + start_positions = list(range(0, len(data) - max_seq_len - 1, stride)) + total_sequences = len(start_positions) + + # Process sequences in batches + for i in range(0, total_sequences, batch_size): + batch_starts = start_positions[i:min(i + batch_size, total_sequences)] + + # Efficiently stack sequences into batch tensors + x_batch = torch.stack([ + data[start:start + max_seq_len] + 
for start in batch_starts + ]).to(device) + + y_batch = torch.stack([ + data[start + 1:start + max_seq_len + 1] + for start in batch_starts + ]).to(device) + + # Forward pass (model should return mean loss averaged over all tokens in batch) + _, mean_loss = model(x_batch, y_batch) + + # Accumulate weighted loss (mean_loss is already averaged over tokens) + num_tokens = y_batch.numel() + total_loss_sum += mean_loss.item() * num_tokens + total_tokens_count += num_tokens + + # Progress update + processed = min(i + batch_size, total_sequences) + print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True) + + print() # Final newline + return np.exp(total_loss_sum / total_tokens_count) + +#################################### Model #########################################mo +def complete(m, start_idxs=[0], max_new_tokens=100): + start_idx = torch.tensor([start_idxs]).to(device) + generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens) + return decode(generated_tokens[0].tolist()) + +m = MODEL_CLASS( + vocab_size=vocab_size, + h_dim=h_dim, + max_seq_len=max_seq_len, + num_heads=num_heads, + pixel_size=pixel_size, + layers_num=layers_num +) +m = m.to(device) +model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}' +writer.add_text('model', model_description, 0) + +print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.") +print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.") + +optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) + +#################################### Checkpoint Function ######################################### + +def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir): + """Save model checkpoint with complete training state""" + checkpoint_path = 
Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt' + torch.save({ + 'step': step, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss, + 'config': config, + 'wtoi': wtoi, + 'itow': itow, + }, checkpoint_path) + +# Training config for checkpointing +training_config = { + 'vocab_size': vocab_size, + 'layers_num': layers_num, + 'h_dim': h_dim, + 'max_seq_len': max_seq_len, + 'num_heads': num_heads, + 'dropout_rate': dropout_rate, + 'batch_size': batch_size, + 'learning_rate': learning_rate, + 'gradient_accumulation_steps': gradient_accumulation_steps, + 'pixel_size': pixel_size, + 'max_iters': max_iters, +} +#################################### Train ######################################### + +start_time = datetime.now() +print("Started at:", start_time) + +m.eval() +task_prompts = [ + "1 2 3 4 5", + "The capital of France is", + "The chemical symbol of gold is", + "If yesterday was Friday, then tomorrow will be", + "The opposite of hot is", + "The planets of the solar system are:", + "My favorite color is", + "If 5*x + 3 = 13, then x is", +] +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, 0) + +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, 0) + +for i in range(max_iters): + m.train() + optimizer.zero_grad(set_to_none=True) + accumulated_loss = 0.0 + for j in range(gradient_accumulation_steps): + xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps) + logits, loss = m(xb, yb) + loss = loss / gradient_accumulation_steps + loss.backward() + accumulated_loss += loss.item() + if i % 100 == 0: + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i) + optimizer.step() + writer.add_scalar('loss', accumulated_loss, i) + 
print(f"\r{i}/{max_iters} {accumulated_loss}", end="") + if i % 5000 == 0: + m.eval() + ppl = perplexity(model=m, data=val_data, batch_size=batch_size) + writer.add_scalar('val_perplexity', ppl.item(), i) + print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") + writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) + task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) + writer.add_text('completions/task', task_results, i) + save_checkpoint( + model=m, + optimizer=optimizer, + step=i, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir + ) + +m.eval() +ppl = perplexity(model=m, data=val_data, batch_size=batch_size) +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") +writer.add_scalar('val_perplexity', ppl.item(), i+1) +writer.add_scalar('loss', accumulated_loss, i) + +ppl = perplexity(model=m, data=test_data, batch_size=batch_size) +writer.add_scalar('test_perplexity', ppl.item(), i+1) +print(f"\rTest Perplexity at {i}: {ppl}") + +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, i+1) +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, i+1) + +# Save final checkpoint +save_checkpoint( + model=m, + optimizer=optimizer, + step=max_iters, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir +) +print(f"\n✓ Training complete. 
Final checkpoint saved to {checkpoints_dir}") \ No newline at end of file diff --git a/src/train_optics_char_gpt2_nokoef_512.py b/src/train_optics_char_gpt2_nokoef_512.py new file mode 100644 index 0000000..bd9b26c --- /dev/null +++ b/src/train_optics_char_gpt2_nokoef_512.py @@ -0,0 +1,471 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +import sys +from pathlib import Path +import shutil +seed = 1337 +torch.manual_seed(seed) + +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 512 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 + +############################### MODEL ############################################################# + +def new_formula(sim, tensor_1, tensor_2): + tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1 + tensor_2 = tensor_2[None,:,:,:] if len(tensor_2.shape) < 4 else tensor_2 + device = tensor_1.device + + A_pos = torch.clamp(tensor_1, min=0) # A⁺ = max(A, 0) + A_neg = torch.clamp(-tensor_1, min=0) # A⁻ = max(-A, 0) + B_pos = torch.clamp(tensor_2, min=0) # B⁺ = max(B, 0) + B_neg = torch.clamp(-tensor_2, min=0) # B⁻ = max(-B, 0) + + max_A_pos = torch.max(A_pos) # Может быть 0, если нет положительных значений + max_A_neg = torch.max(A_neg) # Может быть 0, если нет отрицательных значений + max_B_pos = torch.max(B_pos) + max_B_neg = torch.max(B_neg) + + zero_template = torch.zeros_like( + torch.empty(tensor_1.shape[0],tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])) + + if max_A_pos > 0 and max_B_pos > 0: + t1 = sim(A_pos / max_A_pos, B_pos / 
max_B_pos) * max_A_pos * max_B_pos + else: + t1 = zero_template.clone().to(device) + + if max_A_pos > 0 and max_B_neg > 0: + t2 = sim(A_pos / max_A_pos, B_neg / max_B_neg) * max_A_pos * max_B_neg + else: + t2 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_pos > 0: + t3 = sim(A_neg / max_A_neg, B_pos / max_B_pos) * max_A_neg * max_B_pos + else: + t3 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_neg > 0: + t4 = sim(A_neg / max_A_neg, B_neg / max_B_neg) * max_A_neg * max_B_neg + else: + t4 = zero_template.clone().to(device) + + return (t1 - t2 - t3 + t4)[0,:,:,:] + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, sim_scores, sim_output, 
num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5) + tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) + scores = scores.masked_fill(tril == 0, float('-inf')) + attention = nn.functional.softmax(scores, dim=2) + output = new_formula(self.sim_output, attention, v) + return self.o_proj(self.gather_heads(output, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +class OpticGPT2(nn.Module): + def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, + pixel_size = 3.6e-6): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + if max_seq_len < 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim 
// num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + if max_seq_len >= 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + + self.layers = nn.ModuleList([ + TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads, + 
dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) + for _ in range(layers_num)]) + self.tok_embeds = nn.Embedding(vocab_size, h_dim) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + self.lm_head = nn.Linear(h_dim, vocab_size) + + def forward(self, x, targets=None): + x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + logits = self.lm_head(x) # B,T,C + loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None + return logits, loss + + # what is the purpose? autoregressive inference! + def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx + +################################################################################################### + +MODEL_CLASS = OpticGPT2 +train_data_path = Path("./data/wiki.train.tokens") +val_data_path = Path("./data/wiki.valid.tokens") +test_data_path = Path("./data/wiki.test.tokens") +comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}" + +logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +writer = SummaryWriter(logs_dir) +script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name) +print("Logs dir:", logs_dir) +# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script +shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository +script_snapshot_path.chmod(0o500) # with read-only permission + +# Create standalone checkpoints directory with your 
specified format +checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +Path(checkpoints_dir).mkdir(parents=True, exist_ok=True) +print("Checkpoints dir:", checkpoints_dir) + +#################################### Dataset ######################################### + +# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt +with open(train_data_path, encoding='utf-8') as f: + train_text = f.read() + +with open(val_data_path, encoding='utf-8') as f: + val_text = f.read() + +with open(test_data_path, encoding='utf-8') as f: + test_text = f.read() + +text = train_text + val_text + test_text + +chars = sorted(set(text)) +vocab_size = len(chars) + +wtoi = {w:i for i,w in enumerate(chars)} +itow = {i:w for i,w in enumerate(chars)} + +encode = lambda s: [wtoi[w] for w in s] +decode = lambda idx: ''.join([itow[i] for i in idx]) + +def get_batch(data, seq_len, batch_size): + ix = torch.randint(len(data)-seq_len, (batch_size, )) + x = torch.stack([data[i:i+seq_len] for i in ix]).to(device) + y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device) + return x, y + +train_data = torch.tensor(encode(train_text), dtype=torch.long) +val_data = torch.tensor(encode(val_text), dtype=torch.long) +test_data = torch.tensor(encode(test_text), dtype=torch.long) + +@torch.no_grad() +def perplexity(model, data, batch_size=32): + model.eval() + stride = max(1, len(data) // 10000) + total_loss_sum = 0.0 + total_tokens_count = 0 + + # Precompute all valid start positions + start_positions = list(range(0, len(data) - max_seq_len - 1, stride)) + total_sequences = len(start_positions) + + # Process sequences in batches + for i in range(0, total_sequences, batch_size): + batch_starts = start_positions[i:min(i + batch_size, total_sequences)] + + # Efficiently stack sequences into batch tensors + x_batch = torch.stack([ + data[start:start + max_seq_len] + 
for start in batch_starts + ]).to(device) + + y_batch = torch.stack([ + data[start + 1:start + max_seq_len + 1] + for start in batch_starts + ]).to(device) + + # Forward pass (model should return mean loss averaged over all tokens in batch) + _, mean_loss = model(x_batch, y_batch) + + # Accumulate weighted loss (mean_loss is already averaged over tokens) + num_tokens = y_batch.numel() + total_loss_sum += mean_loss.item() * num_tokens + total_tokens_count += num_tokens + + # Progress update + processed = min(i + batch_size, total_sequences) + print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True) + + print() # Final newline + return np.exp(total_loss_sum / total_tokens_count) + +#################################### Model #########################################mo +def complete(m, start_idxs=[0], max_new_tokens=100): + start_idx = torch.tensor([start_idxs]).to(device) + generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens) + return decode(generated_tokens[0].tolist()) + +m = MODEL_CLASS( + vocab_size=vocab_size, + h_dim=h_dim, + max_seq_len=max_seq_len, + num_heads=num_heads, + pixel_size=pixel_size, + layers_num=layers_num +) +m = m.to(device) +model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}' +writer.add_text('model', model_description, 0) + +print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.") +print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.") + +optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) + +#################################### Checkpoint Function ######################################### + +def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir): + """Save model checkpoint with complete training state""" + checkpoint_path = 
Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt' + torch.save({ + 'step': step, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss, + 'config': config, + 'wtoi': wtoi, + 'itow': itow, + }, checkpoint_path) + +# Training config for checkpointing +training_config = { + 'vocab_size': vocab_size, + 'layers_num': layers_num, + 'h_dim': h_dim, + 'max_seq_len': max_seq_len, + 'num_heads': num_heads, + 'dropout_rate': dropout_rate, + 'batch_size': batch_size, + 'learning_rate': learning_rate, + 'gradient_accumulation_steps': gradient_accumulation_steps, + 'pixel_size': pixel_size, + 'max_iters': max_iters, +} +#################################### Train ######################################### + +start_time = datetime.now() +print("Started at:", start_time) + +m.eval() +task_prompts = [ + "1 2 3 4 5", + "The capital of France is", + "The chemical symbol of gold is", + "If yesterday was Friday, then tomorrow will be", + "The opposite of hot is", + "The planets of the solar system are:", + "My favorite color is", + "If 5*x + 3 = 13, then x is", +] +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, 0) + +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, 0) + +for i in range(max_iters): + m.train() + optimizer.zero_grad(set_to_none=True) + accumulated_loss = 0.0 + for j in range(gradient_accumulation_steps): + xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps) + logits, loss = m(xb, yb) + loss = loss / gradient_accumulation_steps + loss.backward() + accumulated_loss += loss.item() + if i % 100 == 0: + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i) + optimizer.step() + writer.add_scalar('loss', accumulated_loss, i) + 
print(f"\r{i}/{max_iters} {accumulated_loss}", end="") + if i % 5000 == 0: + m.eval() + ppl = perplexity(model=m, data=val_data, batch_size=batch_size) + writer.add_scalar('val_perplexity', ppl.item(), i) + print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") + writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) + task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) + writer.add_text('completions/task', task_results, i) + save_checkpoint( + model=m, + optimizer=optimizer, + step=i, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir + ) + +m.eval() +ppl = perplexity(model=m, data=val_data, batch_size=batch_size) +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") +writer.add_scalar('val_perplexity', ppl.item(), i+1) +writer.add_scalar('loss', accumulated_loss, i) + +ppl = perplexity(model=m, data=test_data, batch_size=batch_size) +writer.add_scalar('test_perplexity', ppl.item(), i+1) +print(f"\rTest Perplexity at {i}: {ppl}") + +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, i+1) +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, i+1) + +# Save final checkpoint +save_checkpoint( + model=m, + optimizer=optimizer, + step=max_iters, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir +) +print(f"\n✓ Training complete. 
Final checkpoint saved to {checkpoints_dir}") \ No newline at end of file diff --git a/src/optics_char_gpt2_nokoef_newf.py b/src/train_optics_char_gpt2_nokoef_64.py similarity index 51% rename from src/optics_char_gpt2_nokoef_newf.py rename to src/train_optics_char_gpt2_nokoef_64.py index d99f149..7dc825a 100644 --- a/src/optics_char_gpt2_nokoef_newf.py +++ b/src/train_optics_char_gpt2_nokoef_64.py @@ -2,16 +2,31 @@ import torch import torch.nn as nn from torch.nn import functional as F from einops import rearrange -import optical_matrix_multiplication as omm -from optical_matrix_multiplication import propagator import numpy as np from torch.utils.tensorboard import SummaryWriter from datetime import datetime -from pathlib import Path import sys -torch.manual_seed(1337) +from pathlib import Path +import shutil +seed = 1337 +torch.manual_seed(seed) + +batch_size = 50 +gradient_accumulation_steps = 1 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = int(4e4) #40000 +eval_interval = 300 +learning_rate = 1e-3 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +eval_iters = 200 +layers_num = 2 +h_dim = 64 +max_seq_len = 64 +num_heads = 1 +dropout_rate = 0.1 +pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 -#################################### Model ######################################### +############################### MODEL ############################################################# def new_formula(sim, tensor_1, tensor_2): tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1 @@ -125,7 +140,7 @@ class TransformerLayer(nn.Module): x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) return x -class OpticGPT2NOKoefNewF(nn.Module): +class OpticGPT2(nn.Module): def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size = 3.6e-6): super().__init__() @@ -190,8 +205,6 @@ class 
OpticGPT2NOKoefNewF(nn.Module): lens_size = 8192 * 2, trainable_cylind_lens=False) ) - self.sim_scores = omm.ScatterDataParallel(self.sim_scores) - self.sim_output = omm.ScatterDataParallel(self.sim_output) self.layers = nn.ModuleList([ TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads, @@ -219,4 +232,240 @@ class OpticGPT2NOKoefNewF(nn.Module): probs = F.softmax(logits, dim=-1) idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) idx = torch.cat([idx, idx_next], dim=1) - return idx \ No newline at end of file + return idx + +################################################################################################### + +MODEL_CLASS = OpticGPT2 +train_data_path = Path("./data/wiki.train.tokens") +val_data_path = Path("./data/wiki.valid.tokens") +test_data_path = Path("./data/wiki.test.tokens") +comment = f"{Path(__file__).stem}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[1] if len(sys.argv) > 1 else ''}" + +logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +writer = SummaryWriter(logs_dir) +script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name) +print("Logs dir:", logs_dir) +# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script +shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository +script_snapshot_path.chmod(0o500) # with read-only permission + +# Create standalone checkpoints directory with your specified format +checkpoints_dir = f'./checkpoints/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' +Path(checkpoints_dir).mkdir(parents=True, exist_ok=True) +print("Checkpoints dir:", checkpoints_dir) + +#################################### Dataset ######################################### + +# wget 
https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt +with open(train_data_path, encoding='utf-8') as f: + train_text = f.read() + +with open(val_data_path, encoding='utf-8') as f: + val_text = f.read() + +with open(test_data_path, encoding='utf-8') as f: + test_text = f.read() + +text = train_text + val_text + test_text + +chars = sorted(set(text)) +vocab_size = len(chars) + +wtoi = {w:i for i,w in enumerate(chars)} +itow = {i:w for i,w in enumerate(chars)} + +encode = lambda s: [wtoi[w] for w in s] +decode = lambda idx: ''.join([itow[i] for i in idx]) + +def get_batch(data, seq_len, batch_size): + ix = torch.randint(len(data)-seq_len, (batch_size, )) + x = torch.stack([data[i:i+seq_len] for i in ix]).to(device) + y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device) + return x, y + +train_data = torch.tensor(encode(train_text), dtype=torch.long) +val_data = torch.tensor(encode(val_text), dtype=torch.long) +test_data = torch.tensor(encode(test_text), dtype=torch.long) + +@torch.no_grad() +def perplexity(model, data, batch_size=32): + model.eval() + stride = max(1, len(data) // 10000) + total_loss_sum = 0.0 + total_tokens_count = 0 + + # Precompute all valid start positions + start_positions = list(range(0, len(data) - max_seq_len - 1, stride)) + total_sequences = len(start_positions) + + # Process sequences in batches + for i in range(0, total_sequences, batch_size): + batch_starts = start_positions[i:min(i + batch_size, total_sequences)] + + # Efficiently stack sequences into batch tensors + x_batch = torch.stack([ + data[start:start + max_seq_len] + for start in batch_starts + ]).to(device) + + y_batch = torch.stack([ + data[start + 1:start + max_seq_len + 1] + for start in batch_starts + ]).to(device) + + # Forward pass (model should return mean loss averaged over all tokens in batch) + _, mean_loss = model(x_batch, y_batch) + + # Accumulate weighted loss (mean_loss is already averaged over tokens) + num_tokens = 
y_batch.numel() + total_loss_sum += mean_loss.item() * num_tokens + total_tokens_count += num_tokens + + # Progress update + processed = min(i + batch_size, total_sequences) + print(f"\rppl {processed}/{total_sequences} ({processed/total_sequences*100:.1f}%)", end="", flush=True) + + print() # Final newline + return np.exp(total_loss_sum / total_tokens_count) + +#################################### Model #########################################mo +def complete(m, start_idxs=[0], max_new_tokens=100): + start_idx = torch.tensor([start_idxs]).to(device) + generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens) + return decode(generated_tokens[0].tolist()) + +m = MODEL_CLASS( + vocab_size=vocab_size, + h_dim=h_dim, + max_seq_len=max_seq_len, + num_heads=num_heads, + pixel_size=pixel_size, + layers_num=layers_num +) +m = m.to(device) +model_description = str(m) + f'\nParameters count - {sum(p.numel() for p in m.parameters())}' +writer.add_text('model', model_description, 0) + +print(f"{sum(p.numel() for p in m.parameters()) * 8} minimum number of tokens to train model.") +print(f"{(sum(p.numel() for p in m.parameters()) * 8)//(batch_size)} minimum number of iterations to train this model.") + +optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) + +#################################### Checkpoint Function ######################################### + +def save_checkpoint(model, optimizer, step, loss, config, wtoi, itow, checkpoint_dir): + """Save model checkpoint with complete training state""" + checkpoint_path = Path(checkpoint_dir) / f'checkpoint_{step:07d}.pt' + torch.save({ + 'step': step, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss, + 'config': config, + 'wtoi': wtoi, + 'itow': itow, + }, checkpoint_path) + +# Training config for checkpointing +training_config = { + 'vocab_size': vocab_size, + 'layers_num': layers_num, + 'h_dim': 
h_dim, + 'max_seq_len': max_seq_len, + 'num_heads': num_heads, + 'dropout_rate': dropout_rate, + 'batch_size': batch_size, + 'learning_rate': learning_rate, + 'gradient_accumulation_steps': gradient_accumulation_steps, + 'pixel_size': pixel_size, + 'max_iters': max_iters, +} +#################################### Train ######################################### + +start_time = datetime.now() +print("Started at:", start_time) + +m.eval() +task_prompts = [ + "1 2 3 4 5", + "The capital of France is", + "The chemical symbol of gold is", + "If yesterday was Friday, then tomorrow will be", + "The opposite of hot is", + "The planets of the solar system are:", + "My favorite color is", + "If 5*x + 3 = 13, then x is", +] +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, 0) + +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, 0) + +for i in range(max_iters): + m.train() + optimizer.zero_grad(set_to_none=True) + accumulated_loss = 0.0 + for j in range(gradient_accumulation_steps): + xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps) + logits, loss = m(xb, yb) + loss = loss / gradient_accumulation_steps + loss.backward() + accumulated_loss += loss.item() + if i % 100 == 0: + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i) + optimizer.step() + writer.add_scalar('loss', accumulated_loss, i) + print(f"\r{i}/{max_iters} {accumulated_loss}", end="") + if i % 5000 == 0: + m.eval() + ppl = perplexity(model=m, data=val_data, batch_size=batch_size) + writer.add_scalar('val_perplexity', ppl.item(), i) + print(f"\r{datetime.now() - start_time} Perplexity at {i}: {ppl}") + writer.add_text('completions', complete(m, encode("\n\n"), max_seq_len), i) + task_results = "\n".join([complete(m, 
encode(task_prompt), 32) for task_prompt in task_prompts]) + writer.add_text('completions/task', task_results, i) + save_checkpoint( + model=m, + optimizer=optimizer, + step=i, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir + ) + +m.eval() +ppl = perplexity(model=m, data=val_data, batch_size=batch_size) +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") +writer.add_scalar('val_perplexity', ppl.item(), i+1) +writer.add_scalar('loss', accumulated_loss, i) + +ppl = perplexity(model=m, data=test_data, batch_size=batch_size) +writer.add_scalar('test_perplexity', ppl.item(), i+1) +print(f"\rTest Perplexity at {i}: {ppl}") + +completion = complete(m, encode("\n\n"), max_seq_len) +print(completion) +writer.add_text('completions', completion, i+1) +task_results = "\n".join([complete(m, encode(task_prompt), 32) for task_prompt in task_prompts]) +print(task_results) +writer.add_text('completions/task', task_results, i+1) + +# Save final checkpoint +save_checkpoint( + model=m, + optimizer=optimizer, + step=max_iters, + loss=accumulated_loss, + config=training_config, + wtoi=wtoi, + itow=itow, + checkpoint_dir=checkpoints_dir +) +print(f"\n✓ Training complete. 
Final checkpoint saved to {checkpoints_dir}") \ No newline at end of file From 6b79a541ec14fa8f8602d544dd9ff0c42d5877cd Mon Sep 17 00:00:00 2001 From: Vladimir Protsenko Date: Sat, 14 Feb 2026 18:08:27 +0000 Subject: [PATCH 9/9] scripts version --- src/char_gpt2_scaledmatmul.py | 117 ---------------------------------- 1 file changed, 117 deletions(-) delete mode 100644 src/char_gpt2_scaledmatmul.py diff --git a/src/char_gpt2_scaledmatmul.py b/src/char_gpt2_scaledmatmul.py deleted file mode 100644 index 25bb992..0000000 --- a/src/char_gpt2_scaledmatmul.py +++ /dev/null @@ -1,117 +0,0 @@ -import torch -import torch.nn as nn -from torch.nn import functional as F -from einops import rearrange -import numpy as np -from torch.utils.tensorboard import SummaryWriter -from datetime import datetime -import sys -from pathlib import Path -torch.manual_seed(1337) - -#################################### Model ######################################### - - -# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 -class RoPE(nn.Module): - def __init__(self, dim, max_seq_len=512): - super().__init__() - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) - t = torch.arange(max_seq_len).type_as(inv_freq) - freqs = torch.einsum('i,j->ij', t, inv_freq) - self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) - - def rotate_half(self, x): - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - - def forward(self, x, offset=0): - seq_len = x.size(1) - emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) - cos = emb.cos() - sin = emb.sin() - return (x * cos) + (self.rotate_half(x) * sin) - -# Transformers without Normalization https://jiachenzhu.github.io/DyT/ -class DyT(nn.Module): - def __init__(self, num_features, alpha_init_value=0.5): - super().__init__() - self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) - self.weight = nn.Parameter(torch.ones(num_features)) - self.bias = 
nn.Parameter(torch.zeros(num_features)) - - def forward(self, x): - x = torch.tanh(self.alpha * x) - return x * self.weight + self.bias - -# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 -# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 -class TransformerLayer(nn.Module): - def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): - super().__init__() - self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) - self.q_proj = nn.Linear(h_dim, h_dim) - self.k_proj = nn.Linear(h_dim, h_dim) - self.v_proj = nn.Linear(h_dim, h_dim) - self.o_proj = nn.Linear(h_dim, h_dim) - self.ff1 = nn.Linear(h_dim, 4*h_dim) - self.ff2 = nn.Linear(4*h_dim, h_dim) - self.ln1 = DyT(h_dim) - self.ln2 = DyT(h_dim) - self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) - self.k1 = nn.Parameter(torch.randn(1)) - self.k2 = nn.Parameter(torch.randn(1)) - - def split_to_heads(self, x, B, T, H): - if self.num_heads <= 1: return x - return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) - - def gather_heads(self, x, B, T, H): - if self.num_heads <= 1: return x - return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) - - def attention(self, x): - q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) - k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) - v = self.split_to_heads(self.v_proj(x), *x.shape) - scores = self.k1 * (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5) - tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) - scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line - attention = nn.functional.softmax(scores, dim=2) - return self.o_proj(self.gather_heads(self.k2 * (attention @ v), *x.shape)) - - def forward(self, x): - x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) - x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) - return x - 
-class GPT2ScaledMM(nn.Module): - def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None): - super().__init__() - self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) - self.layers = nn.ModuleList([ - TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) - for _ in range(layers_num)]) - self.tok_embeds = nn.Embedding(vocab_size, h_dim) - self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) - self.lm_head = nn.Linear(h_dim, vocab_size) - - def forward(self, x, targets=None): - x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] - for l in self.layers: - x = l(x) - logits = self.lm_head(x) # B,T,C - loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None - return logits, loss - - # what is the purpose? autoregressive inference! - def generate(self, start_idx, max_new_tokens): - idx = start_idx - for i in range(max_new_tokens): - idx_cond = idx[:,-self.max_seq_len:] - logits, loss = self(idx_cond) - logits = logits[:,-1,:] # B, C - probs = F.softmax(logits, dim=-1) - idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) - idx = torch.cat([idx, idx_next], dim=1) - return idx \ No newline at end of file