From 77976aef21e640d889cbcdb80c65c2d93d1920b8 Mon Sep 17 00:00:00 2001 From: Onur Date: Tue, 12 Jul 2022 11:37:52 +0200 Subject: [PATCH] fix time aware observation wrapper issue and add try catch for rendering dmc --- alr_envs/alr/mujoco/reacher/mp_wrapper.py | 16 ++++++++-------- alr_envs/black_box/black_box_wrapper.py | 5 +---- alr_envs/dmc/dmc_wrapper.py | 14 ++++++++------ alr_envs/utils/make_env_helpers.py | 11 +++++------ .../time_aware_observation.py | 0 5 files changed, 22 insertions(+), 24 deletions(-) rename alr_envs/{black_box => utils}/time_aware_observation.py (100%) diff --git a/alr_envs/alr/mujoco/reacher/mp_wrapper.py b/alr_envs/alr/mujoco/reacher/mp_wrapper.py index 37422b1..2ef7c78 100644 --- a/alr_envs/alr/mujoco/reacher/mp_wrapper.py +++ b/alr_envs/alr/mujoco/reacher/mp_wrapper.py @@ -1,5 +1,6 @@ from typing import Union, Tuple +import gym import numpy as np from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper @@ -9,14 +10,13 @@ class MPWrapper(RawInterfaceWrapper): @property def context_mask(self): - return np.concatenate([ - [False] * self.env.n_links, # cos - [False] * self.env.n_links, # sin - [True] * 2, # goal position - [False] * self.env.n_links, # angular velocity - [False] * 3, # goal distance - # [False], # step - ]) + return np.concatenate([[False] * self.env.n_links, # cos + [False] * self.env.n_links, # sin + [True] * 2, # goal position + [False] * self.env.n_links, # angular velocity + [False] * 3, # goal distance + # [False], # step + ]) @property def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: diff --git a/alr_envs/black_box/black_box_wrapper.py b/alr_envs/black_box/black_box_wrapper.py index 8635ca7..0ac2d58 100644 --- a/alr_envs/black_box/black_box_wrapper.py +++ b/alr_envs/black_box/black_box_wrapper.py @@ -70,10 +70,7 @@ class BlackBoxWrapper(gym.ObservationWrapper): def observation(self, observation): # return context space if we are - mask = self.env.context_mask - # if self.is_time_aware: - # mask = np.append(mask, False) - obs = observation[mask] if self.return_context_observation else observation + obs = observation[self.env.context_mask] if self.return_context_observation else observation # cast dtype because metaworld returns incorrect that throws gym error return obs.astype(self.observation_space.dtype) diff --git a/alr_envs/dmc/dmc_wrapper.py b/alr_envs/dmc/dmc_wrapper.py index 43c45f5..59f7793 100644 --- a/alr_envs/dmc/dmc_wrapper.py +++ b/alr_envs/dmc/dmc_wrapper.py @@ -2,7 +2,6 @@ # License: MIT # Copyright (c) 2020 Denis Yarats import collections -import cv2 from collections.abc import MutableMapping from typing import Any, Dict, Tuple, Optional, Union, Callable @@ -138,11 +137,14 @@ class DMCWrapper(gym.Env): img = np.dstack([img.astype(np.uint8)] * 3) if mode == 'human': - if self._window is None: - self._window = cv2.namedWindow(self.id, cv2.WINDOW_AUTOSIZE) - - cv2.imshow(self.id, img[..., ::-1]) # Image in BGR - cv2.waitKey(1) + try: + import cv2 + if self._window is None: + self._window = cv2.namedWindow(self.id, cv2.WINDOW_AUTOSIZE) + cv2.imshow(self.id, img[..., ::-1]) # Image in BGR + cv2.waitKey(1) + except ImportError: + raise gym.error.DependencyNotInstalled("Rendering requires opencv. Run `pip install opencv-python`") # PYGAME seems to destroy some global rendering configs from the physics render # except ImportError: # import pygame diff --git a/alr_envs/utils/make_env_helpers.py b/alr_envs/utils/make_env_helpers.py index 48d1e1f..8cdf8c3 100644 --- a/alr_envs/utils/make_env_helpers.py +++ b/alr_envs/utils/make_env_helpers.py @@ -28,7 +28,7 @@ from alr_envs.black_box.factory.controller_factory import get_controller from alr_envs.black_box.factory.phase_generator_factory import get_phase_generator from alr_envs.black_box.factory.trajectory_generator_factory import get_trajectory_generator from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper -from alr_envs.black_box.time_aware_observation import TimeAwareObservation +from alr_envs.utils.time_aware_observation import TimeAwareObservation from alr_envs.utils.utils import nested_update @@ -148,11 +148,10 @@ def make_bb( if learn_sub_trajs and do_replanning: raise ValueError('Cannot used sub-trajectory learning and replanning together.') - if learn_sub_trajs or do_replanning: - # add time_step observation when replanning - if not any(issubclass(w, TimeAwareObservation) for w in kwargs['wrappers']): - # Add as first wrapper in order to alter observation - kwargs['wrappers'].insert(0, TimeAwareObservation) + # add time_step observation when replanning + if (learn_sub_trajs or do_replanning) and not any(issubclass(w, TimeAwareObservation) for w in kwargs['wrappers']): + # Add as first wrapper in order to alter observation + kwargs['wrappers'].insert(0, TimeAwareObservation) env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs) diff --git a/alr_envs/black_box/time_aware_observation.py b/alr_envs/utils/time_aware_observation.py similarity index 100% rename from alr_envs/black_box/time_aware_observation.py rename to alr_envs/utils/time_aware_observation.py