First commit

main
Vladimir Protsenko 6 days ago
commit a13f526f87

@ -0,0 +1,2 @@
data/
logs/

2
.gitattributes vendored

@ -0,0 +1,2 @@
*.pq filter=lfs diff=lfs merge=lfs -text
*.csv filter=lfs diff=lfs merge=lfs -text

219
.gitignore vendored

@ -0,0 +1,219 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[codz]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py.cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
# Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
# poetry.lock
# poetry.toml
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
# pdm.lock
# pdm.toml
.pdm-python
.pdm-build/
# pixi
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
# pixi.lock
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
# in the .venv directory. It is recommended not to include this directory in version control.
.pixi
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# Redis
*.rdb
*.aof
*.pid
# RabbitMQ
mnesia/
rabbitmq/
rabbitmq-data/
# ActiveMQ
activemq-data/
# SageMath parsed files
*.sage.py
# Environments
.env
.envrc
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
# .idea/
# Abstra
# Abstra is an AI-powered process automation framework.
# Ignore directories containing user credentials, local state, and settings.
# Learn more at https://abstra.io/docs
.abstra/
# Visual Studio Code
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
# and can be added to the global gitignore or merged into this file. However, if you prefer,
# you could uncomment the following to ignore the entire vscode folder
# .vscode/
# Ruff stuff:
.ruff_cache/
# PyPI configuration file
.pypirc
# Marimo
marimo/_static/
marimo/_lsp/
__marimo__/
# Streamlit
.streamlit/secrets.toml
logs/
data/

@ -0,0 +1,51 @@
FROM pytorch/pytorch:2.6.0-cuda12.6-cudnn9-runtime
ENV TZ=Europe/Samara
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
ARG USER
ARG GROUP
ARG UID
ARG GID
RUN apt update
RUN apt install sudo -y
RUN sed -i 's/^%sudo.*/%sudo ALL=(ALL) NOPASSWD: ALL/' /etc/sudoers
RUN groupadd --gid ${GID} ${GROUP}
RUN useradd --shell /bin/bash --uid ${UID} --gid ${GID} -G sudo --create-home ${USER}
RUN mkdir /wd
RUN chown ${USER}:${GROUP} /wd
# SYSTEM CONFIGURATION
RUN apt install wget vim htop mc git tree -y
RUN apt-get install -y libssl-dev autoconf libtool make
RUN cd /usr/local/src && \
wget https://curl.haxx.se/download/curl-7.88.1.zip && \
unzip curl-7.88.1.zip && \
cd curl-7.88.1 && \
./buildconf && \
./configure --with-ssl && \
make && \
make install && \
cp /usr/local/bin/curl /usr/bin/curl && \
ldconfig && \
curl -V
RUN curl -fsSL https://code-server.dev/install.sh | sh
RUN /opt/conda/bin/conda install -n base ipykernel --update-deps --force-reinstall -y
USER ${USER}
# USER CONFIGURATION
RUN pip install schedulefree tensorboard opencv-python-headless scipy pandas matplotlib torchmetrics pyarrow einops nvitop
RUN openssl req -x509 -newkey rsa:4096 -keyout /home/${USER}/key.pem -out /home/${USER}/cert.pem -sha256 -nodes -days 365 -subj "/C=RU/ST=SamaraRegion/L=Samara/O=SSAU/OU=LIAV/CN=vscode.ssau.ru/"
RUN mkdir -p /home/${USER}/.config/code-server
RUN echo 'bind-addr: 0.0.0.0:8443' >> /home/${USER}/.config/code-server/config.yaml
RUN echo "cert: /home/${USER}/cert.pem" >> /home/${USER}/.config/code-server/config.yaml
RUN echo "cert-key: /home/${USER}/key.pem" >> /home/${USER}/.config/code-server/config.yaml
RUN code-server --install-extension ms-python.python
ENV SHELL=/bin/bash
SHELL ["/bin/bash", "--login", "-i", "-c"]
WORKDIR /wd

@ -0,0 +1,3 @@
# Скрипты для обучения gpt2 с симулятором оптического умножения матриц
Репозиторий симулятора оптического умножения матриц https://github.com/amacomm/Optical_matrix_multiplication

@ -0,0 +1,8 @@
#!/bin/bash
CURDIRNAME=${PWD##*/}
docker build . -t ${USER}_${CURDIRNAME}_vscode \
--build-arg USER=${USER} \
--build-arg GROUP=${USER} \
--build-arg UID=$(id -u ${USER}) \
--build-arg GID=$(id -g ${USER})

@ -0,0 +1,115 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
torch.manual_seed(1337)
#################################### Model #########################################
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
def __init__(self, dim, max_seq_len=512):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(max_seq_len).type_as(inv_freq)
freqs = torch.einsum('i,j->ij', t, inv_freq)
self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1))
def rotate_half(self, x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def forward(self, x, offset=0):
seq_len = x.size(1)
emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1)
cos = emb.cos()
sin = emb.sin()
return (x * cos) + (self.rotate_half(x) * sin)
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
def __init__(self, num_features, alpha_init_value=0.5):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
def forward(self, x):
x = torch.tanh(self.alpha * x)
return x * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
def __init__(self, h_dim, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
super().__init__()
self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
self.q_proj = nn.Linear(h_dim, h_dim)
self.k_proj = nn.Linear(h_dim, h_dim)
self.v_proj = nn.Linear(h_dim, h_dim)
self.o_proj = nn.Linear(h_dim, h_dim)
self.ff1 = nn.Linear(h_dim, 4*h_dim)
self.ff2 = nn.Linear(4*h_dim, h_dim)
self.ln1 = DyT(h_dim)
self.ln2 = DyT(h_dim)
self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
def split_to_heads(self, x, B, T, H):
if self.num_heads <= 1: return x
return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
def gather_heads(self, x, B, T, H):
if self.num_heads <= 1: return x
return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
def attention(self, x):
q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
v = self.split_to_heads(self.v_proj(x), *x.shape)
scores = (q @ k.transpose(1, 2)) * (self.h_dim ** -0.5)
tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
scores = scores.masked_fill(tril == 0, float('-inf')) # encoder does not need this line
attention = nn.functional.softmax(scores, dim=2)
return self.o_proj(self.gather_heads(attention @ v, *x.shape))
def forward(self, x):
x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate)
x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate)
return x
class GPT2(nn.Module):
def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=128, num_heads=4, dropout_rate = 0.1, pixel_size=None):
super().__init__()
self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
self.layers = nn.ModuleList([
TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
for _ in range(layers_num)])
self.tok_embeds = nn.Embedding(vocab_size, h_dim)
self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
self.lm_head = nn.Linear(h_dim, vocab_size)
def forward(self, x, targets=None):
x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
for l in self.layers:
x = l(x)
logits = self.lm_head(x) # B,T,C
loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None
return logits, loss
# what is the purpose? autoregressive inference!
def generate(self, start_idx, max_new_tokens):
idx = start_idx
for i in range(max_new_tokens):
idx_cond = idx[:,-self.max_seq_len:]
logits, loss = self(idx_cond)
logits = logits[:,-1,:] # B, C
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
idx = torch.cat([idx, idx_next], dim=1)
return idx

@ -0,0 +1,129 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
from char_gpt2 import GPT2
from optics_char_gpt2 import OpticGPT2
seed = 1337
torch.manual_seed(seed)
models = {'gpt2': GPT2, 'optic_gpt2': OpticGPT2}
batch_size = 50
max_iters = 40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
layers_num = 2
h_dim = 64
max_seq_len = 256
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6
# CUDA_VISIBLE_DEVICES=1 python .src/main.py gpt2|optic_gpt2 ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens comment
MODEL_CLASS = models[sys.argv[1]]
train_data_path = Path(sys.argv[2])
val_data_path = Path(sys.argv[3])
test_data_path = Path(sys.argv[4])
comment = f"{sys.argv[1]}_{train_data_path.name}_{sys.argv[5]}_{seed}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).name)
print("Logs dir:", logs_dir)
script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes()) # copy this version of script
script_snapshot_path.chmod(0o400) # with read-only permission
#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
test_text = f.read()
text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w:i for i,w in enumerate(chars)}
itow = {i:w for i,w in enumerate(chars)}
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])
def get_batch(data, seq_len, batch_size):
ix = torch.randint(len(data)-seq_len, (batch_size, ))
x = torch.stack([data[i:i+seq_len] for i in ix]).to(device)
y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device)
return x, y
train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)
@torch.no_grad()
def perplexity(model, data):
stride = max(1, len(data) // 10000)
losses = []
for i in range(0, len(data)-max_seq_len-1, stride):
x = data[i:(i+max_seq_len)].to(device)
y = data[(i+1):(i+max_seq_len+1)].to(device)
logits, loss = model(x[None,...], y[None,...])
losses.append(loss.item())
print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="")
return np.exp(np.mean(losses))
#################################### Model #########################################mo
def complete(m, start_idxs=[0], max_new_tokens=100):
start_idx = torch.tensor([start_idxs]).to(device)
generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
return decode(generated_tokens[0].tolist())
m = MODEL_CLASS(vocab_size=vocab_size, h_dim=h_dim, max_seq_len=max_seq_len, num_heads=num_heads, pixel_size=pixel_size)
m = m.to(device)
#################################### Train #########################################
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
completion = complete(m, encode("\n"*max_seq_len), 2*max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)
for i in range(max_iters):
xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size)
logits, loss = m(xb, yb)
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
writer.add_scalar('loss', loss.item(), i)
print(f"\r{i}/{max_iters} {loss.item()}", end="")
if i % 5000 == 0:
ppl = perplexity(model=m, data=val_data)
writer.add_scalar('val_perplexity', ppl.item(), i)
print(f"\rPerplexity at {i}: {ppl}")
writer.add_text('completions', complete(m, encode("\n"*max_seq_len), 2*max_seq_len), i)
ppl = perplexity(model=m, data=val_data)
print(f"\r{i+1}/{max_iters} {loss.item()}")
print(f"\rPerplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i+1)
writer.add_scalar('loss', loss.item(), i)
ppl = perplexity(model=m, data=test_data)
writer.add_scalar('test_perplexity', ppl.item(), i+1)
print(f"\rTest Perplexity at {i}: {ppl}")
completion = complete(m, encode("\n"*max_seq_len), 2*max_seq_len)
print(completion)
writer.add_text('completions', completion, i+1)

@ -0,0 +1,9 @@
__all__ = ["config",
"propagator",
"optical_mul"]
__version__ = "3.0.0"
from .config import Config
from . import propagator
from .optical_mul import OpticalMul
from .parallel import DataParallel

@ -0,0 +1,379 @@
import torch as _torch
import numpy as _np
from abc import ABC as _ABC
from typing import Tuple as _Tuple
import collections as _collections
import pickle as _pickle
class ConfigOpticBase(_ABC):
"""
Абстрактный класс базовой информации о работе опттической установки.
Поля:
wavelength: длина волны, используемая в оптической установке.
K: волновое число.
distance: дистанция распространения светового поля между плоскостями.
"""
def __init__(self, wavelength: float, distance: float):
"""
Конструктор класса базовой информации о работе опттической установки.
Args:
wavelength: длина волны, используемая в оптической установке.
distance: дистанция распространения светового поля между плоскостями.
"""
self._wavelength: float = wavelength
self._K: float = 2 * _np.pi / wavelength
self._distance = distance
@property
def wavelength(self) -> float:
"""
Returns: длинна волны.
"""
return self._wavelength
@property
def K(self) -> float:
"""
Returns: волновое число.
"""
return self._K
@property
def distance(self) -> float:
"""
Returns: дистанция распространения.
"""
return self._distance
class ConfigDesignPlane:
"""
Класс данных о расчётной плоскости.
"""
def __init__(self,
pixel_count: None | int | _Tuple[int, int] = None,
pixel_size: None | float | _Tuple[float, float] = None,
aperture: None | float | _Tuple[float, float] = None
):
"""
Конструктор класса данных о расчётной плоскости.
Args:
pixel_count: данные о кол-ве пикселей в измерении.
pixel_size: данные о размере пикселей в измерении.
aperture: данные о апертуре измерения.
Note:
1. Достаточно указать два из трёх входных аргумента.
2. Если было переданно одно число, будет считаться, что значение у каждого измерения совпадают,
если было переданно два числа, то первое число будет ассоциированно с Y измерением, второе с X.
"""
if aperture is None:
aperture = ConfigDesignPlane.__get_excluded_third(pixel_count, pixel_size, False)
elif pixel_count is None:
pixel_count = ConfigDesignPlane.__get_excluded_third(aperture, pixel_size)
elif pixel_size is None:
pixel_size = ConfigDesignPlane.__get_excluded_third(aperture, pixel_count)
self._pixel_count: int | _Tuple[int, int] = pixel_count
self._pixel_size: float | _Tuple[float, float] = pixel_size
self._aperture: float | _Tuple[float, float] = aperture
@staticmethod
def __get_excluded_third(first_element, second_element, div: bool = True):
ConfigDesignPlane.__check_for_none(first_element, second_element)
first_element_y, first_element_x = ConfigDesignPlane.__return_tuple(first_element)
second_element_y, second_element_x = ConfigDesignPlane.__return_tuple(second_element)
if div:
return first_element_y / second_element_y, first_element_x / second_element_x
else:
return first_element_y * second_element_y, first_element_x * second_element_x
@staticmethod
def __check_for_none(first_element, second_element):
if (first_element is None) or (second_element is None):
raise TypeError("One of the provided elements is None, it is not possible to obtain the full dimensions of the calculated plane.")
@staticmethod
def __return_element(element: int | _Tuple[int, int], dim: int = 0):
if isinstance(element, _collections.abc.Sequence):
return element[dim]
else:
return element
@staticmethod
def __return_tuple(element: int | _Tuple[int, int]):
if isinstance(element, _collections.abc.Sequence):
return element
else:
return element, element
@staticmethod
def __get_linspace_by_dim(aperture, pixel_count):
linspace = _torch.linspace(-aperture / 2, aperture / 2, pixel_count + 1)[:pixel_count]
linspace += aperture / (2 * pixel_count)
return linspace
@property
def pixel_count_by_x(self) -> float:
"""
Returns:
Информация о кол-ве пикселей по оси X.
"""
return ConfigDesignPlane.__return_element(self._pixel_count, 1)
@property
def pixel_count_by_y(self) -> float:
"""
Returns:
Информация о кол-ве пикселей по оси Y.
"""
return ConfigDesignPlane.__return_element(self._pixel_count)
@property
def pixel_count(self) -> _Tuple[float, float]:
"""
Returns:
Информация о кол-ве пикселей по каждой оси [Y, x].
"""
return ConfigDesignPlane.__return_tuple(self._pixel_count)
@property
def pixel_size_by_x(self) -> int:
"""
Returns:
Информация о размере пикселей по оси X.
"""
return ConfigDesignPlane.__return_element(self._pixel_size, 1)
@property
def pixel_size_by_y(self) -> int:
"""
Returns:
Информация о размере пикселей по оси Y.
"""
return ConfigDesignPlane.__return_element(self._pixel_size)
@property
def pixel_size(self) -> _Tuple[int, int]:
"""
Returns:
Информация о размере пикселей по каждой оси [Y, x].
"""
return ConfigDesignPlane.__return_tuple(self._pixel_size)
@property
def aperture_width(self) -> float:
"""
Returns:
Информация о ширине расчётной плоскасти.
"""
return ConfigDesignPlane.__return_element(self._aperture, 1)
@property
def aperture_height(self) -> float:
"""
Returns:
Информация о высоте расчётной плоскасти.
"""
return ConfigDesignPlane.__return_element(self._aperture)
@property
def aperture(self) -> _Tuple[float, float]:
"""
Returns:
Информация о высоте и ширине расчётной плоскасти [H, W].
"""
return ConfigDesignPlane.__return_tuple(self._aperture)
@property
def linspace_by_x(self) -> _torch.Tensor:
"""
Returns:
Расчётная сетка по оси X.
"""
return ConfigDesignPlane.__get_linspace_by_dim(self.aperture_width, self.pixel_count_by_x)
@property
def linspace_by_y(self) -> _torch.Tensor:
"""
Returns:
Расчётная сетка по оси Y.
"""
return ConfigDesignPlane.__get_linspace_by_dim(self.aperture_height, self.pixel_count_by_y)
@property
def meshgrid(self) -> _Tuple[_torch.Tensor, _torch.Tensor]:
"""
Returns:
Расчётная сетка по осям [Y, X].
"""
linspace_by_x = self.linspace_by_x
linspace_by_y = self.linspace_by_y
return _torch.meshgrid((linspace_by_y, linspace_by_x))
class ConfigModelBase(_ABC):
"""
Абстрактный класс базовой информации об оптической установке.
"""
def __init__(self,
input_vector_plane: ConfigDesignPlane,
first_lens_plane: ConfigDesignPlane,
matrix_plane: ConfigDesignPlane,
second_lens_plane: ConfigDesignPlane,
output_vector_plane: ConfigDesignPlane
):
self._input_vector_plane: ConfigDesignPlane = input_vector_plane
self._first_lens_plane: ConfigDesignPlane = first_lens_plane
self._matrix_plane: ConfigDesignPlane = matrix_plane
self._second_lens_plane: ConfigDesignPlane = second_lens_plane
self._output_vector_plane: ConfigDesignPlane = output_vector_plane
@property
def input_vector_plane(self) -> ConfigDesignPlane:
"""
Returns:
Информация о расчётной плоскости входного вектора.
"""
return self._input_vector_plane
@property
def first_lens_plane(self) -> ConfigDesignPlane:
"""
Returns:
Информация о расчётной плоскости первой скрещенной линзы.
"""
return self._first_lens_plane
@property
def matrix_plane(self) -> ConfigDesignPlane:
"""
Returns:
Информация о расчётной плоскости элемента матрицы.
"""
return self._matrix_plane
@property
def second_lens_plane(self) -> ConfigDesignPlane:
"""
Returns:
Информация о расчётной плоскости второй скрещенной линзы.
"""
return self._second_lens_plane
@property
def output_vector_plane(self) -> ConfigDesignPlane:
"""
Returns:
Информация о расчётной плоскости выходного вектора оптической установки.
"""
return self._output_vector_plane
class Config(ConfigOpticBase, ConfigModelBase):
"""
Класс конфигурации, хранит полную информацию о расчётной системе.
"""
def __init__(self,
right_matrix_count_columns: int,
right_matrix_count_rows: int,
right_matrix_width: float,
right_matrix_height: float,
min_height_gap: float,
right_matrix_split_x: int = 1,
right_matrix_split_y: int = 1,
left_matrix_split_x: int = 1,
left_matrix_split_y: int = 1,
result_matrix_split: int = 1,
camera_pixel_size: float = 3.6e-6,
wavelength: float = 532e-9,
distance: float = 0.03,
lens_pixel_size: float = 1.8e-6,
lens_size: int = 8192):
"""
Конструктор класса.
Args:
right_matrix_count_columns: число столбцов в правой матрице, участвующей в операции матричного умножения.
right_matrix_count_rows: число строк в правой матрице, участвующей в операции матричного умножения.
right_matrix_width: ширина в метрах оптического элемента правой матрицы, участвующей в операции матричного умножения.
right_matrix_height: высота в метрах оптического элемента правой матрицы, участвующей в операции матричного умножения.
min_height_gap: минимально возможный зазор для отображения вектора левой матрицы, участвующей в операции матричного умножения.
right_matrix_split_x: число дробления элементов правой матрицы по X (используется для более точного моделирования).
right_matrix_split_y: число дробления элементов правой матрицы по Y (используется для более точного моделирования).
left_matrix_split_x: число дробления элементов левой матрицы по X (используется для более точного моделирования).
left_matrix_split_y: число дробления элементов левой матрицы по Y (используется для более точного моделирования).
result_matrix_split: число дробления элементов результирующей матрицы (используется для более точного моделирования).
camera_pixel_size: физический размер пикселя камеры, считывающей результирующее световое поле.
wavelength: длины волн в метрах используемых в системе.
distance: дистанция в метрах распространения светового поля между плоскостями.
lens_pixel_size: размер пикселя в метрах скрещенных линз в оптической системе (нужен исключительно для моделирования).
lens_size: размер скрещенных линз в метрах в оптической системе (нужен исключительно для моделирования).
"""
ConfigOpticBase.__init__(self, wavelength, distance)
config_plane_one = ConfigDesignPlane((left_matrix_split_y, left_matrix_split_x * right_matrix_count_rows),
aperture=(min_height_gap, right_matrix_height)
)
config_plane_lens = ConfigDesignPlane(lens_size, lens_pixel_size)
config_plane_three = ConfigDesignPlane((right_matrix_count_columns * left_matrix_split_x, right_matrix_count_rows * right_matrix_split_y),
aperture=(right_matrix_width, right_matrix_height)
)
config_plane_five = ConfigDesignPlane((right_matrix_count_columns * result_matrix_split, 1),
aperture=(right_matrix_width, camera_pixel_size)
)
ConfigModelBase.__init__(self,
config_plane_one,
config_plane_lens,
config_plane_three,
config_plane_lens,
config_plane_five
)
self._matrix_split_x: int = right_matrix_split_x
self._matrix_split_y: int = right_matrix_split_y
self._input_vector_split_x: int = left_matrix_split_x
self._input_vector_split_y: int = left_matrix_split_y
self._result_vector_split: int = result_matrix_split
@property
def matrix_split_x(self) -> int:
"""
Returns:
Информация о разбиении пикселей элементов матрицы по оси X.
"""
return self._matrix_split_x
@property
def matrix_split_y(self) -> int:
"""
Returns:
Информация о разбиении пикселей элементов матрицы по оси Y.
"""
return self._matrix_split_y
@property
def input_vector_split_x(self) -> int:
"""
Returns:
Информация о разбиении пикселей элементов входного вектора по оси X.
"""
return self._input_vector_split_x
@property
def input_vector_split_y(self) -> int:
"""
Returns:
Информация о разбиении пикселей элементов входного вектора по оси Y.
"""
return self._input_vector_split_y
@property
def result_vector_split(self) -> int:
"""
Returns:
Информация о разбиении пикселей элементов выходного вектора по оси Y.
"""
return self._result_vector_split
def save(self, filename: str = "config.pth"):
"""
Метод сохранения параметров конфигурации в файл.
Args:
filename: название файла с параметрами конфигурации.
"""
with open(filename, 'wb') as f:
_pickle.dump(self, f, protocol=_pickle.HIGHEST_PROTOCOL)
@staticmethod
def load(filename: str = "config.pth") -> 'Config':
"""
Метод загрузки параметров конфигурации из файла.
Args:
filename: название файла с параметрами конфигурации.
"""
with open(filename, 'rb') as f:
return _pickle.load(f)

@ -0,0 +1,116 @@
import torch as _torch
import torch.nn as _nn
from .config import Config as _Config
from .propagator import PropagatorCrossLens as _PropCrossLens, PropagatorСylindLens as _PropСylindLens, PropagatorSinc as _PropSinc, Propagator as _Prop
class OpticalMul(_nn.Module):
"""
Класс системы, выполняющей оптически операцию умножения матрицы на матрицу.
"""
def __init__(self, config: _Config):
"""
Конструктор класса.
Args:
config: конфигурация расчётной системы.
"""
super(OpticalMul, self).__init__()
prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config)
prop_two = _PropCrossLens(config.first_lens_plane, config)
prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config)
prop_four = _PropСylindLens(config.matrix_plane, config)
prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config)
prop_six = _PropCrossLens(config.second_lens_plane, config).T
prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config)
self._propagator_one: _Prop = prop_one + prop_two + prop_three + prop_four
self._propagator_two: _Prop = prop_five + prop_six + prop_seven
kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x))
kron_mat_utils = _torch.ones((config.matrix_split_x, config.matrix_split_y))
self.register_buffer('_kron_vec_utils', kron_vec_utils, persistent=True)
self.register_buffer('_kron_mat_utils', kron_mat_utils, persistent=True)
self._avg_pool = _nn.AvgPool2d((1, config.result_vector_split))
def prepare_vector(self, data: _torch.Tensor) -> _torch.Tensor:
"""
Метод подготовки матрицы левой матрицы, как набора векторов столбцов, к подаче на вход системы.
Args:
data: матрица комплексной амплитуды распределений световых полей.
Returns:
Матрицы содержащие вектора левой матрицы.
"""
data = data.cfloat().flip(-1)
data = data.unsqueeze(-2)
data = _torch.kron(data.contiguous(), self._kron_vec_utils)
return data
def prepare_matrix(self, data: _torch.Tensor) -> _torch.Tensor:
"""
Метод подготовки правой матрицы к подаче на вход системы.
Args:
data: матрица комплексной амплитуды распределения светового поля.
Returns:
Матрица - оптический элемент в центре модели.
"""
if (data.dim() > 4) and data.size(-1) == 2:
data = _torch.view_as_complex(data)
data = data.cfloat().transpose(-2, -1)
data = data.unsqueeze(-3)
data = _torch.kron(data.contiguous(), self._kron_mat_utils)
return data
def prepare_out(self, field: _torch.Tensor) -> _torch.Tensor:
"""
Метод получения результата матричного умножения.
Args:
data: матрицы выходого распределения светового поля системы.
Returns:
Вектор столбец (амплитудное распределение).
"""
### Закоментированная часть кода - более физически корректный вариант работы модели,
### однако, данный вариант кода будет требовать большое кол-во памяти во время обучения
field = field.abs().squeeze(-1) #**2
field = self._avg_pool(field)
return field.flip(-1) #**0.5
def forward(self,
input: _torch.Tensor,
other: _torch.Tensor) -> _torch.Tensor:
"""
Метод выполения матричного умножения.
Args:
input: матрица (B, C, H, W).
other: матрица (B, C, W, K).
Returns:
Рензультат матричного умножения (B, C, H, K).
Example:
>>> mul = OpticalMul(...)
>>> A = torch.rand((1, 1, 256, 256)) > 0.5
>>> B = torch.rand((1, 1, 256, 256)) > 0.5
>>> mul(A, B).shape
torch.Size([1, 1, 256, 256])
>>> A = torch.rand((1, 1, 64, 256)) > 0.5
>>> B = torch.rand((1, 1, 256, 128)) > 0.5
>>> mul(A, B).shape
torch.Size([1, 1, 64, 128])
"""
vec_field = self.prepare_vector(input)
mat_field = self.prepare_matrix(other)
vec_field = self._propagator_one(vec_field)
vec_field = self._propagator_two(vec_field * mat_field)
return self.prepare_out(vec_field)

@ -0,0 +1,97 @@
import torch as _torch
import torch.nn as _nn
from typing import List, Union, Any
from collections.abc import Iterator
class DataParallel(_nn.Module):
"""
Класс параллельного вычисления модели на системе с множеством вычислителей.
Поля:
module: расчётная модель.
devices: список задействующихся вычислителей.
output_device: устройство, в которое будут записаны результаты вычислений.
"""
def __init__(self, module: _nn.Module, devices: Union[None, List[Union[int, _torch.device]]] = None,
output_device: Union[int, _torch.device] = None) -> None:
"""
Конструктор класса.
Args:
module: расчётная модель.
devices: список задействующихся вычислителей.
output_device: устройство, в которое будут записаны результаты вычислений.
"""
super(DataParallel, self).__init__()
if not _torch.cuda.is_available():
raise EnvironmentError("cuda is not available.")
return
if not devices:
devices = [_torch.device(x) for x in range(_torch.cuda.device_count())]
if not output_device:
output_device = devices[0]
self.module = module
self.devices = devices
self.output_device = output_device
def buffers(self, *inputs) -> Iterator[_torch.Tensor]:
'''
Return an iterator over module buffers.
Args:
recurse (bool): if True, then yields buffers of this module
and all submodules. Otherwise, yields only buffers that
are direct members of this module.
Yields:
torch.Tensor: module buffer
'''
return self.module.buffers(*inputs)
def parameters(self, *inputs) -> Iterator[_nn.parameter.Parameter]:
'''
Return an iterator over module parameters.
This is typically passed to an optimizer.
Args:
recurse (bool): if True, then yields parameters of this module
and all submodules. Otherwise, yields only parameters that
are direct members of this module.
Yields:
Parameter: module parameter
'''
return self.module.parameters(*inputs)
def forward(self, input: _torch.Tensor, other: _torch.Tensor, **kwargs: Any) -> _torch.Tensor:
'''
Return an iterator over module parameters.
This is typically passed to an optimizer.
Args:
recurse (bool): if True, then yields parameters of this module
and all submodules. Otherwise, yields only parameters that
are direct members of this module.
Returns:
Parameter: module parameter
'''
dim: int = 2
if 'dim' in kwargs:
dim = kwargs['dim']
scattered_input = _nn.parallel.scatter(input, self.devices, dim)
broadcasted_other = _nn.parallel.comm.broadcast(other, self.devices)
replicas = _nn.parallel.replicate(self.module.to(self.devices[0]), self.devices)
stacked_input = [(scattered_input[i],) + (broadcasted_other[i],) for i in range(len(replicas))]
outputs = _nn.parallel.parallel_apply(replicas, stacked_input)
return _nn.parallel.gather(outputs, self.output_device, dim)

@ -0,0 +1,220 @@
import torch as _torch
import torch.nn as _nn
import numpy as _np
from scipy.special import fresnel as _fresnel
from .config import ConfigOpticBase as _ConfigOpticBase, ConfigDesignPlane as _ConfigDesignPlane
from typing import Tuple as _Tuple, Sequence as _Sequence
from abc import ABC as _ABC
import collections as _collections
class Propagator(_ABC, _nn.Module):
"""
Абстрактный класс вычисления распространения светового поля в среде.
Поля:
operator_X: оператор отображающий распроранение светового поля вдоль оси абсцисс
operator_Y: оператор отображающий распроранение светового поля вдоль оси ординат
"""
def __init__(self, operator_X: _torch.Tensor, operator_Y: _torch.Tensor):
super(Propagator, self).__init__()
operator_X: _torch.Tensor = _torch.view_as_real(operator_X)
operator_Y: _torch.Tensor = _torch.view_as_real(operator_Y)
self.register_buffer('_operator_X', operator_X, persistent=True)
self.register_buffer('_operator_Y', operator_Y, persistent=True)
@property
def operator_X(self) -> _torch.Tensor:
"""
Returns:
оператор отображающий распроранение светового поля вдоль оси абсцисс
"""
return _torch.view_as_complex(self._operator_X)
@property
def operator_Y(self) -> _torch.Tensor:
"""
Returns:
оператор отображающий распроранение светового поля вдоль оси ординат
"""
return _torch.view_as_complex(self._operator_Y)
def __operator_multiplication(self, first_X: _torch.Tensor,
second_X: _torch.Tensor,
first_Y: _torch.Tensor,
second_Y: _torch.Tensor)-> _Tuple[_torch.Tensor, _torch.Tensor]:
operator_Y = second_Y @ first_Y
operator_X = first_X @ second_X
return operator_X, operator_Y
def cat(self, propagators: _Sequence['Propagator']) -> 'Propagator':
"""
Метод схлопывания операторов распространения.
Args:
propagators: последовательность для схлопывания
Returns:
новый пропогатор, заменяющих собой серию предыдущих
Warning:
порядок расположения пропагаторов в последовательности важен,
идёт от первого к последниму
"""
operator_X: _torch.Tensor
operator_Y: _torch.Tensor
if not isinstance(propagators, _collections.abc.Sequence):
operator_X, operator_Y = self.__operator_multiplication(self.operator_X,
propagators.operator_X,
self.operator_Y,
propagators.operator_Y)
else:
size = len(propagators)
operator_X = self.operator_X
operator_Y = self.operator_Y
for i in range(size):
operator_X, operator_Y = self.__operator_multiplication(operator_X,
propagators[i].operator_X,
operator_Y,
propagators[i].operator_Y)
return Propagator(operator_X, operator_Y)
def __add__(self, propagator: 'Propagator') -> 'Propagator':
"""
Метод схлопывания двух пропагаторов.
Args:
propagator: пропагатор с которым нужно произвести схлопывание
Returns:
новый пропогатор, заменяющих собой оба предыдущих
Warning:
операция не комутативная
"""
return self.cat(propagator)
def forward(self, field: _torch.Tensor) -> _torch.Tensor:
"""
Метод распространения светового поля в среде.
Args:
field: распределение комплексной амплитуды светового поля.
Returns:
Распределение комплексной амплитуды светового поля,
после распространения.
"""
return self.operator_Y @ field @ self.operator_X
class PropagatorLens(Propagator):
"""
Абстрактный класс распространения света в тонком оптическом элементе.
"""
def transpose(self) -> 'PropagatorLens':
"""
Метод транспонирования тонкого оптического элемента.
Returns:
Новый элемент, транспонированный относительно оригинального.
"""
obj = Propagator.__new__(PropagatorLens)
Propagator.__init__(obj, self.operator_Y, self.operator_X)
return obj
@property
def T(self) -> 'PropagatorLens':
"""
Returns:
Новый элемент, транспонированный относительно текущего.
"""
return self.transpose()
class PropagatorCrossLens(PropagatorLens):
"""
Класс распространения света в скрещенной линзе,
представленной тонким оптическим элементом.
"""
def __init__(self, plane: _ConfigDesignPlane,
config: _ConfigOpticBase):
"""
Конструктор класса скрещенной линзы.
Args:
plane: данные о расчётной плоскости элемента.
config: данные о световом поле модели.
"""
operator_X = _torch.exp(-1j * config.K / config.distance * plane.linspace_by_x**2)
operator_Y = _torch.exp(-1j * config.K / 2 / config.distance * plane.linspace_by_y**2)
super(PropagatorCrossLens, self).__init__(_torch.diag_embed(operator_X),
_torch.diag_embed(operator_Y))
class PropagatorСylindLens(PropagatorLens):
"""
Класс распространения света в цилиндрической линзе,
представленной тонким оптическим элементом.
"""
def __init__(self, plane: _ConfigDesignPlane,
config: _ConfigOpticBase):
"""
Конструктор класса цилиндрической линзы.
Args:
plane: данные о расчётной плоскости элемента.
config: данные о световом поле модели.
"""
operator_X = _torch.exp(-1j * config.K / config.distance * plane.linspace_by_x**2)
operator_Y = _torch.ones_like(plane.linspace_by_y, dtype=_torch.cfloat)
super(PropagatorСylindLens, self).__init__(_torch.diag_embed(operator_X),
_torch.diag_embed(operator_Y))
class PropagatorSinc(Propagator):
"""
Класс распространения света свободном пространстве
с использованием разложения по базисным sinc функциям.
"""
def __init__(self, first_plane: _ConfigDesignPlane,
second_plane: _ConfigDesignPlane,
config: _ConfigOpticBase):
"""
Конструктор класса распространения в свободном пространстве.
Args:
first_plane: данные о начальной расчётной плоскости.
second_plane: данные о конечной расчётной плоскости.
config: данные о световом поле модели.
"""
operator_X, operator_Y = self.__get_operators(first_plane,
second_plane,
config)
super(PropagatorSinc, self).__init__(operator_X, operator_Y)
def __get_operator_for_dim(self,
pixel_size_in: float,
pixel_size_out: float,
difference: float,
config: _ConfigOpticBase) -> _torch.Tensor:
bndW = 0.5 / pixel_size_in
eikz = (_np.exp(1j * config.K * config.distance)**0.5)
sq2p = (2 / _np.pi)**0.5
sqzk = ((2 * config.distance / config.K)**0.5)
mu1 = -_np.pi * sqzk * bndW - difference / sqzk
mu2 = _np.pi * sqzk * bndW - difference / sqzk
S1, C1 = _fresnel(mu1 * sq2p)
S2, C2 = _fresnel(mu2 * sq2p)
return (((pixel_size_in * pixel_size_out)**0.5 / _np.pi) / sqzk * eikz
* _np.exp(0.5j * difference**2 * config.K / config.distance)
* (C2 - C1 - 1j * (S2 - S1)) / sq2p)
def __get_operators(self,
first_plane: _ConfigDesignPlane,
second_plane: _ConfigDesignPlane,
config: _ConfigOpticBase) -> _Tuple[_torch.Tensor, _torch.Tensor]:
difference_x = first_plane.linspace_by_x[None, :] - second_plane.linspace_by_x[:, None]
difference_y = first_plane.linspace_by_y[None, :] - second_plane.linspace_by_y[:, None]
operator_X = self.__get_operator_for_dim(first_plane.pixel_size_by_x,
second_plane.pixel_size_by_x,
difference_x,
config).transpose(-2, -1)
operator_Y = self.__get_operator_for_dim(first_plane.pixel_size_by_y,
second_plane.pixel_size_by_y,
difference_y,
config)
return operator_X, operator_Y

@ -0,0 +1,177 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import optical_matrix_multiplication as omm
from optical_matrix_multiplication import propagator
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from pathlib import Path
import sys
torch.manual_seed(1337)
#################################### Model #########################################
def norm(matrix: torch.Tensor, max_val: float = 1) -> torch.Tensor:
return matrix / (max_val + 1e-10)
def optics_matmul_shift(sim, tensor_1, tensor_2):
tensor_1 = tensor_1[None,:,:,:]
tensor_2 = tensor_2[None,:,:,:]
if torch.min(tensor_1) >= 0 and torch.min(tensor_2) >= 0:
max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2)))
a, b = norm(tensor_1, max_abs), norm(tensor_2, max_abs)
return sim(a, b)[0,:,:,:] * max_abs **2
min_abs = abs(min(torch.min(tensor_1), torch.min(tensor_2)))
max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + min_abs
shift_a = min_abs * torch.ones(tensor_1.shape).to(tensor_1.device)
shift_b = min_abs * torch.ones(tensor_2.shape).to(tensor_1.device)
a_a_sh = tensor_1 + shift_a
b_b_sh = tensor_2 + shift_b
a_a_sh_norm, b_b_sh_norm = norm(a_a_sh, max_abs), norm(b_b_sh, max_abs)
shift_a_norm, shift_b_norm = norm(shift_a, max_abs), norm(shift_b, max_abs)
a_a_sh_b_b_sh = sim(a_a_sh_norm, b_b_sh_norm)
a_a_sh_b_sh = sim(a_a_sh_norm, shift_b_norm)
a_sh_b_b_sh = sim(shift_a_norm, b_b_sh_norm)
a_sh_b_sh = sim(shift_a_norm, shift_b_norm)
return (a_a_sh_b_b_sh - a_a_sh_b_sh - a_sh_b_b_sh + a_sh_b_sh)[0,:,:,:] * max_abs ** 2
# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
class RoPE(nn.Module):
def __init__(self, dim, max_seq_len=512):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(max_seq_len).type_as(inv_freq)
freqs = torch.einsum('i,j->ij', t, inv_freq)
self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1))
def rotate_half(self, x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def forward(self, x, offset=0):
seq_len = x.size(1)
emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1)
cos = emb.cos()
sin = emb.sin()
return (x * cos) + (self.rotate_half(x) * sin)
# Transformers without Normalization https://jiachenzhu.github.io/DyT/
class DyT(nn.Module):
def __init__(self, num_features, alpha_init_value=0.5):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
def forward(self, x):
x = torch.tanh(self.alpha * x)
return x * self.weight + self.bias
# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
class TransformerLayer(nn.Module):
def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate = 0.1, max_seq_len = 128):
super().__init__()
self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
self.q_proj = nn.Linear(h_dim, h_dim)
self.k_proj = nn.Linear(h_dim, h_dim)
self.v_proj = nn.Linear(h_dim, h_dim)
self.o_proj = nn.Linear(h_dim, h_dim)
self.ff1 = nn.Linear(h_dim, 4*h_dim)
self.ff2 = nn.Linear(4*h_dim, h_dim)
self.ln1 = DyT(h_dim)
self.ln2 = DyT(h_dim)
self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
self.k1 = nn.Parameter(torch.randn(1))
self.k2 = nn.Parameter(torch.randn(1))
def split_to_heads(self, x, B, T, H):
if self.num_heads <= 1: return x
return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
def gather_heads(self, x, B, T, H):
if self.num_heads <= 1: return x
return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
def attention(self, x):
q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
v = self.split_to_heads(self.v_proj(x), *x.shape)
scores = self.k1 * optics_matmul_shift(self.sim_scores, q, k.transpose(1, 2)) * (self.h_dim ** -0.5)
tril = torch.tril(torch.ones(x.shape[1],x.shape[1])).to(self.q_proj.bias.device)
scores = scores.masked_fill(tril == 0, float('-inf'))
attention = nn.functional.softmax(scores, dim=2)
output = self.k2 * optics_matmul_shift(self.sim_output, attention, v)
return self.o_proj(self.gather_heads(output, *x.shape))
def forward(self, x):
x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate)
x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate)
return x
class OpticGPT2(nn.Module):
def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=128, num_heads=4, dropout_rate = 0.1,
pixel_size = 3.6e-6):
super().__init__()
self.__dict__.update({k:v for k,v in locals().items() if k != 'self'})
self.sim_scores = omm.OpticalMul(
omm.Config(right_matrix_count_columns = max_seq_len,
right_matrix_count_rows = h_dim // num_heads,
right_matrix_width = pixel_size * max_seq_len,
right_matrix_height = pixel_size * (h_dim // num_heads),
min_height_gap = pixel_size,
right_matrix_split_x = 2,
right_matrix_split_y = 2,
left_matrix_split_x = 2,
left_matrix_split_y = 2,
result_matrix_split = 2,
distance = 0.01)
)
self.sim_output = omm.OpticalMul(
omm.Config(right_matrix_count_columns = h_dim // num_heads,
right_matrix_count_rows = max_seq_len,
right_matrix_width = pixel_size * (h_dim // num_heads),
right_matrix_height = pixel_size * max_seq_len,
min_height_gap = pixel_size,
right_matrix_split_x = 2,
right_matrix_split_y = 2,
left_matrix_split_x = 2,
left_matrix_split_y = 2,
result_matrix_split = 2,
distance = 0.01)
)
self.layers = nn.ModuleList([
TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,
dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
for _ in range(layers_num)])
self.tok_embeds = nn.Embedding(vocab_size, h_dim)
self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
self.lm_head = nn.Linear(h_dim, vocab_size)
def forward(self, x, targets=None):
x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
for l in self.layers:
x = l(x)
logits = self.lm_head(x) # B,T,C
loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if not targets is None else None
return logits, loss
# what is the purpose? autoregressive inference!
def generate(self, start_idx, max_new_tokens):
idx = start_idx
for i in range(max_new_tokens):
idx_cond = idx[:,-self.max_seq_len:]
logits, loss = self(idx_cond)
logits = logits[:,-1,:] # B, C
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
idx = torch.cat([idx, idx_next], dim=1)
return idx
Loading…
Cancel
Save