From c94f4a42b9f6bf4808c87cc9dc37287585732b42 Mon Sep 17 00:00:00 2001 From: Harshil Mehta <37377066+harshil21@users.noreply.github.com> Date: Sat, 1 Apr 2023 01:23:09 +0530 Subject: [PATCH] make it easier to do multiple fold training --- trust_and_safety_models/toxicity/train.py | 25 ++++++++++++----------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/trust_and_safety_models/toxicity/train.py b/trust_and_safety_models/toxicity/train.py index de450ee7b..4e3a33659 100644 --- a/trust_and_safety_models/toxicity/train.py +++ b/trust_and_safety_models/toxicity/train.py @@ -387,15 +387,16 @@ class Trainer(object): fold=i, ) else: - raise ValueError("Sure you want to do multiple fold training") - for mb_generator, steps_per_epoch, val_data, test_data in self.mb_loader(full_df=df): - self._train_single_fold( - mb_generator=mb_generator, - val_data=val_data, - test_data=test_data, - steps_per_epoch=steps_per_epoch, - fold=i, - ) - i += 1 - if i == 3: - break + a = input("Are you sure you want to do multiple fold training? (y/n)") + if a.lower() == "y": + for mb_generator, steps_per_epoch, val_data, test_data in self.mb_loader(full_df=df): + self._train_single_fold( + mb_generator=mb_generator, + val_data=val_data, + test_data=test_data, + steps_per_epoch=steps_per_epoch, + fold=i, + ) + i += 1 + if i == 3: + break