diff --git a/twml/twml/tensorio.py b/twml/twml/tensorio.py index bc551ac56..020daf8e5 100644 --- a/twml/twml/tensorio.py +++ b/twml/twml/tensorio.py @@ -45,22 +45,23 @@ class _KeyRecorder(object): # convert tensorio tensor type to numpy data type. # also returns element size in bytes. def _get_data_type(data_type): - if data_type == 'Double': - return (np.float64, 8) + match data_type: + case 'Double': + return (np.float64, 8) - if data_type == 'Float': - return (np.float32, 4) + case 'Float': + return (np.float32, 4) - if data_type == 'Int': - return (np.int32, 4) + case 'Int': + return (np.int32, 4) - if data_type == 'Long': - return (np.int64, 8) + case 'Long': + return (np.int64, 8) - if data_type == 'Byte': - return (np.int8, 1) - - raise ValueError('Unexpected tensorio data type: ' + data_type) + case 'Byte': + return (np.int8, 1) + case _: + raise ValueError('Unexpected tensorio data type: ' + data_type) class TensorIO(object):