From a13f526f877f744251384f48281a69daafa6cadc Mon Sep 17 00:00:00 2001
From: Vladimir Protsenko
Date: Wed, 29 Oct 2025 11:58:40 +0000
Subject: [PATCH] First commit

---
 .dockerignore                                 |   2 +
 .gitattributes                                |   2 +
 .gitignore                                    | 219 ++++++++++
 Dockerfile                                    |  51 +++
 README.md                                     |   3 +
 build.sh                                      |   8 +
 src/char_gpt2.py                              | 115 ++++++
 src/main.py                                   | 129 ++++++
 src/optical_matrix_multiplication/__init__.py |   9 +
 src/optical_matrix_multiplication/config.py   | 379 ++++++++++++++++++
 .../optical_mul.py                            | 116 ++++++
 src/optical_matrix_multiplication/parallel.py |  97 +++++
 .../propagator.py                             | 220 ++++++++++
 src/optics_char_gpt2.py                       | 177 ++++++++
 14 files changed, 1527 insertions(+)
 create mode 100644 .dockerignore
 create mode 100644 .gitattributes
 create mode 100644 .gitignore
 create mode 100644 Dockerfile
 create mode 100644 README.md
 create mode 100755 build.sh
 create mode 100644 src/char_gpt2.py
 create mode 100644 src/main.py
 create mode 100644 src/optical_matrix_multiplication/__init__.py
 create mode 100644 src/optical_matrix_multiplication/config.py
 create mode 100644 src/optical_matrix_multiplication/optical_mul.py
 create mode 100644 src/optical_matrix_multiplication/parallel.py
 create mode 100644 src/optical_matrix_multiplication/propagator.py
 create mode 100644 src/optics_char_gpt2.py

diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000..d0e243a
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,2 @@
+data/
+logs/
\ No newline at end of file
diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000..a058641
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,2 @@
+*.pq filter=lfs diff=lfs merge=lfs -text
+*.csv filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..d44313c
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,219 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[codz]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py.cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+#   For a library or package, you might want to ignore these files since the code is
+#   intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+# Pipfile.lock
+
+# UV
+#   Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+#   This is especially recommended for binary packages to ensure reproducibility, and is more
+#   commonly ignored for libraries.
+# uv.lock
+
+# poetry
+#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+#   This is especially recommended for binary packages to ensure reproducibility, and is more
+#   commonly ignored for libraries.
+#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+# poetry.lock
+# poetry.toml
+
+# pdm
+#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#   pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
+#   https://pdm-project.org/en/latest/usage/project/#working-with-version-control
+# pdm.lock
+# pdm.toml
+.pdm-python
+.pdm-build/
+
+# pixi
+#   Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
+# pixi.lock
+#   Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
+#   in the .venv directory. It is recommended not to include this directory in version control.
+.pixi
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# Redis
+*.rdb
+*.aof
+*.pid
+
+# RabbitMQ
+mnesia/
+rabbitmq/
+rabbitmq-data/
+
+# ActiveMQ
+activemq-data/
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.envrc
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+#  and can be added to the global gitignore or merged into this file. For a more nuclear
+#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
+# .idea/
+
+# Abstra
+# Abstra is an AI-powered process automation framework.
+# Ignore directories containing user credentials, local state, and settings.
+# Learn more at https://abstra.io/docs
+.abstra/
+
+# Visual Studio Code
+#  Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
+#  that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
+#  and can be added to the global gitignore or merged into this file. However, if you prefer,
+#  you could uncomment the following to ignore the entire vscode folder
+# .vscode/
+
+# Ruff stuff:
+.ruff_cache/
+
+# PyPI configuration file
+.pypirc
+
+# Marimo
+marimo/_static/
+marimo/_lsp/
+__marimo__/
+
+# Streamlit
+.streamlit/secrets.toml
+
+logs/
+data/
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..043ea58
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,51 @@
+FROM pytorch/pytorch:2.6.0-cuda12.6-cudnn9-runtime
+ENV TZ=Europe/Samara
+RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
+
+ARG USER
+ARG GROUP
+ARG UID
+ARG GID
+
+RUN apt update
+RUN apt install sudo -y
+RUN sed -i 's/^%sudo.*/%sudo ALL=(ALL) NOPASSWD: ALL/' /etc/sudoers
+
+RUN groupadd --gid ${GID} ${GROUP}
+RUN useradd --shell /bin/bash --uid ${UID} --gid ${GID} -G sudo --create-home ${USER}
+RUN mkdir /wd
+RUN chown ${USER}:${GROUP} /wd
+
+# SYSTEM CONFIGURATION
+RUN apt install wget vim htop mc git tree unzip -y
+RUN apt-get install -y libssl-dev autoconf libtool make
+RUN cd /usr/local/src && \
+    wget https://curl.haxx.se/download/curl-7.88.1.zip && \
+    unzip curl-7.88.1.zip && \
+    cd curl-7.88.1 && \
+    ./buildconf && \
+    ./configure --with-ssl && \
+    make && \
+    make install && \
+    cp /usr/local/bin/curl /usr/bin/curl && \
+    ldconfig && \
+    curl -V
+RUN curl -fsSL https://code-server.dev/install.sh | sh
+RUN /opt/conda/bin/conda install -n base ipykernel --update-deps --force-reinstall -y
+
+USER ${USER}
+
+# USER CONFIGURATION
+RUN pip install schedulefree tensorboard opencv-python-headless scipy pandas matplotlib torchmetrics pyarrow einops nvitop
+
+RUN openssl req -x509 -newkey rsa:4096 -keyout /home/${USER}/key.pem -out /home/${USER}/cert.pem -sha256 -nodes -days 365 -subj "/C=RU/ST=SamaraRegion/L=Samara/O=SSAU/OU=LIAV/CN=vscode.ssau.ru/"
+RUN mkdir -p /home/${USER}/.config/code-server
+RUN echo 'bind-addr: 0.0.0.0:8443' >> /home/${USER}/.config/code-server/config.yaml
+RUN echo "cert: /home/${USER}/cert.pem" >> /home/${USER}/.config/code-server/config.yaml
+RUN echo "cert-key: /home/${USER}/key.pem" >> /home/${USER}/.config/code-server/config.yaml
+
+RUN code-server --install-extension ms-python.python
+
+ENV SHELL=/bin/bash
+SHELL ["/bin/bash", "--login", "-i", "-c"]
+WORKDIR /wd
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..acfd961
--- /dev/null
+++ b/README.md
@@ -0,0 +1,3 @@
+# Training scripts for GPT-2 with an optical matrix multiplication simulator
+
+Repository of the optical matrix multiplication simulator: https://github.com/amacomm/Optical_matrix_multiplication
\ No newline at end of file
diff --git a/build.sh b/build.sh
new file mode 100755
index 0000000..7796d61
--- /dev/null
+++ b/build.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+CURDIRNAME=${PWD##*/}
+
+docker build . -t ${USER}_${CURDIRNAME}_vscode \
+    --build-arg USER=${USER} \
+    --build-arg GROUP=${USER} \
+    --build-arg UID=$(id -u ${USER}) \
+    --build-arg GID=$(id -g ${USER})
diff --git a/src/char_gpt2.py b/src/char_gpt2.py
new file mode 100644
index 0000000..25a0af0
--- /dev/null
+++ b/src/char_gpt2.py
@@ -0,0 +1,115 @@
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from einops import rearrange
+import numpy as np
+from torch.utils.tensorboard import SummaryWriter
+from datetime import datetime
+import sys
+from pathlib import Path
+torch.manual_seed(1337)
+
+#################################### Model #########################################
+
+
+# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
+class RoPE(nn.Module):
+    def __init__(self, dim, max_seq_len=512):
+        super().__init__()
+        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+        t = torch.arange(max_seq_len).type_as(inv_freq)
+        freqs = torch.einsum('i,j->ij', t, inv_freq)
+        self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1))
+
+    def rotate_half(self, x):
+        x1, x2 = x.chunk(2, dim=-1)
+        return torch.cat((-x2, x1), dim=-1)
+
+    def forward(self, x, offset=0):
+        seq_len = x.size(1)
+        emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1)
+        cos = emb.cos()
+        sin = emb.sin()
+        return (x * cos) + (self.rotate_half(x) * sin)
+
+# Transformers without Normalization https://jiachenzhu.github.io/DyT/
+class DyT(nn.Module):
+    def __init__(self, num_features, alpha_init_value=0.5):
+        super().__init__()
+        self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
+        self.weight = nn.Parameter(torch.ones(num_features))
+        self.bias = nn.Parameter(torch.zeros(num_features))
+
+    def forward(self, x):
+        x = torch.tanh(self.alpha * x)
+        return x * self.weight + self.bias
+
+# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
+# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
+class TransformerLayer(nn.Module):
+    def __init__(self, h_dim, num_heads=4, dropout_rate=0.1, max_seq_len=128):
+        super().__init__()
+        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
+        self.q_proj = nn.Linear(h_dim, h_dim)
+        self.k_proj = nn.Linear(h_dim, h_dim)
+        self.v_proj = nn.Linear(h_dim, h_dim)
+        self.o_proj = nn.Linear(h_dim, h_dim)
+        self.ff1 = nn.Linear(h_dim, 4*h_dim)
+        self.ff2 = nn.Linear(4*h_dim, h_dim)
+        self.ln1 = DyT(h_dim)
+        self.ln2 = DyT(h_dim)
+        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
+
+    def split_to_heads(self, x, B, T, H):
+        if self.num_heads <= 1: return x
+        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
+
+    def gather_heads(self, x, B, T, H):
+        if self.num_heads <= 1: return x
+        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
+
+    def attention(self, x):
+        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
+        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
+        v = self.split_to_heads(self.v_proj(x), *x.shape)
+        scores = (q @ k.transpose(1, 2)) * ((self.h_dim // self.num_heads) ** -0.5)  # scale by the per-head dimension
+        tril = torch.tril(torch.ones(x.shape[1], x.shape[1])).to(self.q_proj.bias.device)
+        scores = scores.masked_fill(tril == 0, float('-inf'))  # an encoder would not need this line
+        attention = nn.functional.softmax(scores, dim=2)
+        return self.o_proj(self.gather_heads(attention @ v, *x.shape))
+
+    def forward(self, x):
+        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
+        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
+        return x
+
+class GPT2(nn.Module):
+    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=128, num_heads=4, dropout_rate=0.1, pixel_size=None):
+        super().__init__()
+        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
+        self.layers = nn.ModuleList([
+            TransformerLayer(h_dim=self.h_dim, num_heads=self.num_heads, dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
+            for _ in range(layers_num)])
+        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
+        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
+        self.lm_head = nn.Linear(h_dim, vocab_size)
+
+    def forward(self, x, targets=None):
+        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
+        for l in self.layers:
+            x = l(x)
+        logits = self.lm_head(x)  # B,T,C
+        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if targets is not None else None
+        return logits, loss
+
+    # what is the purpose? autoregressive inference!
+    def generate(self, start_idx, max_new_tokens):
+        idx = start_idx
+        for i in range(max_new_tokens):
+            idx_cond = idx[:,-self.max_seq_len:]
+            logits, loss = self(idx_cond)
+            logits = logits[:,-1,:]  # B, C
+            probs = F.softmax(logits, dim=-1)
+            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
+            idx = torch.cat([idx, idx_next], dim=1)
+        return idx
\ No newline at end of file
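A minimal smoke test for the model above. This is not part of the patch; the hyperparameters are illustrative and it assumes it is run from `src/`:

```python
import torch
from char_gpt2 import GPT2

model = GPT2(vocab_size=65, layers_num=2, h_dim=64, max_seq_len=128, num_heads=4)
x = torch.randint(0, 65, (2, 16))   # batch of 2 sequences, 16 tokens each
y = torch.randint(0, 65, (2, 16))   # next-token targets
logits, loss = model(x, y)
print(logits.shape, loss.item())    # torch.Size([2, 16, 65]) and a scalar loss

out = model.generate(start_idx=x[:, :1], max_new_tokens=8)
print(out.shape)                    # torch.Size([2, 9]): prompt token + 8 generated
```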
diff --git a/src/main.py b/src/main.py
new file mode 100644
index 0000000..8662127
--- /dev/null
+++ b/src/main.py
@@ -0,0 +1,129 @@
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from einops import rearrange
+import numpy as np
+from torch.utils.tensorboard import SummaryWriter
+from datetime import datetime
+import sys
+from pathlib import Path
+from char_gpt2 import GPT2
+from optics_char_gpt2 import OpticGPT2
+seed = 1337
+torch.manual_seed(seed)
+models = {'gpt2': GPT2, 'optic_gpt2': OpticGPT2}
+
+batch_size = 50
+max_iters = 40000
+eval_interval = 300
+learning_rate = 1e-3
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+eval_iters = 200
+layers_num = 2
+h_dim = 64
+max_seq_len = 256
+num_heads = 1
+dropout_rate = 0.1
+pixel_size = 3.6e-6
+
+# CUDA_VISIBLE_DEVICES=1 python ./src/main.py gpt2|optic_gpt2 ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens comment
+MODEL_CLASS = models[sys.argv[1]]
+train_data_path = Path(sys.argv[2])
+val_data_path = Path(sys.argv[3])
+test_data_path = Path(sys.argv[4])
+comment = f"{sys.argv[1]}_{train_data_path.name}_{sys.argv[5]}_{seed}"
+
+logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
+writer = SummaryWriter(logs_dir)
+script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).name)
+print("Logs dir:", logs_dir)
+script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes())  # copy this version of the script
+script_snapshot_path.chmod(0o400)  # with read-only permission
+#################################### Dataset #########################################
+
+# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
+with open(train_data_path, encoding='utf-8') as f:
+    train_text = f.read()
+
+with open(val_data_path, encoding='utf-8') as f:
+    val_text = f.read()
+
+with open(test_data_path, encoding='utf-8') as f:
+    test_text = f.read()
+
+text = train_text + val_text + test_text
+
+chars = sorted(set(text))
+vocab_size = len(chars)
+
+wtoi = {w:i for i,w in enumerate(chars)}
+itow = {i:w for i,w in enumerate(chars)}
+
+encode = lambda s: [wtoi[w] for w in s]
+decode = lambda idx: ''.join([itow[i] for i in idx])
+
+def get_batch(data, seq_len, batch_size):
+    ix = torch.randint(len(data)-seq_len, (batch_size, ))
+    x = torch.stack([data[i:i+seq_len] for i in ix]).to(device)
+    y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device)
+    return x, y
+
+train_data = torch.tensor(encode(train_text), dtype=torch.long)
+val_data = torch.tensor(encode(val_text), dtype=torch.long)
+test_data = torch.tensor(encode(test_text), dtype=torch.long)
+
+@torch.no_grad()
+def perplexity(model, data):
+    stride = max(1, len(data) // 10000)
+    losses = []
+    for i in range(0, len(data)-max_seq_len-1, stride):
+        x = data[i:(i+max_seq_len)].to(device)
+        y = data[(i+1):(i+max_seq_len+1)].to(device)
+        logits, loss = model(x[None,...], y[None,...])
+        losses.append(loss.item())
+        print(f"\rppl {i}/{len(data)-max_seq_len-1}", end="")
+    return np.exp(np.mean(losses))
+
+#################################### Model #########################################
+def complete(m, start_idxs=[0], max_new_tokens=100):
+    start_idx = torch.tensor([start_idxs]).to(device)
+    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
+    return decode(generated_tokens[0].tolist())
+
+m = MODEL_CLASS(vocab_size=vocab_size, layers_num=layers_num, h_dim=h_dim, max_seq_len=max_seq_len, num_heads=num_heads, dropout_rate=dropout_rate, pixel_size=pixel_size)
+m = m.to(device)
+#################################### Train #########################################
+
+optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)
+
+completion = complete(m, encode("\n"*max_seq_len), 2*max_seq_len)
+print(completion)
+writer.add_text('completions', completion, 0)
+
+for i in range(max_iters):
+    xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size)
+    logits, loss = m(xb, yb)
+    optimizer.zero_grad(set_to_none=True)
+    loss.backward()
+    optimizer.step()
+    writer.add_scalar('loss', loss.item(), i)
+    print(f"\r{i}/{max_iters} {loss.item()}", end="")
+    if i % 5000 == 0:
+        ppl = perplexity(model=m, data=val_data)
+        writer.add_scalar('val_perplexity', ppl.item(), i)
+        print(f"\rPerplexity at {i}: {ppl}")
+        writer.add_text('completions', complete(m, encode("\n"*max_seq_len), 2*max_seq_len), i)
+
+ppl = perplexity(model=m, data=val_data)
+print(f"\r{i+1}/{max_iters} {loss.item()}")
+print(f"\rPerplexity at {i+1}: {ppl}")
+writer.add_scalar('val_perplexity', ppl.item(), i+1)
+writer.add_scalar('loss', loss.item(), i)
+
+ppl = perplexity(model=m, data=test_data)
+writer.add_scalar('test_perplexity', ppl.item(), i+1)
+print(f"\rTest Perplexity at {i+1}: {ppl}")
+
+completion = complete(m, encode("\n"*max_seq_len), 2*max_seq_len)
+print(completion)
+writer.add_text('completions', completion, i+1)
\ No newline at end of file
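The `perplexity` helper above reports the exponential of the average per-window cross-entropy,

$$\mathrm{PPL} = \exp\Big(\frac{1}{N}\sum_{i=1}^{N} -\log p_\theta(x_i \mid x_{<i})\Big),$$

where `F.cross_entropy` already returns the mean negative log-likelihood over each window; averaging strided, overlapping windows makes this an approximation of the exact token-level perplexity rather than the literal formula.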
diff --git a/src/optical_matrix_multiplication/__init__.py b/src/optical_matrix_multiplication/__init__.py
new file mode 100644
index 0000000..9974821
--- /dev/null
+++ b/src/optical_matrix_multiplication/__init__.py
@@ -0,0 +1,9 @@
+__all__ = ["config",
+           "propagator",
+           "optical_mul", "parallel"]
+__version__ = "3.0.0"
+
+from .config import Config
+from . import propagator
+from .optical_mul import OpticalMul
+from .parallel import DataParallel
\ No newline at end of file
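For orientation, the import surface the package exposes (illustrative check; assumes `src/` is on `sys.path`):

```python
import optical_matrix_multiplication as omm

print(omm.__version__)                                # 3.0.0
print(omm.Config, omm.OpticalMul, omm.DataParallel)   # re-exported from the modules below
```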
diff --git a/src/optical_matrix_multiplication/config.py b/src/optical_matrix_multiplication/config.py
new file mode 100644
index 0000000..26c8589
--- /dev/null
+++ b/src/optical_matrix_multiplication/config.py
@@ -0,0 +1,379 @@
+import torch as _torch
+import numpy as _np
+from abc import ABC as _ABC
+from typing import Tuple as _Tuple
+import collections as _collections
+import pickle as _pickle
+
+class ConfigOpticBase(_ABC):
+    """
+    Abstract class with basic information about the optical setup.
+
+    Fields:
+        wavelength: wavelength used in the optical setup.
+        K: wave number.
+        distance: propagation distance of the light field between planes.
+    """
+    def __init__(self, wavelength: float, distance: float):
+        """
+        Constructor of the basic optical setup information class.
+
+        Args:
+            wavelength: wavelength used in the optical setup.
+            distance: propagation distance of the light field between planes.
+        """
+        self._wavelength: float = wavelength
+        self._K: float = 2 * _np.pi / wavelength
+        self._distance = distance
+
+    @property
+    def wavelength(self) -> float:
+        """
+        Returns: the wavelength.
+        """
+        return self._wavelength
+
+    @property
+    def K(self) -> float:
+        """
+        Returns: the wave number.
+        """
+        return self._K
+
+    @property
+    def distance(self) -> float:
+        """
+        Returns: the propagation distance.
+        """
+        return self._distance
+
+class ConfigDesignPlane:
+    """
+    Class describing a design plane of the simulation.
+    """
+    def __init__(self,
+                 pixel_count: None | int | _Tuple[int, int] = None,
+                 pixel_size: None | float | _Tuple[float, float] = None,
+                 aperture: None | float | _Tuple[float, float] = None
+                 ):
+        """
+        Constructor of the design plane description class.
+
+        Args:
+            pixel_count: number of pixels along each dimension.
+            pixel_size: pixel size along each dimension.
+            aperture: aperture of each dimension.
+
+        Note:
+            1. It is sufficient to specify two of the three arguments.
+            2. If a single number is passed, both dimensions are assumed to share that value;
+               if two numbers are passed, the first is associated with the Y dimension, the second with X.
+        """
+        if aperture is None:
+            aperture = ConfigDesignPlane.__get_excluded_third(pixel_count, pixel_size, False)
+        elif pixel_count is None:
+            pixel_count = ConfigDesignPlane.__get_excluded_third(aperture, pixel_size)
+        elif pixel_size is None:
+            pixel_size = ConfigDesignPlane.__get_excluded_third(aperture, pixel_count)
+
+        self._pixel_count: int | _Tuple[int, int] = pixel_count
+        self._pixel_size: float | _Tuple[float, float] = pixel_size
+        self._aperture: float | _Tuple[float, float] = aperture
+ """ + if aperture is None: + aperture = ConfigDesignPlane.__get_excluded_third(pixel_count, pixel_size, False) + elif pixel_count is None: + pixel_count = ConfigDesignPlane.__get_excluded_third(aperture, pixel_size) + elif pixel_size is None: + pixel_size = ConfigDesignPlane.__get_excluded_third(aperture, pixel_count) + + self._pixel_count: int | _Tuple[int, int] = pixel_count + self._pixel_size: float | _Tuple[float, float] = pixel_size + self._aperture: float | _Tuple[float, float] = aperture + + @staticmethod + def __get_excluded_third(first_element, second_element, div: bool = True): + ConfigDesignPlane.__check_for_none(first_element, second_element) + first_element_y, first_element_x = ConfigDesignPlane.__return_tuple(first_element) + second_element_y, second_element_x = ConfigDesignPlane.__return_tuple(second_element) + if div: + return first_element_y / second_element_y, first_element_x / second_element_x + else: + return first_element_y * second_element_y, first_element_x * second_element_x + + @staticmethod + def __check_for_none(first_element, second_element): + if (first_element is None) or (second_element is None): + raise TypeError("One of the provided elements is None, it is not possible to obtain the full dimensions of the calculated plane.") + + @staticmethod + def __return_element(element: int | _Tuple[int, int], dim: int = 0): + if isinstance(element, _collections.abc.Sequence): + return element[dim] + else: + return element + + @staticmethod + def __return_tuple(element: int | _Tuple[int, int]): + if isinstance(element, _collections.abc.Sequence): + return element + else: + return element, element + + @staticmethod + def __get_linspace_by_dim(aperture, pixel_count): + linspace = _torch.linspace(-aperture / 2, aperture / 2, pixel_count + 1)[:pixel_count] + linspace += aperture / (2 * pixel_count) + return linspace + + @property + def pixel_count_by_x(self) -> float: + """ + Returns: + Информация о кол-ве пикселей по оси X. + """ + return ConfigDesignPlane.__return_element(self._pixel_count, 1) + @property + def pixel_count_by_y(self) -> float: + """ + Returns: + Информация о кол-ве пикселей по оси Y. + """ + return ConfigDesignPlane.__return_element(self._pixel_count) + @property + def pixel_count(self) -> _Tuple[float, float]: + """ + Returns: + Информация о кол-ве пикселей по каждой оси [Y, x]. + """ + return ConfigDesignPlane.__return_tuple(self._pixel_count) + @property + def pixel_size_by_x(self) -> int: + """ + Returns: + Информация о размере пикселей по оси X. + """ + return ConfigDesignPlane.__return_element(self._pixel_size, 1) + @property + def pixel_size_by_y(self) -> int: + """ + Returns: + Информация о размере пикселей по оси Y. + """ + return ConfigDesignPlane.__return_element(self._pixel_size) + @property + def pixel_size(self) -> _Tuple[int, int]: + """ + Returns: + Информация о размере пикселей по каждой оси [Y, x]. + """ + return ConfigDesignPlane.__return_tuple(self._pixel_size) + @property + def aperture_width(self) -> float: + """ + Returns: + Информация о ширине расчётной плоскасти. + """ + return ConfigDesignPlane.__return_element(self._aperture, 1) + @property + def aperture_height(self) -> float: + """ + Returns: + Информация о высоте расчётной плоскасти. + """ + return ConfigDesignPlane.__return_element(self._aperture) + @property + def aperture(self) -> _Tuple[float, float]: + """ + Returns: + Информация о высоте и ширине расчётной плоскасти [H, W]. 
+ """ + return ConfigDesignPlane.__return_tuple(self._aperture) + @property + def linspace_by_x(self) -> _torch.Tensor: + """ + Returns: + Расчётная сетка по оси X. + """ + return ConfigDesignPlane.__get_linspace_by_dim(self.aperture_width, self.pixel_count_by_x) + @property + def linspace_by_y(self) -> _torch.Tensor: + """ + Returns: + Расчётная сетка по оси Y. + """ + return ConfigDesignPlane.__get_linspace_by_dim(self.aperture_height, self.pixel_count_by_y) + @property + def meshgrid(self) -> _Tuple[_torch.Tensor, _torch.Tensor]: + """ + Returns: + Расчётная сетка по осям [Y, X]. + """ + linspace_by_x = self.linspace_by_x + linspace_by_y = self.linspace_by_y + return _torch.meshgrid((linspace_by_y, linspace_by_x)) + +class ConfigModelBase(_ABC): + """ + Абстрактный класс базовой информации об оптической установке. + """ + def __init__(self, + input_vector_plane: ConfigDesignPlane, + first_lens_plane: ConfigDesignPlane, + matrix_plane: ConfigDesignPlane, + second_lens_plane: ConfigDesignPlane, + output_vector_plane: ConfigDesignPlane + ): + self._input_vector_plane: ConfigDesignPlane = input_vector_plane + self._first_lens_plane: ConfigDesignPlane = first_lens_plane + self._matrix_plane: ConfigDesignPlane = matrix_plane + self._second_lens_plane: ConfigDesignPlane = second_lens_plane + self._output_vector_plane: ConfigDesignPlane = output_vector_plane + + @property + def input_vector_plane(self) -> ConfigDesignPlane: + """ + Returns: + Информация о расчётной плоскости входного вектора. + """ + return self._input_vector_plane + @property + def first_lens_plane(self) -> ConfigDesignPlane: + """ + Returns: + Информация о расчётной плоскости первой скрещенной линзы. + """ + return self._first_lens_plane + @property + def matrix_plane(self) -> ConfigDesignPlane: + """ + Returns: + Информация о расчётной плоскости элемента матрицы. + """ + return self._matrix_plane + @property + def second_lens_plane(self) -> ConfigDesignPlane: + """ + Returns: + Информация о расчётной плоскости второй скрещенной линзы. + """ + return self._second_lens_plane + @property + def output_vector_plane(self) -> ConfigDesignPlane: + """ + Returns: + Информация о расчётной плоскости выходного вектора оптической установки. + """ + return self._output_vector_plane + +class Config(ConfigOpticBase, ConfigModelBase): + """ + Класс конфигурации, хранит полную информацию о расчётной системе. + """ + def __init__(self, + right_matrix_count_columns: int, + right_matrix_count_rows: int, + right_matrix_width: float, + right_matrix_height: float, + min_height_gap: float, + right_matrix_split_x: int = 1, + right_matrix_split_y: int = 1, + left_matrix_split_x: int = 1, + left_matrix_split_y: int = 1, + result_matrix_split: int = 1, + camera_pixel_size: float = 3.6e-6, + wavelength: float = 532e-9, + distance: float = 0.03, + lens_pixel_size: float = 1.8e-6, + lens_size: int = 8192): + """ + Конструктор класса. + + Args: + right_matrix_count_columns: число столбцов в правой матрице, участвующей в операции матричного умножения. + right_matrix_count_rows: число строк в правой матрице, участвующей в операции матричного умножения. + right_matrix_width: ширина в метрах оптического элемента правой матрицы, участвующей в операции матричного умножения. + right_matrix_height: высота в метрах оптического элемента правой матрицы, участвующей в операции матричного умножения. + min_height_gap: минимально возможный зазор для отображения вектора левой матрицы, участвующей в операции матричного умножения. 
+class Config(ConfigOpticBase, ConfigModelBase):
+    """
+    Configuration class; stores the complete description of the simulated system.
+    """
+    def __init__(self,
+                 right_matrix_count_columns: int,
+                 right_matrix_count_rows: int,
+                 right_matrix_width: float,
+                 right_matrix_height: float,
+                 min_height_gap: float,
+                 right_matrix_split_x: int = 1,
+                 right_matrix_split_y: int = 1,
+                 left_matrix_split_x: int = 1,
+                 left_matrix_split_y: int = 1,
+                 result_matrix_split: int = 1,
+                 camera_pixel_size: float = 3.6e-6,
+                 wavelength: float = 532e-9,
+                 distance: float = 0.03,
+                 lens_pixel_size: float = 1.8e-6,
+                 lens_size: int = 8192):
+        """
+        Constructor of the class.
+
+        Args:
+            right_matrix_count_columns: number of columns of the right matrix of the multiplication.
+            right_matrix_count_rows: number of rows of the right matrix of the multiplication.
+            right_matrix_width: width in meters of the optical element of the right matrix.
+            right_matrix_height: height in meters of the optical element of the right matrix.
+            min_height_gap: minimal possible gap for displaying a vector of the left matrix.
+            right_matrix_split_x: oversampling factor of the right matrix elements along X (used for a more accurate simulation).
+            right_matrix_split_y: oversampling factor of the right matrix elements along Y (used for a more accurate simulation).
+            left_matrix_split_x: oversampling factor of the left matrix elements along X (used for a more accurate simulation).
+            left_matrix_split_y: oversampling factor of the left matrix elements along Y (used for a more accurate simulation).
+            result_matrix_split: oversampling factor of the result matrix elements (used for a more accurate simulation).
+            camera_pixel_size: physical pixel size of the camera reading the resulting light field.
+            wavelength: wavelength in meters used in the system.
+            distance: propagation distance in meters of the light field between planes.
+            lens_pixel_size: pixel size in meters of the crossed lenses in the optical system (needed only for simulation).
+            lens_size: size in pixels of the crossed lenses in the optical system (needed only for simulation).
+        """
+        ConfigOpticBase.__init__(self, wavelength, distance)
+
+        config_plane_one = ConfigDesignPlane((left_matrix_split_y, left_matrix_split_x * right_matrix_count_rows),
+                                             aperture=(min_height_gap, right_matrix_height)
+                                             )
+        config_plane_lens = ConfigDesignPlane(lens_size, lens_pixel_size)
+        config_plane_three = ConfigDesignPlane((right_matrix_count_columns * left_matrix_split_x, right_matrix_count_rows * right_matrix_split_y),
+                                               aperture=(right_matrix_width, right_matrix_height)
+                                               )
+        config_plane_five = ConfigDesignPlane((right_matrix_count_columns * result_matrix_split, 1),
+                                              aperture=(right_matrix_width, camera_pixel_size)
+                                              )
+        ConfigModelBase.__init__(self,
+                                 config_plane_one,
+                                 config_plane_lens,
+                                 config_plane_three,
+                                 config_plane_lens,
+                                 config_plane_five
+                                 )
+
+        self._matrix_split_x: int = right_matrix_split_x
+        self._matrix_split_y: int = right_matrix_split_y
+        self._input_vector_split_x: int = left_matrix_split_x
+        self._input_vector_split_y: int = left_matrix_split_y
+        self._result_vector_split: int = result_matrix_split
+
+    @property
+    def matrix_split_x(self) -> int:
+        """
+        Returns:
+            Oversampling of the matrix element pixels along the X axis.
+        """
+        return self._matrix_split_x
+    @property
+    def matrix_split_y(self) -> int:
+        """
+        Returns:
+            Oversampling of the matrix element pixels along the Y axis.
+        """
+        return self._matrix_split_y
+    @property
+    def input_vector_split_x(self) -> int:
+        """
+        Returns:
+            Oversampling of the input vector element pixels along the X axis.
+        """
+        return self._input_vector_split_x
+    @property
+    def input_vector_split_y(self) -> int:
+        """
+        Returns:
+            Oversampling of the input vector element pixels along the Y axis.
+        """
+        return self._input_vector_split_y
+    @property
+    def result_vector_split(self) -> int:
+        """
+        Returns:
+            Oversampling of the output vector element pixels along the Y axis.
+        """
+        return self._result_vector_split
+
+    def save(self, filename: str = "config.pth"):
+        """
+        Save the configuration parameters to a file.
+
+        Args:
+            filename: name of the file with the configuration parameters.
+        """
+        with open(filename, 'wb') as f:
+            _pickle.dump(self, f, protocol=_pickle.HIGHEST_PROTOCOL)
+ """ + with open(filename, 'rb') as f: + return _pickle.load(f) diff --git a/src/optical_matrix_multiplication/optical_mul.py b/src/optical_matrix_multiplication/optical_mul.py new file mode 100644 index 0000000..4985adc --- /dev/null +++ b/src/optical_matrix_multiplication/optical_mul.py @@ -0,0 +1,116 @@ +import torch as _torch +import torch.nn as _nn +from .config import Config as _Config +from .propagator import PropagatorCrossLens as _PropCrossLens, PropagatorСylindLens as _PropСylindLens, PropagatorSinc as _PropSinc, Propagator as _Prop + +class OpticalMul(_nn.Module): + """ + Класс системы, выполняющей оптически операцию умножения матрицы на матрицу. + """ + def __init__(self, config: _Config): + """ + Конструктор класса. + + Args: + config: конфигурация расчётной системы. + """ + super(OpticalMul, self).__init__() + + prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config) + prop_two = _PropCrossLens(config.first_lens_plane, config) + prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config) + prop_four = _PropСylindLens(config.matrix_plane, config) + prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config) + prop_six = _PropCrossLens(config.second_lens_plane, config).T + prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config) + + self._propagator_one: _Prop = prop_one + prop_two + prop_three + prop_four + self._propagator_two: _Prop = prop_five + prop_six + prop_seven + + kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x)) + kron_mat_utils = _torch.ones((config.matrix_split_x, config.matrix_split_y)) + self.register_buffer('_kron_vec_utils', kron_vec_utils, persistent=True) + self.register_buffer('_kron_mat_utils', kron_mat_utils, persistent=True) + + self._avg_pool = _nn.AvgPool2d((1, config.result_vector_split)) + + def prepare_vector(self, data: _torch.Tensor) -> _torch.Tensor: + """ + Метод подготовки матрицы левой матрицы, как набора векторов столбцов, к подаче на вход системы. + + Args: + data: матрица комплексной амплитуды распределений световых полей. + + Returns: + Матрицы содержащие вектора левой матрицы. + """ + data = data.cfloat().flip(-1) + data = data.unsqueeze(-2) + data = _torch.kron(data.contiguous(), self._kron_vec_utils) + return data + + def prepare_matrix(self, data: _torch.Tensor) -> _torch.Tensor: + """ + Метод подготовки правой матрицы к подаче на вход системы. + + Args: + data: матрица комплексной амплитуды распределения светового поля. + + Returns: + Матрица - оптический элемент в центре модели. + """ + if (data.dim() > 4) and data.size(-1) == 2: + data = _torch.view_as_complex(data) + + data = data.cfloat().transpose(-2, -1) + data = data.unsqueeze(-3) + data = _torch.kron(data.contiguous(), self._kron_mat_utils) + return data + + def prepare_out(self, field: _torch.Tensor) -> _torch.Tensor: + """ + Метод получения результата матричного умножения. + + Args: + data: матрицы выходого распределения светового поля системы. + + Returns: + Вектор столбец (амплитудное распределение). + """ + ### Закоментированная часть кода - более физически корректный вариант работы модели, + ### однако, данный вариант кода будет требовать большое кол-во памяти во время обучения + field = field.abs().squeeze(-1) #**2 + field = self._avg_pool(field) + return field.flip(-1) #**0.5 + + def forward(self, + input: _torch.Tensor, + other: _torch.Tensor) -> _torch.Tensor: + """ + Метод выполения матричного умножения. 
diff --git a/src/optical_matrix_multiplication/optical_mul.py b/src/optical_matrix_multiplication/optical_mul.py
new file mode 100644
index 0000000..4985adc
--- /dev/null
+++ b/src/optical_matrix_multiplication/optical_mul.py
@@ -0,0 +1,116 @@
+import torch as _torch
+import torch.nn as _nn
+from .config import Config as _Config
+from .propagator import PropagatorCrossLens as _PropCrossLens, PropagatorCylindLens as _PropCylindLens, PropagatorSinc as _PropSinc, Propagator as _Prop
+
+class OpticalMul(_nn.Module):
+    """
+    System that performs matrix-by-matrix multiplication optically.
+    """
+    def __init__(self, config: _Config):
+        """
+        Constructor of the class.
+
+        Args:
+            config: configuration of the simulated system.
+        """
+        super(OpticalMul, self).__init__()
+
+        prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config)
+        prop_two = _PropCrossLens(config.first_lens_plane, config)
+        prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config)
+        prop_four = _PropCylindLens(config.matrix_plane, config)
+        prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config)
+        prop_six = _PropCrossLens(config.second_lens_plane, config).T
+        prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config)
+
+        self._propagator_one: _Prop = prop_one + prop_two + prop_three + prop_four
+        self._propagator_two: _Prop = prop_five + prop_six + prop_seven
+
+        kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x))
+        kron_mat_utils = _torch.ones((config.matrix_split_x, config.matrix_split_y))
+        self.register_buffer('_kron_vec_utils', kron_vec_utils, persistent=True)
+        self.register_buffer('_kron_mat_utils', kron_mat_utils, persistent=True)
+
+        self._avg_pool = _nn.AvgPool2d((1, config.result_vector_split))
+
+    def prepare_vector(self, data: _torch.Tensor) -> _torch.Tensor:
+        """
+        Prepare the left matrix, treated as a set of column vectors, for input to the system.
+
+        Args:
+            data: matrix of complex amplitude distributions of the light fields.
+
+        Returns:
+            Matrices containing the vectors of the left matrix.
+        """
+        data = data.cfloat().flip(-1)
+        data = data.unsqueeze(-2)
+        data = _torch.kron(data.contiguous(), self._kron_vec_utils)
+        return data
+
+    def prepare_matrix(self, data: _torch.Tensor) -> _torch.Tensor:
+        """
+        Prepare the right matrix for input to the system.
+
+        Args:
+            data: matrix of the complex amplitude distribution of the light field.
+
+        Returns:
+            The matrix as the optical element at the center of the model.
+        """
+        if (data.dim() > 4) and data.size(-1) == 2:
+            data = _torch.view_as_complex(data)
+
+        data = data.cfloat().transpose(-2, -1)
+        data = data.unsqueeze(-3)
+        data = _torch.kron(data.contiguous(), self._kron_mat_utils)
+        return data
+
+    def prepare_out(self, field: _torch.Tensor) -> _torch.Tensor:
+        """
+        Extract the result of the matrix multiplication.
+
+        Args:
+            field: output light field distributions of the system.
+
+        Returns:
+            A column vector (amplitude distribution).
+        """
+        ### The commented-out parts below are the more physically correct variant of the model;
+        ### however, that variant requires much more memory during training
+        field = field.abs().squeeze(-1) #**2
+        field = self._avg_pool(field)
+        return field.flip(-1) #**0.5
+
+    def forward(self,
+                input: _torch.Tensor,
+                other: _torch.Tensor) -> _torch.Tensor:
+        """
+        Perform the matrix multiplication.
+
+        Args:
+            input: matrix (B, C, H, W).
+            other: matrix (B, C, W, K).
+
+        Returns:
+            Result of the matrix multiplication (B, C, H, K).
+
+        Example:
+            >>> mul = OpticalMul(...)
+            >>> A = torch.rand((1, 1, 256, 256)) > 0.5
+            >>> B = torch.rand((1, 1, 256, 256)) > 0.5
+            >>> mul(A, B).shape
+            torch.Size([1, 1, 256, 256])
+            >>> A = torch.rand((1, 1, 64, 256)) > 0.5
+            >>> B = torch.rand((1, 1, 256, 128)) > 0.5
+            >>> mul(A, B).shape
+            torch.Size([1, 1, 64, 128])
+        """
+        vec_field = self.prepare_vector(input)
+        mat_field = self.prepare_matrix(other)
+
+        vec_field = self._propagator_one(vec_field)
+        vec_field = self._propagator_two(vec_field * mat_field)
+
+        return self.prepare_out(vec_field)
\ No newline at end of file
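A usage sketch for `OpticalMul`; the `Config` values are illustrative, `lens_size` is reduced from the 8192 default so the example stays light, and the optical output only approximates `A @ B` (for inputs scaled into [0, 1]) up to simulation accuracy:

```python
import torch
from optical_matrix_multiplication import Config, OpticalMul

cfg = Config(right_matrix_count_columns=32, right_matrix_count_rows=16,
             right_matrix_width=3.6e-6 * 32, right_matrix_height=3.6e-6 * 16,
             min_height_gap=3.6e-6, lens_size=1024)
mul = OpticalMul(cfg)

A = torch.rand(1, 1, 8, 16)    # left matrix  (B, C, H, W), values in [0, 1]
B = torch.rand(1, 1, 16, 32)   # right matrix (B, C, W, K)
print(mul(A, B).shape)         # torch.Size([1, 1, 8, 32])
```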
diff --git a/src/optical_matrix_multiplication/parallel.py b/src/optical_matrix_multiplication/parallel.py
new file mode 100644
index 0000000..631a1cd
--- /dev/null
+++ b/src/optical_matrix_multiplication/parallel.py
@@ -0,0 +1,97 @@
+import torch as _torch
+import torch.nn as _nn
+from typing import List, Union, Any
+from collections.abc import Iterator
+
+
+class DataParallel(_nn.Module):
+    """
+    Runs a model in parallel on a system with multiple compute devices.
+
+    Fields:
+        module: the model to evaluate.
+        devices: list of devices to use.
+        output_device: device on which the results are gathered.
+    """
+    def __init__(self, module: _nn.Module, devices: Union[None, List[Union[int, _torch.device]]] = None,
+                 output_device: Union[None, int, _torch.device] = None) -> None:
+        """
+        Constructor of the class.
+
+        Args:
+            module: the model to evaluate.
+            devices: list of devices to use.
+            output_device: device on which the results are gathered.
+        """
+        super(DataParallel, self).__init__()
+
+        if not _torch.cuda.is_available():
+            raise EnvironmentError("CUDA is not available.")
+
+        if not devices:
+            devices = [_torch.device(x) for x in range(_torch.cuda.device_count())]
+
+        if not output_device:
+            output_device = devices[0]
+
+        self.module = module
+        self.devices = devices
+        self.output_device = output_device
+
+    def buffers(self, *inputs) -> Iterator[_torch.Tensor]:
+        '''
+        Return an iterator over module buffers.
+
+        Args:
+            recurse (bool): if True, then yields buffers of this module
+                and all submodules. Otherwise, yields only buffers that
+                are direct members of this module.
+
+        Yields:
+            torch.Tensor: module buffer
+        '''
+        return self.module.buffers(*inputs)
+
+    def parameters(self, *inputs) -> Iterator[_nn.parameter.Parameter]:
+        '''
+        Return an iterator over module parameters.
+
+        This is typically passed to an optimizer.
+
+        Args:
+            recurse (bool): if True, then yields parameters of this module
+                and all submodules. Otherwise, yields only parameters that
+                are direct members of this module.
+
+        Yields:
+            Parameter: module parameter
+        '''
+        return self.module.parameters(*inputs)
+
+    def forward(self, input: _torch.Tensor, other: _torch.Tensor, **kwargs: Any) -> _torch.Tensor:
+        '''
+        Run the wrapped module in parallel.
+
+        `input` is scattered across the devices along dimension `dim`
+        (default 2), `other` is broadcast to every device as a whole,
+        the module replicas are applied in parallel, and the outputs
+        are gathered on `output_device`.
+
+        Args:
+            input: tensor scattered along `dim` across the devices.
+            other: tensor broadcast in full to every device.
+
+        Returns:
+            torch.Tensor: gathered output of the module
+        '''
+        dim: int = 2
+        if 'dim' in kwargs:
+            dim = kwargs['dim']
+
+        scattered_input = _nn.parallel.scatter(input, self.devices, dim)
+        broadcasted_other = _nn.parallel.comm.broadcast(other, self.devices)
+        replicas = _nn.parallel.replicate(self.module.to(self.devices[0]), self.devices)
+        stacked_input = [(scattered_input[i],) + (broadcasted_other[i],) for i in range(len(replicas))]
+
+        outputs = _nn.parallel.parallel_apply(replicas, stacked_input)
+
+        return _nn.parallel.gather(outputs, self.output_device, dim)
\ No newline at end of file
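A usage sketch for `DataParallel`, reusing `cfg` from the previous example; it assumes at least two CUDA devices:

```python
import torch
from optical_matrix_multiplication import OpticalMul, DataParallel

par = DataParallel(OpticalMul(cfg), devices=[0, 1], output_device=0)
A = torch.rand(1, 1, 8, 16, device='cuda:0')   # scattered along dim=2 (rows of A)
B = torch.rand(1, 1, 16, 32, device='cuda:0')  # broadcast whole to every device
print(par(A, B).shape)                         # torch.Size([1, 1, 8, 32]) on cuda:0
```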
diff --git a/src/optical_matrix_multiplication/propagator.py b/src/optical_matrix_multiplication/propagator.py
new file mode 100644
index 0000000..c667bd6
--- /dev/null
+++ b/src/optical_matrix_multiplication/propagator.py
@@ -0,0 +1,220 @@
+import torch as _torch
+import torch.nn as _nn
+import numpy as _np
+from scipy.special import fresnel as _fresnel
+from .config import ConfigOpticBase as _ConfigOpticBase, ConfigDesignPlane as _ConfigDesignPlane
+from typing import Tuple as _Tuple, Sequence as _Sequence
+
+from abc import ABC as _ABC
+import collections as _collections
+
+class Propagator(_ABC, _nn.Module):
+    """
+    Abstract class computing propagation of a light field through a medium.
+
+    Fields:
+        operator_X: operator describing propagation of the light field along the abscissa axis
+        operator_Y: operator describing propagation of the light field along the ordinate axis
+    """
+    def __init__(self, operator_X: _torch.Tensor, operator_Y: _torch.Tensor):
+        super(Propagator, self).__init__()
+        operator_X: _torch.Tensor = _torch.view_as_real(operator_X)
+        operator_Y: _torch.Tensor = _torch.view_as_real(operator_Y)
+        self.register_buffer('_operator_X', operator_X, persistent=True)
+        self.register_buffer('_operator_Y', operator_Y, persistent=True)
+
+    @property
+    def operator_X(self) -> _torch.Tensor:
+        """
+        Returns:
+            operator describing propagation of the light field along the abscissa axis
+        """
+        return _torch.view_as_complex(self._operator_X)
+    @property
+    def operator_Y(self) -> _torch.Tensor:
+        """
+        Returns:
+            operator describing propagation of the light field along the ordinate axis
+        """
+        return _torch.view_as_complex(self._operator_Y)
+
+    def __operator_multiplication(self, first_X: _torch.Tensor,
+                                  second_X: _torch.Tensor,
+                                  first_Y: _torch.Tensor,
+                                  second_Y: _torch.Tensor) -> _Tuple[_torch.Tensor, _torch.Tensor]:
+        operator_Y = second_Y @ first_Y
+        operator_X = first_X @ second_X
+        return operator_X, operator_Y
+
+    def cat(self, propagators: _Sequence['Propagator']) -> 'Propagator':
+        """
+        Collapse a series of propagation operators into one.
+
+        Args:
+            propagators: the sequence to collapse
+
+        Returns:
+            a new propagator replacing the whole series
+
+        Warning:
+            the order of the propagators in the sequence matters;
+            it goes from the first to the last
+        """
+        operator_X: _torch.Tensor
+        operator_Y: _torch.Tensor
+        if not isinstance(propagators, _collections.abc.Sequence):
+            operator_X, operator_Y = self.__operator_multiplication(self.operator_X,
+                                                                    propagators.operator_X,
+                                                                    self.operator_Y,
+                                                                    propagators.operator_Y)
+        else:
+            size = len(propagators)
+            operator_X = self.operator_X
+            operator_Y = self.operator_Y
+            for i in range(size):
+                operator_X, operator_Y = self.__operator_multiplication(operator_X,
+                                                                        propagators[i].operator_X,
+                                                                        operator_Y,
+                                                                        propagators[i].operator_Y)
+        return Propagator(operator_X, operator_Y)
+
+    def __add__(self, propagator: 'Propagator') -> 'Propagator':
+        """
+        Collapse two propagators into one.
+        Args:
+            propagator: the propagator to collapse with
+
+        Returns:
+            a new propagator replacing both
+
+        Warning:
+            this operation is not commutative
+        """
+        return self.cat(propagator)
+
+    def forward(self, field: _torch.Tensor) -> _torch.Tensor:
+        """
+        Propagate the light field through the medium.
+
+        Args:
+            field: complex amplitude distribution of the light field.
+
+        Returns:
+            Complex amplitude distribution of the light field
+            after propagation.
+        """
+        return self.operator_Y @ field @ self.operator_X
+ """ + def __init__(self, first_plane: _ConfigDesignPlane, + second_plane: _ConfigDesignPlane, + config: _ConfigOpticBase): + """ + Конструктор класса распространения в свободном пространстве. + + Args: + first_plane: данные о начальной расчётной плоскости. + second_plane: данные о конечной расчётной плоскости. + config: данные о световом поле модели. + """ + operator_X, operator_Y = self.__get_operators(first_plane, + second_plane, + config) + super(PropagatorSinc, self).__init__(operator_X, operator_Y) + + def __get_operator_for_dim(self, + pixel_size_in: float, + pixel_size_out: float, + difference: float, + config: _ConfigOpticBase) -> _torch.Tensor: + bndW = 0.5 / pixel_size_in + eikz = (_np.exp(1j * config.K * config.distance)**0.5) + sq2p = (2 / _np.pi)**0.5 + sqzk = ((2 * config.distance / config.K)**0.5) + mu1 = -_np.pi * sqzk * bndW - difference / sqzk + mu2 = _np.pi * sqzk * bndW - difference / sqzk + S1, C1 = _fresnel(mu1 * sq2p) + S2, C2 = _fresnel(mu2 * sq2p) + return (((pixel_size_in * pixel_size_out)**0.5 / _np.pi) / sqzk * eikz + * _np.exp(0.5j * difference**2 * config.K / config.distance) + * (C2 - C1 - 1j * (S2 - S1)) / sq2p) + + def __get_operators(self, + first_plane: _ConfigDesignPlane, + second_plane: _ConfigDesignPlane, + config: _ConfigOpticBase) -> _Tuple[_torch.Tensor, _torch.Tensor]: + difference_x = first_plane.linspace_by_x[None, :] - second_plane.linspace_by_x[:, None] + difference_y = first_plane.linspace_by_y[None, :] - second_plane.linspace_by_y[:, None] + operator_X = self.__get_operator_for_dim(first_plane.pixel_size_by_x, + second_plane.pixel_size_by_x, + difference_x, + config).transpose(-2, -1) + operator_Y = self.__get_operator_for_dim(first_plane.pixel_size_by_y, + second_plane.pixel_size_by_y, + difference_y, + config) + return operator_X, operator_Y diff --git a/src/optics_char_gpt2.py b/src/optics_char_gpt2.py new file mode 100644 index 0000000..7e77118 --- /dev/null +++ b/src/optics_char_gpt2.py @@ -0,0 +1,177 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange +import optical_matrix_multiplication as omm +from optical_matrix_multiplication import propagator +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +from pathlib import Path +import sys +torch.manual_seed(1337) + +#################################### Model ######################################### + +def norm(matrix: torch.Tensor, max_val: float = 1) -> torch.Tensor: + return matrix / (max_val + 1e-10) + +def optics_matmul_shift(sim, tensor_1, tensor_2): + tensor_1 = tensor_1[None,:,:,:] + tensor_2 = tensor_2[None,:,:,:] + + if torch.min(tensor_1) >= 0 and torch.min(tensor_2) >= 0: + max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + a, b = norm(tensor_1, max_abs), norm(tensor_2, max_abs) + return sim(a, b)[0,:,:,:] * max_abs **2 + + min_abs = abs(min(torch.min(tensor_1), torch.min(tensor_2))) + max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + min_abs + + shift_a = min_abs * torch.ones(tensor_1.shape).to(tensor_1.device) + shift_b = min_abs * torch.ones(tensor_2.shape).to(tensor_1.device) + a_a_sh = tensor_1 + shift_a + b_b_sh = tensor_2 + shift_b + + a_a_sh_norm, b_b_sh_norm = norm(a_a_sh, max_abs), norm(b_b_sh, max_abs) + shift_a_norm, shift_b_norm = norm(shift_a, max_abs), norm(shift_b, max_abs) + a_a_sh_b_b_sh = sim(a_a_sh_norm, b_b_sh_norm) + a_a_sh_b_sh = sim(a_a_sh_norm, shift_b_norm) + a_sh_b_b_sh = sim(shift_a_norm, b_b_sh_norm) + a_sh_b_sh = 
diff --git a/src/optics_char_gpt2.py b/src/optics_char_gpt2.py
new file mode 100644
index 0000000..7e77118
--- /dev/null
+++ b/src/optics_char_gpt2.py
@@ -0,0 +1,177 @@
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from einops import rearrange
+import optical_matrix_multiplication as omm
+from optical_matrix_multiplication import propagator
+import numpy as np
+from torch.utils.tensorboard import SummaryWriter
+from datetime import datetime
+from pathlib import Path
+import sys
+torch.manual_seed(1337)
+
+#################################### Model #########################################
+
+def norm(matrix: torch.Tensor, max_val: float = 1) -> torch.Tensor:
+    return matrix / (max_val + 1e-10)
+
+def optics_matmul_shift(sim, tensor_1, tensor_2):
+    tensor_1 = tensor_1[None,:,:,:]
+    tensor_2 = tensor_2[None,:,:,:]
+
+    if torch.min(tensor_1) >= 0 and torch.min(tensor_2) >= 0:
+        max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2)))
+        a, b = norm(tensor_1, max_abs), norm(tensor_2, max_abs)
+        return sim(a, b)[0,:,:,:] * max_abs**2
+
+    min_abs = abs(min(torch.min(tensor_1), torch.min(tensor_2)))
+    max_abs = abs(max(torch.max(tensor_1), torch.max(tensor_2))) + min_abs
+
+    shift_a = min_abs * torch.ones(tensor_1.shape).to(tensor_1.device)
+    shift_b = min_abs * torch.ones(tensor_2.shape).to(tensor_1.device)
+    a_a_sh = tensor_1 + shift_a
+    b_b_sh = tensor_2 + shift_b
+
+    a_a_sh_norm, b_b_sh_norm = norm(a_a_sh, max_abs), norm(b_b_sh, max_abs)
+    shift_a_norm, shift_b_norm = norm(shift_a, max_abs), norm(shift_b, max_abs)
+    a_a_sh_b_b_sh = sim(a_a_sh_norm, b_b_sh_norm)
+    a_a_sh_b_sh = sim(a_a_sh_norm, shift_b_norm)
+    a_sh_b_b_sh = sim(shift_a_norm, b_b_sh_norm)
+    a_sh_b_sh = sim(shift_a_norm, shift_b_norm)
+
+    return (a_a_sh_b_b_sh - a_a_sh_b_sh - a_sh_b_b_sh + a_sh_b_sh)[0,:,:,:] * max_abs ** 2
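The four shifted products implement the exact identity $(A+S_A)(B+S_B) - (A+S_A)S_B - S_A(B+S_B) + S_A S_B = AB$ with constant shift matrices $S_A, S_B$, which lets a multiplier that only accepts non-negative inputs handle signed operands. A quick numeric check with an exact matmul standing in for the optical simulator (run from `src/`):

```python
import torch
from optics_char_gpt2 import optics_matmul_shift

sim = lambda a, b: a @ b           # exact stand-in for the OpticalMul sim
A = torch.randn(1, 4, 5)           # signed inputs
B = torch.randn(1, 5, 3)
out = optics_matmul_shift(sim, A, B)
print(torch.allclose(out, A @ B, atol=1e-4))  # True
```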
+
+# RoFormer: Enhanced Transformer with Rotary Position Embedding https://arxiv.org/abs/2104.09864
+class RoPE(nn.Module):
+    def __init__(self, dim, max_seq_len=512):
+        super().__init__()
+        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+        t = torch.arange(max_seq_len).type_as(inv_freq)
+        freqs = torch.einsum('i,j->ij', t, inv_freq)
+        self.register_buffer('emb', torch.cat([freqs, freqs], dim=-1))
+
+    def rotate_half(self, x):
+        x1, x2 = x.chunk(2, dim=-1)
+        return torch.cat((-x2, x1), dim=-1)
+
+    def forward(self, x, offset=0):
+        seq_len = x.size(1)
+        emb = self.emb[offset:offset+seq_len].view(1, seq_len, -1)
+        cos = emb.cos()
+        sin = emb.sin()
+        return (x * cos) + (self.rotate_half(x) * sin)
+
+# Transformers without Normalization https://jiachenzhu.github.io/DyT/
+class DyT(nn.Module):
+    def __init__(self, num_features, alpha_init_value=0.5):
+        super().__init__()
+        self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
+        self.weight = nn.Parameter(torch.ones(num_features))
+        self.bias = nn.Parameter(torch.zeros(num_features))
+
+    def forward(self, x):
+        x = torch.tanh(self.alpha * x)
+        return x * self.weight + self.bias
+
+# Attention Is All You Need https://arxiv.org/pdf/1706.03762v7
+# NeoBERT: A Next-Generation BERT https://arxiv.org/html/2502.19587v1
+class TransformerLayer(nn.Module):
+    def __init__(self, h_dim, sim_scores, sim_output, num_heads=4, dropout_rate=0.1, max_seq_len=128):
+        super().__init__()
+        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
+        self.q_proj = nn.Linear(h_dim, h_dim)
+        self.k_proj = nn.Linear(h_dim, h_dim)
+        self.v_proj = nn.Linear(h_dim, h_dim)
+        self.o_proj = nn.Linear(h_dim, h_dim)
+        self.ff1 = nn.Linear(h_dim, 4*h_dim)
+        self.ff2 = nn.Linear(4*h_dim, h_dim)
+        self.ln1 = DyT(h_dim)
+        self.ln2 = DyT(h_dim)
+        self.rope = RoPE(dim=h_dim//self.num_heads, max_seq_len=max_seq_len)
+        self.k1 = nn.Parameter(torch.randn(1))
+        self.k2 = nn.Parameter(torch.randn(1))
+
+    def split_to_heads(self, x, B, T, H):
+        if self.num_heads <= 1: return x
+        return rearrange(x, 'b t (n h) -> (b n) t h', b=B, t=T, n=self.num_heads)
+
+    def gather_heads(self, x, B, T, H):
+        if self.num_heads <= 1: return x
+        return rearrange(x, '(b n) t h -> b t (n h)', b=B, t=T, n=self.num_heads)
+
+    def attention(self, x):
+        q = self.rope(self.split_to_heads(self.q_proj(x), *x.shape))
+        k = self.rope(self.split_to_heads(self.k_proj(x), *x.shape))
+        v = self.split_to_heads(self.v_proj(x), *x.shape)
+        scores = self.k1 * optics_matmul_shift(self.sim_scores, q, k.transpose(1, 2)) * ((self.h_dim // self.num_heads) ** -0.5)  # scale by the per-head dimension
+        tril = torch.tril(torch.ones(x.shape[1], x.shape[1])).to(self.q_proj.bias.device)
+        scores = scores.masked_fill(tril == 0, float('-inf'))
+        attention = nn.functional.softmax(scores, dim=2)
+        output = self.k2 * optics_matmul_shift(self.sim_output, attention, v)
+        return self.o_proj(self.gather_heads(output, *x.shape))
+
+    def forward(self, x):
+        x = x + F.dropout1d(self.attention(self.ln1(x)), p=self.dropout_rate, training=self.training)
+        x = x + F.dropout1d(self.ff2(F.gelu(self.ff1(self.ln2(x)))), p=self.dropout_rate, training=self.training)
+        return x
+
+class OpticGPT2(nn.Module):
+    def __init__(self, vocab_size, layers_num=1, h_dim=64, max_seq_len=128, num_heads=4, dropout_rate=0.1,
+                 pixel_size=3.6e-6):
+        super().__init__()
+        self.__dict__.update({k: v for k, v in locals().items() if k != 'self'})
+        self.sim_scores = omm.OpticalMul(
+            omm.Config(right_matrix_count_columns = max_seq_len,
+                       right_matrix_count_rows = h_dim // num_heads,
+                       right_matrix_width = pixel_size * max_seq_len,
+                       right_matrix_height = pixel_size * (h_dim // num_heads),
+                       min_height_gap = pixel_size,
+                       right_matrix_split_x = 2,
+                       right_matrix_split_y = 2,
+                       left_matrix_split_x = 2,
+                       left_matrix_split_y = 2,
+                       result_matrix_split = 2,
+                       distance = 0.01)
+        )
+
+        self.sim_output = omm.OpticalMul(
+            omm.Config(right_matrix_count_columns = h_dim // num_heads,
+                       right_matrix_count_rows = max_seq_len,
+                       right_matrix_width = pixel_size * (h_dim // num_heads),
+                       right_matrix_height = pixel_size * max_seq_len,
+                       min_height_gap = pixel_size,
+                       right_matrix_split_x = 2,
+                       right_matrix_split_y = 2,
+                       left_matrix_split_x = 2,
+                       left_matrix_split_y = 2,
+                       result_matrix_split = 2,
+                       distance = 0.01)
+        )
+        self.layers = nn.ModuleList([
+            TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,
+                             dropout_rate=self.dropout_rate, max_seq_len=self.max_seq_len)
+            for _ in range(layers_num)])
+        self.tok_embeds = nn.Embedding(vocab_size, h_dim)
+        self.pos_embeds = nn.Parameter(torch.randn(1, self.max_seq_len, h_dim))
+        self.lm_head = nn.Linear(h_dim, vocab_size)
+
+    def forward(self, x, targets=None):
+        x = self.tok_embeds(x) + self.pos_embeds[:, :x.shape[1], :]
+        for l in self.layers:
+            x = l(x)
+        logits = self.lm_head(x)  # B,T,C
+        loss = F.cross_entropy(rearrange(logits, "b t c -> b c t"), targets) if targets is not None else None
+        return logits, loss
+
+    # what is the purpose? autoregressive inference!
+    def generate(self, start_idx, max_new_tokens):
+        idx = start_idx
+        for i in range(max_new_tokens):
+            idx_cond = idx[:,-self.max_seq_len:]
+            logits, loss = self(idx_cond)
+            logits = logits[:,-1,:]  # B, C
+            probs = F.softmax(logits, dim=-1)
+            idx_next = torch.multinomial(probs, num_samples=1).to(self.lm_head.bias.device)
+            idx = torch.cat([idx, idx_next], dim=1)
+        return idx
\ No newline at end of file