the-algorithm/twml/twml/tensorio.py

162 lines
4.8 KiB
Python

# pylint: disable=missing-docstring, bare-except, pointless-statement,
# pointless-string-statement, redundant-unittest-assert, no-else-return,
# no-member, old-style-class, dangerous-default-value, protected-access,
# too-few-public-methods
import os
import numpy as np
import yaml
"""
Utility to load tensors serialized by Deepbird V1.
Note that Deepbird V1 serializes tensor names as \"weight\".\'1\'.
For user-friendliness, the quotes are removed from the tensor names.
"""
# helper class used to assist hierarchical key access by remembering intermediate keys.
class _KeyRecorder(object):
    """
    Helper that supports hierarchical key access such as
    ``tensorio['a']['b']['1']`` by remembering the key segments seen so far.

    Args:
      tensorio: object exposing ``list_tensors()`` (iterable of dotted names)
        and ``_load(name)``.
      keys: key segments accumulated so far. Defaults to no segments.
    """

    def __init__(self, tensorio, keys=None):
        self.tensorio = tensorio
        # A fresh list per instance: avoids the shared-mutable-default pitfall
        # and never aliases a list owned by the caller.
        self.keys = [] if keys is None else list(keys)

    def __getitem__(self, k):
        """
        Extend the recorded key with segment ``k``.

        Returns:
          The loaded value when the dotted key names a stored entry, otherwise
          a new ``_KeyRecorder`` that remembers the longer partial key.

        Raises:
          ValueError: if no stored name equals the dotted key or continues it
            with a further ``.``-separated segment.
        """
        new_keys = self.keys + [str(k)]
        prefix = ".".join(new_keys)
        key_list = self.tensorio.list_tensors()

        # If we have a complete key, load the entry.
        if prefix in key_list:
            return self.tensorio._load(prefix)

        # Otherwise at least one stored name must continue past this prefix at a
        # segment boundary. Matching on "prefix." (not the bare prefix) keeps
        # "wei" from being accepted just because "weight.1" exists.
        child_prefix = prefix + "."
        if any(name.startswith(child_prefix) for name in key_list):
            return _KeyRecorder(self.tensorio, new_keys)

        # No stored name is reachable through this prefix.
        raise ValueError("Key not found: " + prefix)
# convert tensorio tensor type to numpy data type.
# also returns element size in bytes.
def _get_data_type(data_type):
    """
    Translate a tensorio type name into its numpy equivalent.

    Returns:
      A ``(numpy dtype, element size in bytes)`` tuple.

    Raises:
      ValueError: if ``data_type`` is not one of the known type names.
    """
    known_types = (
        ('Double', (np.float64, 8)),
        ('Float', (np.float32, 4)),
        ('Int', (np.int32, 4)),
        ('Long', (np.int64, 8)),
        ('Byte', (np.int8, 1)),
    )
    for type_name, dtype_and_size in known_types:
        if data_type == type_name:
            return dtype_and_size
    raise ValueError('Unexpected tensorio data type: ' + data_type)
class TensorIO(object):
    """
    Reader for a directory of tensors serialized using tensorio.

    Args:
      tensorio_path: a directory containing tensors serialized using tensorio,
        together with their ``spec.yaml`` index. Tar files are not supported.
      mmap_tensor: by default, loaded tensors use mmap storage. Set this to
        False to get in-memory copies instead. Useful when loading multiple
        tensors.

    Raises:
      ValueError: if ``spec.yaml`` cannot be found under ``tensorio_path``.
    """

    def __init__(self, tensorio_path, mmap_tensor=True):
        self._tensorio_path = tensorio_path
        self._mmap_tensor = mmap_tensor

        # Make sure we can locate spec.yaml.
        yaml_file = os.path.join(tensorio_path, 'spec.yaml')
        if not os.path.exists(yaml_file):
            raise ValueError('Invalid tensorio path: no spec.yaml found.')

        # Load spec.yaml.
        with open(yaml_file, 'r') as file_open:
            _spec = yaml.safe_load(file_open)

        # safe_load returns None for an empty document; treat that as a
        # directory holding no entries instead of failing on `.items()`.
        _spec = _spec or {}

        # Tensor names in the yaml look like \"weight\".\'1\'.
        # For user-friendliness, the quotes are removed.
        self._spec = {k.replace("'", '').replace('"', ''): v for (k, v) in _spec.items()}

    def _get_entry(self, name):
        """
        Return the spec entry stored under `name`.

        Raises:
          ValueError: if `name` is not present in the spec.
        """
        if name not in self._spec:
            raise ValueError('The specified key {} is not found in {}'.format(name, self._tensorio_path))
        return self._spec[name]

    def list_tensors(self):
        """
        Returns a list of the entry names saved in the given path.
        """
        # Materialize the keys so callers get an actual list, as documented.
        return list(self._spec)

    def _load_tensor(self, name):
        """
        Load the tensor with the given name.

        Returns:
          A numpy array: memory-mapped when `mmap_tensor` is true, otherwise a
          copy of the mapped data.

        Raises:
          ValueError: if the name is not found, or the entry is not of type
            'tensor'.
        """
        tensor_info = self._get_entry(name)
        if tensor_info['type'] != 'tensor':
            raise ValueError('Trying to load a tensor of unknown type: ' + tensor_info['type'])

        filename = os.path.join(self._tensorio_path, tensor_info['filename'])
        (data_type, element_size) = _get_data_type(tensor_info['tensorType'])

        np_array = np.memmap(
            filename,
            dtype=data_type,
            mode='r',
            # -1 because lua offset is 1 based; memmap wants a byte offset.
            offset=(tensor_info['offset'] - 1) * element_size,
            shape=tuple(tensor_info['size']),
            order='C',
        )

        return np_array if self._mmap_tensor else np_array[:].copy()

    def _load_nontensor_data(self, name):
        """
        Load non-tensor data with the given name.

        Returns the value stored in the entry's 'data' field.
        Raises KeyError if the name or its 'data' field is missing.
        """
        tensor_info = self._spec[name]
        return tensor_info['data']

    def _load(self, name):
        """
        Load data serialized under the given name; it could be a tensor or
        regular data.

        Raises:
          ValueError: if the name is not found in the spec.
        """
        data_type = self._get_entry(name)['type']
        if data_type == 'tensor':
            return self._load_tensor(name)
        return self._load_nontensor_data(name)

    def load_all(self):
        """
        Load every entry stored in the tensorio directory.

        Returns a dictionary from entry name to loaded value (numpy arrays for
        tensors, the stored data for non-tensor entries).
        """
        return {k: self._load(k) for k in self._spec}

    ###########################################
    # The below are utilities for convenience #
    ###########################################

    def __getitem__(self, k):
        """
        Shorthand for _load, which also supports hierarchical access like:
        tensorio['a']['b']['1']
        """
        if k in self._spec:
            # We have a full name: load it through _load so tensor and
            # non-tensor entries behave the same as with hierarchical access.
            return self._load(k)
        return _KeyRecorder(self)[k]