import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import sys
from pathlib import Path
from char_gpt2 import GPT2
from optics_char_gpt2 import OpticGPT2
from optics_char_gpt2_traindiag import OpticGPT2TrainDiag
from optics_char_gpt2_ff import OpticGPT2FF
from optics_char_gpt2_new_formula import OpticGPT2NewFormula
from char_gpt2_scaledmatmul import GPT2ScaledMM
from optics_char_gpt2_nokoef import OpticGPT2NOKoef
from optics_char_gpt2_nokoef_newf import OpticGPT2NOKoefNewF
import shutil

seed = 1337
torch.manual_seed(seed)

models = {
    'gpt2': GPT2,
    'optic_gpt2': OpticGPT2,
    'optic_gpt2_ff': OpticGPT2FF,
    'optic_gpt2_traindiag': OpticGPT2TrainDiag,
    'optic_gpt2_newformula': OpticGPT2NewFormula,
    'optic_gpt2_nokoef': OpticGPT2NOKoef,
    'optic_gpt2_nokoef_newformula': OpticGPT2NOKoefNewF,
    'gpt2_scaledmm': GPT2ScaledMM,
}

batch_size = 50
gradient_accumulation_steps = 2  # check this impl for correctness https://unsloth.ai/blog/gradient
max_iters = 40000
eval_interval = 300
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
layers_num = 2
h_dim = 64
max_seq_len = 256
num_heads = 1
dropout_rate = 0.1
pixel_size = 3.6e-6

assert batch_size % gradient_accumulation_steps == 0

# CUDA_VISIBLE_DEVICES=1 python src/main.py optic_gpt2_ff ./data/wiki.train.tokens ./data/wiki.valid.tokens ./data/wiki.test.tokens seq_128_hdim_64
MODEL_CLASS = models[sys.argv[1]]
train_data_path = Path(sys.argv[2])
val_data_path = Path(sys.argv[3])
test_data_path = Path(sys.argv[4])

comment = f"{sys.argv[1]}_{train_data_path.name}_seq_{max_seq_len}{'_' + sys.argv[5] if len(sys.argv) >= 6 else ''}"
logs_dir = f'logs/{datetime.now().date()}_{datetime.now().hour:02d}_{datetime.now().minute:02d}_{datetime.now().second:02d}_{comment}/'
writer = SummaryWriter(logs_dir)
script_snapshot_path = Path(logs_dir + Path(sys.argv[0]).parent.name)
print("Logs dir:", logs_dir)
# script_snapshot_path.write_bytes(Path(sys.argv[0]).read_bytes())  # copy this version of script
shutil.copytree(Path(sys.argv[0]).parent, script_snapshot_path)  # snapshot this version of repository
script_snapshot_path.chmod(0o500)  # with read-only permission

#################################### Dataset #########################################
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open(train_data_path, encoding='utf-8') as f:
    train_text = f.read()
with open(val_data_path, encoding='utf-8') as f:
    val_text = f.read()
with open(test_data_path, encoding='utf-8') as f:
    test_text = f.read()

text = train_text + val_text + test_text
chars = sorted(set(text))
vocab_size = len(chars)
wtoi = {w: i for i, w in enumerate(chars)}
itow = {i: w for i, w in enumerate(chars)}
encode = lambda s: [wtoi[w] for w in s]
decode = lambda idx: ''.join([itow[i] for i in idx])


def get_batch(data, seq_len, batch_size):
    ix = torch.randint(len(data) - seq_len, (batch_size,))
    x = torch.stack([data[i:i+seq_len] for i in ix]).to(device)
    y = torch.stack([data[i+1:i+1+seq_len] for i in ix]).to(device)
    return x, y


train_data = torch.tensor(encode(train_text), dtype=torch.long)
val_data = torch.tensor(encode(val_text), dtype=torch.long)
test_data = torch.tensor(encode(test_text), dtype=torch.long)


@torch.no_grad()
def perplexity(model, data):
    stride = max(1, len(data) // 10000)
    losses = []
    for i in range(0, len(data) - max_seq_len - 1, stride):
        x = data[i:(i + max_seq_len)].to(device)
        y = data[(i + 1):(i + max_seq_len + 1)].to(device)
        logits, loss = model(x[None, ...], y[None, ...])
        losses.append(loss.item())
        print(f"\rppl {i}/{len(data) - max_seq_len - 1}", end="")
    return np.exp(np.mean(losses))

#################################### Model #########################################


def complete(m, start_idxs=[0], max_new_tokens=100):
    start_idx = torch.tensor([start_idxs]).to(device)
    generated_tokens = m.generate(start_idx=start_idx, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())


m = MODEL_CLASS(
    vocab_size=vocab_size,
    h_dim=h_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    pixel_size=pixel_size,
    layers_num=layers_num,
)
m = m.to(device)
writer.add_text('model', str(m), 0)

#################################### Train #########################################
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, betas=(0.90, 0.95), weight_decay=0.01)

m.eval()
completion = complete(m, encode("\n" * max_seq_len), 2 * max_seq_len)
print(completion)
writer.add_text('completions', completion, 0)

for i in range(max_iters):
    m.train()
    optimizer.zero_grad(set_to_none=True)
    accumulated_loss = 0.0
    for j in range(gradient_accumulation_steps):
        xb, yb = get_batch(train_data, seq_len=max_seq_len, batch_size=batch_size // gradient_accumulation_steps)
        logits, loss = m(xb, yb)
        loss = loss / gradient_accumulation_steps
        loss.backward()
        accumulated_loss += loss.item()

    if i % 100 == 0:
        writer.add_scalar('Gradient_Norm/Total', torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1e6).item(), i)

    optimizer.step()

    writer.add_scalar('loss', accumulated_loss, i)
    print(f"\r{i}/{max_iters} {accumulated_loss}", end="")

    if i % 5000 == 0:
        m.eval()
        ppl = perplexity(model=m, data=val_data)
        writer.add_scalar('val_perplexity', ppl.item(), i)
        print(f"\rPerplexity at {i}: {ppl}")
        writer.add_text('completions', complete(m, encode("\n" * max_seq_len), 2 * max_seq_len), i)

m.eval()
ppl = perplexity(model=m, data=val_data)
print(f"\r{i+1}/{max_iters} {accumulated_loss}")
print(f"\r{datetime.now()} Perplexity at {i}: {ppl}")
writer.add_scalar('val_perplexity', ppl.item(), i + 1)
writer.add_scalar('loss', accumulated_loss, i)

ppl = perplexity(model=m, data=test_data)
writer.add_scalar('test_perplexity', ppl.item(), i + 1)
print(f"\rTest Perplexity at {i}: {ppl}")

completion = complete(m, encode("\n" * max_seq_len), 2 * max_seq_len)
print(completion)
writer.add_text('completions', completion, i + 1)
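
# ------------------------------------------------------------------------------
# Hedged sketch (defined only, never invoked): one way to sanity-check the
# gradient-accumulation loop flagged by the comment next to
# `gradient_accumulation_steps`. With equal-sized micro-batches of fixed-length
# sequences and a loss that is a mean over all tokens, accumulating
# `loss / steps` over `steps` micro-batches should match a single backward pass
# over the concatenated batch up to floating-point error. `check_grad_accum`
# and its arguments are illustrative names, not part of the training script's
# API; it assumes the model returns `(logits, loss)` like the models above.
def check_grad_accum(model, data, steps=gradient_accumulation_steps, micro_bs=4):
    model.eval()  # disable dropout so both passes see identical computations
    batches = [get_batch(data, seq_len=max_seq_len, batch_size=micro_bs) for _ in range(steps)]

    # Pass 1: accumulate scaled micro-batch gradients.
    model.zero_grad(set_to_none=True)
    for x, y in batches:
        _, loss = model(x, y)
        (loss / steps).backward()
    accumulated = [p.grad.detach().clone() for p in model.parameters() if p.grad is not None]

    # Pass 2: one backward pass over the concatenated full batch.
    model.zero_grad(set_to_none=True)
    x_full = torch.cat([x for x, _ in batches])
    y_full = torch.cat([y for _, y in batches])
    _, loss = model(x_full, y_full)
    loss.backward()
    full = [p.grad.detach() for p in model.parameters() if p.grad is not None]

    # Largest element-wise deviation between the two gradient sets;
    # a value near zero suggests the accumulation loop above is consistent.
    return max((a - b).abs().max().item() for a, b in zip(accumulated, full))
# ------------------------------------------------------------------------------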