Fix bugs in creating MP environments. Still conflicts with mp_pytorch_lib.
This commit is contained in:
parent
d4e3b957a9
commit
2161cfd3a6
@ -428,7 +428,7 @@ for _v in _versions:
|
||||
kwargs_dict_bp_promp['name'] = f"alr_envs:{_v}"
|
||||
register(
|
||||
id=_env_id,
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_mp_env_helper',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_bp_promp
|
||||
)
|
||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
|
@ -1 +1 @@
|
||||
from new_mp_wrapper import MPWrapper
|
||||
from .mp_wrapper import MPWrapper
|
||||
|
@ -25,8 +25,9 @@ class MPWrapper(RawInterfaceWrapper):
|
||||
[False] # env steps
|
||||
])
|
||||
|
||||
def _episode_callback(self, action: np.ndarray) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
|
||||
if self.mp.learn_tau:
|
||||
# TODO: Fix this
|
||||
def _episode_callback(self, action: np.ndarray, mp) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
|
||||
if mp.learn_tau:
|
||||
self.env.env.release_step = action[0] / self.env.dt # Tau value
|
||||
return action, None
|
||||
else:
|
||||
|
@ -1,4 +1,3 @@
|
||||
from abc import ABC
|
||||
from typing import Tuple, Union
|
||||
|
||||
import gym
|
||||
@ -80,7 +79,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
||||
bc_time=np.zeros((1,)) if not self.replanning_schedule else self.current_traj_steps * self.dt,
|
||||
bc_pos=self.current_pos, bc_vel=self.current_vel)
|
||||
# TODO: is this correct for replanning? Do we need to adjust anything here?
|
||||
self.traj_gen.set_duration(None if self.learn_sub_trajectories else self.duration, np.array([self.dt]))
|
||||
self.traj_gen.set_duration(None if self.learn_sub_trajectories else np.array([self.duration]), np.array([self.dt]))
|
||||
traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
|
||||
trajectory_tensor, velocity_tensor = traj_dict['pos'], traj_dict['vel']
|
||||
|
||||
@ -109,7 +108,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
||||
""" This function generates a trajectory based on a MP and then does the usual loop over reset and step"""
|
||||
|
||||
# agent to learn when to release the ball
|
||||
mp_params, env_spec_params = self._episode_callback(action)
|
||||
mp_params, env_spec_params = self.env._episode_callback(action, self.traj_gen)
|
||||
trajectory, velocity = self.get_trajectory(mp_params)
|
||||
|
||||
trajectory_length = len(trajectory)
|
||||
|
@ -1,8 +1,9 @@
|
||||
from typing import Union, Tuple
|
||||
from mp_pytorch.mp.mp_interfaces import MPInterface
|
||||
from abc import abstractmethod
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
from abc import abstractmethod
|
||||
|
||||
|
||||
class RawInterfaceWrapper(gym.Wrapper):
|
||||
@ -56,7 +57,7 @@ class RawInterfaceWrapper(gym.Wrapper):
|
||||
# return bool(self.replanning_model(s))
|
||||
return False
|
||||
|
||||
def _episode_callback(self, action: np.ndarray) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
|
||||
def _episode_callback(self, action: np.ndarray, traj_gen: MPInterface) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
|
||||
"""
|
||||
Used to extract the parameters for the motion primitive and other parameters from an action array which might
|
||||
include other actions like ball releasing time for the beer pong environment.
|
||||
|
@ -43,10 +43,11 @@ def make_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwa
|
||||
|
||||
|
||||
def make(env_id, seed, **kwargs):
|
||||
spec = registry.get(env_id)
|
||||
# spec = registry.get(env_id) # TODO: This doesn't work with gym ==0.21.0
|
||||
spec = registry.spec(env_id)
|
||||
# This access is required to allow for nested dict updates
|
||||
all_kwargs = deepcopy(spec._kwargs)
|
||||
nested_update(all_kwargs, **kwargs)
|
||||
nested_update(all_kwargs, kwargs)
|
||||
return _make(env_id, seed, **all_kwargs)
|
||||
|
||||
|
||||
@ -224,7 +225,7 @@ def make_bb_env_helper(**kwargs):
|
||||
seed = kwargs.pop("seed", None)
|
||||
wrappers = kwargs.pop("wrappers")
|
||||
|
||||
traj_gen_kwargs = kwargs.pop("traj_gen_kwargs", {})
|
||||
traj_gen_kwargs = kwargs.pop("trajectory_generator_kwargs", {})
|
||||
black_box_kwargs = kwargs.pop('black_box_kwargs', {})
|
||||
contr_kwargs = kwargs.pop("controller_kwargs", {})
|
||||
phase_kwargs = kwargs.pop("phase_generator_kwargs", {})
|
||||
|
Loading…
Reference in New Issue
Block a user