diff --git a/src/char_gpt2_ff.py b/src/char_gpt2_scaledmatmul.py similarity index 94% rename from src/char_gpt2_ff.py rename to src/char_gpt2_scaledmatmul.py index c1cd2cb..25bb992 100644 --- a/src/char_gpt2_ff.py +++ b/src/char_gpt2_scaledmatmul.py @@ -59,6 +59,8 @@ class TransformerLayer(nn.Module): self.ln1 = DyT(h_dim) self.ln2 = DyT(h_dim) self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + self.k1 = nn.Parameter(torch.randn(1)) + self.k2 = nn.Parameter(torch.randn(1)) def split_to_heads(self, x, B, T, H): if self.num_heads <= 1: return x @@ -72,18 +74,18 @@ class TransformerLayer(nn.Module): q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) v = self.split_to_heads(self.v_proj(x), *x.shape) - scores = (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5) + scores = self.k1 * (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5) tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line attention = nn.functional.softmax(scores, dim=2) - return self.o_proj(self.gather_heads(attention @ v, *x.shape)) + return self.o_proj(self.gather_heads(self.k2 * (attention @ v), *x.shape)) def forward(self, x): x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) return x -class GPT2(nn.Module): +class GPT2ScaledMM(nn.Module): def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None): super().__init__() self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) diff --git a/src/main.py b/src/main.py index 24406f8..52fd104 100644 --- a/src/main.py +++ b/src/main.py @@ -11,41 +11,54 @@ from char_gpt2 import GPT2 from optics_char_gpt2 import OpticGPT2 from optics_char_gpt2_traindiag import OpticGPT2TrainDiag from 
optics_char_gpt2_ff import OpticGPT2FF +from optics_char_gpt2_new_formula import OpticGPT2NewFormula +from char_gpt2_scaledmatmul import GPT2ScaledMM +from optics_char_gpt2_nokoef import OpticGPT2NOKoef +from optics_char_gpt2_nokoef_newf import OpticGPT2NOKoefNewF +import shutil seed = 1337 torch.manual_seed(seed) -models = {'gpt2': GPT2, 'optic_gpt2': OpticGPT2, 'optic_gpt2_ff': OpticGPT2FF, 'optic_gpt2_traindiag':OpticGPT2TrainDiag} - -batch_size = 25 -max_iters = 40000*2 +models = { + 'gpt2': GPT2, + 'optic_gpt2': OpticGPT2, + 'optic_gpt2_ff': OpticGPT2FF, + 'optic_gpt2_traindiag': OpticGPT2TrainDiag, + 'optic_gpt2_newformula': OpticGPT2NewFormula, + 'optic_gpt2_nokoef': OpticGPT2NOKoef, + 'optic_gpt2_nokoef_newformula': OpticGPT2NOKoefNewF, + 'gpt2_scaledmm': GPT2ScaledMM +} + +batch_size = 50 +gradient_accumulation_steps = 2 # check this impl for correctness https://unsloth.ai/blog/gradient +max_iters = 40000 eval_interval = 300 learning_rate = 1e-3 device = 'cuda' if torch.cuda.is_available() else 'cpu' eval_iters = 200 layers_num = 2 h_dim = 64 -max_seq_len = 64 +max_seq_len = 256 num_heads = 1 dropout_rate = 0.1 pixel_size = 3.6e-6 +assert batch_size % gradient_accumulation_steps == 0 # CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens seq_128_hdim_64 -# CUDA_VISIBLE_DEVICES=2 python src/main.py optic_gpt2_ff ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens seq_128_hdim_128 -# CUDA_VISIBLE_DEVICES=3 python src/main.py optic_gpt2_ff ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens seq_128_hdim_256 -# CUDA_VISIBLE_DEVICES=4 python src/main.py gpt2 ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens seq_64_hdim_64 -# CUDA_VISIBLE_DEVICES=5 python src/main.py gpt2 ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens seq_64_hdim_128 -# CUDA_VISIBLE_DEVICES=6 python src/main.py gpt2 
./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens seq_64_hdim_256 -# CUDA_VISIBLE_DEVICES=1 python .src/main.py gpt2|optic_gpt2 ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens comment -MODEL_CLASS = models[sys.argv[1]] + +MODEL_CLASS = models[sys.argv[1]] train_data_path = Path(sys.argv[2]) val_data_path = Path(sys.argv[3]) test_data_path = Path(sys.argv[4]) -comment = f"{sys.argv[1]}_{train_data_path.name}_{sys.argv[5]}_{seed}" +comment = f"{sys.argv[1]}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[5] if len(sys.argv) >= 6 else ''}" logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/' writer = SummaryWriter(logs_dir) -script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).name) +script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name) print("Logs dir:", logs_dir) -script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script -script_snapshot_path.chmod(0o400) # with read-only permission +# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script +shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path) # snapshot this version of repository +script_snapshot_path.chmod(0o500) # with read-only permission + #################################### Dataset ######################################### # wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt @@ -106,33 +119,45 @@ m = MODEL_CLASS( layers_num=layers_num ) m = m.to(device) +writer.add_text('model', str(m), 0) + #################################### Train ######################################### optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01) +m.eval() completion = complete(m, encode("\n"*max_seq_len), 2*max_seq_len) print(completion) writer.add_text('completions', completion, 0) for i in 
range(max_iters): - xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size) - logits, loss = m(xb, yb) + m.train() optimizer.zero_grad(set_to_none=True) - loss.backward() + accumulated_loss = 0.0 + for j in range(gradient_accumulation_steps): + xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size//gradient_accumulation_steps) + logits, loss = m(xb, yb) + loss = loss / gradient_accumulation_steps + loss.backward() + accumulated_loss += loss.item() + if i % 100 == 0: + writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i) optimizer.step() - writer.add_scalar('loss', loss.item(), i) - print(f"\r{i}/{max_iters} {loss.item()}", end="") + writer.add_scalar('loss', accumulated_loss, i) + print(f"\r{i}/{max_iters} {accumulated_loss}", end="") if i % 5000 == 0: + m.eval() ppl = perplexity(model=m, data=val_data) writer.add_scalar('val_perplexity', ppl.item(), i) print(f"\rPerplexity at {i}: {ppl}") writer.add_text('completions', complete(m, encode("\n"*max_seq_len), 2*max_seq_len), i) +m.eval() ppl = perplexity(model=m, data=val_data) -print(f"\r{i+1}/{max_iters} {loss.item()}") -print(f"\rPerplexity at {i}: {ppl}") +print(f"\r{i+1}/{max_iters} {accumulated_loss}") +print(f"\r{datetime.now()} Perplexity at {i}: {ppl}") writer.add_scalar('val_perplexity', ppl.item(), i+1) -writer.add_scalar('loss', loss.item(), i) +writer.add_scalar('loss', accumulated_loss, i) ppl = perplexity(model=m, data=test_data) writer.add_scalar('test_perplexity', ppl.item(), i+1) diff --git a/src/optical_matrix_multiplication/__init__.py b/src/optical_matrix_multiplication/__init__.py index 9974821..9a1844c 100644 --- a/src/optical_matrix_multiplication/__init__.py +++ b/src/optical_matrix_multiplication/__init__.py @@ -6,4 +6,5 @@ __version__ = "3.0.0" from .config import Config from . 
import propagator from .optical_mul import OpticalMul -from .parallel import DataParallel \ No newline at end of file +from .parallel import DataParallel +from .parallel import ScatterDataParallel \ No newline at end of file diff --git a/src/optical_matrix_multiplication/optical_mul.py b/src/optical_matrix_multiplication/optical_mul.py index cb3e398..473cd74 100644 --- a/src/optical_matrix_multiplication/optical_mul.py +++ b/src/optical_matrix_multiplication/optical_mul.py @@ -15,33 +15,23 @@ class OpticalMul(_nn.Module): config: конфигурация расчётной системы. """ super(OpticalMul, self).__init__() + self.trainable_cylind_lens = config._trainable_cylind_lens prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config) prop_two = _PropCrossLens(config.first_lens_plane, config) prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config) - prop_four = _PropСylindLens(config.matrix_plane, config, trainable=config._trainable_cylind_lens) + prop_four = _PropСylindLens(config.matrix_plane, config, trainable=self.trainable_cylind_lens) prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config) prop_six = _PropCrossLens(config.second_lens_plane, config).T prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config) - # print(prop_one) - # print(prop_two) - # print(prop_three) - # print(prop_four) - # print(prop_five) - # print((prop_one + prop_two + prop_three)) - # print((prop_one + prop_two + prop_three + prop_four)) - - self._propagator_one: _Prop = prop_one + prop_two + prop_three - self._propagator_between = prop_four + if self.trainable_cylind_lens: + self._propagator_one: _Prop = prop_one + prop_two + prop_three + self._propagator_between = prop_four + else: + self._propagator_one: _Prop = prop_one + prop_two + prop_three + prop_four self._propagator_two: _Prop = prop_five + prop_six + prop_seven - # print(self._propagator_one) - # print(self._propagator_between) - # 
print(self._propagator_between.operator_X) - # print(self._propagator_between.operator_Y) - # print(self._propagator_two) - kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x)) kron_mat_utils = _torch.ones((config.matrix_split_x, config.matrix_split_y)) self.register_buffer('_kron_vec_utils', kron_vec_utils, persistent=True) @@ -125,8 +115,11 @@ class OpticalMul(_nn.Module): vec_field = self.prepare_vector(input) mat_field = self.prepare_matrix(other) - vec_field = self._propagator_one(vec_field) - vec_field = self._propagator_between(vec_field) + if self.trainable_cylind_lens: + vec_field = self._propagator_one(vec_field) + vec_field = self._propagator_between(vec_field) + else: + vec_field = self._propagator_one(vec_field) vec_field = self._propagator_two(vec_field * mat_field) return self.prepare_out(vec_field) \ No newline at end of file diff --git a/src/optical_matrix_multiplication/parallel.py b/src/optical_matrix_multiplication/parallel.py index 631a1cd..b707415 100644 --- a/src/optical_matrix_multiplication/parallel.py +++ b/src/optical_matrix_multiplication/parallel.py @@ -94,4 +94,71 @@ class DataParallel(_nn.Module): outputs = _nn.parallel.parallel_apply(replicas, stacked_input) - return _nn.parallel.gather(outputs, self.output_device, dim) \ No newline at end of file + return _nn.parallel.gather(outputs, self.output_device, dim) + +class ScatterDataParallel(_nn.Module): + """ + Оптимизированный DataParallel для работы с attention матрицами разных размеров. + Эквивалентно DataParallel от PyTorch? 
+ """ + def __init__(self, module: _nn.Module, devices: Union[None, List[Union[int, _torch.device]]] = None, + output_device: Union[int, _torch.device] = None) -> None: + super(ScatterDataParallel, self).__init__() + + if not _torch.cuda.is_available(): + raise EnvironmentError("cuda is not available.") + + if not devices: + devices = [_torch.device(f'cuda:{x}') for x in range(_torch.cuda.device_count())] + + if not output_device: + output_device = devices[0] + + self.module = module + self.devices = devices + self.output_device = output_device + + def buffers(self, *inputs) -> Iterator[_torch.Tensor]: + return self.module.buffers(*inputs) + + def parameters(self, *inputs) -> Iterator[_nn.parameter.Parameter]: + return self.module.parameters(*inputs) + + def forward(self, input: _torch.Tensor, other: _torch.Tensor, **kwargs: Any) -> _torch.Tensor: + ''' + Оптимизированный forward для attention матриц. + + Особенности: + - Scatter по batch dimension (0) вместо произвольного dim + - Оба тензора scatter'ятся для согласованности размерностей + - Поддержка многомерных attention тензоров [batch, heads, seq, dim] + ''' + + # Определяем dimension для scatter на основе структуры тензоров + if input.dim() >= 3 and other.dim() >= 3: + # Для attention матриц scatter по batch dimension + scatter_dim = 0 + else: + # Для обычных 2D матриц используем dim из kwargs или по умолчанию 2 + scatter_dim = kwargs.get('dim', 2) + + # Подготовка модуля и данных + self.module = self.module.to(self.devices[0]) + + # Scatter ОБОИХ тензоров для согласованности размерностей + scattered_input = _nn.parallel.scatter(input, self.devices, scatter_dim) + scattered_other = _nn.parallel.scatter(other, self.devices, scatter_dim) + + # Создаем реплики модуля + replicas = _nn.parallel.replicate(self.module, self.devices) + + # Формируем входные данные для каждого устройства + # Убедимся, что все списки одинаковой длины + min_len = min(len(scattered_input), len(scattered_other), len(replicas)) + 
stacked_input = [(scattered_input[i], scattered_other[i]) for i in range(min_len)] + + # Параллельное вычисление + outputs = _nn.parallel.parallel_apply(replicas[:min_len], stacked_input) + + # Сбор результатов + return _nn.parallel.gather(outputs, self.output_device, scatter_dim) \ No newline at end of file diff --git a/src/optics_char_gpt2.py b/src/optics_char_gpt2.py index 1c68c7a..3989ed6 100644 --- a/src/optics_char_gpt2.py +++ b/src/optics_char_gpt2.py @@ -121,7 +121,37 @@ class OpticGPT2(nn.Module): pixel_size = 3.6e-6): super().__init__() self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) - self.sim_scores = omm.OpticalMul( + if max_seq_len < 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + if max_seq_len >= 512: + self.sim_scores = omm.OpticalMul( omm.Config(right_matrix_count_columns = max_seq_len, right_matrix_count_rows = h_dim // num_heads, right_matrix_width = pixel_size * max_seq_len, @@ -132,10 +162,11 @@ class OpticGPT2(nn.Module): left_matrix_split_x = 2, left_matrix_split_y = 2, result_matrix_split = 2, - distance = 0.01) + distance = 0.15, + 
lens_size = 8192 * 2, + trainable_cylind_lens=False) ) - - self.sim_output = omm.OpticalMul( + self.sim_output = omm.OpticalMul( omm.Config(right_matrix_count_columns = h_dim // num_heads, right_matrix_count_rows = max_seq_len, right_matrix_width = pixel_size * (h_dim // num_heads), @@ -146,8 +177,11 @@ class OpticGPT2(nn.Module): left_matrix_split_x = 2, left_matrix_split_y = 2, result_matrix_split = 2, - distance = 0.01) + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) ) + self.layers = nn.ModuleList([ TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) diff --git a/src/optics_char_gpt2_ff.py b/src/optics_char_gpt2_ff.py index 2347bf9..59bbd93 100644 --- a/src/optics_char_gpt2_ff.py +++ b/src/optics_char_gpt2_ff.py @@ -96,7 +96,6 @@ class OpticLinear(nn.Module): self.weight = nn.Parameter( torch.empty((in_features, out_features), **factory_kwargs) ) - # print(self.weight.shape) if bias: self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs)) else: diff --git a/src/optics_char_gpt2_new_formula.py b/src/optics_char_gpt2_new_formula.py new file mode 100644 index 0000000..6b59105 --- /dev/null +++ b/src/optics_char_gpt2_new_formula.py @@ -0,0 +1,227 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import optical_matrix_multiplication as omm +from optical_matrix_multiplication import propagator +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +from pathlib import Path +import sys +torch.manual_seed(1337) + +#################################### Model ######################################### + +# def norm(matrix: torch.Tensor, max_val: float = 1) -> torch.Tensor: +# return matrix / (max_val + 1e-10) + +def new_formula(sim, tensor_1, tensor_2): + tensor_1 = tensor_1[None,:,:,:] if 
len(tensor_1.shape) < 4 else tensor_1 + tensor_2 = tensor_2[None,:,:,:] if len(tensor_2.shape) < 4 else tensor_2 + device = tensor_1.device + + A_pos = torch.clamp(tensor_1, min=0) # A⁺ = max(A, 0) + A_neg = torch.clamp(-tensor_1, min=0) # A⁻ = max(-A, 0) + B_pos = torch.clamp(tensor_2, min=0) # B⁺ = max(B, 0) + B_neg = torch.clamp(-tensor_2, min=0) # B⁻ = max(-B, 0) + + max_A_pos = torch.max(A_pos) # Может быть 0, если нет положительных значений + max_A_neg = torch.max(A_neg) # Может быть 0, если нет отрицательных значений + max_B_pos = torch.max(B_pos) + max_B_neg = torch.max(B_neg) + + zero_template = torch.zeros_like( + torch.empty(tensor_1.shape[0],tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])) + + if max_A_pos > 0 and max_B_pos > 0: + t1 = sim(A_pos / max_A_pos, B_pos / max_B_pos) * max_A_pos * max_B_pos + else: + t1 = zero_template.clone().to(device) + + if max_A_pos > 0 and max_B_neg > 0: + t2 = sim(A_pos / max_A_pos, B_neg / max_B_neg) * max_A_pos * max_B_neg + else: + t2 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_pos > 0: + t3 = sim(A_neg / max_A_neg, B_pos / max_B_pos) * max_A_neg * max_B_pos + else: + t3 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_neg > 0: + t4 = sim(A_neg / max_A_neg, B_neg / max_B_neg) * max_A_neg * max_B_neg + else: + t4 = zero_template.clone().to(device) + + return (t1 - t2 - t3 + t4)[0,:,:,:] + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = 
self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + self.k1 = nn.Parameter(torch.randn(1)) + self.k2 = nn.Parameter(torch.randn(1)) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = self.k1 * new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5) + 
tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) + scores = scores.masked_fill(tril == 0, float('-inf')) + attention = nn.functional.softmax(scores, dim=2) + output = self.k2 * new_formula(self.sim_output, attention, v) + return self.o_proj(self.gather_heads(output, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +class OpticGPT2NewFormula(nn.Module): + def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, + pixel_size = 3.6e-6): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + if max_seq_len != 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + if max_seq_len == 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // 
num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + self.sim_scores = omm.DataParallel(self.sim_scores) + self.sim_output = omm.DataParallel(self.sim_output) + + self.layers = nn.ModuleList([ + TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads, + dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) + for _ in range(layers_num)]) + self.tok_embeds = nn.Embedding(vocab_size, h_dim) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + self.lm_head = nn.Linear(h_dim, vocab_size) + + def forward(self, x, targets=None): + x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + logits = self.lm_head(x) # B,T,C + loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None + return logits, loss + + # what is the purpose? autoregressive inference! 
+ def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx \ No newline at end of file diff --git a/src/optics_char_gpt2_nokoef.py b/src/optics_char_gpt2_nokoef.py new file mode 100644 index 0000000..af1e518 --- /dev/null +++ b/src/optics_char_gpt2_nokoef.py @@ -0,0 +1,209 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import optical_matrix_multiplication as omm +from optical_matrix_multiplication import propagator +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +from pathlib import Path +import sys +torch.manual_seed(1337) + +#################################### Model ######################################### + +def norm(matrix: torch.Tensor, max_val: float = 1) -> torch.Tensor: + return matrix / (max_val + 1e-10) + +def optics_matmul_shift(sim, tensor_1, tensor_2): + tensor_1 = tensor_1[None,:,:,:] + tensor_2 = tensor_2[None,:,:,:] + + if torch.min(tensor_1) >= 0 and torch.min(tensor_2) >= 0: + max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + a, b = norm(tensor_1, max_abs), norm(tensor_2, max_abs) + return sim(a, b)[0,:,:,:] * max_abs **2 + + min_abs = abs(min(torch.min(tensor_1), torch.min(tensor_2))) + max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + min_abs + + shift_a = min_abs * torch.ones(tensor_1.shape).to(tensor_1.device) + shift_b = min_abs * torch.ones(tensor_2.shape).to(tensor_1.device) + a_a_sh = tensor_1 + shift_a + b_b_sh = tensor_2 + shift_b + + a_a_sh_norm, b_b_sh_norm = norm(a_a_sh, max_abs), norm(b_b_sh, max_abs) + shift_a_norm, shift_b_norm = norm(shift_a, max_abs), norm(shift_b, max_abs) + 
a_a_sh_b_b_sh = sim(a_a_sh_norm, b_b_sh_norm) + a_a_sh_b_sh = sim(a_a_sh_norm, shift_b_norm) + a_sh_b_b_sh = sim(shift_a_norm, b_b_sh_norm) + a_sh_b_sh = sim(shift_a_norm, shift_b_norm) + + return (a_a_sh_b_b_sh - a_a_sh_b_sh - a_sh_b_b_sh + a_sh_b_sh)[0,:,:,:] * max_abs ** 2 + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x * self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = 
nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = optics_matmul_shift(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5) + tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) + scores = scores.masked_fill(tril == 0, float('-inf')) + attention = nn.functional.softmax(scores, dim=2) + output = optics_matmul_shift(self.sim_output, attention, v) + return self.o_proj(self.gather_heads(output, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +class OpticGPT2NOKoef(nn.Module): + def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, + pixel_size = 3.6e-6): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + if max_seq_len < 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + 
distance = 0.01, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + if max_seq_len >= 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + + self.layers = nn.ModuleList([ + TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads, + dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) + for _ in range(layers_num)]) + self.tok_embeds = nn.Embedding(vocab_size, h_dim) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + self.lm_head = nn.Linear(h_dim, vocab_size) + + def forward(self, x, targets=None): + x = 
self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + logits = self.lm_head(x) # B,T,C + loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None + return logits, loss + + # what is the purpose? autoregressive inference! + def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx \ No newline at end of file diff --git a/src/optics_char_gpt2_nokoef_newf.py b/src/optics_char_gpt2_nokoef_newf.py new file mode 100644 index 0000000..f865c18 --- /dev/null +++ b/src/optics_char_gpt2_nokoef_newf.py @@ -0,0 +1,220 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import optical_matrix_multiplication as omm +from optical_matrix_multiplication import propagator +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +from pathlib import Path +import sys +torch.manual_seed(1337) + +#################################### Model ######################################### + +def new_formula(sim, tensor_1, tensor_2): + tensor_1 = tensor_1[None,:,:,:] if len(tensor_1.shape) < 4 else tensor_1 + tensor_2 = tensor_2[None,:,:,:] if len(tensor_2.shape) < 4 else tensor_2 + device = tensor_1.device + + A_pos = torch.clamp(tensor_1, min=0) # A⁺ = max(A, 0) + A_neg = torch.clamp(-tensor_1, min=0) # A⁻ = max(-A, 0) + B_pos = torch.clamp(tensor_2, min=0) # B⁺ = max(B, 0) + B_neg = torch.clamp(-tensor_2, min=0) # B⁻ = max(-B, 0) + + max_A_pos = torch.max(A_pos) # may be 0 if there are no positive values + max_A_neg = torch.max(A_neg) # may be 0 if there are no negative values + max_B_pos = 
torch.max(B_pos) + max_B_neg = torch.max(B_neg) + + zero_template = torch.zeros_like( + torch.empty(tensor_1.shape[0],tensor_1.shape[1], tensor_1.shape[2], tensor_2.shape[3])) + + if max_A_pos > 0 and max_B_pos > 0: + t1 = sim(A_pos / max_A_pos, B_pos / max_B_pos) * max_A_pos * max_B_pos + else: + t1 = zero_template.clone().to(device) + + if max_A_pos > 0 and max_B_neg > 0: + t2 = sim(A_pos / max_A_pos, B_neg / max_B_neg) * max_A_pos * max_B_neg + else: + t2 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_pos > 0: + t3 = sim(A_neg / max_A_neg, B_pos / max_B_pos) * max_A_neg * max_B_pos + else: + t3 = zero_template.clone().to(device) + + if max_A_neg > 0 and max_B_neg > 0: + t4 = sim(A_neg / max_A_neg, B_neg / max_B_neg) * max_A_neg * max_B_neg + else: + t4 = zero_template.clone().to(device) + + return (t1 - t2 - t3 + t4)[0,:,:,:] + +# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864 +class RoPE(nn.Module): + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_seq_len).type_as(inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) + self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1)) + + def rotate_half(self, x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def forward(self, x, offset=0): + seq_len = x.size(1) + emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1) + cos = emb.cos() + sin = emb.sin() + return (x * cos) + (self.rotate_half(x) * sin) + +# Transformers without Normalization https://jiachenzhu.github.io/DyT/ +class DyT(nn.Module): + def __init__(self, num_features, alpha_init_value=0.5): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + x = torch.tanh(self.alpha * x) + return x 
* self.weight + self.bias + +# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 +# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 +class TransformerLayer(nn.Module): + def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + self.q_proj = nn.Linear(h_dim, h_dim) + self.k_proj = nn.Linear(h_dim, h_dim) + self.v_proj = nn.Linear(h_dim, h_dim) + self.o_proj = nn.Linear(h_dim, h_dim) + self.ff1 = nn.Linear(h_dim, 4*h_dim) + self.ff2 = nn.Linear(4*h_dim, h_dim) + self.ln1 = DyT(h_dim) + self.ln2 = DyT(h_dim) + self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) + + def split_to_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads) + + def gather_heads(self, x, B, T, H): + if self.num_heads <= 1: return x + return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads) + + def attention(self, x): + q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape)) + k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape)) + v = self.split_to_heads(self.v_proj(x), *x.shape) + scores = new_formula(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5) + tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device) + scores = scores.masked_fill(tril == 0, float('-inf')) + attention = nn.functional.softmax(scores, dim=2) + output = new_formula(self.sim_output, attention, v) + return self.o_proj(self.gather_heads(output, *x.shape)) + + def forward(self, x): + x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate) + x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate) + return x + +class OpticGPT2NOKoefNewF(nn.Module): + def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, + 
pixel_size = 3.6e-6): + super().__init__() + self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) + if max_seq_len < 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=False) + ) + if max_seq_len >= 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + 
result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2, + trainable_cylind_lens=False) + ) + + self.layers = nn.ModuleList([ + TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads, + dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) + for _ in range(layers_num)]) + self.tok_embeds = nn.Embedding(vocab_size, h_dim) + self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim)) + self.lm_head = nn.Linear(h_dim, vocab_size) + + def forward(self, x, targets=None): + x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :] + for l in self.layers: + x = l(x) + logits = self.lm_head(x) # B,T,C + loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None + return logits, loss + + # what is the purpose? autoregressive inference! + def generate(self, start_idx, max_new_tokens): + idx = start_idx + for i in range(max_new_tokens): + idx_cond = idx[:,-self.max_seq_len:] + logits, loss = self(idx_cond) + logits = logits[:,-1,:] # B, C + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device) + idx = torch.cat([idx, idx_next], dim=1) + return idx \ No newline at end of file diff --git a/src/optics_char_gpt2_traindiag.py b/src/optics_char_gpt2_traindiag.py index 2cdfc81..26de9d1 100644 --- a/src/optics_char_gpt2_traindiag.py +++ b/src/optics_char_gpt2_traindiag.py @@ -77,7 +77,7 @@ class DyT(nn.Module): # Attention Is All You Need https://arxiv.org/pdf/1706.03762v7 # NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1 class TransformerLayer(nn.Module): - def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128): + def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128, pixel_size=3.6e-6): super().__init__() self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) self.q_proj = 
nn.Linear(h_dim, h_dim) @@ -91,6 +91,66 @@ class TransformerLayer(nn.Module): self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len) self.k1 = nn.Parameter(torch.randn(1)) self.k2 = nn.Parameter(torch.randn(1)) + if max_seq_len < 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=True) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.01, + trainable_cylind_lens=True) + ) + if max_seq_len >= 512: + self.sim_scores = omm.OpticalMul( + omm.Config(right_matrix_count_columns = max_seq_len, + right_matrix_count_rows = h_dim // num_heads, + right_matrix_width = pixel_size * max_seq_len, + right_matrix_height = pixel_size * (h_dim // num_heads), + min_height_gap = pixel_size, + right_matrix_split_x = 2, + right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2) + ) + self.sim_output = omm.OpticalMul( + omm.Config(right_matrix_count_columns = h_dim // num_heads, + right_matrix_count_rows = max_seq_len, + right_matrix_width = pixel_size * (h_dim // num_heads), + right_matrix_height = pixel_size * max_seq_len, + min_height_gap = pixel_size, + right_matrix_split_x = 2, + 
right_matrix_split_y = 2, + left_matrix_split_x = 2, + left_matrix_split_y = 2, + result_matrix_split = 2, + distance = 0.15, + lens_size = 8192 * 2) + ) + self.sim_scores = omm.DataParallel(self.sim_scores) + self.sim_output = omm.DataParallel(self.sim_output) def split_to_heads(self, x, B, T, H): if self.num_heads <= 1: return x @@ -121,38 +181,9 @@ class OpticGPT2TrainDiag(nn.Module): pixel_size = 3.6e-6): super().__init__() self.__dict__.update({k:v for k,v in locals().items() if k != 'self'}) - self.sim_scores = omm.OpticalMul( - omm.Config(right_matrix_count_columns = max_seq_len, - right_matrix_count_rows = h_dim // num_heads, - right_matrix_width = pixel_size * max_seq_len, - right_matrix_height = pixel_size * (h_dim // num_heads), - min_height_gap = pixel_size, - right_matrix_split_x = 2, - right_matrix_split_y = 2, - left_matrix_split_x = 2, - left_matrix_split_y = 2, - result_matrix_split = 2, - distance = 0.01, - trainable_cylind_lens=True) - ) - - self.sim_output = omm.OpticalMul( - omm.Config(right_matrix_count_columns = h_dim // num_heads, - right_matrix_count_rows = max_seq_len, - right_matrix_width = pixel_size * (h_dim // num_heads), - right_matrix_height = pixel_size * max_seq_len, - min_height_gap = pixel_size, - right_matrix_split_x = 2, - right_matrix_split_y = 2, - left_matrix_split_x = 2, - left_matrix_split_y = 2, - result_matrix_split = 2, - distance = 0.01, - trainable_cylind_lens=True) - ) self.layers = nn.ModuleList([ - TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads, - dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len) + TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, + dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len, pixel_size=pixel_size) for _ in range(layers_num)]) self.tok_embeds = nn.Embedding(vocab_size, h_dim) self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))