twml - Changed If statements with match casing in python
This commit is contained in:
parent
ec83d01dca
commit
b7e005427c
|
@ -45,21 +45,22 @@ class _KeyRecorder(object):
|
|||
# convert tensorio tensor type to numpy data type.
|
||||
# also returns element size in bytes.
|
||||
def _get_data_type(data_type):
|
||||
if data_type == 'Double':
|
||||
match data_type:
|
||||
case 'Double':
|
||||
return (np.float64, 8)
|
||||
|
||||
if data_type == 'Float':
|
||||
case 'Float':
|
||||
return (np.float32, 4)
|
||||
|
||||
if data_type == 'Int':
|
||||
case 'Int':
|
||||
return (np.int32, 4)
|
||||
|
||||
if data_type == 'Long':
|
||||
case 'Long':
|
||||
return (np.int64, 8)
|
||||
|
||||
if data_type == 'Byte':
|
||||
case 'Byte':
|
||||
return (np.int8, 1)
|
||||
|
||||
case _:
|
||||
raise ValueError('Unexpected tensorio data type: ' + data_type)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue