minor bug fixes
This commit is contained in:
parent
4aacd71ed3
commit
7957632eb0
@ -67,26 +67,27 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
||||
|
||||
def observation(self, observation):
|
||||
# return context space if we are
|
||||
obs = observation[self.env.context_mask] if self.return_context_observation else observation
|
||||
if self.return_context_observation:
|
||||
observation = observation[self.env.context_mask]
|
||||
# cast dtype because metaworld returns incorrect that throws gym error
|
||||
return obs.astype(self.observation_space.dtype)
|
||||
return observation.astype(self.observation_space.dtype)
|
||||
|
||||
def get_trajectory(self, action: np.ndarray) -> Tuple:
|
||||
clipped_params = np.clip(action, self.traj_gen_action_space.low, self.traj_gen_action_space.high)
|
||||
self.traj_gen.set_params(clipped_params)
|
||||
# TODO: is this correct for replanning? Do we need to adjust anything here?
|
||||
bc_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt)
|
||||
# TODO we could think about initializing with the previous desired value in order to have a smooth transition
|
||||
# at least from the planning point of view.
|
||||
self.traj_gen.set_boundary_conditions(bc_time, self.current_pos, self.current_vel)
|
||||
# TODO: remove the - self.dt after Bruces fix.
|
||||
self.traj_gen.set_duration(None if self.learn_sub_trajectories else self.duration - self.dt, self.dt)
|
||||
duration = None if self.learn_sub_trajectories else self.duration
|
||||
self.traj_gen.set_duration(duration, self.dt)
|
||||
# traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
|
||||
trajectory = get_numpy(self.traj_gen.get_traj_pos())
|
||||
velocity = get_numpy(self.traj_gen.get_traj_vel())
|
||||
|
||||
if self.do_replanning:
|
||||
# Remove first part of trajectory as this is already over
|
||||
trajectory = trajectory[self.current_traj_steps:]
|
||||
velocity = velocity[self.current_traj_steps:]
|
||||
# Remove first element of trajectory as this is the current position and velocity
|
||||
trajectory = trajectory[1:]
|
||||
velocity = velocity[1:]
|
||||
|
||||
return trajectory, velocity
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import MutableMapping
|
||||
@ -310,7 +311,11 @@ def make_gym(env_id, seed, **kwargs):
|
||||
"""
|
||||
# Getting the existing keywords to allow for nested dict updates for BB envs
|
||||
# gym only allows for non nested updates.
|
||||
try:
|
||||
all_kwargs = deepcopy(registry.get(env_id).kwargs)
|
||||
except AttributeError as e:
|
||||
logging.error(f'The gym environment with id {env_id} could not been found.')
|
||||
raise e
|
||||
nested_update(all_kwargs, kwargs)
|
||||
kwargs = all_kwargs
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user