minor bug fixes
This commit is contained in:
parent
4aacd71ed3
commit
7957632eb0
@ -67,26 +67,27 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
|
|
||||||
def observation(self, observation):
|
def observation(self, observation):
|
||||||
# return context space if we are
|
# return context space if we are
|
||||||
obs = observation[self.env.context_mask] if self.return_context_observation else observation
|
if self.return_context_observation:
|
||||||
|
observation = observation[self.env.context_mask]
|
||||||
# cast dtype because metaworld returns incorrect that throws gym error
|
# cast dtype because metaworld returns incorrect that throws gym error
|
||||||
return obs.astype(self.observation_space.dtype)
|
return observation.astype(self.observation_space.dtype)
|
||||||
|
|
||||||
def get_trajectory(self, action: np.ndarray) -> Tuple:
|
def get_trajectory(self, action: np.ndarray) -> Tuple:
|
||||||
clipped_params = np.clip(action, self.traj_gen_action_space.low, self.traj_gen_action_space.high)
|
clipped_params = np.clip(action, self.traj_gen_action_space.low, self.traj_gen_action_space.high)
|
||||||
self.traj_gen.set_params(clipped_params)
|
self.traj_gen.set_params(clipped_params)
|
||||||
# TODO: is this correct for replanning? Do we need to adjust anything here?
|
|
||||||
bc_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt)
|
bc_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt)
|
||||||
|
# TODO we could think about initializing with the previous desired value in order to have a smooth transition
|
||||||
|
# at least from the planning point of view.
|
||||||
self.traj_gen.set_boundary_conditions(bc_time, self.current_pos, self.current_vel)
|
self.traj_gen.set_boundary_conditions(bc_time, self.current_pos, self.current_vel)
|
||||||
# TODO: remove the - self.dt after Bruces fix.
|
duration = None if self.learn_sub_trajectories else self.duration
|
||||||
self.traj_gen.set_duration(None if self.learn_sub_trajectories else self.duration - self.dt, self.dt)
|
self.traj_gen.set_duration(duration, self.dt)
|
||||||
# traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
|
# traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
|
||||||
trajectory = get_numpy(self.traj_gen.get_traj_pos())
|
trajectory = get_numpy(self.traj_gen.get_traj_pos())
|
||||||
velocity = get_numpy(self.traj_gen.get_traj_vel())
|
velocity = get_numpy(self.traj_gen.get_traj_vel())
|
||||||
|
|
||||||
if self.do_replanning:
|
# Remove first element of trajectory as this is the current position and velocity
|
||||||
# Remove first part of trajectory as this is already over
|
trajectory = trajectory[1:]
|
||||||
trajectory = trajectory[self.current_traj_steps:]
|
velocity = velocity[1:]
|
||||||
velocity = velocity[self.current_traj_steps:]
|
|
||||||
|
|
||||||
return trajectory, velocity
|
return trajectory, velocity
|
||||||
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import MutableMapping
|
from collections.abc import MutableMapping
|
||||||
@ -310,7 +311,11 @@ def make_gym(env_id, seed, **kwargs):
|
|||||||
"""
|
"""
|
||||||
# Getting the existing keywords to allow for nested dict updates for BB envs
|
# Getting the existing keywords to allow for nested dict updates for BB envs
|
||||||
# gym only allows for non nested updates.
|
# gym only allows for non nested updates.
|
||||||
|
try:
|
||||||
all_kwargs = deepcopy(registry.get(env_id).kwargs)
|
all_kwargs = deepcopy(registry.get(env_id).kwargs)
|
||||||
|
except AttributeError as e:
|
||||||
|
logging.error(f'The gym environment with id {env_id} could not been found.')
|
||||||
|
raise e
|
||||||
nested_update(all_kwargs, kwargs)
|
nested_update(all_kwargs, kwargs)
|
||||||
kwargs = all_kwargs
|
kwargs = all_kwargs
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user