Merge remote-tracking branch 'origin/clean_api' into clean_api

This commit is contained in:
Onur 2022-07-19 12:17:19 +02:00
commit 5590318329
7 changed files with 29 additions and 27 deletions

View File

@ -50,7 +50,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
self.tracking_controller = tracking_controller self.tracking_controller = tracking_controller
# self.time_steps = np.linspace(0, self.duration, self.traj_steps) # self.time_steps = np.linspace(0, self.duration, self.traj_steps)
# self.traj_gen.set_mp_times(self.time_steps) # self.traj_gen.set_mp_times(self.time_steps)
self.traj_gen.set_duration(np.array([self.duration]), np.array([self.dt])) self.traj_gen.set_duration(self.duration - self.dt, self.dt)
# reward computation # reward computation
self.reward_aggregation = reward_aggregation self.reward_aggregation = reward_aggregation
@ -78,8 +78,8 @@ class BlackBoxWrapper(gym.ObservationWrapper):
self.traj_gen.set_boundary_conditions( self.traj_gen.set_boundary_conditions(
bc_time=np.array(0) if not self.do_replanning else np.array([self.current_traj_steps * self.dt]), bc_time=np.array(0) if not self.do_replanning else np.array([self.current_traj_steps * self.dt]),
bc_pos=self.current_pos, bc_vel=self.current_vel) bc_pos=self.current_pos, bc_vel=self.current_vel)
self.traj_gen.set_duration(None if self.learn_sub_trajectories else np.array([self.duration]), # TODO remove the - self.dt after Bruces fix.
np.array([self.dt])) self.traj_gen.set_duration(None if self.learn_sub_trajectories else self.duration - self.dt, self.dt)
traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True) traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
trajectory_tensor, velocity_tensor = traj_dict['pos'], traj_dict['vel'] trajectory_tensor, velocity_tensor = traj_dict['pos'], traj_dict['vel']
@ -87,7 +87,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
def _get_traj_gen_action_space(self): def _get_traj_gen_action_space(self):
"""This function can be used to set up an individual space for the parameters of the traj_gen.""" """This function can be used to set up an individual space for the parameters of the traj_gen."""
min_action_bounds, max_action_bounds = self.traj_gen.get_param_bounds() min_action_bounds, max_action_bounds = self.traj_gen.get_params_bounds().t()
action_space = gym.spaces.Box(low=min_action_bounds.numpy(), high=max_action_bounds.numpy(), action_space = gym.spaces.Box(low=min_action_bounds.numpy(), high=max_action_bounds.numpy(),
dtype=self.env.action_space.dtype) dtype=self.env.action_space.dtype)
return action_space return action_space

View File

@ -1,5 +1,5 @@
from mp_pytorch import PhaseGenerator, NormalizedRBFBasisGenerator, ZeroStartNormalizedRBFBasisGenerator from mp_pytorch.basis_gn import NormalizedRBFBasisGenerator, ZeroPaddingNormalizedRBFBasisGenerator
from mp_pytorch.basis_gn.rhytmic_basis import RhythmicBasisGenerator from mp_pytorch.phase_gn import PhaseGenerator
ALL_TYPES = ["rbf", "zero_rbf", "rhythmic"] ALL_TYPES = ["rbf", "zero_rbf", "rhythmic"]
@ -9,9 +9,10 @@ def get_basis_generator(basis_generator_type: str, phase_generator: PhaseGenerat
if basis_generator_type == "rbf": if basis_generator_type == "rbf":
return NormalizedRBFBasisGenerator(phase_generator, **kwargs) return NormalizedRBFBasisGenerator(phase_generator, **kwargs)
elif basis_generator_type == "zero_rbf": elif basis_generator_type == "zero_rbf":
return ZeroStartNormalizedRBFBasisGenerator(phase_generator, **kwargs) return ZeroPaddingNormalizedRBFBasisGenerator(phase_generator, **kwargs)
elif basis_generator_type == "rhythmic": elif basis_generator_type == "rhythmic":
return RhythmicBasisGenerator(phase_generator, **kwargs) raise NotImplementedError()
# return RhythmicBasisGenerator(phase_generator, **kwargs)
else: else:
raise ValueError(f"Specified basis generator type {basis_generator_type} not supported, " raise ValueError(f"Specified basis generator type {basis_generator_type} not supported, "
f"please choose one of {ALL_TYPES}.") f"please choose one of {ALL_TYPES}.")

View File

@ -1,6 +1,7 @@
from mp_pytorch import LinearPhaseGenerator, ExpDecayPhaseGenerator from mp_pytorch.phase_gn import LinearPhaseGenerator, ExpDecayPhaseGenerator
from mp_pytorch.phase_gn.rhythmic_phase_generator import RhythmicPhaseGenerator
from mp_pytorch.phase_gn.smooth_phase_generator import SmoothPhaseGenerator # from mp_pytorch.phase_gn.rhythmic_phase_generator import RhythmicPhaseGenerator
# from mp_pytorch.phase_gn.smooth_phase_generator import SmoothPhaseGenerator
ALL_TYPES = ["linear", "exp", "rhythmic", "smooth"] ALL_TYPES = ["linear", "exp", "rhythmic", "smooth"]
@ -12,9 +13,11 @@ def get_phase_generator(phase_generator_type, **kwargs):
elif phase_generator_type == "exp": elif phase_generator_type == "exp":
return ExpDecayPhaseGenerator(**kwargs) return ExpDecayPhaseGenerator(**kwargs)
elif phase_generator_type == "rhythmic": elif phase_generator_type == "rhythmic":
return RhythmicPhaseGenerator(**kwargs) raise NotImplementedError()
# return RhythmicPhaseGenerator(**kwargs)
elif phase_generator_type == "smooth": elif phase_generator_type == "smooth":
return SmoothPhaseGenerator(**kwargs) raise NotImplementedError()
# return SmoothPhaseGenerator(**kwargs)
else: else:
raise ValueError(f"Specified phase generator type {phase_generator_type} not supported, " raise ValueError(f"Specified phase generator type {phase_generator_type} not supported, "
f"please choose one of {ALL_TYPES}.") f"please choose one of {ALL_TYPES}.")

View File

@ -1,7 +1,5 @@
from mp_pytorch.basis_gn.basis_generator import BasisGenerator from mp_pytorch.basis_gn import BasisGenerator
from mp_pytorch.mp.dmp import DMP from mp_pytorch.mp import ProDMP, DMP, ProMP
from mp_pytorch.mp.idmp import IDMP
from mp_pytorch.mp.promp import ProMP
ALL_TYPES = ["promp", "dmp", "idmp"] ALL_TYPES = ["promp", "dmp", "idmp"]
@ -14,8 +12,10 @@ def get_trajectory_generator(
return ProMP(basis_generator, action_dim, **kwargs) return ProMP(basis_generator, action_dim, **kwargs)
elif trajectory_generator_type == "dmp": elif trajectory_generator_type == "dmp":
return DMP(basis_generator, action_dim, **kwargs) return DMP(basis_generator, action_dim, **kwargs)
elif trajectory_generator_type == 'idmp': elif trajectory_generator_type == 'prodmp':
return IDMP(basis_generator, action_dim, **kwargs) from mp_pytorch.basis_gn import ProDMPBasisGenerator
assert isinstance(basis_generator, ProDMPBasisGenerator)
return ProDMP(basis_generator, action_dim, **kwargs)
else: else:
raise ValueError(f"Specified movement primitive type {trajectory_generator_type} not supported, " raise ValueError(f"Specified movement primitive type {trajectory_generator_type} not supported, "
f"please choose one of {ALL_TYPES}.") f"please choose one of {ALL_TYPES}.")

View File

@ -126,7 +126,7 @@ for _dims in [5, 7]:
register( register(
id=f'Reacher{_dims}dSparse-v0', id=f'Reacher{_dims}dSparse-v0',
entry_point='fancy_gym.envs.mujoco:ReacherEnv', entry_point='fancy_gym.envs.mujoco:ReacherEnv',
max_episode_steps=MAX_EPISODE_STEPS_REACHER, max_episode_steps=200,
kwargs={ kwargs={
"sparse": True, "sparse": True,
'reward_weight': 200, 'reward_weight': 200,

View File

@ -10,15 +10,14 @@ import numpy as np
from gym.envs.registration import register, registry from gym.envs.registration import register, registry
try: try:
from dm_control import suite, manipulation, composer from dm_control import suite, manipulation
from dm_control.rl import control
except ImportError: except ImportError:
pass pass
try: try:
import metaworld import metaworld
except Exception: except Exception:
# catch Exception due to Mujoco-py # catch Exception as Import error does not catch missing mujoco-py
pass pass
import fancy_gym import fancy_gym
@ -227,7 +226,7 @@ def make_bb_env_helper(**kwargs):
def make_dmc( def make_dmc(
env_id: Union[str, composer.Environment, control.Environment], env_id: str,
seed: int = None, seed: int = None,
visualize_reward: bool = True, visualize_reward: bool = True,
time_limit: Union[None, float] = None, time_limit: Union[None, float] = None,
@ -274,7 +273,7 @@ def make_dmc(
return env return env
def make_metaworld(env_id, seed, **kwargs): def make_metaworld(env_id: str, seed: int, **kwargs):
if env_id not in metaworld.ML1.ENV_NAMES: if env_id not in metaworld.ML1.ENV_NAMES:
raise ValueError(f'Specified environment "{env_id}" not present in metaworld ML1.') raise ValueError(f'Specified environment "{env_id}" not present in metaworld ML1.')

View File

@ -27,8 +27,7 @@ setup(
], ],
extras_require=extras, extras_require=extras,
install_requires=[ install_requires=[
'gym>=0.24.0', 'gym[mujoco]>=0.24.0',
'mujoco==2.2.0',
'mp_pytorch @ git+https://github.com/ALRhub/MP_PyTorch.git@main' 'mp_pytorch @ git+https://github.com/ALRhub/MP_PyTorch.git@main'
], ],
packages=[package for package in find_packages() if package.startswith("fancy_gym")], packages=[package for package in find_packages() if package.startswith("fancy_gym")],