Fix bugs to create mp environments. Still conflicts with mp_pytorch_lib

Onur 2022-07-01 11:42:42 +02:00
parent d4e3b957a9
commit 2161cfd3a6
6 changed files with 14 additions and 12 deletions

View File

@@ -428,7 +428,7 @@ for _v in _versions:
     kwargs_dict_bp_promp['name'] = f"alr_envs:{_v}"
     register(
         id=_env_id,
-        entry_point='alr_envs.utils.make_env_helpers:make_mp_env_helper',
+        entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
         kwargs=kwargs_dict_bp_promp
     )
     ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
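With the corrected entry point, every ProMP variant registered in this loop is constructed by make_bb_env_helper instead of the removed make_mp_env_helper. A minimal usage sketch, assuming the registration module has been imported; the env id below is a placeholder, since the real ids are generated from _versions above:

```python
import gym
import alr_envs  # noqa: F401 -- importing is assumed to run the registration loop above

# "ALRBeerPongProMP-v0" is a placeholder id for illustration only; the actual
# ids depend on the entries of _versions.
env = gym.make("ALRBeerPongProMP-v0")  # resolved via make_bb_env_helper
obs = env.reset()
```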

View File

@@ -1 +1 @@
-from new_mp_wrapper import MPWrapper
+from .mp_wrapper import MPWrapper

View File

@@ -25,8 +25,9 @@ class MPWrapper(RawInterfaceWrapper):
             [False]  # env steps
         ])
 
-    def _episode_callback(self, action: np.ndarray) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
-        if self.mp.learn_tau:
+    # TODO: Fix this
+    def _episode_callback(self, action: np.ndarray, mp) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
+        if mp.learn_tau:
             self.env.env.release_step = action[0] / self.env.dt  # Tau value
             return action, None
         else:
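The `action[0] / self.env.dt` line above converts the learned release time (tau, in seconds) into an environment step index. A worked illustration with made-up numbers:

```python
# Made-up values for illustration: with a control timestep of 0.02 s and a
# learned release time of 0.5 s, the ball is released at step 0.5 / 0.02 = 25.
dt = 0.02    # corresponds to self.env.dt
tau = 0.5    # corresponds to action[0]
release_step = tau / dt
assert release_step == 25.0
```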

View File

@@ -1,4 +1,3 @@
-from abc import ABC
 from typing import Tuple, Union
 
 import gym
@@ -80,7 +79,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
                 bc_time=np.zeros((1,)) if not self.replanning_schedule else self.current_traj_steps * self.dt,
                 bc_pos=self.current_pos, bc_vel=self.current_vel)
         # TODO: is this correct for replanning? Do we need to adjust anything here?
-        self.traj_gen.set_duration(None if self.learn_sub_trajectories else self.duration, np.array([self.dt]))
+        self.traj_gen.set_duration(None if self.learn_sub_trajectories else np.array([self.duration]), np.array([self.dt]))
 
         traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
         trajectory_tensor, velocity_tensor = traj_dict['pos'], traj_dict['vel']
@@ -109,7 +108,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         """ This function generates a trajectory based on a MP and then does the usual loop over reset and step"""
         # agent to learn when to release the ball
-        mp_params, env_spec_params = self._episode_callback(action)
+        mp_params, env_spec_params = self.env._episode_callback(action, self.traj_gen)
         trajectory, velocity = self.get_trajectory(mp_params)
         trajectory_length = len(trajectory)
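Two changes land in BlackBoxWrapper here: the scalar episode duration is wrapped in np.array([...]) before being passed to set_duration (mp_pytorch appears to expect array-like inputs), and the episode callback is now looked up on the wrapped env and handed the trajectory generator. The sketch below mirrors the trajectory-generation path with simplified, hypothetical names; it is not the actual class body:

```python
import numpy as np

def generate_trajectory(traj_gen, duration: float, dt: float,
                        learn_sub_trajectories: bool = False):
    """Illustrative stand-in for the get_trajectory path changed above."""
    # Scalar duration wrapped in an array, matching the change in this commit.
    traj_gen.set_duration(
        None if learn_sub_trajectories else np.array([duration]),
        np.array([dt]),
    )
    traj_dict = traj_gen.get_trajs(get_pos=True, get_vel=True)
    return traj_dict['pos'], traj_dict['vel']
```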

View File

@@ -1,8 +1,9 @@
 from typing import Union, Tuple
+from mp_pytorch.mp.mp_interfaces import MPInterface
+from abc import abstractmethod
 
 import gym
 import numpy as np
-from abc import abstractmethod
 
 
 class RawInterfaceWrapper(gym.Wrapper):
@@ -56,7 +57,7 @@ class RawInterfaceWrapper(gym.Wrapper):
         # return bool(self.replanning_model(s))
         return False
 
-    def _episode_callback(self, action: np.ndarray) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
+    def _episode_callback(self, action: np.ndarray, traj_gen: MPInterface) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
        """
         Used to extract the parameters for the motion primitive and other parameters from an action array which might
         include other actions like ball releasing time for the beer pong environment.
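Since the trajectory generator is now passed in explicitly, concrete wrappers override the callback with the extended signature instead of reaching for self.mp. A minimal sketch following the beer-pong example from the hunks above; the class name is hypothetical, the import of RawInterfaceWrapper is omitted because the diff does not show its module path, and the non-learn_tau branch is elided in the diff, so its handling here is only a placeholder:

```python
from typing import Tuple, Union

import numpy as np
from mp_pytorch.mp.mp_interfaces import MPInterface

# RawInterfaceWrapper is the interface class defined in the file diffed above.

class BeerPongLikeWrapper(RawInterfaceWrapper):  # hypothetical subclass name
    def _episode_callback(self, action: np.ndarray, traj_gen: MPInterface) \
            -> Tuple[np.ndarray, Union[np.ndarray, None]]:
        if traj_gen.learn_tau:
            # First action entry is the learned release time (tau), converted
            # to a step index using the wrapped environment's timestep.
            self.env.env.release_step = action[0] / self.env.dt
            return action, None
        # Branch elided in the hunk above; passing the action through unchanged
        # is a placeholder, not the repo's actual handling.
        return action, None
```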

View File

@@ -43,10 +43,11 @@ def make_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwa
 def make(env_id, seed, **kwargs):
-    spec = registry.get(env_id)
+    # spec = registry.get(env_id)  # TODO: This doesn't work with gym ==0.21.0
+    spec = registry.spec(env_id)
     # This access is required to allow for nested dict updates
     all_kwargs = deepcopy(spec._kwargs)
-    nested_update(all_kwargs, **kwargs)
+    nested_update(all_kwargs, kwargs)
     return _make(env_id, seed, **all_kwargs)
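nested_update now receives the user kwargs as a single dict argument rather than unpacked keyword arguments, which allows nested sub-dicts such as trajectory_generator_kwargs to be merged instead of replaced wholesale. The helper below is an illustrative stand-in with assumed merge semantics, not the repo's implementation, and the inner keys are made up:

```python
def nested_update(base: dict, update: dict) -> dict:
    """Recursively merge `update` into `base` (assumed behaviour)."""
    for key, value in update.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            nested_update(base[key], value)
        else:
            base[key] = value
    return base

defaults = {"trajectory_generator_kwargs": {"weights_scale": 1.0}}
overrides = {"trajectory_generator_kwargs": {"learn_tau": True}}
nested_update(defaults, overrides)
# -> {'trajectory_generator_kwargs': {'weights_scale': 1.0, 'learn_tau': True}}
```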
@@ -224,7 +225,7 @@ def make_bb_env_helper(**kwargs):
     seed = kwargs.pop("seed", None)
     wrappers = kwargs.pop("wrappers")
-    traj_gen_kwargs = kwargs.pop("traj_gen_kwargs", {})
+    traj_gen_kwargs = kwargs.pop("trajectory_generator_kwargs", {})
     black_box_kwargs = kwargs.pop('black_box_kwargs', {})
     contr_kwargs = kwargs.pop("controller_kwargs", {})
     phase_kwargs = kwargs.pop("phase_generator_kwargs", {})
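Config dicts passed to make_bb_env_helper accordingly have to use the key trajectory_generator_kwargs instead of the old traj_gen_kwargs. A sketch of the expected top-level structure, based only on the keys popped in this hunk and the 'name' entry set in the registration file; inner values are left empty because they depend on the concrete environment, and the id is a placeholder:

```python
kwargs_dict = {
    "name": "alr_envs:...",             # underlying env id (placeholder)
    "wrappers": [],                      # RawInterfaceWrapper subclasses
    "trajectory_generator_kwargs": {},   # renamed from "traj_gen_kwargs"
    "black_box_kwargs": {},
    "controller_kwargs": {},
    "phase_generator_kwargs": {},
}
```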