diff --git a/trust_and_safety_models/toxicity/train.py b/trust_and_safety_models/toxicity/train.py index de450ee7b..4e3a33659 100644 --- a/trust_and_safety_models/toxicity/train.py +++ b/trust_and_safety_models/toxicity/train.py @@ -387,15 +387,16 @@ class Trainer(object): fold=i, ) else: - raise ValueError("Sure you want to do multiple fold training") - for mb_generator, steps_per_epoch, val_data, test_data in self.mb_loader(full_df=df): - self._train_single_fold( - mb_generator=mb_generator, - val_data=val_data, - test_data=test_data, - steps_per_epoch=steps_per_epoch, - fold=i, - ) - i += 1 - if i == 3: - break + a = input("Are you sure you want to do multiple fold training? (y/n)") + if a.lower() == "y": + for mb_generator, steps_per_epoch, val_data, test_data in self.mb_loader(full_df=df): + self._train_single_fold( + mb_generator=mb_generator, + val_data=val_data, + test_data=test_data, + steps_per_epoch=steps_per_epoch, + fold=i, + ) + i += 1 + if i == 3: + break