227 lines
7.0 KiB
Python
227 lines
7.0 KiB
Python
import os
|
|
|
|
from toxicity_ml_pipeline.settings.default_settings_tox import LOCAL_DIR, MAX_SEQ_LENGTH
|
|
try:
|
|
from toxicity_ml_pipeline.optim.losses import MaskedBCE
|
|
except ImportError:
|
|
print('No MaskedBCE loss')
|
|
from toxicity_ml_pipeline.utils.helpers import execute_command
|
|
|
|
import tensorflow as tf
|
|
|
|
|
|
try:
|
|
from twitter.cuad.representation.models.text_encoder import TextEncoder
|
|
except ModuleNotFoundError:
|
|
print("No TextEncoder package")
|
|
|
|
try:
|
|
from transformers import TFAutoModelForSequenceClassification
|
|
except ModuleNotFoundError:
|
|
print("No HuggingFace package")
|
|
|
|
LOCAL_MODEL_DIR = os.path.join(LOCAL_DIR, "models")
|
|
|
|
|
|
def reload_model_weights(weights_dir, language, **kwargs):
|
|
optimizer = tf.keras.optimizers.Adam(0.01)
|
|
model_type = (
|
|
"twitter_bert_base_en_uncased_mlm"
|
|
if language == "en"
|
|
else "twitter_multilingual_bert_base_cased_mlm"
|
|
)
|
|
model = load(optimizer=optimizer, seed=42, model_type=model_type, **kwargs)
|
|
model.load_weights(weights_dir)
|
|
|
|
return model
|
|
|
|
|
|
def _locally_copy_models(model_type):
|
|
if model_type == "twitter_multilingual_bert_base_cased_mlm":
|
|
preprocessor = "bert_multi_cased_preprocess_3"
|
|
elif model_type == "twitter_bert_base_en_uncased_mlm":
|
|
preprocessor = "bert_en_uncased_preprocess_3"
|
|
else:
|
|
raise NotImplementedError
|
|
|
|
copy_cmd = """mkdir {local_dir}
|
|
gsutil cp -r ...
|
|
gsutil cp -r ..."""
|
|
execute_command(
|
|
copy_cmd.format(model_type=model_type, preprocessor=preprocessor, local_dir=LOCAL_MODEL_DIR)
|
|
)
|
|
|
|
return preprocessor
|
|
|
|
|
|
def load_encoder(model_type, trainable):
|
|
try:
|
|
model = TextEncoder(
|
|
max_seq_lengths=MAX_SEQ_LENGTH,
|
|
model_type=model_type,
|
|
cluster="gcp",
|
|
trainable=trainable,
|
|
enable_dynamic_shapes=True,
|
|
)
|
|
except (OSError, tf.errors.AbortedError) as e:
|
|
print(e)
|
|
preprocessor = _locally_copy_models(model_type)
|
|
|
|
model = TextEncoder(
|
|
max_seq_lengths=MAX_SEQ_LENGTH,
|
|
local_model_path=f"models/{model_type}",
|
|
local_preprocessor_path=f"models/{preprocessor}",
|
|
cluster="gcp",
|
|
trainable=trainable,
|
|
enable_dynamic_shapes=True,
|
|
)
|
|
|
|
return model
|
|
|
|
|
|
def get_loss(loss_name, from_logits, **kwargs):
|
|
loss_name = loss_name.lower()
|
|
if loss_name == "bce":
|
|
print("Binary CE loss")
|
|
return tf.keras.losses.BinaryCrossentropy(from_logits=from_logits)
|
|
|
|
if loss_name == "cce":
|
|
print("Categorical cross-entropy loss")
|
|
return tf.keras.losses.CategoricalCrossentropy(from_logits=from_logits)
|
|
|
|
if loss_name == "scce":
|
|
print("Sparse categorical cross-entropy loss")
|
|
return tf.keras.losses.SparseCategoricalCrossentropy(from_logits=from_logits)
|
|
|
|
if loss_name == "focal_bce":
|
|
gamma = kwargs.get("gamma", 2)
|
|
print("Focal binary CE loss", gamma)
|
|
return tf.keras.losses.BinaryFocalCrossentropy(gamma=gamma, from_logits=from_logits)
|
|
|
|
if loss_name == 'masked_bce':
|
|
multitask = kwargs.get("multitask", False)
|
|
if from_logits or multitask:
|
|
raise NotImplementedError
|
|
print(f'Masked Binary Cross Entropy')
|
|
return MaskedBCE()
|
|
|
|
if loss_name == "inv_kl_loss":
|
|
raise NotImplementedError
|
|
|
|
raise ValueError(
|
|
f"This loss name is not valid: {loss_name}. Accepted loss names: BCE, masked BCE, CCE, sCCE, "
|
|
f"Focal_BCE, inv_KL_loss"
|
|
)
|
|
|
|
def _add_additional_embedding_layer(doc_embedding, glorot, seed):
|
|
doc_embedding = tf.keras.layers.Dense(768, activation="tanh", kernel_initializer=glorot)(doc_embedding)
|
|
doc_embedding = tf.keras.layers.Dropout(rate=0.1, seed=seed)(doc_embedding)
|
|
return doc_embedding
|
|
|
|
def _get_bias(**kwargs):
|
|
smart_bias_value = kwargs.get('smart_bias_value', 0)
|
|
print('Smart bias init to ', smart_bias_value)
|
|
output_bias = tf.keras.initializers.Constant(smart_bias_value)
|
|
return output_bias
|
|
|
|
|
|
def load_inhouse_bert(model_type, trainable, seed, **kwargs):
|
|
inputs = tf.keras.layers.Input(shape=(), dtype=tf.string)
|
|
encoder = load_encoder(model_type=model_type, trainable=trainable)
|
|
doc_embedding = encoder([inputs])["pooled_output"]
|
|
doc_embedding = tf.keras.layers.Dropout(rate=0.1, seed=seed)(doc_embedding)
|
|
|
|
glorot = tf.keras.initializers.glorot_uniform(seed=seed)
|
|
if kwargs.get("additional_layer", False):
|
|
doc_embedding = _add_additional_embedding_layer(doc_embedding, glorot, seed)
|
|
|
|
if kwargs.get('content_num_classes', None):
|
|
probs = get_last_layer(glorot=glorot, last_layer_name='target_output', **kwargs)(doc_embedding)
|
|
second_probs = get_last_layer(num_classes=kwargs['content_num_classes'],
|
|
last_layer_name='content_output',
|
|
glorot=glorot)(doc_embedding)
|
|
probs = [probs, second_probs]
|
|
else:
|
|
probs = get_last_layer(glorot=glorot, **kwargs)(doc_embedding)
|
|
model = tf.keras.models.Model(inputs=inputs, outputs=probs)
|
|
|
|
return model, False
|
|
|
|
def get_last_layer(**kwargs):
|
|
output_bias = _get_bias(**kwargs)
|
|
if 'glorot' in kwargs:
|
|
glorot = kwargs['glorot']
|
|
else:
|
|
glorot = tf.keras.initializers.glorot_uniform(seed=kwargs['seed'])
|
|
layer_name = kwargs.get('last_layer_name', 'dense_1')
|
|
|
|
if kwargs.get('num_classes', 1) > 1:
|
|
last_layer = tf.keras.layers.Dense(
|
|
kwargs["num_classes"], activation="softmax", kernel_initializer=glorot,
|
|
bias_initializer=output_bias, name=layer_name
|
|
)
|
|
|
|
elif kwargs.get('num_raters', 1) > 1:
|
|
if kwargs.get('multitask', False):
|
|
raise NotImplementedError
|
|
last_layer = tf.keras.layers.Dense(
|
|
kwargs['num_raters'], activation="sigmoid", kernel_initializer=glorot,
|
|
bias_initializer=output_bias, name='probs')
|
|
|
|
else:
|
|
last_layer = tf.keras.layers.Dense(
|
|
1, activation="sigmoid", kernel_initializer=glorot,
|
|
bias_initializer=output_bias, name=layer_name
|
|
)
|
|
|
|
return last_layer
|
|
|
|
def load_bertweet(**kwargs):
|
|
bert = TFAutoModelForSequenceClassification.from_pretrained(
|
|
os.path.join(LOCAL_MODEL_DIR, "bertweet-base"),
|
|
num_labels=1,
|
|
classifier_dropout=0.1,
|
|
hidden_size=768,
|
|
)
|
|
if "num_classes" in kwargs and kwargs["num_classes"] > 2:
|
|
raise NotImplementedError
|
|
|
|
return bert, True
|
|
|
|
|
|
def load(
|
|
optimizer,
|
|
seed,
|
|
model_type="twitter_multilingual_bert_base_cased_mlm",
|
|
loss_name="BCE",
|
|
trainable=True,
|
|
**kwargs,
|
|
):
|
|
if model_type == "bertweet-base":
|
|
model, from_logits = load_bertweet()
|
|
else:
|
|
model, from_logits = load_inhouse_bert(model_type, trainable, seed, **kwargs)
|
|
|
|
pr_auc = tf.keras.metrics.AUC(curve="PR", name="pr_auc", from_logits=from_logits)
|
|
roc_auc = tf.keras.metrics.AUC(curve="ROC", name="roc_auc", from_logits=from_logits)
|
|
|
|
loss = get_loss(loss_name, from_logits, **kwargs)
|
|
if kwargs.get('content_num_classes', None):
|
|
second_loss = get_loss(loss_name=kwargs['content_loss_name'], from_logits=from_logits)
|
|
loss_weights = {'content_output': kwargs['content_loss_weight'], 'target_output': 1}
|
|
model.compile(
|
|
optimizer=optimizer,
|
|
loss={'content_output': second_loss, 'target_output': loss},
|
|
loss_weights=loss_weights,
|
|
metrics=[pr_auc, roc_auc],
|
|
)
|
|
|
|
else:
|
|
model.compile(
|
|
optimizer=optimizer,
|
|
loss=loss,
|
|
metrics=[pr_auc, roc_auc],
|
|
)
|
|
print(model.summary(), "logits: ", from_logits)
|
|
|
|
return model |