diff --git a/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/train.py b/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/train.py index b10a7e240..db6744d8a 100644 --- a/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/train.py +++ b/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/train.py @@ -23,99 +23,104 @@ from .tf_model.weights_initializer_builder import TFModelWeightsInitializerBuild import twml def get_feature_values(features_values, params): - if params.lolly_model_tsv: - # The default DBv2 HashingDiscretizer bin membership interval is (a, b] - # - # The Earlybird Lolly prediction engine discretizer bin membership interval is [a, b) - # - # TFModelInitializerBuilder converts (a, b] to [a, b) by inverting the bin boundaries. - # - # Thus, invert the feature values, so that HashingDiscretizer can to find the correct bucket. - return tf.multiply(features_values, -1.0) - else: - return features_values + if params.lolly_model_tsv: + # The default DBv2 HashingDiscretizer bin membership interval is (a, b] + # + # The Earlybird Lolly prediction engine discretizer bin membership interval is [a, b) + # + # TFModelInitializerBuilder converts (a, b] to [a, b) by inverting the bin boundaries. + # + # Thus, invert the feature values, so that HashingDiscretizer can to find the correct bucket. + return tf.multiply(features_values, -1.0) + else: + return features_values def build_graph(features, label, mode, params, config=None): - ... - if mode != "infer": - ... - if opt.print_data_examples: - logits = print_data_example(logits, lolly_activations, features) - ... + # Function to build the Earlybird model graph + weights = None + if "weights" in features: + weights = make_weights_tensor(features["weights"], label, params) -# Added line breaks and indentation to improve readability -def print_data_example(logits, lolly_activations, features): - return tf.Print( - logits, - [ - logits, - lolly_activations, - tf.reshape(features['keys'], (1, -1)), - tf.reshape(tf.multiply(features['values'], -1.0), (1, -1)) - ], - message="DATA EXAMPLE = ", - summarize=10000 + num_bits = params.input_size_bits + + if mode == "infer": + indices = twml.limit_bits(features["input_sparse_tensor_indices"], num_bits) + dense_shape = tf.stack([features["input_sparse_tensor_shape"][0], 1 << num_bits]) + sparse_tf = tf.SparseTensor( + indices=indices, + values=get_feature_values(features["input_sparse_tensor_values"], params), + dense_shape=dense_shape ) + else: + features["values"] = get_feature_values(features["values"], params) + sparse_tf = twml.util.convert_to_sparse(features, num_bits) + if params.lolly_model_tsv: + tf_model_initializer = TFModelInitializerBuilder().build(LollyModelReader(params.lolly_model_tsv)) + bias_initializer, weight_initializer = TFModelWeightsInitializerBuilder(num_bits).build(tf_model_initializer) + discretizer = TFModelDiscretizerBuilder(num_bits).build(tf_model_initializer) + else: + discretizer = hub.Module(params.discretizer_save_dir) + bias_initializer, weight_initializer = None, None -# Import statements reformatted for better readability -import tensorflow.compat.v1 as tf -from tensorflow.python.estimator.export.export import build_raw_serving_input_receiver_fn -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -import tensorflow_hub as hub + input_sparse = discretizer(sparse_tf, signature="hashing_discretizer_calibrator") -from datetime import datetime -from tensorflow.compat.v1 import logging -from twitter.deepbird.projects.timelines.configs import all_configs -from twml.trainers import DataRecordTrainer -from twml.contrib.calibrators.common_calibrators import build_percentile_discretizer_graph -from twml.contrib.calibrators.common_calibrators import calibrate_discretizer_and_export -from .metrics import get_multi_binary_class_metric_fn -from .constants import TARGET_LABEL_IDX, PREDICTED_CLASSES -from .example_weights import add_weight_arguments, make_weights_tensor -from .lolly.data_helpers import get_lolly_logits -from .lolly.tf_model_initializer_builder import TFModelInitializerBuilder -from .lolly.reader import LollyModelReader -from .tf_model.discretizer_builder import TFModelDiscretizerBuilder -from .tf_model.weights_initializer_builder import TFModelWeightsInitializerBuilder + logits = twml.layers.full_sparse( + inputs=input_sparse, + output_size=1, + bias_initializer=bias_initializer, + weight_initializer=weight_initializer, + use_sparse_grads=(mode == "train"), + use_binary_values=True, + name="full_sparse_1" + ) -import twml + loss = None -# Added line breaks and indentation to improve readability -def get_feature_values(features_values, params): - if params.lolly_model_tsv: - return tf.multiply(features_values, -1.0) + if mode != "infer": + lolly_activations = get_lolly_logits(label) + + if opt.print_data_examples: + logits = print_data_example(logits, lolly_activations, features) + + if params.replicate_lolly: + loss = tf.reduce_mean(tf.math.squared_difference(logits, lolly_activations)) else: - return features_values + batch_size = tf.shape(label)[0] + target_label = tf.reshape(tensor=label[:, TARGET_LABEL_IDX], shape=(batch_size, 1)) + loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=target_label, logits=logits) + loss = twml.util.weighted_average(loss, weights) -# Added line breaks and indentation to improve readability -def build_graph(features, label, mode, params, config=None): - ... + num_labels = tf.shape(label)[1] + eb_scores = tf.tile(lolly_activations, [1, num_labels]) + logits = tf.tile(logits, [1, num_labels]) + logits = tf.concat([logits, eb_scores], axis=1) - if mode != "infer": - ... + output = tf.nn.sigmoid(logits) - if opt.print_data_examples: - logits = print_data_example(logits, lolly_activations, features) + return {"output": output, "loss": loss, "weights": weights} - ... - -# Added line breaks and indentation to improve readability def print_data_example(logits, lolly_activations, features): - return tf.Print( - logits, - [ - logits, - lolly_activations, - tf.reshape(features['keys'], (1, -1)), - tf.reshape(tf.multiply(features['values'], -1.0), (1, -1)) - ], - message="DATA EXAMPLE = ", - summarize=10000 - ) + # Function to print data example + return tf.Print( + logits, + [logits, lolly_activations, tf.reshape(features['keys'], (1, -1)), tf.reshape(tf.multiply(features['values'], -1.0), (1, -1))], + message="DATA EXAMPLE = ", + summarize=10000 + ) + +def earlybird_output_fn(graph_output): + # Function to process the Earlybird model output + export_outputs = { + tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: + tf.estimator.export.PredictOutput( + {"prediction": tf.identity(graph_output["output"], name="output_scores")} + ) + } + return export_outputs if __name__ == "__main__": + # Set up argument parser parser = DataRecordTrainer.add_parser_arguments() parser = twml.contrib.calibrators.add_discretizer_arguments(parser) @@ -136,8 +141,10 @@ if __name__ == "__main__": help="Prints 'DATA EXAMPLE = [[tf logit]][[logged lolly logit]][[feature ids][feature values]]'") add_weight_arguments(parser) + # Parse arguments opt = parser.parse_args() + # Set up feature configuration feature_config_module = all_configs.select_feature_config(opt.feature_config) feature_config = feature_config_module.get_feature_config(data_spec_path=opt.data_spec, label=opt.label) @@ -146,6 +153,7 @@ if __name__ == "__main__": feature_config, keep_fields=("ids", "keys", "values", "batch_size", "total_size", "codes")) + # Discretizer calibration (if necessary) if not opt.lolly_model_tsv: if opt.model_use_existing_discretizer: logging.info("Skipping discretizer calibration [model.use_existing_discretizer=True]") @@ -162,6 +170,7 @@ if __name__ == "__main__": build_graph_fn=build_percentile_discretizer_graph, feature_config=feature_config) + # Initialize trainer trainer = DataRecordTrainer( name="earlybird", params=opt, @@ -175,6 +184,7 @@ if __name__ == "__main__": warm_start_from=None ) + # Train and evaluate model train_input_fn = trainer.get_train_input_fn(parse_fn=parse_fn) eval_input_fn = trainer.get_eval_input_fn(parse_fn=parse_fn) @@ -184,6 +194,7 @@ if __name__ == "__main__": trainingEndTime = datetime.now() logging.info("Training and Evaluation time: " + str(trainingEndTime - trainingStartTime)) + # Export model (if current node is chief) if trainer._estimator.config.is_chief: serving_input_in_earlybird = { "input_sparse_tensor_indices": array_ops.placeholder( @@ -209,6 +220,3 @@ if __name__ == "__main__": feature_spec=feature_config.get_feature_spec() ) logging.info("The export model path is: " + opt.export_dir) - - -