fix time aware observation wrapper issue and add try catch for rendering dmc
This commit is contained in:
parent
da49d1b7f7
commit
77976aef21
@ -1,5 +1,6 @@
|
|||||||
from typing import Union, Tuple
|
from typing import Union, Tuple
|
||||||
|
|
||||||
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper
|
from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper
|
||||||
@ -9,14 +10,13 @@ class MPWrapper(RawInterfaceWrapper):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def context_mask(self):
|
def context_mask(self):
|
||||||
return np.concatenate([
|
return np.concatenate([[False] * self.env.n_links, # cos
|
||||||
[False] * self.env.n_links, # cos
|
[False] * self.env.n_links, # sin
|
||||||
[False] * self.env.n_links, # sin
|
[True] * 2, # goal position
|
||||||
[True] * 2, # goal position
|
[False] * self.env.n_links, # angular velocity
|
||||||
[False] * self.env.n_links, # angular velocity
|
[False] * 3, # goal distance
|
||||||
[False] * 3, # goal distance
|
# [False], # step
|
||||||
# [False], # step
|
])
|
||||||
])
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
|
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||||
|
@ -70,10 +70,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
|
|
||||||
def observation(self, observation):
|
def observation(self, observation):
|
||||||
# return context space if we are
|
# return context space if we are
|
||||||
mask = self.env.context_mask
|
obs = observation[self.env.context_mask] if self.return_context_observation else observation
|
||||||
# if self.is_time_aware:
|
|
||||||
# mask = np.append(mask, False)
|
|
||||||
obs = observation[mask] if self.return_context_observation else observation
|
|
||||||
# cast dtype because metaworld returns incorrect that throws gym error
|
# cast dtype because metaworld returns incorrect that throws gym error
|
||||||
return obs.astype(self.observation_space.dtype)
|
return obs.astype(self.observation_space.dtype)
|
||||||
|
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
# License: MIT
|
# License: MIT
|
||||||
# Copyright (c) 2020 Denis Yarats
|
# Copyright (c) 2020 Denis Yarats
|
||||||
import collections
|
import collections
|
||||||
import cv2
|
|
||||||
from collections.abc import MutableMapping
|
from collections.abc import MutableMapping
|
||||||
from typing import Any, Dict, Tuple, Optional, Union, Callable
|
from typing import Any, Dict, Tuple, Optional, Union, Callable
|
||||||
|
|
||||||
@ -138,11 +137,14 @@ class DMCWrapper(gym.Env):
|
|||||||
img = np.dstack([img.astype(np.uint8)] * 3)
|
img = np.dstack([img.astype(np.uint8)] * 3)
|
||||||
|
|
||||||
if mode == 'human':
|
if mode == 'human':
|
||||||
if self._window is None:
|
try:
|
||||||
self._window = cv2.namedWindow(self.id, cv2.WINDOW_AUTOSIZE)
|
import cv2
|
||||||
|
if self._window is None:
|
||||||
cv2.imshow(self.id, img[..., ::-1]) # Image in BGR
|
self._window = cv2.namedWindow(self.id, cv2.WINDOW_AUTOSIZE)
|
||||||
cv2.waitKey(1)
|
cv2.imshow(self.id, img[..., ::-1]) # Image in BGR
|
||||||
|
cv2.waitKey(1)
|
||||||
|
except ImportError:
|
||||||
|
raise gym.error.DependencyNotInstalled("Rendering requires opencv. Run `pip install opencv-python`")
|
||||||
# PYGAME seems to destroy some global rendering configs from the physics render
|
# PYGAME seems to destroy some global rendering configs from the physics render
|
||||||
# except ImportError:
|
# except ImportError:
|
||||||
# import pygame
|
# import pygame
|
||||||
|
@ -28,7 +28,7 @@ from alr_envs.black_box.factory.controller_factory import get_controller
|
|||||||
from alr_envs.black_box.factory.phase_generator_factory import get_phase_generator
|
from alr_envs.black_box.factory.phase_generator_factory import get_phase_generator
|
||||||
from alr_envs.black_box.factory.trajectory_generator_factory import get_trajectory_generator
|
from alr_envs.black_box.factory.trajectory_generator_factory import get_trajectory_generator
|
||||||
from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper
|
from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper
|
||||||
from alr_envs.black_box.time_aware_observation import TimeAwareObservation
|
from alr_envs.utils.time_aware_observation import TimeAwareObservation
|
||||||
from alr_envs.utils.utils import nested_update
|
from alr_envs.utils.utils import nested_update
|
||||||
|
|
||||||
|
|
||||||
@ -148,11 +148,10 @@ def make_bb(
|
|||||||
if learn_sub_trajs and do_replanning:
|
if learn_sub_trajs and do_replanning:
|
||||||
raise ValueError('Cannot used sub-trajectory learning and replanning together.')
|
raise ValueError('Cannot used sub-trajectory learning and replanning together.')
|
||||||
|
|
||||||
if learn_sub_trajs or do_replanning:
|
# add time_step observation when replanning
|
||||||
# add time_step observation when replanning
|
if (learn_sub_trajs or do_replanning) and not any(issubclass(w, TimeAwareObservation) for w in kwargs['wrappers']):
|
||||||
if not any(issubclass(w, TimeAwareObservation) for w in kwargs['wrappers']):
|
# Add as first wrapper in order to alter observation
|
||||||
# Add as first wrapper in order to alter observation
|
kwargs['wrappers'].insert(0, TimeAwareObservation)
|
||||||
kwargs['wrappers'].insert(0, TimeAwareObservation)
|
|
||||||
|
|
||||||
env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
|
env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user