373 lines
15 KiB
Python
373 lines
15 KiB
Python
|
"""
|
||
|
This module implements custom tf.data.datasets for twml.
|
||
|
"""
|
||
|
import numbers
|
||
|
|
||
|
from absl import logging
|
||
|
from kazoo.client import KazooClient
|
||
|
from libtwml import OPLIB
|
||
|
import tensorflow.compat.v1 as tf
|
||
|
from twml.constants import DEFAULT_ZOOKEEPER_BASE_ZNODE, DEFAULT_ZOOKEEPER_HOST
|
||
|
|
||
|
|
||
|
class BlockFormatDataset(tf.data.Dataset):
    """A ``tf.data.Dataset`` of string records read from one or more files.

    The records are produced by a ``block_format_dataset`` custom op: the one
    exposed by ``libtwml_internal`` when that module can be imported, and
    ``OPLIB.block_format_dataset_v2`` otherwise.
    """

    def __init__(self, filenames, compression_type="auto", buffer_size=1 << 20):
        """
        Creates a ``BlockFormatDataset``.

        Args:
          filenames:
            A `tf.string` tensor containing one or more filenames.
          compression_type:
            A string specifying the compression type. It is lower-cased before use.
            Can be one of 'gz' (or 'gzip'), 'none', 'auto' (default).
            When compression_type == 'auto', it is inferred from file extension.
          buffer_size:
            Buffer size to be used during decompression. default: 1<<20.

        Note that ``compression_type`` and ``buffer_size`` are only forwarded to
        ``OPLIB.block_format_dataset_v2``; the ``libtwml_internal`` op is called
        with the filenames alone.
        """
        self._filenames = tf.convert_to_tensor(
            filenames, dtype=tf.string, name="filenames")
        self._compression_type = tf.convert_to_tensor(
            compression_type.lower(), name="compression_type")
        self._buffer_size = tf.convert_to_tensor(
            buffer_size, dtype=tf.int64, name="buffer_size")
        # The parent class calls self._as_variant_tensor() in its __init__, and
        # that method reads the attributes set above, so this call comes last.
        super(BlockFormatDataset, self).__init__()

    def _as_variant_tensor(self):
        """
        Create the resource handle for the dataset.

        Prefers the op from ``libtwml_internal``; an ``ImportError`` raised while
        doing so switches to ``OPLIB.block_format_dataset_v2``.
        """
        try:
            internal_op = __import__("libtwml_internal").OPLIB.block_format_dataset
            return internal_op(self._filenames)
        except ImportError:
            return OPLIB.block_format_dataset_v2(
                self._filenames, self._compression_type, self._buffer_size)

    def _inputs(self):
        # This dataset is a source: it has no upstream datasets.
        return []

    @property
    def output_shapes(self):
        """Shape of each element: a scalar."""
        return tf.TensorShape([])

    @property
    def output_types(self):
        """Dtype of each element: ``tf.string``."""
        return tf.string

    @property
    def output_classes(self):
        """Class of each element: ``tf.Tensor``."""
        return tf.Tensor
|
||
|
|
||
|
|
||
|
def downsample_dataset(dataset, sample_rate, rate_name):
    """
    Downsample a tf.data.Dataset at sample_rate

    Args:
      dataset:
        The dataset to downsample. Only its ``filter`` method is used.
      sample_rate:
        Fraction of elements to keep, a real number in ``(0, 1]``.
        ``None`` and ``1.0`` both mean "keep everything".
      rate_name:
        Name of the rate argument, used only in error messages.

    Returns:
      ``dataset`` itself when no downsampling is requested; otherwise
      ``dataset.filter(...)`` with a predicate that keeps an element when a
      fresh uniform random draw is below ``sample_rate``.

    Raises:
      TypeError: if ``sample_rate`` is not a real number.
      ValueError: if ``sample_rate`` is outside ``(0, 1]`` (NaN included).
    """
    if sample_rate is None or sample_rate == 1.0:
        return dataset
    if not isinstance(sample_rate, numbers.Real):
        raise TypeError("dataset %s must be a real number" % rate_name)
    # Expressed as a positive range test so that NaN, for which every
    # comparison is False, is rejected instead of silently dropping all records.
    if not 0 < sample_rate <= 1:
        raise ValueError("dataset %s must be in range (0, 1]" % rate_name)
    return dataset.filter(lambda _: tf.squeeze(tf.random_uniform([1])) < sample_rate)
|
||
|
|
||
|
|
||
|
def _filenames_dataset(files, shards=None, shard_index=None):
|
||
|
"""
|
||
|
Get a tf.data.Dataset with file names from a list of files
|
||
|
Optionally shard the file list (see stream_block_format_dataset)
|
||
|
"""
|
||
|
files = tf.data.Dataset.from_tensor_slices(files)
|
||
|
|
||
|
if [shards, shard_index] != [None, None]:
|
||
|
logging.info("Sharding files dataset (index: %d, shards: %d)" % (shard_index, shards))
|
||
|
files = files.shard(num_shards=shards, index=shard_index)
|
||
|
|
||
|
return files
|
||
|
|
||
|
|
||
|
def stream_block_format_dataset(
        files, parse_fn, batch_size, num_threads,
        shuffle=True, repeat=False,
        block_length=None, part_file_parallelism=None, file_shuffle_size=None,
        record_shuffle_size=None, dataset_fn=None,
        keep_rate=None, parts_downsampling_rate=None, prefetch_size=2,
        shards=None, shard_index=None, shuffle_files=True, interleave=True):
    """
    Helper function to stream a list of part files.

    Args:
      files:
        List of input files which will create a dataset.
      parse_fn:
        A function that takes a byte tensor containing a datarecord and decodes it.
      batch_size:
        The batch size for each step.
      num_threads:
        Number of threads working on the data in parallel.
      shuffle:
        Shuffle records within each file using ``record_shuffle_size``. Defaults to True.
      repeat:
        Repeat the dataset indefinitely. Defaults to False.
        Useful when you want to use an ``[train,eval]_steps`` greater than the size of the dataset
        (otherwise ``Estimator.[train,evaluate]`` stop when the end of the dataset is reached).
      block_length (optional):
        Number of consecutive records to pull from a single part file.
        Defaults to batch_size. Only used when ``interleave`` is True.
      part_file_parallelism (optional):
        Number of part files to read from in parallel. Once a part file is completely read, it will
        be replaced by the next part file in the part file list.
        Defaults to ``num_threads``. Only used when ``interleave`` is True.

        ``num_threads`` specifies a reader thread pool size, while ``part_file_parallelism`` specifies
        the number of files to read from in parallel. If ``part_file_parallelism`` is greater than or
        equal to ``num_threads``, the reads will be distributed over ``num_threads``. On the other hand,
        if ``part_file_parallelism`` is smaller than ``num_threads``, it is very likely that the reader
        thread pool will be underutilized, since it can never be the case that every reader thread has
        a part file to read from.

      file_shuffle_size (optional):
        the buffer_size used for shuffling of the list of files.
        Defaults to 100000.
      record_shuffle_size (optional):
        the ``buffer_size`` used for shuffling records of each part file.
        Defaults to ``batch_size * 8`` records.
      dataset_fn (optional):
        A function that modifies the dataset after it reads different interleaved parts files.
        It is called as ``dataset_fn(dataset, parse_fn, batch_size)`` and its result is returned.
        When it is None, the behavior is equivalent to:

        .. code-block:: python

          def dataset_fn(dataset, parse_fn, batch_size):
            dataset = dataset.batch(batch_size)
            dataset = dataset.map(parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
            return dataset.prefetch(prefetch_size)

      keep_rate (optional):
        A float value in (0.0, 1.0] that indicates to drop records according to the Bernoulli
        distribution with p = 1 - keep_rate.
        Defaults to None (no records dropped).

      parts_downsampling_rate (optional):
        A float value in ``(0.0, 1.0]`` that indicates the factor by which to downsample part files.
        For example, a value of 0.2 means only 20 percent of part files become part of the dataset.

        Note that this argument is only useful in conjunction with a [train,eval]_steps of -1
        (that is, when the entire dataset is used). Furthermore, note that even in this case, each
        epoch will see a different set of part files. This is because new part files are re-sampled
        every epoch. In other words, this argument is only provided for backwards compatibility with
        DeepBird v1. We recommend you use a smaller [train,eval]_steps (or specify a keep_rate)
        instead.

      prefetch_size (optional):
        Number of parsed batches to prefetch. Defaults to 2. Ignored if ``dataset_fn`` is provided.

      shards (optional):
        Number of partitions to shard the dataset into. This is useful for codistillation and other
        techniques that require each worker to train on disjoint partitions of the dataset.
        The dataset is not sharded by default.

      shard_index (optional):
        Which partition of the dataset to use if ``shards`` is set.

      shuffle_files (optional):
        Shuffle the list of files. Defaults to True.
        When False, files are iterated in the order they are passed in.

      interleave (optional):
        Interleave records from multiple files in parallel. Defaults to True.
        When False, part files are read one after another with ``flat_map``.

    Returns:
      tf.data.DataSet of batches of HashedDataRecord resource handles decoded and streamed online
      (or whatever ``dataset_fn`` returns when it is provided).
    """
    # Dataset of part-file names, optionally restricted to one shard.
    part_files = _filenames_dataset(files, shards=shards, shard_index=shard_index)

    # Resolve the optional tuning knobs to their defaults.
    if file_shuffle_size is None:
        file_shuffle_size = 100000
    if record_shuffle_size is None:
        record_shuffle_size = batch_size * 8
    if block_length is None:
        block_length = batch_size

    logging.info("NUM_THREADS: %d", num_threads)

    if repeat:
        part_files = part_files.repeat()

    if shuffle_files:
        # Randomize the order in which part files are visited.
        part_files = part_files.shuffle(buffer_size=file_shuffle_size)

    # Keep only a random fraction of the part files, if requested.
    part_files = downsample_dataset(
        part_files, parts_downsampling_rate, "parts_downsampling_rate")

    def read_part_file(filename):
        '''Turn one part-file name into a dataset of its serialized records.'''
        records_in_file = BlockFormatDataset(filename)

        # early prefetching can sometimes improve performance (like on GCS)
        records_in_file = records_in_file.prefetch(tf.data.experimental.AUTOTUNE)

        if shuffle:
            records_in_file = records_in_file.shuffle(buffer_size=record_shuffle_size)

        return records_in_file

    if interleave:
        if part_file_parallelism is None:
            part_file_parallelism = num_threads
        # block_length == batch_size results in batch_size records being read
        # from a single file before moving on to the next one in the cycle.
        records = part_files.interleave(
            read_part_file,
            cycle_length=part_file_parallelism,
            block_length=block_length,
            num_parallel_calls=num_threads)
    else:
        records = part_files.flat_map(read_part_file)

    # Keep only a random fraction of the records, if requested.
    records = downsample_dataset(records, keep_rate, "keep_rate")

    if dataset_fn is not None:
        return dataset_fn(records, parse_fn, batch_size)

    # Default pipeline: batch the serialized records, decode each batch, prefetch.
    batches = records.batch(batch_size)
    parsed_batches = batches.map(
        parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return parsed_batches.prefetch(prefetch_size)
|
||
|
|
||
|
|
||
|
def cx_zk_path(path):
    """
    Build the full Zookeeper path of a dataset pointer and log it.

    Args:
      path:
        Path relative to ``DEFAULT_ZOOKEEPER_BASE_ZNODE``.

    Returns:
      ``DEFAULT_ZOOKEEPER_BASE_ZNODE`` and ``path`` joined with ``"/"``.

    Raises:
      ValueError: if ``path`` is None.
    """
    if path is None:
        raise ValueError("Path for zookeeper dataset pointer is None. You must specify a path.")

    full_znode_path = "/".join((DEFAULT_ZOOKEEPER_BASE_ZNODE, path))
    logging.info("Zookeeper path is: {}".format(full_znode_path))
    return full_znode_path
|
||
|
|
||
|
|
||
|
def zookeeper_ordered_dataset(
        files, parse_fn, batch_size, zk_counter_path, repeat=False,
        num_threads=2, block_length=None, part_file_parallelism=None,
        batch_shuffle_size=None, file_keep_rate=None, record_keep_rate=None,
        prefetch_size=2, interleave=False, dataset_fn=None, verbose=False):
    """
    Make a tf.Dataset given an ordered list of filenames, using Zookeeper to keep track of
    which file to read, and to coordinate multiple workers.

    Args:
      files:
        ordered list of (typically HDFS) filenames. This must remain consistent
        between different workers, and between worker restarts (e.g. in the case
        of instance failure or preemption).
        To ensure this remains consistent, consider using the --train.files_list
        option from DataRecordTrainer.
      parse_fn:
        A function that takes a byte tensor containing a datarecord and decodes it.
      batch_size:
        The batch size for each step.
      zk_counter_path:
        Path under the root node for the underlying zookeeper shared counter that
        is used to coordinate distributed iteration over the list of files.
        Full path will be `'/'.join([DEFAULT_ZOOKEEPER_BASE_ZNODE, zk_counter_path])`.
      repeat:
        Default False. Set True to repeat over the files forever.
      num_threads:
        Default 2. Number of threads working on the data in parallel.
        Only used if interleave=True.
      block_length:
        Default None. Number of consecutive records to pull from a single part file.
        If None, then block_length=batch_size will be used.
        Only used if interleave=True.
      part_file_parallelism:
        Default None. Number of part files to read from in parallel. Once a part file is completely
        read, it will be replaced by the next part file indicated by the zookeeper counter.
        If None, then part_file_parallelism=num_threads will be used.
        Only used if interleave=True.

        ``num_threads`` specifies a reader thread pool size, while ``part_file_parallelism`` specifies
        the number of files to read from in parallel. If ``part_file_parallelism`` is greater than or
        equal to ``num_threads``, the reads will be distributed over ``num_threads``. On the other hand,
        if ``part_file_parallelism`` is smaller than ``num_threads``, it is very likely that the reader
        thread pool will be underutilized, since it can never be the case that every reader thread has
        a part file to read from.

      batch_shuffle_size:
        Default None. Size of shuffle buffer, for shuffling that will be applied after batching.
        if None, then batches will not be shuffled. Ignored if dataset_fn is provided.
      file_keep_rate:
        Default None. Fraction of files to keep, or None to keep all files.
      record_keep_rate:
        Default None. Fraction of records to keep, or None to keep all records.
      prefetch_size:
        Default 2. Number of parsed batches to prefetch. Ignored if dataset_fn is provided.
      interleave:
        Default False. Set True to use tf.data.Dataset.interleave rather than flat_map.
      dataset_fn:
        A function that is applied to the dataset of individual records, after
        these have been read from the parts files.
        If ``None`` (the default), the behavior will be as though dataset_fn were set to:

        .. code-block:: python

          def dataset_fn(dataset, parse_fn, batch_size):
            dataset = dataset.batch(batch_size)
            dataset = dataset.map(parse_fn, tf.data.experimental.AUTOTUNE)
            if batch_shuffle_size:
              dataset = dataset.shuffle(batch_shuffle_size)
            return dataset.prefetch(prefetch_size)

      verbose:
        Default False. Set True to log the names of files loaded by TF.

    Returns:
      The resulting tf.data.Dataset: parsed (and optionally shuffled) batches by default,
      or whatever ``dataset_fn`` returns when it is provided.
    """
    block_length = batch_size if block_length is None else block_length
    part_file_parallelism = num_threads if part_file_parallelism is None else part_file_parallelism

    # The file list is bound as a default argument because the name ``files`` is
    # rebound to a tf.data.Dataset below, before the generator body ever runs.
    def zk_index_generator(my_files=files):
        """Yield file names in the order handed out by the shared Zookeeper counter."""
        num_files = len(my_files)
        zk = KazooClient(hosts=DEFAULT_ZOOKEEPER_HOST)
        zk.start()
        # try/finally so the client is stopped even when the generator is closed
        # early (e.g. the consumer stops iterating) or an error is raised.
        try:
            my_counter = zk.Counter(cx_zk_path(zk_counter_path), default=0)
            while True:
                # Claim the next index: increment the shared counter and use the
                # value it reports as ``pre_value`` as this worker's file index.
                my_counter += 1
                counter_pre_value = my_counter.pre_value
                # Wrap around when repeating. The ``num_files`` guard avoids a
                # ZeroDivisionError on an empty file list, which then simply ends.
                if repeat and num_files:
                    counter_pre_value = counter_pre_value % num_files
                if counter_pre_value >= num_files:
                    break
                chosen_file = my_files[counter_pre_value]
                if verbose:
                    logging.info("{}. yielding {}".format(counter_pre_value, chosen_file))
                yield chosen_file
        finally:
            zk.stop()

    files = tf.data.Dataset.from_generator(zk_index_generator, tf.string)

    # Downsample parts files
    files = downsample_dataset(files, file_keep_rate, "file_keep_rate")

    def map_fn(filenames):
        """Map a file name to a prefetching dataset of its records."""
        return BlockFormatDataset(filenames).prefetch(20)

    # Dont interleave for sequential training
    if interleave:
        dataset = files.interleave(
            map_fn,
            cycle_length=part_file_parallelism,
            block_length=block_length,
            num_parallel_calls=num_threads)
    else:
        dataset = files.flat_map(map_fn)

    # Downsample DataRecords
    dataset = downsample_dataset(dataset, record_keep_rate, "record_keep_rate")

    if dataset_fn is None:
        # Create a batch of datarecords and decode them
        dataset = dataset.batch(batch_size)
        dataset = dataset.map(parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        # shuffle after batching and parsing for performance reasons
        # faster b/c 1 random selection is made per batch rather than per record
        if batch_shuffle_size:
            dataset = dataset.shuffle(buffer_size=batch_shuffle_size)
        dataset = dataset.prefetch(prefetch_size)
    else:
        dataset = dataset_fn(dataset, parse_fn, batch_size)

    return dataset
|