fix time aware observation wrapper issue and add try catch for rendering dmc
This commit is contained in:
parent
da49d1b7f7
commit
77976aef21
@ -1,5 +1,6 @@
|
||||
from typing import Union, Tuple
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
||||
from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper
|
||||
@ -9,14 +10,13 @@ class MPWrapper(RawInterfaceWrapper):
|
||||
|
||||
@property
|
||||
def context_mask(self):
|
||||
return np.concatenate([
|
||||
[False] * self.env.n_links, # cos
|
||||
[False] * self.env.n_links, # sin
|
||||
[True] * 2, # goal position
|
||||
[False] * self.env.n_links, # angular velocity
|
||||
[False] * 3, # goal distance
|
||||
# [False], # step
|
||||
])
|
||||
return np.concatenate([[False] * self.env.n_links, # cos
|
||||
[False] * self.env.n_links, # sin
|
||||
[True] * 2, # goal position
|
||||
[False] * self.env.n_links, # angular velocity
|
||||
[False] * 3, # goal distance
|
||||
# [False], # step
|
||||
])
|
||||
|
||||
@property
|
||||
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||
|
@ -70,10 +70,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
||||
|
||||
def observation(self, observation):
|
||||
# return context space if we are
|
||||
mask = self.env.context_mask
|
||||
# if self.is_time_aware:
|
||||
# mask = np.append(mask, False)
|
||||
obs = observation[mask] if self.return_context_observation else observation
|
||||
obs = observation[self.env.context_mask] if self.return_context_observation else observation
|
||||
# cast dtype because metaworld returns incorrect that throws gym error
|
||||
return obs.astype(self.observation_space.dtype)
|
||||
|
||||
|
@ -2,7 +2,6 @@
|
||||
# License: MIT
|
||||
# Copyright (c) 2020 Denis Yarats
|
||||
import collections
|
||||
import cv2
|
||||
from collections.abc import MutableMapping
|
||||
from typing import Any, Dict, Tuple, Optional, Union, Callable
|
||||
|
||||
@ -138,11 +137,14 @@ class DMCWrapper(gym.Env):
|
||||
img = np.dstack([img.astype(np.uint8)] * 3)
|
||||
|
||||
if mode == 'human':
|
||||
if self._window is None:
|
||||
self._window = cv2.namedWindow(self.id, cv2.WINDOW_AUTOSIZE)
|
||||
|
||||
cv2.imshow(self.id, img[..., ::-1]) # Image in BGR
|
||||
cv2.waitKey(1)
|
||||
try:
|
||||
import cv2
|
||||
if self._window is None:
|
||||
self._window = cv2.namedWindow(self.id, cv2.WINDOW_AUTOSIZE)
|
||||
cv2.imshow(self.id, img[..., ::-1]) # Image in BGR
|
||||
cv2.waitKey(1)
|
||||
except ImportError:
|
||||
raise gym.error.DependencyNotInstalled("Rendering requires opencv. Run `pip install opencv-python`")
|
||||
# PYGAME seems to destroy some global rendering configs from the physics render
|
||||
# except ImportError:
|
||||
# import pygame
|
||||
|
@ -28,7 +28,7 @@ from alr_envs.black_box.factory.controller_factory import get_controller
|
||||
from alr_envs.black_box.factory.phase_generator_factory import get_phase_generator
|
||||
from alr_envs.black_box.factory.trajectory_generator_factory import get_trajectory_generator
|
||||
from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper
|
||||
from alr_envs.black_box.time_aware_observation import TimeAwareObservation
|
||||
from alr_envs.utils.time_aware_observation import TimeAwareObservation
|
||||
from alr_envs.utils.utils import nested_update
|
||||
|
||||
|
||||
@ -148,11 +148,10 @@ def make_bb(
|
||||
if learn_sub_trajs and do_replanning:
|
||||
raise ValueError('Cannot used sub-trajectory learning and replanning together.')
|
||||
|
||||
if learn_sub_trajs or do_replanning:
|
||||
# add time_step observation when replanning
|
||||
if not any(issubclass(w, TimeAwareObservation) for w in kwargs['wrappers']):
|
||||
# Add as first wrapper in order to alter observation
|
||||
kwargs['wrappers'].insert(0, TimeAwareObservation)
|
||||
# add time_step observation when replanning
|
||||
if (learn_sub_trajs or do_replanning) and not any(issubclass(w, TimeAwareObservation) for w in kwargs['wrappers']):
|
||||
# Add as first wrapper in order to alter observation
|
||||
kwargs['wrappers'].insert(0, TimeAwareObservation)
|
||||
|
||||
env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user