You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
652 lines
29 KiB
Python
652 lines
29 KiB
Python
import torch as _torch
|
|
import torch.nn as _nn
|
|
from .config import Config as _Config
|
|
from .propagator import PropagatorCrossLens as _PropCrossLens, PropagatorCylindLens as _PropCylindLens, PropagatorSinc as _PropSinc, Propagator as _Prop
|
|
from .propagator import (
|
|
PropagatorTrainableCylindLens as _PropagatorTrainableCylindLens,
|
|
PropagatorTrainableFocalDistCylindLens as _PropagatorTrainableFocalDistCylindLens
|
|
)
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
from typing import Optional
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
class OpticalMul(_nn.Module):
|
|
"""
|
|
Класс системы, выполняющей оптически операцию умножения матрицы на матрицу.
|
|
"""
|
|
def __init__(self, config: _Config):
|
|
"""
|
|
Конструктор класса.
|
|
|
|
Args:
|
|
config: конфигурация расчётной системы.
|
|
"""
|
|
super().__init__()
|
|
|
|
prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config)
|
|
prop_two = _PropCrossLens(config.first_lens_plane, config)
|
|
prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config)
|
|
prop_four = _PropCylindLens(config.matrix_plane, config)
|
|
prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config)
|
|
prop_six = _PropCrossLens(config.second_lens_plane, config).T
|
|
prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config)
|
|
|
|
self._propagator_one: _Prop = prop_one + prop_two + prop_three + prop_four
|
|
self._propagator_two: _Prop = prop_five + prop_six + prop_seven
|
|
|
|
kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x))
|
|
kron_mat_utils = _torch.ones((config.matrix_split_x, config.matrix_split_y))
|
|
self.register_buffer('_kron_vec_utils', kron_vec_utils, persistent=True)
|
|
self.register_buffer('_kron_mat_utils', kron_mat_utils, persistent=True)
|
|
|
|
self._avg_pool = _nn.AvgPool2d((1, config.result_vector_split))
|
|
|
|
def prepare_vector(self, data: _torch.Tensor) -> _torch.Tensor:
|
|
"""
|
|
Метод подготовки матрицы левой матрицы, как набора векторов столбцов, к подаче на вход системы.
|
|
|
|
Args:
|
|
data: матрица комплексной амплитуды распределений световых полей.
|
|
|
|
Returns:
|
|
Матрицы содержащие вектора левой матрицы.
|
|
"""
|
|
data = data.cfloat().flip(-1)
|
|
data = data.unsqueeze(-2)
|
|
data = _torch.kron(data.contiguous(), self._kron_vec_utils)
|
|
return data
|
|
|
|
def prepare_matrix(self, data: _torch.Tensor) -> _torch.Tensor:
|
|
"""
|
|
Метод подготовки правой матрицы к подаче на вход системы.
|
|
|
|
Args:
|
|
data: матрица комплексной амплитуды распределения светового поля.
|
|
|
|
Returns:
|
|
Матрица - оптический элемент в центре модели.
|
|
"""
|
|
if (data.dim() > 4) and data.size(-1) == 2:
|
|
data = _torch.view_as_complex(data)
|
|
|
|
data = data.cfloat().transpose(-2, -1)
|
|
data = data.unsqueeze(-3)
|
|
data = _torch.kron(data.contiguous(), self._kron_mat_utils)
|
|
return data
|
|
|
|
def prepare_out(self, field: _torch.Tensor) -> _torch.Tensor:
|
|
"""
|
|
Метод получения результата матричного умножения.
|
|
|
|
Args:
|
|
data: матрицы выходого распределения светового поля системы.
|
|
|
|
Returns:
|
|
Вектор столбец (амплитудное распределение).
|
|
"""
|
|
### Закоментированная часть кода - более физически корректный вариант работы модели,
|
|
### однако, данный вариант кода будет требовать большое кол-во памяти во время обучения
|
|
field = field.abs().squeeze(-1) #**2
|
|
field = self._avg_pool(field)
|
|
return field.flip(-1) #**0.5
|
|
|
|
def forward(self,
|
|
input: _torch.Tensor,
|
|
other: _torch.Tensor) -> _torch.Tensor:
|
|
"""
|
|
Метод выполения матричного умножения.
|
|
|
|
Args:
|
|
input: матрица (B, C, H, W).
|
|
other: матрица (B, C, W, K).
|
|
|
|
Returns:
|
|
Рензультат матричного умножения (B, C, H, K).
|
|
|
|
Example:
|
|
>>> mul = OpticalMul(...)
|
|
>>> A = torch.rand((1, 1, 256, 256)) > 0.5
|
|
>>> B = torch.rand((1, 1, 256, 256)) > 0.5
|
|
>>> mul(A, B).shape
|
|
torch.Size([1, 1, 256, 256])
|
|
>>> A = torch.rand((1, 1, 64, 256)) > 0.5
|
|
>>> B = torch.rand((1, 1, 256, 128)) > 0.5
|
|
>>> mul(A, B).shape
|
|
torch.Size([1, 1, 64, 128])
|
|
"""
|
|
vec_field = self.prepare_vector(input)
|
|
mat_field = self.prepare_matrix(other)
|
|
|
|
vec_field = self._propagator_one(vec_field, mat_field.shape[-2:])
|
|
vec_field = self._propagator_two(vec_field * mat_field, (mat_field.size(-2), 1))
|
|
|
|
return self.prepare_out(vec_field)
|
|
|
|
class TrainableScalarOpticalMul(_nn.Module):
|
|
"""
|
|
Класс системы, выполняющей оптически операцию умножения матрицы на матрицу.
|
|
"""
|
|
def __init__(self, config: _Config):
|
|
"""
|
|
Конструктор класса.
|
|
|
|
Args:
|
|
config: конфигурация расчётной системы.
|
|
"""
|
|
super().__init__()
|
|
|
|
prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config)
|
|
prop_two = _PropCrossLens(config.first_lens_plane, config)
|
|
prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config)
|
|
prop_four = _PropCylindLens(config.matrix_plane, config)
|
|
prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config)
|
|
prop_six = _PropCrossLens(config.second_lens_plane, config).T
|
|
prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config)
|
|
|
|
self._propagator_one: _Prop = prop_one + prop_two + prop_three + prop_four
|
|
self._propagator_two: _Prop = prop_five + prop_six + prop_seven
|
|
|
|
kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x))
|
|
kron_mat_utils = _torch.ones((config.matrix_split_x, config.matrix_split_y))
|
|
self.register_buffer('_kron_vec_utils', kron_vec_utils, persistent=True)
|
|
self.register_buffer('_kron_mat_utils', kron_mat_utils, persistent=True)
|
|
|
|
self._avg_pool = _nn.AvgPool2d((1, config.result_vector_split))
|
|
self.k = nn.Parameter(_torch.tensor(1))
|
|
|
|
def prepare_vector(self, data: _torch.Tensor) -> _torch.Tensor:
|
|
"""
|
|
Метод подготовки матрицы левой матрицы, как набора векторов столбцов, к подаче на вход системы.
|
|
|
|
Args:
|
|
data: матрица комплексной амплитуды распределений световых полей.
|
|
|
|
Returns:
|
|
Матрицы содержащие вектора левой матрицы.
|
|
"""
|
|
data = data.cfloat().flip(-1)
|
|
data = data.unsqueeze(-2)
|
|
data = _torch.kron(data.contiguous(), self._kron_vec_utils)
|
|
return data
|
|
|
|
def prepare_matrix(self, data: _torch.Tensor) -> _torch.Tensor:
|
|
"""
|
|
Метод подготовки правой матрицы к подаче на вход системы.
|
|
|
|
Args:
|
|
data: матрица комплексной амплитуды распределения светового поля.
|
|
|
|
Returns:
|
|
Матрица - оптический элемент в центре модели.
|
|
"""
|
|
if (data.dim() > 4) and data.size(-1) == 2:
|
|
data = _torch.view_as_complex(data)
|
|
|
|
data = data.cfloat().transpose(-2, -1)
|
|
data = data.unsqueeze(-3)
|
|
data = _torch.kron(data.contiguous(), self._kron_mat_utils)
|
|
return data
|
|
|
|
def prepare_out(self, field: _torch.Tensor) -> _torch.Tensor:
|
|
"""
|
|
Метод получения результата матричного умножения.
|
|
|
|
Args:
|
|
data: матрицы выходого распределения светового поля системы.
|
|
|
|
Returns:
|
|
Вектор столбец (амплитудное распределение).
|
|
"""
|
|
### Закоментированная часть кода - более физически корректный вариант работы модели,
|
|
### однако, данный вариант кода будет требовать большое кол-во памяти во время обучения
|
|
field = field.abs().squeeze(-1) #**2
|
|
field = self._avg_pool(field)
|
|
return field.flip(-1) #**0.5
|
|
|
|
def forward(self,
|
|
input: _torch.Tensor,
|
|
other: _torch.Tensor) -> _torch.Tensor:
|
|
"""
|
|
Метод выполения матричного умножения.
|
|
|
|
Args:
|
|
input: матрица (B, C, H, W).
|
|
other: матрица (B, C, W, K).
|
|
|
|
Returns:
|
|
Рензультат матричного умножения (B, C, H, K).
|
|
|
|
Example:
|
|
>>> mul = OpticalMul(...)
|
|
>>> A = torch.rand((1, 1, 256, 256)) > 0.5
|
|
>>> B = torch.rand((1, 1, 256, 256)) > 0.5
|
|
>>> mul(A, B).shape
|
|
torch.Size([1, 1, 256, 256])
|
|
>>> A = torch.rand((1, 1, 64, 256)) > 0.5
|
|
>>> B = torch.rand((1, 1, 256, 128)) > 0.5
|
|
>>> mul(A, B).shape
|
|
torch.Size([1, 1, 64, 128])
|
|
"""
|
|
vec_field = self.prepare_vector(input)
|
|
mat_field = self.prepare_matrix(other)
|
|
|
|
vec_field = self._propagator_one(vec_field, mat_field.shape[-2:])
|
|
vec_field = self._propagator_two(vec_field * mat_field, (mat_field.size(-2), 1))
|
|
|
|
return self.k * self.prepare_out(vec_field)
|
|
|
|
class TrainableLensOpticalMul(_nn.Module):
|
|
"""
|
|
Класс системы, выполняющей оптически операцию умножения матрицы на матрицу.
|
|
"""
|
|
def __init__(self, config: _Config):
|
|
"""
|
|
Конструктор класса.
|
|
|
|
Args:
|
|
config: конфигурация расчётной системы.
|
|
"""
|
|
super().__init__()
|
|
|
|
prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config)
|
|
prop_two = _PropCrossLens(config.first_lens_plane, config)
|
|
prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config)
|
|
prop_four = _PropagatorTrainableCylindLens(config.matrix_plane, config)
|
|
prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config)
|
|
prop_six = _PropCrossLens(config.second_lens_plane, config).T
|
|
prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config)
|
|
|
|
self._propagator_one: _Prop = prop_one + prop_two + prop_three
|
|
self._propagator_cylind_lens: _Prop = prop_four
|
|
self._propagator_three: _Prop = prop_five + prop_six + prop_seven
|
|
|
|
kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x))
|
|
kron_mat_utils = _torch.ones((config.matrix_split_x, config.matrix_split_y))
|
|
self.register_buffer('_kron_vec_utils', kron_vec_utils, persistent=True)
|
|
self.register_buffer('_kron_mat_utils', kron_mat_utils, persistent=True)
|
|
|
|
self._avg_pool = _nn.AvgPool2d((1, config.result_vector_split))
|
|
|
|
def prepare_vector(self, data: _torch.Tensor) -> _torch.Tensor:
|
|
"""
|
|
Метод подготовки матрицы левой матрицы, как набора векторов столбцов, к подаче на вход системы.
|
|
|
|
Args:
|
|
data: матрица комплексной амплитуды распределений световых полей.
|
|
|
|
Returns:
|
|
Матрицы содержащие вектора левой матрицы.
|
|
"""
|
|
data = data.cfloat().flip(-1)
|
|
data = data.unsqueeze(-2)
|
|
data = _torch.kron(data.contiguous(), self._kron_vec_utils)
|
|
return data
|
|
|
|
def prepare_matrix(self, data: _torch.Tensor) -> _torch.Tensor:
|
|
"""
|
|
Метод подготовки правой матрицы к подаче на вход системы.
|
|
|
|
Args:
|
|
data: матрица комплексной амплитуды распределения светового поля.
|
|
|
|
Returns:
|
|
Матрица - оптический элемент в центре модели.
|
|
"""
|
|
if (data.dim() > 4) and data.size(-1) == 2:
|
|
data = _torch.view_as_complex(data)
|
|
|
|
data = data.cfloat().transpose(-2, -1)
|
|
data = data.unsqueeze(-3)
|
|
# TODO data should be at least two seq length. For one we get
|
|
# untimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
|
|
data = _torch.kron(data.contiguous(), self._kron_mat_utils)
|
|
return data
|
|
|
|
def prepare_out(self, field: _torch.Tensor) -> _torch.Tensor:
|
|
"""
|
|
Метод получения результата матричного умножения.
|
|
|
|
Args:
|
|
data: матрицы выходого распределения светового поля системы.
|
|
|
|
Returns:
|
|
Вектор столбец (амплитудное распределение).
|
|
"""
|
|
### Закоментированная часть кода - более физически корректный вариант работы модели,
|
|
### однако, данный вариант кода будет требовать большое кол-во памяти во время обучения
|
|
field = field.abs().squeeze(-1) #**2
|
|
field = self._avg_pool(field)
|
|
return field.flip(-1) #**0.5
|
|
|
|
def forward(self,
|
|
input: _torch.Tensor,
|
|
other: _torch.Tensor) -> _torch.Tensor:
|
|
"""
|
|
Метод выполения матричного умножения.
|
|
|
|
Args:
|
|
input: матрица (B, C, H, W).
|
|
other: матрица (B, C, W, K).
|
|
|
|
Returns:
|
|
Рензультат матричного умножения (B, C, H, K).
|
|
|
|
Example:
|
|
>>> mul = OpticalMul(...)
|
|
>>> A = torch.rand((1, 1, 256, 256)) > 0.5
|
|
>>> B = torch.rand((1, 1, 256, 256)) > 0.5
|
|
>>> mul(A, B).shape
|
|
torch.Size([1, 1, 256, 256])
|
|
>>> A = torch.rand((1, 1, 64, 256)) > 0.5
|
|
>>> B = torch.rand((1, 1, 256, 128)) > 0.5
|
|
>>> mul(A, B).shape
|
|
torch.Size([1, 1, 64, 128])
|
|
"""
|
|
vec_field = self.prepare_vector(input)
|
|
mat_field = self.prepare_matrix(other)
|
|
vec_field = self._propagator_one(vec_field, mat_field.shape[-2:])
|
|
vec_field = self._propagator_cylind_lens(vec_field, mat_field.shape[-2:])
|
|
vec_field = self._propagator_three(vec_field * mat_field, (mat_field.size(-2), 1))
|
|
|
|
return self.prepare_out(vec_field)
|
|
|
|
@_torch.no_grad()
|
|
def log_cylind_lens_operator_x(
|
|
self,
|
|
writer: SummaryWriter,
|
|
tag: str,
|
|
global_step: Optional[int] = None,
|
|
):
|
|
# 1. Apply exp to get the wrapped phase as it would be physically seen
|
|
# This ensures values outside [-pi, pi] wrap correctly
|
|
complex_op = _torch.exp(-1j * self._propagator_cylind_lens._operator_X_phi)
|
|
wrapped_phase = _torch.angle(complex_op).float() # Range: [-π, π]
|
|
|
|
# 2. Normalize for Image Visualization [0, 1]
|
|
phase_normalized = (wrapped_phase + _torch.pi) / (2 * _torch.pi)
|
|
|
|
# 3. Log as a 1-pixel high row
|
|
# Shape: [1, 1, Width]
|
|
phase_row = phase_normalized.unsqueeze(0).unsqueeze(0)
|
|
writer.add_image(f"{tag}/phase_row", phase_row, global_step, dataformats='CHW')
|
|
|
|
fig, ax = plt.subplots(figsize=(12, 4))
|
|
ax.plot(wrapped_phase.detach().cpu().numpy(), label=f'Step {global_step}')
|
|
ax.set_title(f"Cylindrical Lens Phase Profile (x) {tag}")
|
|
ax.set_xlabel("Pixel Index")
|
|
ax.set_ylabel("Phase (rad)")
|
|
ax.set_ylim([-_torch.pi - 0.5, _torch.pi + 0.5])
|
|
ax.grid(True, linestyle='--', alpha=0.6)
|
|
|
|
|
|
# Send the figure to the "Plots" or "Images" tab in TensorBoard
|
|
writer.add_figure(f"{tag}/phase_profile", fig, global_step)
|
|
plt.close(fig) # Important: prevent memory leaks
|
|
|
|
|
|
class TrainableFocalDistLensOpticalMul(_nn.Module):
|
|
"""
|
|
Класс системы, выполняющей оптически операцию умножения матрицы на матрицу.
|
|
"""
|
|
def __init__(self, config: _Config):
|
|
"""
|
|
Конструктор класса.
|
|
|
|
Args:
|
|
config: конфигурация расчётной системы.
|
|
"""
|
|
super().__init__()
|
|
|
|
prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config)
|
|
prop_two = _PropCrossLens(config.first_lens_plane, config)
|
|
prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config)
|
|
prop_four = _PropagatorTrainableFocalDistCylindLens(config.matrix_plane, config)
|
|
prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config)
|
|
prop_six = _PropCrossLens(config.second_lens_plane, config).T
|
|
prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config)
|
|
|
|
self._propagator_one: _Prop = prop_one + prop_two + prop_three
|
|
self._propagator_cylind_lens: _Prop = prop_four
|
|
self._propagator_three: _Prop = prop_five + prop_six + prop_seven
|
|
|
|
kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x))
|
|
kron_mat_utils = _torch.ones((config.matrix_split_x, config.matrix_split_y))
|
|
self.register_buffer('_kron_vec_utils', kron_vec_utils, persistent=True)
|
|
self.register_buffer('_kron_mat_utils', kron_mat_utils, persistent=True)
|
|
|
|
self._avg_pool = _nn.AvgPool2d((1, config.result_vector_split))
|
|
|
|
def prepare_vector(self, data: _torch.Tensor) -> _torch.Tensor:
|
|
"""
|
|
Метод подготовки матрицы левой матрицы, как набора векторов столбцов, к подаче на вход системы.
|
|
|
|
Args:
|
|
data: матрица комплексной амплитуды распределений световых полей.
|
|
|
|
Returns:
|
|
Матрицы содержащие вектора левой матрицы.
|
|
"""
|
|
data = data.cfloat().flip(-1)
|
|
data = data.unsqueeze(-2)
|
|
data = _torch.kron(data.contiguous(), self._kron_vec_utils)
|
|
return data
|
|
|
|
def prepare_matrix(self, data: _torch.Tensor) -> _torch.Tensor:
|
|
"""
|
|
Метод подготовки правой матрицы к подаче на вход системы.
|
|
|
|
Args:
|
|
data: матрица комплексной амплитуды распределения светового поля.
|
|
|
|
Returns:
|
|
Матрица - оптический элемент в центре модели.
|
|
"""
|
|
if (data.dim() > 4) and data.size(-1) == 2:
|
|
data = _torch.view_as_complex(data)
|
|
|
|
data = data.cfloat().transpose(-2, -1)
|
|
data = data.unsqueeze(-3)
|
|
# TODO data should be at least two seq length. For one we get
|
|
# untimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
|
|
data = _torch.kron(data.contiguous(), self._kron_mat_utils)
|
|
return data
|
|
|
|
def prepare_out(self, field: _torch.Tensor) -> _torch.Tensor:
|
|
"""
|
|
Метод получения результата матричного умножения.
|
|
|
|
Args:
|
|
data: матрицы выходого распределения светового поля системы.
|
|
|
|
Returns:
|
|
Вектор столбец (амплитудное распределение).
|
|
"""
|
|
### Закоментированная часть кода - более физически корректный вариант работы модели,
|
|
### однако, данный вариант кода будет требовать большое кол-во памяти во время обучения
|
|
field = field.abs().squeeze(-1) #**2
|
|
field = self._avg_pool(field)
|
|
return field.flip(-1) #**0.5
|
|
|
|
def forward(self,
|
|
input: _torch.Tensor,
|
|
other: _torch.Tensor) -> _torch.Tensor:
|
|
"""
|
|
Метод выполения матричного умножения.
|
|
|
|
Args:
|
|
input: матрица (B, C, H, W).
|
|
other: матрица (B, C, W, K).
|
|
|
|
Returns:
|
|
Рензультат матричного умножения (B, C, H, K).
|
|
|
|
Example:
|
|
>>> mul = OpticalMul(...)
|
|
>>> A = torch.rand((1, 1, 256, 256)) > 0.5
|
|
>>> B = torch.rand((1, 1, 256, 256)) > 0.5
|
|
>>> mul(A, B).shape
|
|
torch.Size([1, 1, 256, 256])
|
|
>>> A = torch.rand((1, 1, 64, 256)) > 0.5
|
|
>>> B = torch.rand((1, 1, 256, 128)) > 0.5
|
|
>>> mul(A, B).shape
|
|
torch.Size([1, 1, 64, 128])
|
|
"""
|
|
vec_field = self.prepare_vector(input)
|
|
mat_field = self.prepare_matrix(other)
|
|
vec_field = self._propagator_one(vec_field, mat_field.shape[-2:])
|
|
vec_field = self._propagator_cylind_lens(vec_field, mat_field.shape[-2:])
|
|
vec_field = self._propagator_three(vec_field * mat_field, (mat_field.size(-2), 1))
|
|
|
|
return self.prepare_out(vec_field)
|
|
|
|
@_torch.no_grad()
|
|
def log_cylind_lens_operator_x(
|
|
self,
|
|
writer: SummaryWriter,
|
|
tag: str,
|
|
global_step: Optional[int] = None,
|
|
):
|
|
# 1. Apply exp to get the wrapped phase as it would be physically seen
|
|
# This ensures values outside [-pi, pi] wrap correctly
|
|
lens = self._propagator_cylind_lens
|
|
writer.add_scalar(f"{tag}/focal_distance", lens._distance.detach().cpu().numpy(), global_step)
|
|
|
|
complex_op = _torch.exp(-1j * lens._K / lens._distance * lens._linspace_by_x**2)
|
|
wrapped_phase = _torch.angle(complex_op).float() # Range: [-π, π]
|
|
|
|
# 2. Normalize for Image Visualization [0, 1]
|
|
phase_normalized = (wrapped_phase + _torch.pi) / (2 * _torch.pi)
|
|
|
|
# 3. Log as a 1-pixel high row
|
|
# Shape: [1, 1, Width]
|
|
phase_row = phase_normalized.unsqueeze(0).unsqueeze(0)
|
|
writer.add_image(f"{tag}/phase_row", phase_row, global_step, dataformats='CHW')
|
|
|
|
fig, ax = plt.subplots(figsize=(6, 4))
|
|
ax.plot(wrapped_phase.detach().cpu().numpy(), label=f'Step {global_step}')
|
|
ax.set_title(f"Cylindrical Lens Phase Profile (x) {tag}\nFocal distance: {lens._distance.detach().cpu().numpy()}")
|
|
ax.set_xlabel("Pixel Index")
|
|
ax.set_ylabel("Phase (rad)")
|
|
ax.set_ylim([-_torch.pi - 0.5, _torch.pi + 0.5])
|
|
ax.grid(True, linestyle='--', alpha=0.6)
|
|
|
|
# Send the figure to the "Plots" or "Images" tab in TensorBoard
|
|
writer.add_figure(f"{tag}/phase_profile", fig, global_step)
|
|
plt.close(fig) # Important: prevent memory leaks
|
|
|
|
|
|
class TrainableScalarAndLensOpticalMul(_nn.Module):
|
|
"""
|
|
Класс системы, выполняющей оптически операцию умножения матрицы на матрицу.
|
|
"""
|
|
def __init__(self, config: _Config):
|
|
"""
|
|
Конструктор класса.
|
|
|
|
Args:
|
|
config: конфигурация расчётной системы.
|
|
"""
|
|
super().__init__()
|
|
|
|
prop_one = _PropSinc(config.input_vector_plane, config.first_lens_plane, config)
|
|
prop_two = _PropCrossLens(config.first_lens_plane, config)
|
|
prop_three = _PropSinc(config.first_lens_plane, config.matrix_plane, config)
|
|
prop_four = _PropagatorTrainableCylindLens(config.matrix_plane, config)
|
|
prop_five = _PropSinc(config.matrix_plane, config.second_lens_plane, config)
|
|
prop_six = _PropCrossLens(config.second_lens_plane, config).T
|
|
prop_seven = _PropSinc(config.second_lens_plane, config.output_vector_plane, config)
|
|
|
|
self._propagator_one: _Prop = prop_one + prop_two + prop_three
|
|
self._propagator_two: _Prop = prop_four
|
|
self._propagator_three: _Prop = prop_five + prop_six + prop_seven
|
|
|
|
kron_vec_utils = _torch.ones((config.input_vector_split_y, config.input_vector_split_x))
|
|
kron_mat_utils = _torch.ones((config.matrix_split_x, config.matrix_split_y))
|
|
self.register_buffer('_kron_vec_utils', kron_vec_utils, persistent=True)
|
|
self.register_buffer('_kron_mat_utils', kron_mat_utils, persistent=True)
|
|
|
|
self._avg_pool = _nn.AvgPool2d((1, config.result_vector_split))
|
|
self.k = nn.Parameter(torch.tensor(1))
|
|
|
|
def prepare_vector(self, data: _torch.Tensor) -> _torch.Tensor:
|
|
"""
|
|
Метод подготовки матрицы левой матрицы, как набора векторов столбцов, к подаче на вход системы.
|
|
|
|
Args:
|
|
data: матрица комплексной амплитуды распределений световых полей.
|
|
|
|
Returns:
|
|
Матрицы содержащие вектора левой матрицы.
|
|
"""
|
|
data = data.cfloat().flip(-1)
|
|
data = data.unsqueeze(-2)
|
|
data = _torch.kron(data.contiguous(), self._kron_vec_utils)
|
|
return data
|
|
|
|
def prepare_matrix(self, data: _torch.Tensor) -> _torch.Tensor:
|
|
"""
|
|
Метод подготовки правой матрицы к подаче на вход системы.
|
|
|
|
Args:
|
|
data: матрица комплексной амплитуды распределения светового поля.
|
|
|
|
Returns:
|
|
Матрица - оптический элемент в центре модели.
|
|
"""
|
|
if (data.dim() > 4) and data.size(-1) == 2:
|
|
data = _torch.view_as_complex(data)
|
|
|
|
data = data.cfloat().transpose(-2, -1)
|
|
data = data.unsqueeze(-3)
|
|
data = _torch.kron(data.contiguous(), self._kron_mat_utils)
|
|
return data
|
|
|
|
def prepare_out(self, field: _torch.Tensor) -> _torch.Tensor:
|
|
"""
|
|
Метод получения результата матричного умножения.
|
|
|
|
Args:
|
|
data: матрицы выходого распределения светового поля системы.
|
|
|
|
Returns:
|
|
Вектор столбец (амплитудное распределение).
|
|
"""
|
|
### Закоментированная часть кода - более физически корректный вариант работы модели,
|
|
### однако, данный вариант кода будет требовать большое кол-во памяти во время обучения
|
|
field = field.abs().squeeze(-1) #**2
|
|
field = self._avg_pool(field)
|
|
return field.flip(-1) #**0.5
|
|
|
|
def forward(self,
|
|
input: _torch.Tensor,
|
|
other: _torch.Tensor) -> _torch.Tensor:
|
|
"""
|
|
Метод выполения матричного умножения.
|
|
|
|
Args:
|
|
input: матрица (B, C, H, W).
|
|
other: матрица (B, C, W, K).
|
|
|
|
Returns:
|
|
Рензультат матричного умножения (B, C, H, K).
|
|
|
|
Example:
|
|
>>> mul = OpticalMul(...)
|
|
>>> A = torch.rand((1, 1, 256, 256)) > 0.5
|
|
>>> B = torch.rand((1, 1, 256, 256)) > 0.5
|
|
>>> mul(A, B).shape
|
|
torch.Size([1, 1, 256, 256])
|
|
>>> A = torch.rand((1, 1, 64, 256)) > 0.5
|
|
>>> B = torch.rand((1, 1, 256, 128)) > 0.5
|
|
>>> mul(A, B).shape
|
|
torch.Size([1, 1, 64, 128])
|
|
"""
|
|
vec_field = self.prepare_vector(input)
|
|
mat_field = self.prepare_matrix(other)
|
|
|
|
vec_field = self._propagator_one(vec_field, mat_field.shape[-2:])
|
|
vec_field = self._propagator_two(vec_field, mat_field.shape[-2:])
|
|
vec_field = self._propagator_three(vec_field * mat_field, (mat_field.size(-2), 1))
|
|
|
|
return self.k * self.prepare_out(vec_field) |