fix time aware observation wrapper issue and add try catch for rendering dmc

This commit is contained in:
Onur 2022-07-12 11:37:52 +02:00
parent da49d1b7f7
commit 77976aef21
5 changed files with 22 additions and 24 deletions

View File

@ -1,5 +1,6 @@
from typing import Union, Tuple
import gym
import numpy as np
from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper
@ -9,8 +10,7 @@ class MPWrapper(RawInterfaceWrapper):
@property
def context_mask(self):
return np.concatenate([
[False] * self.env.n_links, # cos
return np.concatenate([[False] * self.env.n_links, # cos
[False] * self.env.n_links, # sin
[True] * 2, # goal position
[False] * self.env.n_links, # angular velocity

View File

@ -70,10 +70,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
def observation(self, observation):
# return context space if we are
mask = self.env.context_mask
# if self.is_time_aware:
# mask = np.append(mask, False)
obs = observation[mask] if self.return_context_observation else observation
obs = observation[self.env.context_mask] if self.return_context_observation else observation
# cast dtype because metaworld returns incorrect that throws gym error
return obs.astype(self.observation_space.dtype)

View File

@ -2,7 +2,6 @@
# License: MIT
# Copyright (c) 2020 Denis Yarats
import collections
import cv2
from collections.abc import MutableMapping
from typing import Any, Dict, Tuple, Optional, Union, Callable
@ -138,11 +137,14 @@ class DMCWrapper(gym.Env):
img = np.dstack([img.astype(np.uint8)] * 3)
if mode == 'human':
try:
import cv2
if self._window is None:
self._window = cv2.namedWindow(self.id, cv2.WINDOW_AUTOSIZE)
cv2.imshow(self.id, img[..., ::-1]) # Image in BGR
cv2.waitKey(1)
except ImportError:
raise gym.error.DependencyNotInstalled("Rendering requires opencv. Run `pip install opencv-python`")
# PYGAME seems to destroy some global rendering configs from the physics render
# except ImportError:
# import pygame

View File

@ -28,7 +28,7 @@ from alr_envs.black_box.factory.controller_factory import get_controller
from alr_envs.black_box.factory.phase_generator_factory import get_phase_generator
from alr_envs.black_box.factory.trajectory_generator_factory import get_trajectory_generator
from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper
from alr_envs.black_box.time_aware_observation import TimeAwareObservation
from alr_envs.utils.time_aware_observation import TimeAwareObservation
from alr_envs.utils.utils import nested_update
@ -148,9 +148,8 @@ def make_bb(
if learn_sub_trajs and do_replanning:
raise ValueError('Cannot used sub-trajectory learning and replanning together.')
if learn_sub_trajs or do_replanning:
# add time_step observation when replanning
if not any(issubclass(w, TimeAwareObservation) for w in kwargs['wrappers']):
if (learn_sub_trajs or do_replanning) and not any(issubclass(w, TimeAwareObservation) for w in kwargs['wrappers']):
# Add as first wrapper in order to alter observation
kwargs['wrappers'].insert(0, TimeAwareObservation)