Fix bugs to create mp environments. Still conflicts with mp_pytorch_lib

Onur 2022-07-01 11:42:42 +02:00
parent d4e3b957a9
commit 2161cfd3a6
6 changed files with 14 additions and 12 deletions

View File

@@ -428,7 +428,7 @@ for _v in _versions:
     kwargs_dict_bp_promp['name'] = f"alr_envs:{_v}"
     register(
         id=_env_id,
-        entry_point='alr_envs.utils.make_env_helpers:make_mp_env_helper',
+        entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
         kwargs=kwargs_dict_bp_promp
     )
     ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
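For context: gym resolves a 'module:callable' entry-point string and calls that callable with the registered kwargs, so after this change these ProMP ids are built by make_bb_env_helper. A minimal sketch of the resolution mechanism only; 'os.path:join' is a stand-in entry point, not part of this repo:

import importlib

# Mirrors how gym.make turns an entry-point string into a callable.
entry_point = "os.path:join"                      # illustrative, same 'module:callable' format as above
module_name, attr = entry_point.split(":")
fn = getattr(importlib.import_module(module_name), attr)
print(fn("a", "b"))                               # a/b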

View File

@@ -1 +1 @@
-from new_mp_wrapper import MPWrapper
+from .mp_wrapper import MPWrapper

View File

@@ -25,8 +25,9 @@ class MPWrapper(RawInterfaceWrapper):
             [False] # env steps
         ])
 
-    def _episode_callback(self, action: np.ndarray) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
-        if self.mp.learn_tau:
+    # TODO: Fix this
+    def _episode_callback(self, action: np.ndarray, mp) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
+        if mp.learn_tau:
             self.env.env.release_step = action[0] / self.env.dt # Tau value
             return action, None
         else:
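The division by dt converts the learned duration (tau) into a simulation step index for the ball release. Illustrative numbers only; the actual dt comes from the underlying env:

# With a learned tau of 0.5 s and an env dt of 0.02 s, the release is scheduled at step 25.
tau, dt = 0.5, 0.02
release_step = tau / dt  # 25.0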

View File

@@ -1,4 +1,3 @@
-from abc import ABC
 from typing import Tuple, Union
 import gym
@@ -80,7 +79,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
            bc_time=np.zeros((1,)) if not self.replanning_schedule else self.current_traj_steps * self.dt,
            bc_pos=self.current_pos, bc_vel=self.current_vel)
         # TODO: is this correct for replanning? Do we need to adjust anything here?
-        self.traj_gen.set_duration(None if self.learn_sub_trajectories else self.duration, np.array([self.dt]))
+        self.traj_gen.set_duration(None if self.learn_sub_trajectories else np.array([self.duration]), np.array([self.dt]))
         traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
         trajectory_tensor, velocity_tensor = traj_dict['pos'], traj_dict['vel']
@@ -109,7 +108,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         """ This function generates a trajectory based on a MP and then does the usual loop over reset and step"""
         # agent to learn when to release the ball
-        mp_params, env_spec_params = self._episode_callback(action)
+        mp_params, env_spec_params = self.env._episode_callback(action, self.traj_gen)
         trajectory, velocity = self.get_trajectory(mp_params)
         trajectory_length = len(trajectory)
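The step logic now lets the wrapped env (a RawInterfaceWrapper subclass) split the raw action, and passes the trajectory generator in explicitly instead of having the env reach for an attribute of its own. A small self-contained sketch of that return contract; all names and values here are made up for illustration:

import numpy as np

class DummyTrajGen:
    learn_tau = True  # stand-in for the mp_pytorch generator owned by the black-box wrapper

def episode_callback(action: np.ndarray, traj_gen) -> tuple:
    # Env-specific part (here: a release time) is split off; the rest stays as MP parameters.
    if traj_gen.learn_tau:
        return action, np.array([action[0]])
    return action, None

mp_params, env_spec_params = episode_callback(np.array([0.5, 1.0, -0.3]), DummyTrajGen())
print(mp_params, env_spec_params)  # e.g. [ 0.5  1.  -0.3] [0.5]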

View File

@@ -1,8 +1,9 @@
 from typing import Union, Tuple
+from mp_pytorch.mp.mp_interfaces import MPInterface
+from abc import abstractmethod
 import gym
 import numpy as np
-from abc import abstractmethod
 
 class RawInterfaceWrapper(gym.Wrapper):
@@ -56,7 +57,7 @@ class RawInterfaceWrapper(gym.Wrapper):
         # return bool(self.replanning_model(s))
         return False
 
-    def _episode_callback(self, action: np.ndarray) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
+    def _episode_callback(self, action: np.ndarray, traj_gen: MPInterface) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
         """
         Used to extract the parameters for the motion primitive and other parameters from an action array which might
         include other actions like ball releasing time for the beer pong environment.

View File

@@ -43,10 +43,11 @@ def make_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwa
 def make(env_id, seed, **kwargs):
-    spec = registry.get(env_id)
+    # spec = registry.get(env_id) # TODO: This doesn't work with gym ==0.21.0
+    spec = registry.spec(env_id)
     # This access is required to allow for nested dict updates
     all_kwargs = deepcopy(spec._kwargs)
-    nested_update(all_kwargs, **kwargs)
+    nested_update(all_kwargs, kwargs)
     return _make(env_id, seed, **all_kwargs)
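Passing kwargs as a single dict (instead of unpacking it with **) matches a merge helper with a (base, update) signature that recurses into nested dicts. A sketch of that behaviour under this assumption; the real nested_update in alr_envs may differ in detail, and the keys below are illustrative only:

def nested_update(base: dict, update: dict) -> dict:
    # Recursively merge `update` into `base`: nested dicts are merged, everything else overwritten.
    for key, value in update.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            nested_update(base[key], value)
        else:
            base[key] = value
    return base

defaults = {"phase_generator_kwargs": {"phase_generator_type": "linear"}}
nested_update(defaults, {"phase_generator_kwargs": {"learn_tau": True}})
# -> {'phase_generator_kwargs': {'phase_generator_type': 'linear', 'learn_tau': True}}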
@@ -224,7 +225,7 @@ def make_bb_env_helper(**kwargs):
     seed = kwargs.pop("seed", None)
     wrappers = kwargs.pop("wrappers")
-    traj_gen_kwargs = kwargs.pop("traj_gen_kwargs", {})
+    traj_gen_kwargs = kwargs.pop("trajectory_generator_kwargs", {})
     black_box_kwargs = kwargs.pop('black_box_kwargs', {})
     contr_kwargs = kwargs.pop("controller_kwargs", {})
     phase_kwargs = kwargs.pop("phase_generator_kwargs", {})
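After the rename, callers (e.g. the registration dicts in the first file of this commit) must use the new key. An illustrative layout of the kwargs this helper consumes; only the top-level keys are taken from the code above, all values are placeholders:

kwargs_dict = {
    "name": "alr_envs:SomeEnv-v0",        # id of the underlying step-based env (placeholder value)
    "wrappers": [],                       # RawInterfaceWrapper subclasses for the target env
    "trajectory_generator_kwargs": {},    # renamed from "traj_gen_kwargs"
    "black_box_kwargs": {},
    "controller_kwargs": {},
    "phase_generator_kwargs": {},
    "seed": 1,                            # popped first; typically supplied at make time
}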