fix time aware observation wrapper issue and add try catch for rendering dmc

This commit is contained in:
Onur 2022-07-12 11:37:52 +02:00
parent da49d1b7f7
commit 77976aef21
5 changed files with 22 additions and 24 deletions

View File

@ -1,5 +1,6 @@
from typing import Union, Tuple from typing import Union, Tuple
import gym
import numpy as np import numpy as np
from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper
@ -9,8 +10,7 @@ class MPWrapper(RawInterfaceWrapper):
@property @property
def context_mask(self): def context_mask(self):
return np.concatenate([ return np.concatenate([[False] * self.env.n_links, # cos
[False] * self.env.n_links, # cos
[False] * self.env.n_links, # sin [False] * self.env.n_links, # sin
[True] * 2, # goal position [True] * 2, # goal position
[False] * self.env.n_links, # angular velocity [False] * self.env.n_links, # angular velocity

View File

@ -70,10 +70,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
def observation(self, observation): def observation(self, observation):
# return context space if we are # return context space if we are
mask = self.env.context_mask obs = observation[self.env.context_mask] if self.return_context_observation else observation
# if self.is_time_aware:
# mask = np.append(mask, False)
obs = observation[mask] if self.return_context_observation else observation
# cast dtype because metaworld returns incorrect that throws gym error # cast dtype because metaworld returns incorrect that throws gym error
return obs.astype(self.observation_space.dtype) return obs.astype(self.observation_space.dtype)

View File

@ -2,7 +2,6 @@
# License: MIT # License: MIT
# Copyright (c) 2020 Denis Yarats # Copyright (c) 2020 Denis Yarats
import collections import collections
import cv2
from collections.abc import MutableMapping from collections.abc import MutableMapping
from typing import Any, Dict, Tuple, Optional, Union, Callable from typing import Any, Dict, Tuple, Optional, Union, Callable
@ -138,11 +137,14 @@ class DMCWrapper(gym.Env):
img = np.dstack([img.astype(np.uint8)] * 3) img = np.dstack([img.astype(np.uint8)] * 3)
if mode == 'human': if mode == 'human':
try:
import cv2
if self._window is None: if self._window is None:
self._window = cv2.namedWindow(self.id, cv2.WINDOW_AUTOSIZE) self._window = cv2.namedWindow(self.id, cv2.WINDOW_AUTOSIZE)
cv2.imshow(self.id, img[..., ::-1]) # Image in BGR cv2.imshow(self.id, img[..., ::-1]) # Image in BGR
cv2.waitKey(1) cv2.waitKey(1)
except ImportError:
raise gym.error.DependencyNotInstalled("Rendering requires opencv. Run `pip install opencv-python`")
# PYGAME seems to destroy some global rendering configs from the physics render # PYGAME seems to destroy some global rendering configs from the physics render
# except ImportError: # except ImportError:
# import pygame # import pygame

View File

@ -28,7 +28,7 @@ from alr_envs.black_box.factory.controller_factory import get_controller
from alr_envs.black_box.factory.phase_generator_factory import get_phase_generator from alr_envs.black_box.factory.phase_generator_factory import get_phase_generator
from alr_envs.black_box.factory.trajectory_generator_factory import get_trajectory_generator from alr_envs.black_box.factory.trajectory_generator_factory import get_trajectory_generator
from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper
from alr_envs.black_box.time_aware_observation import TimeAwareObservation from alr_envs.utils.time_aware_observation import TimeAwareObservation
from alr_envs.utils.utils import nested_update from alr_envs.utils.utils import nested_update
@ -148,9 +148,8 @@ def make_bb(
if learn_sub_trajs and do_replanning: if learn_sub_trajs and do_replanning:
raise ValueError('Cannot used sub-trajectory learning and replanning together.') raise ValueError('Cannot used sub-trajectory learning and replanning together.')
if learn_sub_trajs or do_replanning:
# add time_step observation when replanning # add time_step observation when replanning
if not any(issubclass(w, TimeAwareObservation) for w in kwargs['wrappers']): if (learn_sub_trajs or do_replanning) and not any(issubclass(w, TimeAwareObservation) for w in kwargs['wrappers']):
# Add as first wrapper in order to alter observation # Add as first wrapper in order to alter observation
kwargs['wrappers'].insert(0, TimeAwareObservation) kwargs['wrappers'].insert(0, TimeAwareObservation)