import numpy as np
import pandas as pd
import torch
import torch.nn as nn

try:
    from IPython.display import display, clear_output
except ImportError:  # Allow use outside IPython/Jupyter; print_results falls back to print().
    display = None
    clear_output = None


def print_results(epoch, logs):
    """Show one epoch's metrics as a 2x2 table (rows train/val, columns loss/accuracy).

    Args:
        epoch: Label used as the table caption (``train`` passes the epoch index).
        logs: Mapping with exactly four values, read in insertion order as
            train loss, train accuracy, val loss, val accuracy.
    """
    results = pd.DataFrame(
        data=np.array(list(logs.values())).reshape(2, 2),
        columns=['loss', 'accuracy'],
        index=['train', 'val'],
    )
    if display is None:
        # No IPython available: plain-text output instead of a refreshed styled table.
        print(f"{epoch}\n{results}")
        return
    styled = results.style.set_caption(f"{epoch}")
    # Replace the previous epoch's table in the notebook output instead of appending to it.
    clear_output(wait=True)
    display(styled)


def _as_tensor(data, dtype=None, device=None):
    """Convert array-like ``data`` (or an existing tensor) to a tensor on ``device``."""
    if not isinstance(data, torch.Tensor):
        # np.asarray also unwraps pandas objects, which torch cannot consume directly.
        data = np.asarray(data)
    return torch.as_tensor(data, dtype=dtype, device=device)


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader``: train when ``optimizer`` is given, else evaluate.

    Returns:
        (mean per-sample loss, fraction of predictions that match the labels
        after thresholding the model output at 0.5).
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    total_correct = 0
    n_samples = 0
    with torch.set_grad_enabled(training):
        for X_batch, y_batch in loader:
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device)
            if training:
                optimizer.zero_grad()
            y_pred = model(X_batch)
            loss = criterion(y_pred, y_batch.float())
            if training:
                loss.backward()
                optimizer.step()
            batch_len = X_batch.size(0)
            # criterion averages over the batch; re-weight so the epoch mean is per sample.
            total_loss += loss.item() * batch_len
            total_correct += ((y_pred > 0.5) == y_batch).sum().item()
            n_samples += batch_len
    return total_loss / n_samples, total_correct / n_samples


def train(X_train, y_train, X_val, y_val, model, epochs=50, lr=1e-2, verbose=False,
          *, batch_size=4096, device=None, compile_model=True):
    """Train a binary classifier with BCE loss and Adam, validating after every epoch.

    ``model`` must output values in [0, 1] (``nn.BCELoss`` is applied directly)
    with the same shape as the label batches. It is moved to ``device`` in place.

    Args:
        X_train, X_val: Features (anything ``np.asarray`` accepts, or tensors); cast to float32.
        y_train, y_val: 0/1 labels aligned with the first dimension of the features.
        model: ``nn.Module`` to train.
        epochs: Number of passes over the training data; must be >= 1.
        lr: Adam learning rate.
        verbose: If True, show a metrics table after every epoch via ``print_results``.
        batch_size: Mini-batch size used for both training and validation.
        device: Torch device; defaults to ``'cuda'`` when available, otherwise ``'cpu'``.
        compile_model: Wrap the model with ``torch.compile`` before training.

    Returns:
        ``(final, history)``: ``final`` maps 'loss', 'accuracy', 'val_loss',
        'val_accuracy' to the last epoch's values; ``history`` maps the same keys
        to per-epoch lists.

    Raises:
        ValueError: If ``epochs`` < 1 (there would be no metrics to return).
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    X_train = _as_tensor(X_train, dtype=torch.float32, device=device)
    X_val = _as_tensor(X_val, dtype=torch.float32, device=device)
    y_train = _as_tensor(y_train, device=device)
    y_val = _as_tensor(y_val, device=device)

    model = model.to(device)
    if compile_model:
        model = torch.compile(model)
    criterion = nn.BCELoss()
    # Use the caller's learning rate (previously ignored in favour of a hard-coded 1e-3).
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(X_train, y_train), batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(X_val, y_val), batch_size=batch_size, shuffle=False)

    keys = ('loss', 'accuracy', 'val_loss', 'val_accuracy')
    history = {key: [] for key in keys}
    logs = {}
    for epoch in range(epochs):
        train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)
        logs = dict(zip(keys, (train_loss, train_acc, val_loss, val_acc)))
        for key, value in logs.items():
            history[key].append(value)
        if verbose:
            print_results(epoch, logs)
    return logs, history