diff --git a/src/bert_training.py b/src/bert_training.py index 62069ff..d8836da 100644 --- a/src/bert_training.py +++ b/src/bert_training.py @@ -262,8 +262,9 @@ def test(start_time, epoch, batches_per_epoch, batch_size, model, optimizer, cre outputs = model(test_cat_inputs, test_num_inputs) test_auroc.update(outputs, test_targets.long()) print(f"\r {test_batch_id}/{len(credit_dataset.test_uniq_client_ids)//batch_size} {test_auroc.compute().item():.5f}", end = " "*20) - writer.add_scalar('test_roc_auc', test_auroc.compute().item(), epoch * batches_per_epoch) - print(f"\r {datetime.now() - start_time} {epoch}/{epochs} Test rocauc: {test_auroc.compute().item():.5f}", end = " "*20) + if not writer is None: + writer.add_scalar('test_roc_auc', test_auroc.compute().item(), epoch * batches_per_epoch) + print(f"\r {datetime.now() - start_time} {epoch}/{epochs} Test rocauc: {test_auroc.compute().item():.5f}", end = " "*20) print() ######################################### Training ################################################################ @@ -385,7 +386,7 @@ try: optimizer=optimizer, credit_dataset=credit_train_dataset, test_auroc=test_auroc, - writer=writer + writer=None ) rocauc = test_auroc.compute().item() save_checkpoint(