|
|
|
@ -262,8 +262,9 @@ def test(start_time, epoch, batches_per_epoch, batch_size, model, optimizer, cre
|
|
|
|
|
outputs = model(test_cat_inputs, test_num_inputs)
|
|
|
|
|
test_auroc.update(outputs, test_targets.long())
|
|
|
|
|
print(f"\r {test_batch_id}/{len(credit_dataset.test_uniq_client_ids)//batch_size} {test_auroc.compute().item():.5f}", end = " "*20)
|
|
|
|
|
writer.add_scalar('test_roc_auc', test_auroc.compute().item(), epoch * batches_per_epoch)
|
|
|
|
|
print(f"\r {datetime.now() - start_time} {epoch}/{epochs} Test rocauc: {test_auroc.compute().item():.5f}", end = " "*20)
|
|
|
|
|
if not writer is None:
|
|
|
|
|
writer.add_scalar('test_roc_auc', test_auroc.compute().item(), epoch * batches_per_epoch)
|
|
|
|
|
print(f"\r {datetime.now() - start_time} {epoch}/{epochs} Test rocauc: {test_auroc.compute().item():.5f}", end = " "*20)
|
|
|
|
|
print()
|
|
|
|
|
|
|
|
|
|
######################################### Training ################################################################
|
|
|
|
@ -385,7 +386,7 @@ try:
|
|
|
|
|
optimizer=optimizer,
|
|
|
|
|
credit_dataset=credit_train_dataset,
|
|
|
|
|
test_auroc=test_auroc,
|
|
|
|
|
writer=writer
|
|
|
|
|
writer=None
|
|
|
|
|
)
|
|
|
|
|
rocauc = test_auroc.compute().item()
|
|
|
|
|
save_checkpoint(
|
|
|
|
|