2022-06-30 14:08:54 +02:00
|
|
|
from collections import Mapping, MutableMapping
|
|
|
|
|
2020-12-07 11:13:27 +01:00
|
|
|
import numpy as np
|
2022-06-29 16:30:36 +02:00
|
|
|
import torch as ch
|
2020-12-07 11:13:27 +01:00
|
|
|
|
|
|
|
|
|
|
|
def angle_normalize(x, type="deg"):
    """
    Wrap an angle into the half-open interval [-pi, pi).

    Args:
        x: Angle in either degrees or radians
        type: "deg" if x is given in degrees, "rad" if x is given in radians

    Returns:
        The wrapped angle. The result is always expressed in radians: degree
        input is converted to radians before wrapping and is not converted back.

    Raises:
        ValueError: if type is neither "deg" nor "rad"
    """
    if type not in ("deg", "rad"):
        raise ValueError(f"Invalid type {type}. Choose one of 'deg' or 'rad'.")

    # Work in radians from here on.
    radians = np.deg2rad(x) if type == "deg" else x

    full_turn = 2 * np.pi
    # Shifting by pi before flooring counts whole turns relative to -pi,
    # so subtracting them lands the result in [-pi, pi).
    whole_turns = np.floor((radians + np.pi) / full_turn)
    return radians - full_turn * whole_turns
|
2022-06-29 16:30:36 +02:00
|
|
|
|
|
|
|
|
|
|
|
def get_numpy(x: ch.Tensor):
    """
    Convert a torch tensor into a numpy array.

    Args:
        x: tensor to convert

    Returns:
        numpy array produced by detaching x, moving it to the CPU and
        calling ``numpy()`` on the result
    """
    detached = x.detach()  # drop the autograd history first
    on_cpu = detached.cpu()
    return on_cpu.numpy()
|
2022-06-30 14:08:54 +02:00
|
|
|
|
|
|
|
|
|
|
|
def nested_update(base: MutableMapping, update):
|
|
|
|
"""
|
|
|
|
Updated method for nested Mappings
|
|
|
|
Args:
|
|
|
|
base: main Mapping to be updated
|
|
|
|
update: updated values for base Mapping
|
|
|
|
|
|
|
|
"""
|
|
|
|
for k, v in update.items():
|
|
|
|
base[k] = nested_update(base.get(k, {}), v) if isinstance(v, Mapping) else v
|