Fix bugs in the creation of MP environments. Still conflicts with mp_pytorch_lib.
This commit is contained in:
parent
d4e3b957a9
commit
2161cfd3a6
@ -428,7 +428,7 @@ for _v in _versions:
|
|||||||
kwargs_dict_bp_promp['name'] = f"alr_envs:{_v}"
|
kwargs_dict_bp_promp['name'] = f"alr_envs:{_v}"
|
||||||
register(
|
register(
|
||||||
id=_env_id,
|
id=_env_id,
|
||||||
entry_point='alr_envs.utils.make_env_helpers:make_mp_env_helper',
|
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
||||||
kwargs=kwargs_dict_bp_promp
|
kwargs=kwargs_dict_bp_promp
|
||||||
)
|
)
|
||||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||||
|
@ -1 +1 @@
|
|||||||
from new_mp_wrapper import MPWrapper
|
from .mp_wrapper import MPWrapper
|
||||||
|
@ -25,8 +25,9 @@ class MPWrapper(RawInterfaceWrapper):
|
|||||||
[False] # env steps
|
[False] # env steps
|
||||||
])
|
])
|
||||||
|
|
||||||
def _episode_callback(self, action: np.ndarray) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
|
# TODO: Fix this
|
||||||
if self.mp.learn_tau:
|
def _episode_callback(self, action: np.ndarray, mp) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
|
||||||
|
if mp.learn_tau:
|
||||||
self.env.env.release_step = action[0] / self.env.dt # Tau value
|
self.env.env.release_step = action[0] / self.env.dt # Tau value
|
||||||
return action, None
|
return action, None
|
||||||
else:
|
else:
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
from abc import ABC
|
|
||||||
from typing import Tuple, Union
|
from typing import Tuple, Union
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
@ -80,7 +79,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
bc_time=np.zeros((1,)) if not self.replanning_schedule else self.current_traj_steps * self.dt,
|
bc_time=np.zeros((1,)) if not self.replanning_schedule else self.current_traj_steps * self.dt,
|
||||||
bc_pos=self.current_pos, bc_vel=self.current_vel)
|
bc_pos=self.current_pos, bc_vel=self.current_vel)
|
||||||
# TODO: is this correct for replanning? Do we need to adjust anything here?
|
# TODO: is this correct for replanning? Do we need to adjust anything here?
|
||||||
self.traj_gen.set_duration(None if self.learn_sub_trajectories else self.duration, np.array([self.dt]))
|
self.traj_gen.set_duration(None if self.learn_sub_trajectories else np.array([self.duration]), np.array([self.dt]))
|
||||||
traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
|
traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
|
||||||
trajectory_tensor, velocity_tensor = traj_dict['pos'], traj_dict['vel']
|
trajectory_tensor, velocity_tensor = traj_dict['pos'], traj_dict['vel']
|
||||||
|
|
||||||
@ -109,7 +108,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
""" This function generates a trajectory based on a MP and then does the usual loop over reset and step"""
|
""" This function generates a trajectory based on a MP and then does the usual loop over reset and step"""
|
||||||
|
|
||||||
# agent to learn when to release the ball
|
# agent to learn when to release the ball
|
||||||
mp_params, env_spec_params = self._episode_callback(action)
|
mp_params, env_spec_params = self.env._episode_callback(action, self.traj_gen)
|
||||||
trajectory, velocity = self.get_trajectory(mp_params)
|
trajectory, velocity = self.get_trajectory(mp_params)
|
||||||
|
|
||||||
trajectory_length = len(trajectory)
|
trajectory_length = len(trajectory)
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
from typing import Union, Tuple
|
from typing import Union, Tuple
|
||||||
|
from mp_pytorch.mp.mp_interfaces import MPInterface
|
||||||
|
from abc import abstractmethod
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from abc import abstractmethod
|
|
||||||
|
|
||||||
|
|
||||||
class RawInterfaceWrapper(gym.Wrapper):
|
class RawInterfaceWrapper(gym.Wrapper):
|
||||||
@ -56,7 +57,7 @@ class RawInterfaceWrapper(gym.Wrapper):
|
|||||||
# return bool(self.replanning_model(s))
|
# return bool(self.replanning_model(s))
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _episode_callback(self, action: np.ndarray) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
|
def _episode_callback(self, action: np.ndarray, traj_gen: MPInterface) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
|
||||||
"""
|
"""
|
||||||
Used to extract the parameters for the motion primitive and other parameters from an action array which might
|
Used to extract the parameters for the motion primitive and other parameters from an action array which might
|
||||||
include other actions like ball releasing time for the beer pong environment.
|
include other actions like ball releasing time for the beer pong environment.
|
||||||
|
@ -43,10 +43,11 @@ def make_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwa
|
|||||||
|
|
||||||
|
|
||||||
def make(env_id, seed, **kwargs):
|
def make(env_id, seed, **kwargs):
|
||||||
spec = registry.get(env_id)
|
# spec = registry.get(env_id)  # TODO: registry.get() doesn't work with gym==0.21.0
|
||||||
|
spec = registry.spec(env_id)
|
||||||
# This access is required to allow for nested dict updates
|
# This access is required to allow for nested dict updates
|
||||||
all_kwargs = deepcopy(spec._kwargs)
|
all_kwargs = deepcopy(spec._kwargs)
|
||||||
nested_update(all_kwargs, **kwargs)
|
nested_update(all_kwargs, kwargs)
|
||||||
return _make(env_id, seed, **all_kwargs)
|
return _make(env_id, seed, **all_kwargs)
|
||||||
|
|
||||||
|
|
||||||
@ -224,7 +225,7 @@ def make_bb_env_helper(**kwargs):
|
|||||||
seed = kwargs.pop("seed", None)
|
seed = kwargs.pop("seed", None)
|
||||||
wrappers = kwargs.pop("wrappers")
|
wrappers = kwargs.pop("wrappers")
|
||||||
|
|
||||||
traj_gen_kwargs = kwargs.pop("traj_gen_kwargs", {})
|
traj_gen_kwargs = kwargs.pop("trajectory_generator_kwargs", {})
|
||||||
black_box_kwargs = kwargs.pop('black_box_kwargs', {})
|
black_box_kwargs = kwargs.pop('black_box_kwargs', {})
|
||||||
contr_kwargs = kwargs.pop("controller_kwargs", {})
|
contr_kwargs = kwargs.pop("controller_kwargs", {})
|
||||||
phase_kwargs = kwargs.pop("phase_generator_kwargs", {})
|
phase_kwargs = kwargs.pop("phase_generator_kwargs", {})
|
||||||
|
Loading…
Reference in New Issue
Block a user