minor bug fixes

This commit is contained in:
Fabian 2022-07-27 16:34:35 +02:00
parent 4aacd71ed3
commit 7957632eb0
2 changed files with 16 additions and 10 deletions

View File

@ -67,26 +67,27 @@ class BlackBoxWrapper(gym.ObservationWrapper):
def observation(self, observation): def observation(self, observation):
# return context space if we are # return context space if we are
obs = observation[self.env.context_mask] if self.return_context_observation else observation if self.return_context_observation:
observation = observation[self.env.context_mask]
# cast dtype because metaworld returns incorrect that throws gym error # cast dtype because metaworld returns incorrect that throws gym error
return obs.astype(self.observation_space.dtype) return observation.astype(self.observation_space.dtype)
def get_trajectory(self, action: np.ndarray) -> Tuple: def get_trajectory(self, action: np.ndarray) -> Tuple:
clipped_params = np.clip(action, self.traj_gen_action_space.low, self.traj_gen_action_space.high) clipped_params = np.clip(action, self.traj_gen_action_space.low, self.traj_gen_action_space.high)
self.traj_gen.set_params(clipped_params) self.traj_gen.set_params(clipped_params)
# TODO: is this correct for replanning? Do we need to adjust anything here?
bc_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt) bc_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt)
# TODO we could think about initializing with the previous desired value in order to have a smooth transition
# at least from the planning point of view.
self.traj_gen.set_boundary_conditions(bc_time, self.current_pos, self.current_vel) self.traj_gen.set_boundary_conditions(bc_time, self.current_pos, self.current_vel)
# TODO: remove the - self.dt after Bruces fix. duration = None if self.learn_sub_trajectories else self.duration
self.traj_gen.set_duration(None if self.learn_sub_trajectories else self.duration - self.dt, self.dt) self.traj_gen.set_duration(duration, self.dt)
# traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True) # traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
trajectory = get_numpy(self.traj_gen.get_traj_pos()) trajectory = get_numpy(self.traj_gen.get_traj_pos())
velocity = get_numpy(self.traj_gen.get_traj_vel()) velocity = get_numpy(self.traj_gen.get_traj_vel())
if self.do_replanning: # Remove first element of trajectory as this is the current position and velocity
# Remove first part of trajectory as this is already over trajectory = trajectory[1:]
trajectory = trajectory[self.current_traj_steps:] velocity = velocity[1:]
velocity = velocity[self.current_traj_steps:]
return trajectory, velocity return trajectory, velocity

View File

@ -1,3 +1,4 @@
import logging
import re import re
import uuid import uuid
from collections.abc import MutableMapping from collections.abc import MutableMapping
@ -310,7 +311,11 @@ def make_gym(env_id, seed, **kwargs):
""" """
# Getting the existing keywords to allow for nested dict updates for BB envs # Getting the existing keywords to allow for nested dict updates for BB envs
# gym only allows for non nested updates. # gym only allows for non nested updates.
try:
all_kwargs = deepcopy(registry.get(env_id).kwargs) all_kwargs = deepcopy(registry.get(env_id).kwargs)
except AttributeError as e:
logging.error(f'The gym environment with id {env_id} could not been found.')
raise e
nested_update(all_kwargs, kwargs) nested_update(all_kwargs, kwargs)
kwargs = all_kwargs kwargs = all_kwargs