Call the superclass constructor in the observation wrapper

This commit is contained in:
Fabian 2022-06-30 14:20:52 +02:00
parent 3273f455c5
commit c3a8352c63
8 changed files with 18 additions and 123 deletions

View File

@@ -1,10 +1,11 @@
from typing import Iterable, Union
from abc import ABC, abstractmethod
from typing import Union
import gym
import matplotlib.pyplot as plt
import numpy as np
from gym import spaces
from gym.utils import seeding
from alr_envs.alr.classic_control.utils import intersect

View File

@@ -1,27 +0,0 @@
from typing import Tuple, Union
import numpy as np
from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
class MPWrapper(RawInterfaceWrapper):
def get_context_mask(self):
return np.hstack([
[self.env.random_start] * self.env.n_links, # cos
[self.env.random_start] * self.env.n_links, # sin
[self.env.random_start] * self.env.n_links, # velocity
[self.env.initial_width is None], # hole width
# [self.env.hole_depth is None], # hole depth
[True] * 2, # x-y coordinates of target distance
[False] # env steps
])
@property
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
return self.env.current_pos
@property
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
return self.env.current_vel

View File

@@ -6,8 +6,8 @@ from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
class MPWrapper(RawInterfaceWrapper):
@property
def context_mask(self) -> np.ndarray:
def context_mask(self):
return np.hstack([
[self.env.random_start] * self.env.n_links, # cos
[self.env.random_start] * self.env.n_links, # sin
@@ -25,10 +25,6 @@ class MPWrapper(RawInterfaceWrapper):
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
return self.env.current_vel
@property
def goal_pos(self) -> Union[float, int, np.ndarray, Tuple]:
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
@property
def dt(self) -> Union[float, int]:
return self.env.dt

View File

@@ -1,30 +0,0 @@
from typing import Tuple, Union
import numpy as np
from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
class MPWrapper(RawInterfaceWrapper):
def context_mask(self):
return np.hstack([
[self.env.random_start] * self.env.n_links, # cos
[self.env.random_start] * self.env.n_links, # sin
[self.env.random_start] * self.env.n_links, # velocity
[self.env.initial_via_target is None] * 2, # x-y coordinates of via point distance
[True] * 2, # x-y coordinates of target distance
[False] # env steps
])
@property
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
return self.env.current_pos
@property
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
return self.env.current_vel
@property
def dt(self) -> Union[float, int]:
return self.env.dt

View File

@@ -1,2 +1,2 @@
from .mp_wrapper import MPWrapper
from .new_mp_wrapper import MPWrapper as NewMPWrapper
from .mp_wrapper import MPWrapper as NewMPWrapper

View File

@@ -1,4 +1,4 @@
from typing import Union
from typing import Union, Tuple
import numpy as np
@@ -8,37 +8,21 @@ from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
class MPWrapper(RawInterfaceWrapper):
@property
def context_mask(self) -> np.ndarray:
def context_mask(self):
return np.concatenate([
[False] * self.n_links, # cos
[False] * self.n_links, # sin
[False] * self.env.n_links, # cos
[False] * self.env.n_links, # sin
[True] * 2, # goal position
[False] * self.n_links, # angular velocity
[False] * self.env.n_links, # angular velocity
[False] * 3, # goal distance
# self.get_body_com("target"), # only return target to make problem harder
[False], # step
])
# @property
# def active_obs(self):
# return np.concatenate([
# [True] * self.n_links, # cos, True
# [True] * self.n_links, # sin, True
# [True] * 2, # goal position
# [True] * self.n_links, # angular velocity, True
# [True] * 3, # goal distance
# # self.get_body_com("target"), # only return target to make problem harder
# [False], # step
# ])
@property
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
return self.env.sim.data.qpos.flat[:self.env.n_links]
@property
def current_vel(self) -> Union[float, int, np.ndarray]:
return self.sim.data.qvel.flat[:self.n_links]
@property
def current_pos(self) -> Union[float, int, np.ndarray]:
return self.sim.data.qpos.flat[:self.n_links]
@property
def dt(self) -> Union[float, int]:
return self.env.dt
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
return self.env.sim.data.qvel.flat[:self.env.n_links]

View File

@@ -1,28 +0,0 @@
from alr_envs.mp.black_box_wrapper import BlackBoxWrapper
from typing import Union, Tuple
import numpy as np
from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
class MPWrapper(RawInterfaceWrapper):
@property
def context_mask(self):
return np.concatenate([
[False] * self.env.n_links, # cos
[False] * self.env.n_links, # sin
[True] * 2, # goal position
[False] * self.env.n_links, # angular velocity
[False] * 3, # goal distance
# self.get_body_com("target"), # only return target to make problem harder
[False], # step
])
@property
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
return self.env.sim.data.qpos.flat[:self.env.n_links]
@property
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
return self.env.sim.data.qvel.flat[:self.env.n_links]

View File

@@ -11,7 +11,7 @@ from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
from alr_envs.utils.utils import get_numpy
class BlackBoxWrapper(gym.ObservationWrapper, ABC):
class BlackBoxWrapper(gym.ObservationWrapper):
def __init__(self,
env: RawInterfaceWrapper,
@@ -34,9 +34,8 @@ class BlackBoxWrapper(gym.ObservationWrapper, ABC):
reward_aggregation: function that takes the np.ndarray of step rewards as input and returns the trajectory
reward, default summation over all values.
"""
super().__init__()
super().__init__(env)
self.env = env
self.duration = duration
self.learn_sub_trajectories = learn_sub_trajectories
self.replanning_schedule = replanning_schedule