optic mul error test
parent
85249dfbba
commit
1ace114d0c
@ -0,0 +1,93 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import optical_matrix_multiplication as omm
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
# Torch device that the optical multipliers and all test tensors are moved to.
device = 'cpu'

# Size of the dimension shared by the two operands of the q @ k product
# (last axis of q, second-to-last axis of the transposed k).
h_dim = 64
# Per-element pitch used to derive the physical matrix width/height passed to
# omm.Config (units are not stated in this file - presumably metres, confirm).
pixel_size = 3.6e-6
# Leading dimension of every randomly generated test tensor.
batch_size = 100
# Sequence lengths swept by the test; 512 is handled with different
# `distance` / `lens_size` settings than the shorter lengths.
test_lengths = [59, 64, 128, 256, 512]
|
||||||
|
|
||||||
|
def cko(x, y):
    """Return a percentage error figure between two tensors.

    Both inputs are squared element-wise, each squared tensor is divided by
    its own mean, and the root-mean-square of the difference between the two
    normalised tensors is multiplied by 100.

    Args:
        x: First tensor (the exact ``@`` product at the call sites below).
        y: Second tensor (the simulated product at the call sites below);
            must be broadcast-compatible with ``x``.

    Returns:
        A zero-dimensional tensor holding the error scaled by 100. No guard
        is applied for an all-zero input, whose mean would be zero.
    """
    x = x**2
    y = y**2
    return (((x / x.mean() - y / y.mean())**2).mean())**0.5 * 100


def _build_optical_mul(columns, rows, distance, lens_size):
    """Build an ``omm.OpticalMul`` whose right-hand matrix is rows x columns.

    Width and height are derived from the element counts via ``pixel_size``;
    every split factor is fixed at 2 and ``min_height_gap`` at one pixel,
    exactly as in each of the original hand-written configurations.

    Args:
        columns: Value for ``right_matrix_count_columns``.
        rows: Value for ``right_matrix_count_rows``.
        distance: Forwarded unchanged to ``omm.Config(distance=...)``.
        lens_size: Forwarded unchanged to ``omm.Config(lens_size=...)``.

    Returns:
        The constructed ``omm.OpticalMul`` instance (not yet moved to a device).
    """
    return omm.OpticalMul(
        omm.Config(right_matrix_count_columns=columns,
                   right_matrix_count_rows=rows,
                   right_matrix_width=pixel_size * columns,
                   right_matrix_height=pixel_size * rows,
                   min_height_gap=pixel_size,
                   right_matrix_split_x=2,
                   right_matrix_split_y=2,
                   left_matrix_split_x=2,
                   left_matrix_split_y=2,
                   result_matrix_split=2,
                   distance=distance,
                   lens_size=lens_size)
    )


for max_seq_len in test_lengths:
    # Only `distance` and `lens_size` depend on the sequence length: lengths
    # below 512 use one pair of values, 512 and above use the other.
    if max_seq_len < 512:
        distance, lens_size = 0.01, 8192
    else:
        distance, lens_size = 0.15, 8192 * 2

    # sim_scores multiplies q by the transposed k (right matrix: h_dim rows,
    # max_seq_len columns); sim_output multiplies scores by v (right matrix:
    # max_seq_len rows, h_dim columns).
    sim_scores = _build_optical_mul(max_seq_len, h_dim, distance, lens_size)
    sim_output = _build_optical_mul(h_dim, max_seq_len, distance, lens_size)

    sim_scores = sim_scores.to(device=device)
    sim_output = sim_output.to(device=device)

    # First product: compare the exact q @ k with the simulated one.
    q = torch.rand((batch_size, 1, max_seq_len, h_dim)).to(device=device)
    k = torch.rand((batch_size, 1, max_seq_len, h_dim)).to(device=device).transpose(-2, -1)
    true_scores = q @ k
    opt_scores = sim_scores(q, k)
    CKO_scores = cko(true_scores, opt_scores).detach().cpu().numpy()

    # Second product: a fresh random `scores` tensor is used as the left
    # operand rather than the result of the first product.
    scores = torch.rand((batch_size, 1, max_seq_len, max_seq_len)).to(device=device)
    v = torch.rand((batch_size, 1, max_seq_len, h_dim)).to(device=device)
    true_o = scores @ v
    opt_o = sim_output(scores, v)
    CKO_o = cko(true_o, opt_o).detach().cpu().numpy()

    print(f"CKO sim_scores[{h_dim},{max_seq_len}] [{q.shape[-2]}, {q.shape[-1]}]x[{k.shape[-2]}, {k.shape[-1]}]: {CKO_scores}")
    # The left-operand shape is read from `scores`, the tensor actually fed to
    # sim_output; it has the same shape as `true_scores`, so the printed text
    # is unchanged.
    print(f"CKO sim_output[{max_seq_len},{h_dim}] [{scores.shape[-2]}, {scores.shape[-1]}]x[{v.shape[-2]}, {v.shape[-1]}]: {CKO_o}")
|
||||||
Loading…
Reference in New Issue