the-algorithm/twml/twml/dataset.py

373 lines
15 KiB
Python
Raw Normal View History

"""
This module implements custom tf.data.datasets for twml.
"""
import numbers
from absl import logging
from kazoo.client import KazooClient
from libtwml import OPLIB
import tensorflow.compat.v1 as tf
from twml.constants import DEFAULT_ZOOKEEPER_BASE_ZNODE, DEFAULT_ZOOKEEPER_HOST
class BlockFormatDataset(tf.data.Dataset):
  """A ``tf.data.Dataset`` of records read from one or more files by the ``block_format_dataset`` op."""

  def __init__(self, filenames, compression_type="auto", buffer_size=1 << 20):
    """
    Build a ``BlockFormatDataset`` from file names and decompression settings.

    Args:
      filenames:
        A `tf.string` tensor (or a value convertible to one) containing one or more filenames.
      compression_type:
        A string naming the compression of the input files.
        One of 'gz' (or 'gzip'), 'none' or 'auto' (default).
        With 'auto' the compression is inferred from the file extension.
        The string is lower-cased before being converted to a tensor.
      buffer_size:
        Buffer size used during decompression. Defaults to 1 << 20.
    """
    self._filenames = tf.convert_to_tensor(filenames, dtype=tf.string, name="filenames")
    normalized_compression = compression_type.lower()
    self._compression_type = tf.convert_to_tensor(normalized_compression, name="compression_type")
    self._buffer_size = tf.convert_to_tensor(buffer_size, dtype=tf.int64, name="buffer_size")
    # The parent constructor calls self._as_variant_tensor(), which reads the
    # attributes assigned above, so it has to run last.
    super(BlockFormatDataset, self).__init__()

  def _as_variant_tensor(self):
    """
    Create the resource handle for the dataset.

    Uses ``block_format_dataset`` from ``libtwml_internal`` when that module can be
    imported; on ImportError falls back to ``OPLIB.block_format_dataset_v2``, which
    additionally receives the compression type and buffer size.
    """
    try:
      internal_op = __import__("libtwml_internal").OPLIB.block_format_dataset
      return internal_op(self._filenames)
    except ImportError:
      fallback_op = OPLIB.block_format_dataset_v2
      return fallback_op(self._filenames, self._compression_type, self._buffer_size)

  def _inputs(self):
    """Return the input datasets of this dataset (always an empty list)."""
    return []

  @property
  def output_shapes(self):
    """Shape of each element: a scalar."""
    return tf.TensorShape([])

  @property
  def output_types(self):
    """Type of each element: ``tf.string``."""
    return tf.string

  @property
  def output_classes(self):
    """Class of each element: ``tf.Tensor``."""
    return tf.Tensor
def downsample_dataset(dataset, sample_rate, rate_name):
  """
  Randomly keep a fraction of the elements of a ``tf.data.Dataset``.

  Args:
    dataset:
      The dataset to downsample.
    sample_rate:
      Fraction of elements to keep: a real number in (0, 1].
      ``None`` or ``1.0`` returns ``dataset`` unchanged.
    rate_name:
      Name of the rate argument; only used to build error messages.

  Returns:
    ``dataset`` itself when no downsampling is requested, otherwise ``dataset``
    filtered so that each element is kept when a fresh uniform random draw is
    below ``sample_rate``.

  Raises:
    TypeError: if ``sample_rate`` is not a real number.
    ValueError: if ``sample_rate`` is outside (0, 1] (this includes NaN).
  """
  if sample_rate is None or sample_rate == 1.0:
    return dataset
  if not isinstance(sample_rate, numbers.Real):
    raise TypeError("dataset %s must be a real number" % rate_name)
  # Written as a positive range test so that NaN, for which every comparison is
  # False, is rejected instead of silently producing a filter that drops everything.
  if not 0 < sample_rate <= 1:
    raise ValueError("dataset %s must be in range (0, 1]" % rate_name)
  return dataset.filter(lambda _: tf.squeeze(tf.random_uniform([1])) < sample_rate)
def _filenames_dataset(files, shards=None, shard_index=None):
"""
Get a tf.data.Dataset with file names from a list of files
Optionally shard the file list (see stream_block_format_dataset)
"""
files = tf.data.Dataset.from_tensor_slices(files)
if [shards, shard_index] != [None, None]:
logging.info("Sharding files dataset (index: %d, shards: %d)" % (shard_index, shards))
files = files.shard(num_shards=shards, index=shard_index)
return files
def stream_block_format_dataset(
    files, parse_fn, batch_size, num_threads,
    shuffle=True, repeat=False,
    block_length=None, part_file_parallelism=None, file_shuffle_size=None,
    record_shuffle_size=None, dataset_fn=None,
    keep_rate=None, parts_downsampling_rate=None, prefetch_size=2,
    shards=None, shard_index=None, shuffle_files=True, interleave=True):
  """
  Stream records from a list of part files as a ``tf.data.Dataset``.

  Args:
    files:
      List of input files from which the dataset is created.
    parse_fn:
      A function that takes a byte tensor containing a datarecord and decodes it.
    batch_size:
      The batch size for each step.
    num_threads:
      Number of threads working on the data in parallel
      (passed as ``num_parallel_calls`` to the interleave).
    shuffle:
      Shuffle the records of each file using ``record_shuffle_size``. Defaults to True.
    repeat:
      Repeat the list of files indefinitely. Defaults to False.
      Useful when ``[train,eval]_steps`` is greater than the size of the dataset
      (otherwise ``Estimator.[train,evaluate]`` stops when the end of the dataset is reached).
    block_length (optional):
      Number of consecutive records to pull from a single part file.
      Defaults to ``batch_size``, so that a whole batch worth of records is read from one file.
    part_file_parallelism (optional):
      Number of part files to read from in parallel. Once a part file is completely read, it is
      replaced by the next part file in the list. Defaults to ``num_threads``.
      ``num_threads`` is the size of the reader thread pool, whereas ``part_file_parallelism`` is
      the number of files being read at once. When ``part_file_parallelism`` is greater than or
      equal to ``num_threads``, the reads are distributed over ``num_threads``. When it is smaller
      than ``num_threads``, the thread pool is very likely underutilized, because not every reader
      thread can have a part file to read from.
    file_shuffle_size (optional):
      The ``buffer_size`` used for shuffling the list of files.
      Defaults to 100000. For example, with a value of 1000 and 2000 files, the first
      1000 files are shuffled together and iterated through, then the next 1000 files are
      shuffled and iterated through.
    record_shuffle_size (optional):
      The ``buffer_size`` used for shuffling the records of each file.
      Defaults to ``batch_size * 8`` records.
    dataset_fn (optional):
      A function that transforms the dataset of records after the part files have been read.
      It is called as ``dataset_fn(dataset, parse_fn, batch_size)`` and its result is returned.
      When omitted, the behavior is equivalent to:

      .. code-block:: python

        def dataset_fn(dataset, parse_fn, batch_size):
          dataset = dataset.batch(batch_size)
          dataset = dataset.map(parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
          return dataset.prefetch(prefetch_size)

    keep_rate (optional):
      A float value in (0.0, 1.0] that indicates to drop records according to the Bernoulli
      distribution with p = 1 - keep_rate.
      Defaults to None (no records dropped).
    parts_downsampling_rate (optional):
      A float value in ``(0.0, 1.0]`` giving the factor by which to downsample part files.
      For example, a value of 0.2 means only 20 percent of part files become part of the dataset.
      This argument is only useful together with a [train,eval]_steps of -1 (that is, when the
      entire dataset is used). Even then, each epoch sees a different set of part files, because
      part files are re-sampled every epoch. The argument is only provided for backwards
      compatibility with DeepBird v1; prefer a smaller [train,eval]_steps (or a keep_rate).
    prefetch_size (optional):
      Number of parsed batches to prefetch. Defaults to 2. Not applied when ``dataset_fn`` is given.
    shards (optional):
      Number of partitions to shard the dataset into. This is useful for codistillation and other
      techniques that require each worker to train on disjoint partitions of the dataset.
      The dataset is not sharded by default.
    shard_index (optional):
      Which partition of the dataset to use if ``shards`` is set.
    shuffle_files (optional):
      Shuffle the list of files. Defaults to True.
      When False, files are iterated in the order they are passed in.
    interleave (optional):
      Interleave records from multiple files in parallel. Defaults to True.
      When False, the files are read one after the other.

  Returns:
    A tf.data.Dataset of batches of decoded records, or whatever ``dataset_fn`` returns
    when it is provided.
  """
  # Dataset of (optionally sharded) file names.
  files = _filenames_dataset(files, shards=shards, shard_index=shard_index)

  # Fill in the defaults that depend on other arguments.
  if file_shuffle_size is None:
    file_shuffle_size = 100000
  if record_shuffle_size is None:
    record_shuffle_size = batch_size * 8
  if block_length is None:
    block_length = batch_size

  logging.info("NUM_THREADS: %d", num_threads)

  if repeat:
    files = files.repeat()

  if shuffle_files:
    # Randomize the order of the part files.
    files = files.shuffle(buffer_size=file_shuffle_size)

  # Keep only a fraction of the part files.
  files = downsample_dataset(files, parts_downsampling_rate, "parts_downsampling_rate")

  def map_fn(filenames):
    """Map a file name to the (optionally shuffled) dataset of its records."""
    # Prefetching this early can sometimes improve performance (like on GCS).
    records = BlockFormatDataset(filenames).prefetch(tf.data.experimental.AUTOTUNE)
    if not shuffle:
      return records
    return records.shuffle(buffer_size=record_shuffle_size)

  if interleave:
    if part_file_parallelism is None:
      part_file_parallelism = num_threads
    # With block_length == batch_size, batch_size consecutive records come from one file.
    dataset = files.interleave(
      map_fn,
      cycle_length=part_file_parallelism,
      block_length=block_length,
      num_parallel_calls=num_threads)
  else:
    dataset = files.flat_map(map_fn)

  # Keep only a fraction of the records.
  dataset = downsample_dataset(dataset, keep_rate, "keep_rate")

  if dataset_fn is not None:
    return dataset_fn(dataset, parse_fn, batch_size)

  # Default pipeline: batch the raw records, decode each batch, then prefetch.
  batched = dataset.batch(batch_size)
  parsed = batched.map(parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  return parsed.prefetch(prefetch_size)
def cx_zk_path(path):
  """
  Build the full zookeeper path of a dataset pointer.

  Args:
    path:
      Path relative to ``DEFAULT_ZOOKEEPER_BASE_ZNODE``.

  Returns:
    ``DEFAULT_ZOOKEEPER_BASE_ZNODE`` and ``path`` joined with "/". The result is also logged.

  Raises:
    ValueError: if ``path`` is None.
  """
  if path is None:
    raise ValueError("Path for zookeeper dataset pointer is None. You must specify a path.")
  path_parts = (DEFAULT_ZOOKEEPER_BASE_ZNODE, path)
  full_path = "/".join(path_parts)
  logging.info("Zookeeper path is: {}".format(full_path))
  return full_path
def zookeeper_ordered_dataset(
    files, parse_fn, batch_size, zk_counter_path, repeat=False,
    num_threads=2, block_length=None, part_file_parallelism=None,
    batch_shuffle_size=None, file_keep_rate=None, record_keep_rate=None,
    prefetch_size=2, interleave=False, dataset_fn=None, verbose=False):
  """
  Make a tf.Dataset given an ordered list of filenames, using Zookeeper to keep track of
  which file to read, and to coordinate multiple workers.

  Args:
    files:
      ordered list of (typically HDFS) filenames. This must remain consistent
      between different workers, and between worker restarts (e.g. in the case
      of instance failure or preemption).
      To ensure this remains consistent, consider using the --train.files_list
      option from DataRecordTrainer.
    parse_fn:
      A function that takes a byte tensor containing a datarecord and decodes it.
    batch_size:
      The batch size for each step.
    zk_counter_path:
      Path under the root node for the underlying zookeeper shared counter that
      is used to coordinate distributed iteration over the list of files.
      Full path will be `'/'.join([DEFAULT_ZOOKEEPER_BASE_ZNODE, zk_counter_path])`.
    repeat:
      Default False. Set True to repeat over the files forever.
    num_threads:
      Default 2. Number of threads working on the data in parallel.
      Only used if interleave=True.
    block_length:
      Default None. Number of consecutive records to pull from a single part file.
      If None, then block_length=batch_size will be used.
      Only used if interleave=True.
    part_file_parallelism:
      Default None (meaning ``num_threads``). Number of part files to read from in parallel.
      Once a part file is completely read, it will be replaced by the next part file
      indicated by the zookeeper counter.
      Only used if interleave=True.
      ``num_threads`` specifies a reader thread pool size, while ``part_file_parallelism`` specifies
      the number of files to read from in parallel. If ``part_file_parallelism`` is greater than or
      equal to ``num_threads``, the reads will be distributed over ``num_threads``. On the other hand,
      if ``part_file_parallelism`` is smaller than ``num_threads``, it is very likely that the reader
      thread pool will be underutilized, since it can never be the case that every reader thread has
      a part file to read from.
    batch_shuffle_size:
      Default None. Size of shuffle buffer, for shuffling that will be applied after batching.
      If None, then batches will not be shuffled. Ignored if dataset_fn is provided.
    file_keep_rate:
      Default None. Fraction of files to keep, or None to keep all files.
    record_keep_rate:
      Default None. Fraction of records to keep, or None to keep all records.
    prefetch_size:
      Default 2. Number of parsed batches to prefetch. Ignored if dataset_fn is provided.
    interleave:
      Default False. Set True to use tf.data.Dataset.interleave rather than flat_map.
    dataset_fn:
      A function that is applied to the dataset of individual records, after
      these have been read from the parts files.
      If ``None`` (the default), the behavior will be as though dataset_fn were set to:

      .. code-block:: python

        def dataset_fn(dataset, parse_fn, batch_size):
          dataset = dataset.batch(batch_size)
          dataset = dataset.map(parse_fn, tf.data.experimental.AUTOTUNE)
          if batch_shuffle_size:
            dataset = dataset.shuffle(batch_shuffle_size)
          return dataset.prefetch(prefetch_size)

    verbose:
      Default False. Set True to log the names of files loaded by TF.

  Returns:
    A tf.data.Dataset of parsed batches, or whatever ``dataset_fn`` returns when it is provided.
  """
  block_length = batch_size if block_length is None else block_length
  part_file_parallelism = num_threads if part_file_parallelism is None else part_file_parallelism

  # ``files`` is rebound to a tf.data.Dataset further down, so the generator receives the
  # original list through a default argument rather than through the closure.
  def zk_index_generator(my_files=files):
    """Yield file names, picking each index from the shared zookeeper counter."""
    num_files = len(my_files)
    zk = KazooClient(hosts=DEFAULT_ZOOKEEPER_HOST)
    zk.start()
    try:
      my_counter = zk.Counter(cx_zk_path(zk_counter_path), default=0)
      while True:
        # Claim the next index: increment the shared counter and read the value it
        # held before the increment (see the Kazoo Counter recipe).
        my_counter += 1
        counter_pre_value = my_counter.pre_value
        # Wrap around when repeating; the emptiness guard avoids a modulo by zero.
        if repeat and num_files:
          counter_pre_value = counter_pre_value % num_files
        if counter_pre_value >= num_files:
          break
        chosen_file = my_files[counter_pre_value]
        if verbose:
          logging.info("{}. yielding {}".format(counter_pre_value, chosen_file))
        yield chosen_file
    finally:
      # Runs on normal exhaustion, on errors, and when the generator is closed
      # before exhaustion, so the zookeeper connection is always released.
      zk.stop()
      zk.close()

  files = tf.data.Dataset.from_generator(zk_index_generator, tf.string)

  # Downsample parts files
  files = downsample_dataset(files, file_keep_rate, "file_keep_rate")

  def map_fn(filenames):
    """Map a file name to the prefetched dataset of its records."""
    return BlockFormatDataset(filenames).prefetch(20)

  # Dont interleave for sequential training
  if interleave:
    dataset = files.interleave(
      map_fn,
      cycle_length=part_file_parallelism,
      block_length=block_length,
      num_parallel_calls=num_threads)
  else:
    dataset = files.flat_map(map_fn)

  # Downsample DataRecords
  dataset = downsample_dataset(dataset, record_keep_rate, "record_keep_rate")

  if dataset_fn is None:
    # Create a batch of datarecords and decode them
    dataset = dataset.batch(batch_size)
    dataset = dataset.map(parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # shuffle after batching and parsing for performance reasons
    # faster b/c 1 random selection is made per batch rather than per record
    if batch_shuffle_size:
      dataset = dataset.shuffle(buffer_size=batch_shuffle_size)
    dataset = dataset.prefetch(prefetch_size)
  else:
    dataset = dataset_fn(dataset, parse_fn, batch_size)

  return dataset