call on superclass for obs wrapper

This commit is contained in:
Fabian 2022-06-30 14:20:52 +02:00
parent 3273f455c5
commit c3a8352c63
8 changed files with 18 additions and 123 deletions

View File

@ -1,10 +1,11 @@
from typing import Iterable, Union
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Union
import gym import gym
import matplotlib.pyplot as plt
import numpy as np import numpy as np
from gym import spaces from gym import spaces
from gym.utils import seeding from gym.utils import seeding
from alr_envs.alr.classic_control.utils import intersect from alr_envs.alr.classic_control.utils import intersect

View File

@ -1,27 +0,0 @@
from typing import Tuple, Union
import numpy as np
from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
class MPWrapper(RawInterfaceWrapper):
def get_context_mask(self):
return np.hstack([
[self.env.random_start] * self.env.n_links, # cos
[self.env.random_start] * self.env.n_links, # sin
[self.env.random_start] * self.env.n_links, # velocity
[self.env.initial_width is None], # hole width
# [self.env.hole_depth is None], # hole depth
[True] * 2, # x-y coordinates of target distance
[False] # env steps
])
@property
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
return self.env.current_pos
@property
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
return self.env.current_vel

View File

@ -6,8 +6,8 @@ from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
class MPWrapper(RawInterfaceWrapper): class MPWrapper(RawInterfaceWrapper):
@property
def context_mask(self) -> np.ndarray: def context_mask(self):
return np.hstack([ return np.hstack([
[self.env.random_start] * self.env.n_links, # cos [self.env.random_start] * self.env.n_links, # cos
[self.env.random_start] * self.env.n_links, # sin [self.env.random_start] * self.env.n_links, # sin
@ -25,10 +25,6 @@ class MPWrapper(RawInterfaceWrapper):
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
return self.env.current_vel return self.env.current_vel
@property
def goal_pos(self) -> Union[float, int, np.ndarray, Tuple]:
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
@property @property
def dt(self) -> Union[float, int]: def dt(self) -> Union[float, int]:
return self.env.dt return self.env.dt

View File

@ -1,30 +0,0 @@
from typing import Tuple, Union
import numpy as np
from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
class MPWrapper(RawInterfaceWrapper):
def context_mask(self):
return np.hstack([
[self.env.random_start] * self.env.n_links, # cos
[self.env.random_start] * self.env.n_links, # sin
[self.env.random_start] * self.env.n_links, # velocity
[self.env.initial_via_target is None] * 2, # x-y coordinates of via point distance
[True] * 2, # x-y coordinates of target distance
[False] # env steps
])
@property
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
return self.env.current_pos
@property
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
return self.env.current_vel
@property
def dt(self) -> Union[float, int]:
return self.env.dt

View File

@ -1,2 +1,2 @@
from .mp_wrapper import MPWrapper from .mp_wrapper import MPWrapper
from .new_mp_wrapper import MPWrapper as NewMPWrapper from .mp_wrapper import MPWrapper as NewMPWrapper

View File

@ -1,4 +1,4 @@
from typing import Union from typing import Union, Tuple
import numpy as np import numpy as np
@ -8,37 +8,21 @@ from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
class MPWrapper(RawInterfaceWrapper): class MPWrapper(RawInterfaceWrapper):
@property @property
def context_mask(self) -> np.ndarray: def context_mask(self):
return np.concatenate([ return np.concatenate([
[False] * self.n_links, # cos [False] * self.env.n_links, # cos
[False] * self.n_links, # sin [False] * self.env.n_links, # sin
[True] * 2, # goal position [True] * 2, # goal position
[False] * self.n_links, # angular velocity [False] * self.env.n_links, # angular velocity
[False] * 3, # goal distance [False] * 3, # goal distance
# self.get_body_com("target"), # only return target to make problem harder # self.get_body_com("target"), # only return target to make problem harder
[False], # step [False], # step
]) ])
# @property @property
# def active_obs(self): def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
# return np.concatenate([ return self.env.sim.data.qpos.flat[:self.env.n_links]
# [True] * self.n_links, # cos, True
# [True] * self.n_links, # sin, True
# [True] * 2, # goal position
# [True] * self.n_links, # angular velocity, True
# [True] * 3, # goal distance
# # self.get_body_com("target"), # only return target to make problem harder
# [False], # step
# ])
@property @property
def current_vel(self) -> Union[float, int, np.ndarray]: def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
return self.sim.data.qvel.flat[:self.n_links] return self.env.sim.data.qvel.flat[:self.env.n_links]
@property
def current_pos(self) -> Union[float, int, np.ndarray]:
return self.sim.data.qpos.flat[:self.n_links]
@property
def dt(self) -> Union[float, int]:
return self.env.dt

View File

@ -1,28 +0,0 @@
from alr_envs.mp.black_box_wrapper import BlackBoxWrapper
from typing import Union, Tuple
import numpy as np
from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
class MPWrapper(RawInterfaceWrapper):
@property
def context_mask(self):
return np.concatenate([
[False] * self.env.n_links, # cos
[False] * self.env.n_links, # sin
[True] * 2, # goal position
[False] * self.env.n_links, # angular velocity
[False] * 3, # goal distance
# self.get_body_com("target"), # only return target to make problem harder
[False], # step
])
@property
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
return self.env.sim.data.qpos.flat[:self.env.n_links]
@property
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
return self.env.sim.data.qvel.flat[:self.env.n_links]

View File

@ -11,7 +11,7 @@ from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
from alr_envs.utils.utils import get_numpy from alr_envs.utils.utils import get_numpy
class BlackBoxWrapper(gym.ObservationWrapper, ABC): class BlackBoxWrapper(gym.ObservationWrapper):
def __init__(self, def __init__(self,
env: RawInterfaceWrapper, env: RawInterfaceWrapper,
@ -34,9 +34,8 @@ class BlackBoxWrapper(gym.ObservationWrapper, ABC):
reward_aggregation: function that takes the np.ndarray of step rewards as input and returns the trajectory reward_aggregation: function that takes the np.ndarray of step rewards as input and returns the trajectory
reward, default summation over all values. reward, default summation over all values.
""" """
super().__init__() super().__init__(env)
self.env = env
self.duration = duration self.duration = duration
self.learn_sub_trajectories = learn_sub_trajectories self.learn_sub_trajectories = learn_sub_trajectories
self.replanning_schedule = replanning_schedule self.replanning_schedule = replanning_schedule