optic mul error test

pull/1/head
Vladimir Protsenko 1 month ago
parent 85249dfbba
commit 1ace114d0c

@ -0,0 +1,93 @@
import torch
import torch.nn as nn
import optical_matrix_multiplication as omm
import matplotlib.pyplot as plt
# Device onto which both simulators and every random test tensor are moved.
device = 'cpu'
# Inner dimension of the q @ k product and last dimension of q and v.
h_dim = 64
# Multiplied by the row/column counts to get the matrix height/width passed
# to omm.Config, and also used as min_height_gap.
# NOTE(review): units are not visible here — presumably metres; confirm.
pixel_size = 3.6e-6
# Leading (batch) dimension of every random test tensor.
batch_size = 100
# Sequence lengths to sweep; lengths below 512 and the length 512 use
# different distance / lens_size settings in the loop below.
test_lengths = [59, 64, 128, 256, 512]
def _build_optical_mul(count_rows, count_columns, distance, lens_size):
    """Create an ``omm.OpticalMul`` for a right-hand matrix of the given size.

    All split factors are fixed to 2 and the physical width/height are the
    element counts scaled by the module-level ``pixel_size``.

    Args:
        count_rows: Number of rows of the right-hand matrix.
        count_columns: Number of columns of the right-hand matrix.
        distance: Forwarded unchanged as ``distance`` to ``omm.Config``.
        lens_size: Forwarded unchanged as ``lens_size`` to ``omm.Config``.

    Returns:
        The ``omm.OpticalMul`` built from the resulting ``omm.Config``.
    """
    config = omm.Config(
        right_matrix_count_columns=count_columns,
        right_matrix_count_rows=count_rows,
        right_matrix_width=pixel_size * count_columns,
        right_matrix_height=pixel_size * count_rows,
        min_height_gap=pixel_size,
        right_matrix_split_x=2,
        right_matrix_split_y=2,
        left_matrix_split_x=2,
        left_matrix_split_y=2,
        result_matrix_split=2,
        distance=distance,
        lens_size=lens_size,
    )
    return omm.OpticalMul(config)


def cko(x, y):
    """Return the normalised RMS mismatch between ``x`` and ``y``, times 100.

    Both tensors are squared element-wise and each is divided by its own
    mean, so a constant scale factor on either input cancels out. The result
    is the square root of the mean squared difference of those normalised
    values, multiplied by 100.

    Args:
        x: Reference tensor.
        y: Tensor to compare against ``x``; must broadcast with ``x``.

    Returns:
        A zero-dimensional tensor holding the error value.
    """
    x = x**2
    y = y**2
    return (((x / x.mean() - y / y.mean())**2).mean())**0.5 * 100


for max_seq_len in test_lengths:
    # Only the optical geometry differs between the short and long sequences.
    if max_seq_len < 512:
        distance, lens_size = 0.01, 8192
    else:
        distance, lens_size = 0.15, 8192 * 2

    # q @ k: right-hand matrix is [h_dim, max_seq_len].
    sim_scores = _build_optical_mul(h_dim, max_seq_len, distance, lens_size)
    # scores @ v: right-hand matrix is [max_seq_len, h_dim].
    sim_output = _build_optical_mul(max_seq_len, h_dim, distance, lens_size)
    sim_scores = sim_scores.to(device=device)
    sim_output = sim_output.to(device=device)

    # Pure evaluation: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        q = torch.rand((batch_size, 1, max_seq_len, h_dim)).to(device=device)
        k = torch.rand((batch_size, 1, max_seq_len, h_dim)).to(device=device).transpose(-2, -1)
        true_scores = q @ k
        opt_scores = sim_scores(q, k)
        CKO_scores = cko(true_scores, opt_scores).detach().cpu().numpy()

        scores = torch.rand((batch_size, 1, max_seq_len, max_seq_len)).to(device=device)
        v = torch.rand((batch_size, 1, max_seq_len, h_dim)).to(device=device)
        true_o = scores @ v
        opt_o = sim_output(scores, v)
        CKO_o = cko(true_o, opt_o).detach().cpu().numpy()

    print(f"CKO sim_scores[{h_dim},{max_seq_len}] [{q.shape[-2]}, {q.shape[-1]}]x[{k.shape[-2]}, {k.shape[-1]}]: {CKO_scores}")
    print(f"CKO sim_output[{max_seq_len},{h_dim}] [{scores.shape[-2]}, {scores.shape[-1]}]x[{v.shape[-2]}, {v.shape[-1]}]: {CKO_o}")
Loading…
Cancel
Save