45 lines
1.3 KiB
Python
45 lines
1.3 KiB
Python
|
from typing import Callable
|
||
|
|
||
|
import tensorflow as tf
|
||
|
|
||
|
|
||
|
class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
|
||
|
def __init__(
|
||
|
self,
|
||
|
initial_learning_rate: float,
|
||
|
decay_schedule_fn: Callable,
|
||
|
warmup_steps: int,
|
||
|
power: float = 1.0,
|
||
|
name: str = "",
|
||
|
):
|
||
|
super().__init__()
|
||
|
self.initial_learning_rate = initial_learning_rate
|
||
|
self.warmup_steps = warmup_steps
|
||
|
self.power = power
|
||
|
self.decay_schedule_fn = decay_schedule_fn
|
||
|
self.name = name
|
||
|
|
||
|
def __call__(self, step):
|
||
|
with tf.name_scope(self.name or "WarmUp") as name:
|
||
|
global_step_float = tf.cast(step, tf.float32)
|
||
|
warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
|
||
|
warmup_percent_done = global_step_float / warmup_steps_float
|
||
|
warmup_learning_rate = self.initial_learning_rate * tf.math.pow(
|
||
|
warmup_percent_done, self.power
|
||
|
)
|
||
|
return tf.cond(
|
||
|
global_step_float < warmup_steps_float,
|
||
|
lambda: warmup_learning_rate,
|
||
|
lambda: self.decay_schedule_fn(step - self.warmup_steps),
|
||
|
name=name,
|
||
|
)
|
||
|
|
||
|
def get_config(self):
|
||
|
return {
|
||
|
"initial_learning_rate": self.initial_learning_rate,
|
||
|
"decay_schedule_fn": self.decay_schedule_fn,
|
||
|
"warmup_steps": self.warmup_steps,
|
||
|
"power": self.power,
|
||
|
"name": self.name,
|
||
|
}
|