From b7e005427c6f2732172dd7f1d1b00de1f0b14c12 Mon Sep 17 00:00:00 2001 From: sainishwanth Date: Sun, 2 Apr 2023 10:00:45 +0530 Subject: [PATCH] twml - Changed If statements with match casing in python --- twml/twml/tensorio.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/twml/twml/tensorio.py b/twml/twml/tensorio.py index bc551ac56..020daf8e5 100644 --- a/twml/twml/tensorio.py +++ b/twml/twml/tensorio.py @@ -45,22 +45,23 @@ class _KeyRecorder(object): # convert tensorio tensor type to numpy data type. # also returns element size in bytes. def _get_data_type(data_type): - if data_type == 'Double': - return (np.float64, 8) + match data_type: + case 'Double': + return (np.float64, 8) - if data_type == 'Float': - return (np.float32, 4) + case 'Float': + return (np.float32, 4) - if data_type == 'Int': - return (np.int32, 4) + case 'Int': + return (np.int32, 4) - if data_type == 'Long': - return (np.int64, 8) + case 'Long': + return (np.int64, 8) - if data_type == 'Byte': - return (np.int8, 1) - - raise ValueError('Unexpected tensorio data type: ' + data_type) + case 'Byte': + return (np.int8, 1) + case _: + raise ValueError('Unexpected tensorio data type: ' + data_type) class TensorIO(object):