fixed imports
This commit is contained in:
parent
855dba7fde
commit
84386fd8e4
@ -1,5 +1,5 @@
|
||||
from mp_pytorch import PhaseGenerator, NormalizedRBFBasisGenerator, ZeroStartNormalizedRBFBasisGenerator
|
||||
from mp_pytorch.basis_gn.rhytmic_basis import RhythmicBasisGenerator
|
||||
from mp_pytorch.basis_gn import NormalizedRBFBasisGenerator, ZeroPaddingNormalizedRBFBasisGenerator
|
||||
from mp_pytorch.phase_gn import PhaseGenerator
|
||||
|
||||
ALL_TYPES = ["rbf", "zero_rbf", "rhythmic"]
|
||||
|
||||
@ -9,9 +9,10 @@ def get_basis_generator(basis_generator_type: str, phase_generator: PhaseGenerat
|
||||
if basis_generator_type == "rbf":
|
||||
return NormalizedRBFBasisGenerator(phase_generator, **kwargs)
|
||||
elif basis_generator_type == "zero_rbf":
|
||||
return ZeroStartNormalizedRBFBasisGenerator(phase_generator, **kwargs)
|
||||
return ZeroPaddingNormalizedRBFBasisGenerator(phase_generator, **kwargs)
|
||||
elif basis_generator_type == "rhythmic":
|
||||
return RhythmicBasisGenerator(phase_generator, **kwargs)
|
||||
raise NotImplementedError()
|
||||
# return RhythmicBasisGenerator(phase_generator, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"Specified basis generator type {basis_generator_type} not supported, "
|
||||
f"please choose one of {ALL_TYPES}.")
|
||||
|
@ -1,6 +1,7 @@
|
||||
from mp_pytorch import LinearPhaseGenerator, ExpDecayPhaseGenerator
|
||||
from mp_pytorch.phase_gn.rhythmic_phase_generator import RhythmicPhaseGenerator
|
||||
from mp_pytorch.phase_gn.smooth_phase_generator import SmoothPhaseGenerator
|
||||
from mp_pytorch.phase_gn import LinearPhaseGenerator, ExpDecayPhaseGenerator
|
||||
|
||||
# from mp_pytorch.phase_gn.rhythmic_phase_generator import RhythmicPhaseGenerator
|
||||
# from mp_pytorch.phase_gn.smooth_phase_generator import SmoothPhaseGenerator
|
||||
|
||||
ALL_TYPES = ["linear", "exp", "rhythmic", "smooth"]
|
||||
|
||||
@ -12,9 +13,11 @@ def get_phase_generator(phase_generator_type, **kwargs):
|
||||
elif phase_generator_type == "exp":
|
||||
return ExpDecayPhaseGenerator(**kwargs)
|
||||
elif phase_generator_type == "rhythmic":
|
||||
return RhythmicPhaseGenerator(**kwargs)
|
||||
raise NotImplementedError()
|
||||
# return RhythmicPhaseGenerator(**kwargs)
|
||||
elif phase_generator_type == "smooth":
|
||||
return SmoothPhaseGenerator(**kwargs)
|
||||
raise NotImplementedError()
|
||||
# return SmoothPhaseGenerator(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Specified phase generator type {phase_generator_type} not supported, "
|
||||
f"please choose one of {ALL_TYPES}.")
|
||||
|
@ -1,7 +1,5 @@
|
||||
from mp_pytorch.basis_gn.basis_generator import BasisGenerator
|
||||
from mp_pytorch.mp.dmp import DMP
|
||||
from mp_pytorch.mp.idmp import IDMP
|
||||
from mp_pytorch.mp.promp import ProMP
|
||||
from mp_pytorch.basis_gn import BasisGenerator
|
||||
from mp_pytorch.mp import ProDMP, DMP, ProMP
|
||||
|
||||
ALL_TYPES = ["promp", "dmp", "idmp"]
|
||||
|
||||
@ -15,7 +13,7 @@ def get_trajectory_generator(
|
||||
elif trajectory_generator_type == "dmp":
|
||||
return DMP(basis_generator, action_dim, **kwargs)
|
||||
elif trajectory_generator_type == 'idmp':
|
||||
return IDMP(basis_generator, action_dim, **kwargs)
|
||||
return ProDMP(basis_generator, action_dim, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"Specified movement primitive type {trajectory_generator_type} not supported, "
|
||||
f"please choose one of {ALL_TYPES}.")
|
||||
|
@ -10,15 +10,14 @@ import numpy as np
|
||||
from gym.envs.registration import register, registry
|
||||
|
||||
try:
|
||||
from dm_control import suite, manipulation, composer
|
||||
from dm_control.rl import control
|
||||
from dm_control import suite, manipulation
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
import metaworld
|
||||
except Exception:
|
||||
# catch Exception due to Mujoco-py
|
||||
# catch Exception as Import error does not catch missing mujoco-py
|
||||
pass
|
||||
|
||||
import fancy_gym
|
||||
@ -227,7 +226,7 @@ def make_bb_env_helper(**kwargs):
|
||||
|
||||
|
||||
def make_dmc(
|
||||
env_id: Union[str, composer.Environment, control.Environment],
|
||||
env_id: str,
|
||||
seed: int = None,
|
||||
visualize_reward: bool = True,
|
||||
time_limit: Union[None, float] = None,
|
||||
@ -274,7 +273,7 @@ def make_dmc(
|
||||
return env
|
||||
|
||||
|
||||
def make_metaworld(env_id, seed, **kwargs):
|
||||
def make_metaworld(env_id: str, seed: int, **kwargs):
|
||||
if env_id not in metaworld.ML1.ENV_NAMES:
|
||||
raise ValueError(f'Specified environment "{env_id}" not present in metaworld ML1.')
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user