diff --git a/fancy_gym/black_box/factory/basis_generator_factory.py b/fancy_gym/black_box/factory/basis_generator_factory.py index 610ad20..4601953 100644 --- a/fancy_gym/black_box/factory/basis_generator_factory.py +++ b/fancy_gym/black_box/factory/basis_generator_factory.py @@ -1,5 +1,5 @@ -from mp_pytorch import PhaseGenerator, NormalizedRBFBasisGenerator, ZeroStartNormalizedRBFBasisGenerator -from mp_pytorch.basis_gn.rhytmic_basis import RhythmicBasisGenerator +from mp_pytorch.basis_gn import NormalizedRBFBasisGenerator, ZeroPaddingNormalizedRBFBasisGenerator +from mp_pytorch.phase_gn import PhaseGenerator ALL_TYPES = ["rbf", "zero_rbf", "rhythmic"] @@ -9,9 +9,10 @@ def get_basis_generator(basis_generator_type: str, phase_generator: PhaseGenerat if basis_generator_type == "rbf": return NormalizedRBFBasisGenerator(phase_generator, **kwargs) elif basis_generator_type == "zero_rbf": - return ZeroStartNormalizedRBFBasisGenerator(phase_generator, **kwargs) + return ZeroPaddingNormalizedRBFBasisGenerator(phase_generator, **kwargs) elif basis_generator_type == "rhythmic": - return RhythmicBasisGenerator(phase_generator, **kwargs) + raise NotImplementedError() + # return RhythmicBasisGenerator(phase_generator, **kwargs) else: raise ValueError(f"Specified basis generator type {basis_generator_type} not supported, " f"please choose one of {ALL_TYPES}.") diff --git a/fancy_gym/black_box/factory/phase_generator_factory.py b/fancy_gym/black_box/factory/phase_generator_factory.py index ca0dd84..67e17f8 100644 --- a/fancy_gym/black_box/factory/phase_generator_factory.py +++ b/fancy_gym/black_box/factory/phase_generator_factory.py @@ -1,6 +1,7 @@ -from mp_pytorch import LinearPhaseGenerator, ExpDecayPhaseGenerator -from mp_pytorch.phase_gn.rhythmic_phase_generator import RhythmicPhaseGenerator -from mp_pytorch.phase_gn.smooth_phase_generator import SmoothPhaseGenerator +from mp_pytorch.phase_gn import LinearPhaseGenerator, ExpDecayPhaseGenerator + +# from mp_pytorch.phase_gn.rhythmic_phase_generator import RhythmicPhaseGenerator +# from mp_pytorch.phase_gn.smooth_phase_generator import SmoothPhaseGenerator ALL_TYPES = ["linear", "exp", "rhythmic", "smooth"] @@ -12,9 +13,11 @@ def get_phase_generator(phase_generator_type, **kwargs): elif phase_generator_type == "exp": return ExpDecayPhaseGenerator(**kwargs) elif phase_generator_type == "rhythmic": - return RhythmicPhaseGenerator(**kwargs) + raise NotImplementedError() + # return RhythmicPhaseGenerator(**kwargs) elif phase_generator_type == "smooth": - return SmoothPhaseGenerator(**kwargs) + raise NotImplementedError() + # return SmoothPhaseGenerator(**kwargs) else: raise ValueError(f"Specified phase generator type {phase_generator_type} not supported, " f"please choose one of {ALL_TYPES}.") diff --git a/fancy_gym/black_box/factory/trajectory_generator_factory.py b/fancy_gym/black_box/factory/trajectory_generator_factory.py index f7ca6e2..2b93a6c 100644 --- a/fancy_gym/black_box/factory/trajectory_generator_factory.py +++ b/fancy_gym/black_box/factory/trajectory_generator_factory.py @@ -1,7 +1,5 @@ -from mp_pytorch.basis_gn.basis_generator import BasisGenerator -from mp_pytorch.mp.dmp import DMP -from mp_pytorch.mp.idmp import IDMP -from mp_pytorch.mp.promp import ProMP +from mp_pytorch.basis_gn import BasisGenerator +from mp_pytorch.mp import ProDMP, DMP, ProMP ALL_TYPES = ["promp", "dmp", "idmp"] @@ -15,7 +13,7 @@ def get_trajectory_generator( elif trajectory_generator_type == "dmp": return DMP(basis_generator, action_dim, **kwargs) elif trajectory_generator_type == 'idmp': - return IDMP(basis_generator, action_dim, **kwargs) + return ProDMP(basis_generator, action_dim, **kwargs) else: raise ValueError(f"Specified movement primitive type {trajectory_generator_type} not supported, " f"please choose one of {ALL_TYPES}.") diff --git a/fancy_gym/utils/make_env_helpers.py b/fancy_gym/utils/make_env_helpers.py index 0f832ae..b0bfe8b 100644 --- a/fancy_gym/utils/make_env_helpers.py +++ b/fancy_gym/utils/make_env_helpers.py @@ -10,15 +10,14 @@ import numpy as np from gym.envs.registration import register, registry try: - from dm_control import suite, manipulation, composer - from dm_control.rl import control + from dm_control import suite, manipulation except ImportError: pass try: import metaworld except Exception: - # catch Exception due to Mujoco-py + # catch Exception as Import error does not catch missing mujoco-py pass import fancy_gym @@ -227,7 +226,7 @@ def make_bb_env_helper(**kwargs): def make_dmc( - env_id: Union[str, composer.Environment, control.Environment], + env_id: str, seed: int = None, visualize_reward: bool = True, time_limit: Union[None, float] = None, @@ -274,7 +273,7 @@ def make_dmc( return env -def make_metaworld(env_id, seed, **kwargs): +def make_metaworld(env_id: str, seed: int, **kwargs): if env_id not in metaworld.ML1.ENV_NAMES: raise ValueError(f'Specified environment "{env_id}" not present in metaworld ML1.') diff --git a/setup.py b/setup.py index 3f67e1a..3d428e2 100644 --- a/setup.py +++ b/setup.py @@ -27,8 +27,7 @@ setup( ], extras_require=extras, install_requires=[ - 'gym>=0.24.0', - 'mujoco==2.2.0', + 'gym[mujoco]>=0.24.0', 'mp_pytorch @ git+https://github.com/ALRhub/MP_PyTorch.git@main' ], packages=[package for package in find_packages() if package.startswith("fancy_gym")],