diff --git a/src/bert_training.py b/src/bert_training.py index ec78b55..62069ff 100644 --- a/src/bert_training.py +++ b/src/bert_training.py @@ -324,6 +324,7 @@ print(f"Model parameters count: ", sum(p.numel() for p in model.parameters())) # The Road Less Scheduled https://arxiv.org/html/2405.15682v4 optimizer = schedulefree.AdamWScheduleFree(model.parameters()) +# class weighting is important positive_counts = credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids].values.sum() negative_counts = len(credit_train_dataset.targets_df.loc[credit_train_dataset.train_uniq_client_ids]) - positive_counts pos_weight = negative_counts / (positive_counts + 1e-15)