parallel dim is not 0, but 1. diagonal nontrainable lens bug fixed.

pull/1/head
Vladimir Protsenko 2 months ago
parent a042b64d7e
commit 51147b36b3

@ -30,7 +30,7 @@ models = {
}
batch_size = 50
gradient_accumulation_steps = 2 # check this impl for correctness https://unsloth.ai/blog/gradient
gradient_accumulation_steps = 5 # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = 40000
eval_interval = 300
learning_rate = 1e-3
@ -38,7 +38,7 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
layers_num = 2
h_dim = 64
max_seq_len = 256
max_seq_len = 512
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6
@ -149,7 +149,7 @@ for i in range(max_iters):
m.eval()
ppl = perplexity(model=m, data=val_data)
writer.add_scalar('val_perplexity', ppl.item(), i)
print(f"\rPerplexity at {i}: {ppl}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_text('completions', complete(m, encode("\n"*max_seq_len), 2*max_seq_len), i)
m.eval()

@ -114,12 +114,12 @@ class OpticalMul(_nn.Module):
"""
vec_field = self.prepare_vector(input)
mat_field = self.prepare_matrix(other)
if self.trainable_cylind_lens:
vec_field = self._propagator_one(vec_field)
vec_field = self._propagator_between(vec_field)
else:
vec_field = self._propagator_one(vec_field)
vec_field = self._propagator_two(vec_field * mat_field)
return self.prepare_out(vec_field)

@ -129,34 +129,31 @@ class ScatterDataParallel(_nn.Module):
Оптимизированный forward для attention матриц.
Особенности:
- Scatter по batch dimension (0) вместо произвольного dim
- Scatter по batch dimension (1) вместо произвольного dim
- Оба тензора scatter'ятся для согласованности размерностей
- Поддержка многомерных attention тензоров [batch, heads, seq, dim]
- Поддержка многомерных attention тензоров [batch, heads, seq, dim] ??
'''
# Определяем dimension для scatter на основе структуры тензоров
if input.dim() >= 3 and other.dim() >= 3:
# Для attention матриц scatter по batch dimension
scatter_dim = 0
scatter_dim = 1
else:
# Для обычных 2D матриц используем dim из kwargs или по умолчанию 2
scatter_dim = kwargs.get('dim', 2)
# Подготовка модуля и данных
self.module = self.module.to(self.devices[0])
replicas = _nn.parallel.replicate(self.module, self.devices)
# Scatter ОБОИХ тензоров для согласованности размерностей
scattered_input = _nn.parallel.scatter(input, self.devices, scatter_dim)
scattered_other = _nn.parallel.scatter(other, self.devices, scatter_dim)
# Создаем реплики модуля
replicas = _nn.parallel.replicate(self.module, self.devices)
# Формируем входные данные для каждого устройства
# Убедимся, что все списки одинаковой длины
min_len = min(len(scattered_input), len(scattered_other), len(replicas))
stacked_input = [(scattered_input[i], scattered_other[i]) for i in range(min_len)]
# Параллельное вычисление
outputs = _nn.parallel.parallel_apply(replicas[:min_len], stacked_input)

@ -16,7 +16,7 @@ class Propagator(_ABC, _nn.Module):
operator_X: оператор отображающий распроранение светового поля вдоль оси абсцисс
operator_Y: оператор отображающий распроранение светового поля вдоль оси ординат
"""
def __init__(self, operator_X: _torch.Tensor, operator_Y: _torch.Tensor, trainable = False, diagonal = False):
def __init__(self, operator_X: _torch.Tensor, operator_Y: _torch.Tensor, trainable = False):
super(Propagator, self).__init__()
operator_X: _torch.Tensor = _torch.view_as_real(operator_X)
operator_Y: _torch.Tensor = _torch.view_as_real(operator_Y)
@ -24,12 +24,10 @@ class Propagator(_ABC, _nn.Module):
self._operator_X = _nn.Parameter(operator_X)
self._operator_Y = _nn.Parameter(operator_Y)
self._trainable = trainable
self._diagonal = diagonal
else:
self.register_buffer('_operator_X', operator_X, persistent=True)
self.register_buffer('_operator_Y', operator_Y, persistent=True)
self._trainable = trainable
self._diagonal = diagonal
@property
def operator_X(self) -> _torch.Tensor:
@ -111,13 +109,13 @@ class Propagator(_ABC, _nn.Module):
Распределение комплексной амплитуды светового поля,
после распространения.
"""
if self._diagonal:
if self._trainable:
return _torch.diag_embed(self.operator_Y) @ field @ _torch.diag_embed(self.operator_X)
else:
return self.operator_Y @ field @ self.operator_X
def __repr__(self):
return f"Diag: {self._diagonal} Trainable: {self._trainable} Y shape: {self.operator_Y.shape}, X shape: {self.operator_X.shape}"
return f"Trainable: {self._trainable} Y shape: {self.operator_Y.shape}, X shape: {self.operator_X.shape}"
class PropagatorLens(Propagator):
"""
@ -167,7 +165,8 @@ class PropagatorСylindLens(PropagatorLens):
представленной тонким оптическим элементом.
"""
def __init__(self, plane: _ConfigDesignPlane,
config: _ConfigOpticBase, trainable = False):
config: _ConfigOpticBase,
trainable = False):
"""
Конструктор класса цилиндрической линзы.
@ -177,10 +176,14 @@ class PropagatorСylindLens(PropagatorLens):
"""
operator_X = _torch.exp(-1j * config.K / config.distance * plane.linspace_by_x**2)
operator_Y = _torch.ones_like(plane.linspace_by_y, dtype=_torch.cfloat)
super(PropagatorСylindLens, self).__init__(operator_X,
operator_Y,
trainable,
diagonal=True)
if trainable:
super(PropagatorСylindLens, self).__init__(operator_X,
operator_Y,
trainable)
else:
super(PropagatorСylindLens, self).__init__(_torch.diag_embed(operator_X),
_torch.diag_embed(operator_Y),
trainable)
class PropagatorSinc(Propagator):
"""

@ -190,6 +190,8 @@ class OpticGPT2NOKoefNewF(nn.Module):
lens_size = 8192 * 2,
trainable_cylind_lens=False)
)
self.sim_scores = omm.ScatterDataParallel(self.sim_scores)
self.sim_output = omm.ScatterDataParallel(self.sim_output)
self.layers = nn.ModuleList([
TransformerLayer(h_dim=self.h_dim, sim_scores=self.sim_scores, sim_output=self.sim_output, num_heads=self.num_heads,

Loading…
Cancel
Save