# optic mul error test
# parent: 85249dfbba
# commit: 1ace114d0c
# @ -0,0 +1,93 @
import torch
|
||||
import torch.nn as nn
|
||||
import optical_matrix_multiplication as omm
|
||||
import matplotlib.pyplot as plt
|
||||
# Torch device on which the simulated multipliers and all test tensors live.
device = 'cpu'

# Shared inner (feature) dimension of the random operands.
h_dim = 64
# Size of one matrix element; the physical matrix width/height passed to the
# optical config are this value times the element count.
# NOTE(review): presumably metres (3.6 um) - confirm against omm.Config.
pixel_size = 3.6e-6
# Number of random matrices evaluated per sequence length.
batch_size = 100
# Sequence lengths for which the multiplication error is reported.
test_lengths = [59, 64, 128, 256, 512]
|
||||
|
||||
def cko(x, y):
    """Return the percentage RMS deviation between two mean-normalised squares.

    Both inputs are squared element-wise, each square is divided by its own
    mean, and the root-mean-square of the difference between the two
    normalised results is multiplied by 100.

    Args:
        x: Reference tensor.
        y: Tensor compared against ``x``; must broadcast with it.

    Returns:
        The deviation as a scalar (a zero-dimensional tensor for torch inputs).
    """
    x = x**2
    y = y**2
    return (((x / x.mean() - y / y.mean())**2).mean())**0.5 * 100


def _build_optical_mul(count_columns, count_rows, distance, lens_size):
    """Build an ``omm.OpticalMul`` for a right matrix of the given shape.

    The physical width and height are derived from the element counts and the
    module-level ``pixel_size``; every split factor is fixed at 2.

    Args:
        count_columns: Number of columns of the right-hand matrix.
        count_rows: Number of rows of the right-hand matrix.
        distance: Value forwarded as ``omm.Config(distance=...)``.
        lens_size: Value forwarded as ``omm.Config(lens_size=...)``.

    Returns:
        The configured ``omm.OpticalMul`` instance.
    """
    return omm.OpticalMul(
        omm.Config(
            right_matrix_count_columns=count_columns,
            right_matrix_count_rows=count_rows,
            right_matrix_width=pixel_size * count_columns,
            right_matrix_height=pixel_size * count_rows,
            min_height_gap=pixel_size,
            right_matrix_split_x=2,
            right_matrix_split_y=2,
            left_matrix_split_x=2,
            left_matrix_split_y=2,
            result_matrix_split=2,
            distance=distance,
            lens_size=lens_size,
        )
    )


for max_seq_len in test_lengths:
    # Sequence lengths of 512 and above use a larger distance and a lens
    # twice as large; everything else in the two configurations is shared.
    if max_seq_len < 512:
        distance, lens_size = 0.01, 8192
    else:
        distance, lens_size = 0.15, 8192 * 2

    # sim_scores multiplies by an [h_dim, max_seq_len] right matrix (q @ k),
    # sim_output by a [max_seq_len, h_dim] right matrix (scores @ v).
    sim_scores = _build_optical_mul(max_seq_len, h_dim, distance, lens_size)
    sim_output = _build_optical_mul(h_dim, max_seq_len, distance, lens_size)
    sim_scores = sim_scores.to(device=device)
    sim_output = sim_output.to(device=device)

    # First product: compare the exact q @ k with the simulated result.
    q = torch.rand((batch_size, 1, max_seq_len, h_dim), device=device)
    k = torch.rand((batch_size, 1, max_seq_len, h_dim), device=device).transpose(-2, -1)
    true_scores = q @ k
    opt_scores = sim_scores(q, k)
    CKO_scores = cko(true_scores, opt_scores).detach().cpu().numpy()

    # Second product: fresh random scores (not the result above) times v.
    scores = torch.rand((batch_size, 1, max_seq_len, max_seq_len), device=device)
    v = torch.rand((batch_size, 1, max_seq_len, h_dim), device=device)
    true_o = scores @ v
    opt_o = sim_output(scores, v)
    CKO_o = cko(true_o, opt_o).detach().cpu().numpy()

    print(f"CKO sim_scores[{h_dim},{max_seq_len}] [{q.shape[-2]}, {q.shape[-1]}]x[{k.shape[-2]}, {k.shape[-1]}]: {CKO_scores}")
    # Report the shape of the actual left operand of `scores @ v`.
    print(f"CKO sim_output[{max_seq_len},{h_dim}] [{scores.shape[-2]}, {scores.shape[-1]}]x[{v.shape[-2]}, {v.shape[-1]}]: {CKO_o}")
|
||||
Loading…
Reference in New Issue