trainable cylind lens

pull/1/head
Vladimir Protsenko 3 months ago
parent 58e34b4017
commit d9aa740746

@ -0,0 +1,115 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
torch.manual_seed(1337)
#################################### Model #########################################
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
def __init__(self, dim, max_seq_len=512):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(max_seq_len).type_as(inv_freq)
freqs = torch.einsum('i,j->ij', t, inv_freq)
self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1))
def rotate_half(self, x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def forward(self, x, offset=0):
seq_len = x.size(1)
emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1)
cos = emb.cos()
sin = emb.sin()
return (x * cos) + (self.rotate_half(x) * sin)
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
def __init__(self, num_features, alpha_init_value=0.5):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
def forward(self, x):
x = torch.tanh(self.alpha * x)
return x * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
super().__init__()
self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
self.q_proj = nn.Linear(h_dim, h_dim)
self.k_proj = nn.Linear(h_dim, h_dim)
self.v_proj = nn.Linear(h_dim, h_dim)
self.o_proj = nn.Linear(h_dim, h_dim)
self.ff1 = nn.Linear(h_dim, 4*h_dim)
self.ff2 = nn.Linear(4*h_dim, h_dim)
self.ln1 = DyT(h_dim)
self.ln2 = DyT(h_dim)
self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
def split_to_heads(self, x, B, T, H):
if self.num_heads <= 1: return x
return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
def gather_heads(self, x, B, T, H):
if self.num_heads <= 1: return x
return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
def attention(self, x):
q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
v = self.split_to_heads(self.v_proj(x), *x.shape)
scores = (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5)
tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line
attention = nn.functional.softmax(scores, dim=2)
return self.o_proj(self.gather_heads(attention @ v, *x.shape))
def forward(self, x):
x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate)
x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate)
return x
class GPT2(nn.Module):
def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1, pixel_size=None):
super().__init__()
self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
self.layers = nn.ModuleList([
TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
for _ in range(layers_num)])
self.tok_embeds = nn.Embedding(vocab_size, h_dim)
self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
self.lm_head = nn.Linear(h_dim, vocab_size)
def forward(self, x, targets=None):
x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
for l in self.layers:
x = l(x)
logits = self.lm_head(x) # B,T,C
loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None
return logits, loss
# what is the purpose? autoregressive inference!
def generate(self, start_idx, max_new_tokens):
idx = start_idx
for i in range(max_new_tokens):
idx_cond = idx[:,-self.max_seq_len:]
logits, loss = self(idx_cond)
logits = logits[:,-1,:] # B, C
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
idx = torch.cat([idx, idx_next], dim=1)
return idx

@ -9,23 +9,30 @@ import sys
from pathlib import Path from pathlib import Path
from char_gpt2 import GPT2 from char_gpt2 import GPT2
from optics_char_gpt2 import OpticGPT2 from optics_char_gpt2 import OpticGPT2
from optics_char_gpt2_traindiag import OpticGPT2TrainDiag
from optics_char_gpt2_ff import OpticGPT2FF
seed = 1337 seed = 1337
torch.manual_seed(seed) torch.manual_seed(seed)
models = {'gpt2': GPT2, 'optic_gpt2': OpticGPT2} models = {'gpt2': GPT2, 'optic_gpt2': OpticGPT2, 'optic_gpt2_ff': OpticGPT2FF, 'optic_gpt2_traindiag':OpticGPT2TrainDiag}
batch_size = 50 batch_size = 25
max_iters = 40000 max_iters = 40000*2
eval_interval = 300 eval_interval = 300
learning_rate = 1e-3 learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu' device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200 eval_iters = 200
layers_num = 22 layers_num = 2
h_dim = 64 h_dim = 64
max_seq_len = 256 max_seq_len = 64
num_heads = 4 num_heads = 1
dropout_rate = 0.1 dropout_rate = 0.1
pixel_size = 3.6e-6 pixel_size = 3.6e-6
# CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens seq_128_hdim_64
# CUDA_VISIBLE_DEVICES=2 python src/main.py optic_gpt2_ff ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens seq_128_hdim_128
# CUDA_VISIBLE_DEVICES=3 python src/main.py optic_gpt2_ff ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens seq_128_hdim_256
# CUDA_VISIBLE_DEVICES=4 python src/main.py gpt2 ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens seq_64_hdim_64
# CUDA_VISIBLE_DEVICES=5 python src/main.py gpt2 ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens seq_64_hdim_128
# CUDA_VISIBLE_DEVICES=6 python src/main.py gpt2 ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens seq_64_hdim_256
# CUDA_VISIBLE_DEVICES=1 python .src/main.py gpt2|optic_gpt2 ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens comment # CUDA_VISIBLE_DEVICES=1 python .src/main.py gpt2|optic_gpt2 ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens comment
MODEL_CLASS = models[sys.argv[1]] MODEL_CLASS = models[sys.argv[1]]
train_data_path = Path(sys.argv[2]) train_data_path = Path(sys.argv[2])

@ -274,7 +274,8 @@ class Config(ConfigOpticBase, ConfigModelBase):
wavelength: float = 532e-9, wavelength: float = 532e-9,
distance: float = 0.03, distance: float = 0.03,
lens_pixel_size: float = 1.8e-6, lens_pixel_size: float = 1.8e-6,
lens_size: int = 8192): lens_size: int = 8192,
trainable_cylind_lens = False):
""" """
Конструктор класса. Конструктор класса.
@ -294,6 +295,7 @@ class Config(ConfigOpticBase, ConfigModelBase):
distance: дистанция в метрах распространения светового поля между плоскостями. distance: дистанция в метрах распространения светового поля между плоскостями.
lens_pixel_size: размер пикселя в метрах скрещенных линз в оптической системе (нужен исключительно для моделирования). lens_pixel_size: размер пикселя в метрах скрещенных линз в оптической системе (нужен исключительно для моделирования).
lens_size: размер скрещенных линз в метрах в оптической системе (нужен исключительно для моделирования). lens_size: размер скрещенных линз в метрах в оптической системе (нужен исключительно для моделирования).
trainable_cylind_lens: обучаемые диагональные матрицы, линза перед фурье плоскостью
""" """
ConfigOpticBase.__init__(self, wavelength, distance) ConfigOpticBase.__init__(self, wavelength, distance)
@ -320,6 +322,7 @@ class Config(ConfigOpticBase, ConfigModelBase):
self._input_vector_split_x: int = left_matrix_split_x self._input_vector_split_x: int = left_matrix_split_x
self._input_vector_split_y: int = left_matrix_split_y self._input_vector_split_y: int = left_matrix_split_y
self._result_vector_split: int = result_matrix_split self._result_vector_split: int = result_matrix_split
self._trainable_cylind_lens = trainable_cylind_lens
@property @property
def matrix_split_x(self) -> int: def matrix_split_x(self) -> int:

@ -19,14 +19,29 @@ class OpticalMul(_nn.Module):
prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config) prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config)
prop_two = _PropCrossLens(config.first_lens_plane, config) prop_two = _PropCrossLens(config.first_lens_plane, config)
prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config) prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config)
prop_four = _PropСylindLens(config.matrix_plane, config) prop_four = _PropСylindLens(config.matrix_plane, config, trainable=config._trainable_cylind_lens)
prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config) prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config)
prop_six = _PropCrossLens(config.second_lens_plane, config).T prop_six = _PropCrossLens(config.second_lens_plane, config).T
prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config) prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config)
self._propagator_one: _Prop = prop_one + prop_two + prop_three + prop_four # print(prop_one)
# print(prop_two)
# print(prop_three)
# print(prop_four)
# print(prop_five)
# print((prop_one + prop_two + prop_three))
# print((prop_one + prop_two + prop_three + prop_four))
self._propagator_one: _Prop = prop_one + prop_two + prop_three
self._propagator_between = prop_four
self._propagator_two: _Prop = prop_five + prop_six + prop_seven self._propagator_two: _Prop = prop_five + prop_six + prop_seven
# print(self._propagator_one)
# print(self._propagator_between)
# print(self._propagator_between.operator_X)
# print(self._propagator_between.operator_Y)
# print(self._propagator_two)
kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x)) kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x))
kron_mat_utils = _torch.ones((config.matrix_split_x, config.matrix_split_y)) kron_mat_utils = _torch.ones((config.matrix_split_x, config.matrix_split_y))
self.register_buffer('_kron_vec_utils', kron_vec_utils, persistent=True) self.register_buffer('_kron_vec_utils', kron_vec_utils, persistent=True)
@ -111,6 +126,7 @@ class OpticalMul(_nn.Module):
mat_field = self.prepare_matrix(other) mat_field = self.prepare_matrix(other)
vec_field = self._propagator_one(vec_field) vec_field = self._propagator_one(vec_field)
vec_field = self._propagator_between(vec_field)
vec_field = self._propagator_two(vec_field * mat_field) vec_field = self._propagator_two(vec_field * mat_field)
return self.prepare_out(vec_field) return self.prepare_out(vec_field)

@ -16,12 +16,20 @@ class Propagator(_ABC, _nn.Module):
operator_X: оператор отображающий распроранение светового поля вдоль оси абсцисс operator_X: оператор отображающий распроранение светового поля вдоль оси абсцисс
operator_Y: оператор отображающий распроранение светового поля вдоль оси ординат operator_Y: оператор отображающий распроранение светового поля вдоль оси ординат
""" """
def __init__(self, operator_X: _torch.Tensor, operator_Y: _torch.Tensor): def __init__(self, operator_X: _torch.Tensor, operator_Y: _torch.Tensor, trainable = False, diagonal = False):
super(Propagator, self).__init__() super(Propagator, self).__init__()
operator_X: _torch.Tensor = _torch.view_as_real(operator_X) operator_X: _torch.Tensor = _torch.view_as_real(operator_X)
operator_Y: _torch.Tensor = _torch.view_as_real(operator_Y) operator_Y: _torch.Tensor = _torch.view_as_real(operator_Y)
if trainable:
self._operator_X = _nn.Parameter(operator_X)
self._operator_Y = _nn.Parameter(operator_Y)
self._trainable = trainable
self._diagonal = diagonal
else:
self.register_buffer('_operator_X', operator_X, persistent=True) self.register_buffer('_operator_X', operator_X, persistent=True)
self.register_buffer('_operator_Y', operator_Y, persistent=True) self.register_buffer('_operator_Y', operator_Y, persistent=True)
self._trainable = trainable
self._diagonal = diagonal
@property @property
def operator_X(self) -> _torch.Tensor: def operator_X(self) -> _torch.Tensor:
@ -103,8 +111,14 @@ class Propagator(_ABC, _nn.Module):
Распределение комплексной амплитуды светового поля, Распределение комплексной амплитуды светового поля,
после распространения. после распространения.
""" """
if self._diagonal:
return _torch.diag_embed(self.operator_Y) @ field @ _torch.diag_embed(self.operator_X)
else:
return self.operator_Y @ field @ self.operator_X return self.operator_Y @ field @ self.operator_X
def __repr__(self):
return f"Diag: {self._diagonal} Trainable: {self._trainable} Y shape: {self.operator_Y.shape}, X shape: {self.operator_X.shape}"
class PropagatorLens(Propagator): class PropagatorLens(Propagator):
""" """
Абстрактный класс распространения света в тонком оптическом элементе. Абстрактный класс распространения света в тонком оптическом элементе.
@ -133,7 +147,7 @@ class PropagatorCrossLens(PropagatorLens):
представленной тонким оптическим элементом. представленной тонким оптическим элементом.
""" """
def __init__(self, plane: _ConfigDesignPlane, def __init__(self, plane: _ConfigDesignPlane,
config: _ConfigOpticBase): config: _ConfigOpticBase, trainable = False):
""" """
Конструктор класса скрещенной линзы. Конструктор класса скрещенной линзы.
@ -144,7 +158,8 @@ class PropagatorCrossLens(PropagatorLens):
operator_X = _torch.exp(-1j * config.K / config.distance * plane.linspace_by_x**2) operator_X = _torch.exp(-1j * config.K / config.distance * plane.linspace_by_x**2)
operator_Y = _torch.exp(-1j * config.K / 2 / config.distance * plane.linspace_by_y**2) operator_Y = _torch.exp(-1j * config.K / 2 / config.distance * plane.linspace_by_y**2)
super(PropagatorCrossLens, self).__init__(_torch.diag_embed(operator_X), super(PropagatorCrossLens, self).__init__(_torch.diag_embed(operator_X),
_torch.diag_embed(operator_Y)) _torch.diag_embed(operator_Y),
trainable)
class PropagatorСylindLens(PropagatorLens): class PropagatorСylindLens(PropagatorLens):
""" """
@ -152,7 +167,7 @@ class PropagatorСylindLens(PropagatorLens):
представленной тонким оптическим элементом. представленной тонким оптическим элементом.
""" """
def __init__(self, plane: _ConfigDesignPlane, def __init__(self, plane: _ConfigDesignPlane,
config: _ConfigOpticBase): config: _ConfigOpticBase, trainable = False):
""" """
Конструктор класса цилиндрической линзы. Конструктор класса цилиндрической линзы.
@ -162,8 +177,10 @@ class PropagatorСylindLens(PropagatorLens):
""" """
operator_X = _torch.exp(-1j * config.K / config.distance * plane.linspace_by_x**2) operator_X = _torch.exp(-1j * config.K / config.distance * plane.linspace_by_x**2)
operator_Y = _torch.ones_like(plane.linspace_by_y, dtype=_torch.cfloat) operator_Y = _torch.ones_like(plane.linspace_by_y, dtype=_torch.cfloat)
super(PropagatorСylindLens, self).__init__(_torch.diag_embed(operator_X), super(PropagatorСylindLens, self).__init__(operator_X,
_torch.diag_embed(operator_Y)) operator_Y,
trainable,
diagonal=True)
class PropagatorSinc(Propagator): class PropagatorSinc(Propagator):
""" """
@ -172,7 +189,7 @@ class PropagatorSinc(Propagator):
""" """
def __init__(self, first_plane: _ConfigDesignPlane, def __init__(self, first_plane: _ConfigDesignPlane,
second_plane: _ConfigDesignPlane, second_plane: _ConfigDesignPlane,
config: _ConfigOpticBase): config: _ConfigOpticBase, trainable = False):
""" """
Конструктор класса распространения в свободном пространстве. Конструктор класса распространения в свободном пространстве.
@ -184,7 +201,7 @@ class PropagatorSinc(Propagator):
operator_X, operator_Y = self.__get_operators(first_plane, operator_X, operator_Y = self.__get_operators(first_plane,
second_plane, second_plane,
config) config)
super(PropagatorSinc, self).__init__(operator_X, operator_Y) super(PropagatorSinc, self).__init__(operator_X, operator_Y, trainable)
def __get_operator_for_dim(self, def __get_operator_for_dim(self,
pixel_size_in: float, pixel_size_in: float,

@ -0,0 +1,218 @@
import torch
import torch.nn as nn
from torch.nn import functional as F, init
from einops import rearrange
import optical_matrix_multiplication as omm
from optical_matrix_multiplication import propagator
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from pathlib import Path
import sys
import math
torch.manual_seed(1337)
#################################### Model #########################################
def norm(matrix: torch.Tensor, max_val: float = 1) -> torch.Tensor:
return matrix / (max_val + 1e-10)
def optics_matmul_shift(sim, tensor_1, tensor_2):
# print(tensor_1.shape, tensor_2.shape)
tensor_1 = tensor_1[None,:,:,:]
tensor_2 = tensor_2[None,None,:,:]
# print(tensor_1.shape, tensor_2.shape)
# raise RuntimeError
if torch.min(tensor_1) >= 0 and torch.min(tensor_2) >= 0:
max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2)))
a, b = norm(tensor_1, max_abs), norm(tensor_2, max_abs)
return sim(a, b)[0,:,:,:] * max_abs **2
min_abs = abs(min(torch.min(tensor_1), torch.min(tensor_2)))
max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + min_abs
shift_a = min_abs * torch.ones(tensor_1.shape).to(tensor_1.device)
shift_b = min_abs * torch.ones(tensor_2.shape).to(tensor_1.device)
a_a_sh = tensor_1 + shift_a
b_b_sh = tensor_2 + shift_b
a_a_sh_norm, b_b_sh_norm = norm(a_a_sh, max_abs), norm(b_b_sh, max_abs)
shift_a_norm, shift_b_norm = norm(shift_a, max_abs), norm(shift_b, max_abs)
a_a_sh_b_b_sh = sim(a_a_sh_norm, b_b_sh_norm)
a_a_sh_b_sh = sim(a_a_sh_norm, shift_b_norm)
a_sh_b_b_sh = sim(shift_a_norm, b_b_sh_norm)
a_sh_b_sh = sim(shift_a_norm, shift_b_norm)
return (a_a_sh_b_b_sh - a_a_sh_b_sh - a_sh_b_b_sh + a_sh_b_sh)[0,:,:,:] * max_abs ** 2
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
def __init__(self, dim, max_seq_len=512):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(max_seq_len).type_as(inv_freq)
freqs = torch.einsum('i,j->ij', t, inv_freq)
self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1))
def rotate_half(self, x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def forward(self, x, offset=0):
seq_len = x.size(1)
emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1)
cos = emb.cos()
sin = emb.sin()
return (x * cos) + (self.rotate_half(x) * sin)
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
def __init__(self, num_features, alpha_init_value=0.5):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
def forward(self, x):
x = torch.tanh(self.alpha * x)
return x * self.weight + self.bias
class OpticLinear(nn.Module):
def __init__(
self,
in_features,
out_features,
bias = True,
device = None,
dtype = None,
pixel_size = 3.6e-6
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = nn.Parameter(
torch.empty((in_features, out_features), **factory_kwargs)
)
# print(self.weight.shape)
if bias:
self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
else:
self.register_parameter("bias", None)
self.k = nn.Parameter(torch.randn(1))
self.sim = omm.OpticalMul(
omm.Config(
right_matrix_count_columns = out_features ,
right_matrix_count_rows = in_features,
right_matrix_width = pixel_size * out_features ,
right_matrix_height = pixel_size * in_features,
min_height_gap = pixel_size,
right_matrix_split_x = 2,
right_matrix_split_y = 2,
left_matrix_split_x = 2,
left_matrix_split_y = 2,
result_matrix_split = 2,
distance = 0.01)
)
self.reset_parameters()
def forward(self, input):
"""
Runs the forward pass.
"""
return self.k * optics_matmul_shift(self.sim, input, self.weight) + self.bias
def reset_parameters(self) -> None:
"""
Resets parameters based on their initialization used in ``__init__``.
"""
# Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
# uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
# https://github.com/pytorch/pytorch/issues/57109
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(self.bias, -bound, bound)
def extra_repr(self) -> str:
"""
Return the extra representation of the module.
"""
return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
super().__init__()
self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
self.q_proj = nn.Linear(h_dim, h_dim)
self.k_proj = nn.Linear(h_dim, h_dim)
self.v_proj = nn.Linear(h_dim, h_dim)
self.o_proj = nn.Linear(h_dim, h_dim)
self.ff1 = OpticLinear(h_dim, 4*h_dim)
self.ff2 = OpticLinear(4*h_dim, h_dim)
self.ln1 = DyT(h_dim)
self.ln2 = DyT(h_dim)
self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
def split_to_heads(self, x, B, T, H):
if self.num_heads <= 1: return x
return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
def gather_heads(self, x, B, T, H):
if self.num_heads <= 1: return x
return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
def attention(self, x):
q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
v = self.split_to_heads(self.v_proj(x), *x.shape)
scores = (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5)
tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line
attention = nn.functional.softmax(scores, dim=2)
return self.o_proj(self.gather_heads(attention @ v, *x.shape))
def forward(self, x):
x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate)
x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate)
return x
class OpticGPT2FF(nn.Module):
def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
pixel_size = 3.6e-6):
super().__init__()
self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
self.layers = nn.ModuleList([
TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads,
dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
for _ in range(layers_num)])
self.tok_embeds = nn.Embedding(vocab_size, h_dim)
self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
self.lm_head = nn.Linear(h_dim, vocab_size)
def forward(self, x, targets=None):
x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
for l in self.layers:
x = l(x)
logits = self.lm_head(x) # B,T,C
loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None
return logits, loss
# what is the purpose? autoregressive inference!
def generate(self, start_idx, max_new_tokens):
idx = start_idx
for i in range(max_new_tokens):
idx_cond = idx[:,-self.max_seq_len:]
logits, loss = self(idx_cond)
logits = logits[:,-1,:] # B, C
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
idx = torch.cat([idx, idx_next], dim=1)
return idx

@ -0,0 +1,179 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import optical_matrix_multiplication as omm
from optical_matrix_multiplication import propagator
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from pathlib import Path
import sys
torch.manual_seed(1337)
#################################### Model #########################################
def norm(matrix: torch.Tensor, max_val: float = 1) -> torch.Tensor:
return matrix / (max_val + 1e-10)
def optics_matmul_shift(sim, tensor_1, tensor_2):
tensor_1 = tensor_1[None,:,:,:]
tensor_2 = tensor_2[None,:,:,:]
if torch.min(tensor_1) >= 0 and torch.min(tensor_2) >= 0:
max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2)))
a, b = norm(tensor_1, max_abs), norm(tensor_2, max_abs)
return sim(a, b)[0,:,:,:] * max_abs **2
min_abs = abs(min(torch.min(tensor_1), torch.min(tensor_2)))
max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + min_abs
shift_a = min_abs * torch.ones(tensor_1.shape).to(tensor_1.device)
shift_b = min_abs * torch.ones(tensor_2.shape).to(tensor_1.device)
a_a_sh = tensor_1 + shift_a
b_b_sh = tensor_2 + shift_b
a_a_sh_norm, b_b_sh_norm = norm(a_a_sh, max_abs), norm(b_b_sh, max_abs)
shift_a_norm, shift_b_norm = norm(shift_a, max_abs), norm(shift_b, max_abs)
a_a_sh_b_b_sh = sim(a_a_sh_norm, b_b_sh_norm)
a_a_sh_b_sh = sim(a_a_sh_norm, shift_b_norm)
a_sh_b_b_sh = sim(shift_a_norm, b_b_sh_norm)
a_sh_b_sh = sim(shift_a_norm, shift_b_norm)
return (a_a_sh_b_b_sh - a_a_sh_b_sh - a_sh_b_b_sh + a_sh_b_sh)[0,:,:,:] * max_abs ** 2
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
def __init__(self, dim, max_seq_len=512):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(max_seq_len).type_as(inv_freq)
freqs = torch.einsum('i,j->ij', t, inv_freq)
self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1))
def rotate_half(self, x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def forward(self, x, offset=0):
seq_len = x.size(1)
emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1)
cos = emb.cos()
sin = emb.sin()
return (x * cos) + (self.rotate_half(x) * sin)
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
def __init__(self, num_features, alpha_init_value=0.5):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
def forward(self, x):
x = torch.tanh(self.alpha * x)
return x * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
super().__init__()
self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
self.q_proj = nn.Linear(h_dim, h_dim)
self.k_proj = nn.Linear(h_dim, h_dim)
self.v_proj = nn.Linear(h_dim, h_dim)
self.o_proj = nn.Linear(h_dim, h_dim)
self.ff1 = nn.Linear(h_dim, 4*h_dim)
self.ff2 = nn.Linear(4*h_dim, h_dim)
self.ln1 = DyT(h_dim)
self.ln2 = DyT(h_dim)
self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
self.k1 = nn.Parameter(torch.randn(1))
self.k2 = nn.Parameter(torch.randn(1))
def split_to_heads(self, x, B, T, H):
if self.num_heads <= 1: return x
return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
def gather_heads(self, x, B, T, H):
if self.num_heads <= 1: return x
return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
def attention(self, x):
q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
v = self.split_to_heads(self.v_proj(x), *x.shape)
scores = self.k1 * optics_matmul_shift(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
scores = scores.masked_fill(tril == 0, float('-inf'))
attention = nn.functional.softmax(scores, dim=2)
output = self.k2 * optics_matmul_shift(self.sim_output, attention, v)
return self.o_proj(self.gather_heads(output, *x.shape))
def forward(self, x):
x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate)
x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate)
return x
class OpticGPT2TrainDiag(nn.Module):
def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=64, num_heads=1, dropout_rate = 0.1,
pixel_size = 3.6e-6):
super().__init__()
self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
self.sim_scores = omm.OpticalMul(
omm.Config(right_matrix_count_columns = max_seq_len,
right_matrix_count_rows = h_dim // num_heads,
right_matrix_width = pixel_size * max_seq_len,
right_matrix_height = pixel_size * (h_dim // num_heads),
min_height_gap = pixel_size,
right_matrix_split_x = 2,
right_matrix_split_y = 2,
left_matrix_split_x = 2,
left_matrix_split_y = 2,
result_matrix_split = 2,
distance = 0.01,
trainable_cylind_lens=True)
)
self.sim_output = omm.OpticalMul(
omm.Config(right_matrix_count_columns = h_dim // num_heads,
right_matrix_count_rows = max_seq_len,
right_matrix_width = pixel_size * (h_dim // num_heads),
right_matrix_height = pixel_size * max_seq_len,
min_height_gap = pixel_size,
right_matrix_split_x = 2,
right_matrix_split_y = 2,
left_matrix_split_x = 2,
left_matrix_split_y = 2,
result_matrix_split = 2,
distance = 0.01,
trainable_cylind_lens=True)
)
self.layers = nn.ModuleList([
TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,
dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
for _ in range(layers_num)])
self.tok_embeds = nn.Embedding(vocab_size, h_dim)
self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
self.lm_head = nn.Linear(h_dim, vocab_size)
def forward(self, x, targets=None):
x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
for l in self.layers:
x = l(x)
logits = self.lm_head(x) # B,T,C
loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None
return logits, loss
# what is the purpose? autoregressive inference!
def generate(self, start_idx, max_new_tokens):
idx = start_idx
for i in range(max_new_tokens):
idx_cond = idx[:,-self.max_seq_len:]
logits, loss = self(idx_cond)
logits = logits[:,-1,:] # B, C
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
idx = torch.cat([idx, idx_next], dim=1)
return idx
Loading…
Cancel
Save