fixed imports
This commit is contained in:
parent
855dba7fde
commit
84386fd8e4
@ -1,5 +1,5 @@
|
|||||||
from mp_pytorch import PhaseGenerator, NormalizedRBFBasisGenerator, ZeroStartNormalizedRBFBasisGenerator
|
from mp_pytorch.basis_gn import NormalizedRBFBasisGenerator, ZeroPaddingNormalizedRBFBasisGenerator
|
||||||
from mp_pytorch.basis_gn.rhytmic_basis import RhythmicBasisGenerator
|
from mp_pytorch.phase_gn import PhaseGenerator
|
||||||
|
|
||||||
ALL_TYPES = ["rbf", "zero_rbf", "rhythmic"]
|
ALL_TYPES = ["rbf", "zero_rbf", "rhythmic"]
|
||||||
|
|
||||||
@ -9,9 +9,10 @@ def get_basis_generator(basis_generator_type: str, phase_generator: PhaseGenerat
|
|||||||
if basis_generator_type == "rbf":
|
if basis_generator_type == "rbf":
|
||||||
return NormalizedRBFBasisGenerator(phase_generator, **kwargs)
|
return NormalizedRBFBasisGenerator(phase_generator, **kwargs)
|
||||||
elif basis_generator_type == "zero_rbf":
|
elif basis_generator_type == "zero_rbf":
|
||||||
return ZeroStartNormalizedRBFBasisGenerator(phase_generator, **kwargs)
|
return ZeroPaddingNormalizedRBFBasisGenerator(phase_generator, **kwargs)
|
||||||
elif basis_generator_type == "rhythmic":
|
elif basis_generator_type == "rhythmic":
|
||||||
return RhythmicBasisGenerator(phase_generator, **kwargs)
|
raise NotImplementedError()
|
||||||
|
# return RhythmicBasisGenerator(phase_generator, **kwargs)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Specified basis generator type {basis_generator_type} not supported, "
|
raise ValueError(f"Specified basis generator type {basis_generator_type} not supported, "
|
||||||
f"please choose one of {ALL_TYPES}.")
|
f"please choose one of {ALL_TYPES}.")
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from mp_pytorch import LinearPhaseGenerator, ExpDecayPhaseGenerator
|
from mp_pytorch.phase_gn import LinearPhaseGenerator, ExpDecayPhaseGenerator
|
||||||
from mp_pytorch.phase_gn.rhythmic_phase_generator import RhythmicPhaseGenerator
|
|
||||||
from mp_pytorch.phase_gn.smooth_phase_generator import SmoothPhaseGenerator
|
# from mp_pytorch.phase_gn.rhythmic_phase_generator import RhythmicPhaseGenerator
|
||||||
|
# from mp_pytorch.phase_gn.smooth_phase_generator import SmoothPhaseGenerator
|
||||||
|
|
||||||
ALL_TYPES = ["linear", "exp", "rhythmic", "smooth"]
|
ALL_TYPES = ["linear", "exp", "rhythmic", "smooth"]
|
||||||
|
|
||||||
@ -12,9 +13,11 @@ def get_phase_generator(phase_generator_type, **kwargs):
|
|||||||
elif phase_generator_type == "exp":
|
elif phase_generator_type == "exp":
|
||||||
return ExpDecayPhaseGenerator(**kwargs)
|
return ExpDecayPhaseGenerator(**kwargs)
|
||||||
elif phase_generator_type == "rhythmic":
|
elif phase_generator_type == "rhythmic":
|
||||||
return RhythmicPhaseGenerator(**kwargs)
|
raise NotImplementedError()
|
||||||
|
# return RhythmicPhaseGenerator(**kwargs)
|
||||||
elif phase_generator_type == "smooth":
|
elif phase_generator_type == "smooth":
|
||||||
return SmoothPhaseGenerator(**kwargs)
|
raise NotImplementedError()
|
||||||
|
# return SmoothPhaseGenerator(**kwargs)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Specified phase generator type {phase_generator_type} not supported, "
|
raise ValueError(f"Specified phase generator type {phase_generator_type} not supported, "
|
||||||
f"please choose one of {ALL_TYPES}.")
|
f"please choose one of {ALL_TYPES}.")
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
from mp_pytorch.basis_gn.basis_generator import BasisGenerator
|
from mp_pytorch.basis_gn import BasisGenerator
|
||||||
from mp_pytorch.mp.dmp import DMP
|
from mp_pytorch.mp import ProDMP, DMP, ProMP
|
||||||
from mp_pytorch.mp.idmp import IDMP
|
|
||||||
from mp_pytorch.mp.promp import ProMP
|
|
||||||
|
|
||||||
ALL_TYPES = ["promp", "dmp", "idmp"]
|
ALL_TYPES = ["promp", "dmp", "idmp"]
|
||||||
|
|
||||||
@ -15,7 +13,7 @@ def get_trajectory_generator(
|
|||||||
elif trajectory_generator_type == "dmp":
|
elif trajectory_generator_type == "dmp":
|
||||||
return DMP(basis_generator, action_dim, **kwargs)
|
return DMP(basis_generator, action_dim, **kwargs)
|
||||||
elif trajectory_generator_type == 'idmp':
|
elif trajectory_generator_type == 'idmp':
|
||||||
return IDMP(basis_generator, action_dim, **kwargs)
|
return ProDMP(basis_generator, action_dim, **kwargs)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Specified movement primitive type {trajectory_generator_type} not supported, "
|
raise ValueError(f"Specified movement primitive type {trajectory_generator_type} not supported, "
|
||||||
f"please choose one of {ALL_TYPES}.")
|
f"please choose one of {ALL_TYPES}.")
|
||||||
|
@ -10,15 +10,14 @@ import numpy as np
|
|||||||
from gym.envs.registration import register, registry
|
from gym.envs.registration import register, registry
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from dm_control import suite, manipulation, composer
|
from dm_control import suite, manipulation
|
||||||
from dm_control.rl import control
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import metaworld
|
import metaworld
|
||||||
except Exception:
|
except Exception:
|
||||||
# catch Exception due to Mujoco-py
|
# catch Exception as Import error does not catch missing mujoco-py
|
||||||
pass
|
pass
|
||||||
|
|
||||||
import fancy_gym
|
import fancy_gym
|
||||||
@ -227,7 +226,7 @@ def make_bb_env_helper(**kwargs):
|
|||||||
|
|
||||||
|
|
||||||
def make_dmc(
|
def make_dmc(
|
||||||
env_id: Union[str, composer.Environment, control.Environment],
|
env_id: str,
|
||||||
seed: int = None,
|
seed: int = None,
|
||||||
visualize_reward: bool = True,
|
visualize_reward: bool = True,
|
||||||
time_limit: Union[None, float] = None,
|
time_limit: Union[None, float] = None,
|
||||||
@ -274,7 +273,7 @@ def make_dmc(
|
|||||||
return env
|
return env
|
||||||
|
|
||||||
|
|
||||||
def make_metaworld(env_id, seed, **kwargs):
|
def make_metaworld(env_id: str, seed: int, **kwargs):
|
||||||
if env_id not in metaworld.ML1.ENV_NAMES:
|
if env_id not in metaworld.ML1.ENV_NAMES:
|
||||||
raise ValueError(f'Specified environment "{env_id}" not present in metaworld ML1.')
|
raise ValueError(f'Specified environment "{env_id}" not present in metaworld ML1.')
|
||||||
|
|
||||||
|
3
setup.py
3
setup.py
@ -27,8 +27,7 @@ setup(
|
|||||||
],
|
],
|
||||||
extras_require=extras,
|
extras_require=extras,
|
||||||
install_requires=[
|
install_requires=[
|
||||||
'gym>=0.24.0',
|
'gym[mujoco]>=0.24.0',
|
||||||
'mujoco==2.2.0',
|
|
||||||
'mp_pytorch @ git+https://github.com/ALRhub/MP_PyTorch.git@main'
|
'mp_pytorch @ git+https://github.com/ALRhub/MP_PyTorch.git@main'
|
||||||
],
|
],
|
||||||
packages=[package for package in find_packages() if package.startswith("fancy_gym")],
|
packages=[package for package in find_packages() if package.startswith("fancy_gym")],
|
||||||
|
Loading…
Reference in New Issue
Block a user