call on superclass for obs wrapper
This commit is contained in:
parent
3273f455c5
commit
c3a8352c63
@ -1,10 +1,11 @@
|
|||||||
from typing import Iterable, Union
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gym import spaces
|
from gym import spaces
|
||||||
from gym.utils import seeding
|
from gym.utils import seeding
|
||||||
|
|
||||||
from alr_envs.alr.classic_control.utils import intersect
|
from alr_envs.alr.classic_control.utils import intersect
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,27 +0,0 @@
|
|||||||
from typing import Tuple, Union
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
|
|
||||||
|
|
||||||
|
|
||||||
class MPWrapper(RawInterfaceWrapper):
|
|
||||||
|
|
||||||
def get_context_mask(self):
|
|
||||||
return np.hstack([
|
|
||||||
[self.env.random_start] * self.env.n_links, # cos
|
|
||||||
[self.env.random_start] * self.env.n_links, # sin
|
|
||||||
[self.env.random_start] * self.env.n_links, # velocity
|
|
||||||
[self.env.initial_width is None], # hole width
|
|
||||||
# [self.env.hole_depth is None], # hole depth
|
|
||||||
[True] * 2, # x-y coordinates of target distance
|
|
||||||
[False] # env steps
|
|
||||||
])
|
|
||||||
|
|
||||||
@property
|
|
||||||
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
|
|
||||||
return self.env.current_pos
|
|
||||||
|
|
||||||
@property
|
|
||||||
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
|
||||||
return self.env.current_vel
|
|
@ -6,8 +6,8 @@ from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
|
|||||||
|
|
||||||
|
|
||||||
class MPWrapper(RawInterfaceWrapper):
|
class MPWrapper(RawInterfaceWrapper):
|
||||||
@property
|
|
||||||
def context_mask(self) -> np.ndarray:
|
def context_mask(self):
|
||||||
return np.hstack([
|
return np.hstack([
|
||||||
[self.env.random_start] * self.env.n_links, # cos
|
[self.env.random_start] * self.env.n_links, # cos
|
||||||
[self.env.random_start] * self.env.n_links, # sin
|
[self.env.random_start] * self.env.n_links, # sin
|
||||||
@ -25,10 +25,6 @@ class MPWrapper(RawInterfaceWrapper):
|
|||||||
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||||
return self.env.current_vel
|
return self.env.current_vel
|
||||||
|
|
||||||
@property
|
|
||||||
def goal_pos(self) -> Union[float, int, np.ndarray, Tuple]:
|
|
||||||
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dt(self) -> Union[float, int]:
|
def dt(self) -> Union[float, int]:
|
||||||
return self.env.dt
|
return self.env.dt
|
||||||
|
@ -1,30 +0,0 @@
|
|||||||
from typing import Tuple, Union
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
|
|
||||||
|
|
||||||
|
|
||||||
class MPWrapper(RawInterfaceWrapper):
|
|
||||||
|
|
||||||
def context_mask(self):
|
|
||||||
return np.hstack([
|
|
||||||
[self.env.random_start] * self.env.n_links, # cos
|
|
||||||
[self.env.random_start] * self.env.n_links, # sin
|
|
||||||
[self.env.random_start] * self.env.n_links, # velocity
|
|
||||||
[self.env.initial_via_target is None] * 2, # x-y coordinates of via point distance
|
|
||||||
[True] * 2, # x-y coordinates of target distance
|
|
||||||
[False] # env steps
|
|
||||||
])
|
|
||||||
|
|
||||||
@property
|
|
||||||
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
|
|
||||||
return self.env.current_pos
|
|
||||||
|
|
||||||
@property
|
|
||||||
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
|
||||||
return self.env.current_vel
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dt(self) -> Union[float, int]:
|
|
||||||
return self.env.dt
|
|
@ -1,2 +1,2 @@
|
|||||||
from .mp_wrapper import MPWrapper
|
from .mp_wrapper import MPWrapper
|
||||||
from .new_mp_wrapper import MPWrapper as NewMPWrapper
|
from .mp_wrapper import MPWrapper as NewMPWrapper
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Union
|
from typing import Union, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -8,37 +8,21 @@ from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
|
|||||||
class MPWrapper(RawInterfaceWrapper):
|
class MPWrapper(RawInterfaceWrapper):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def context_mask(self) -> np.ndarray:
|
def context_mask(self):
|
||||||
return np.concatenate([
|
return np.concatenate([
|
||||||
[False] * self.n_links, # cos
|
[False] * self.env.n_links, # cos
|
||||||
[False] * self.n_links, # sin
|
[False] * self.env.n_links, # sin
|
||||||
[True] * 2, # goal position
|
[True] * 2, # goal position
|
||||||
[False] * self.n_links, # angular velocity
|
[False] * self.env.n_links, # angular velocity
|
||||||
[False] * 3, # goal distance
|
[False] * 3, # goal distance
|
||||||
# self.get_body_com("target"), # only return target to make problem harder
|
# self.get_body_com("target"), # only return target to make problem harder
|
||||||
[False], # step
|
[False], # step
|
||||||
])
|
])
|
||||||
|
|
||||||
# @property
|
@property
|
||||||
# def active_obs(self):
|
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||||
# return np.concatenate([
|
return self.env.sim.data.qpos.flat[:self.env.n_links]
|
||||||
# [True] * self.n_links, # cos, True
|
|
||||||
# [True] * self.n_links, # sin, True
|
|
||||||
# [True] * 2, # goal position
|
|
||||||
# [True] * self.n_links, # angular velocity, True
|
|
||||||
# [True] * 3, # goal distance
|
|
||||||
# # self.get_body_com("target"), # only return target to make problem harder
|
|
||||||
# [False], # step
|
|
||||||
# ])
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def current_vel(self) -> Union[float, int, np.ndarray]:
|
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||||
return self.sim.data.qvel.flat[:self.n_links]
|
return self.env.sim.data.qvel.flat[:self.env.n_links]
|
||||||
|
|
||||||
@property
|
|
||||||
def current_pos(self) -> Union[float, int, np.ndarray]:
|
|
||||||
return self.sim.data.qpos.flat[:self.n_links]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dt(self) -> Union[float, int]:
|
|
||||||
return self.env.dt
|
|
||||||
|
@ -1,28 +0,0 @@
|
|||||||
from alr_envs.mp.black_box_wrapper import BlackBoxWrapper
|
|
||||||
from typing import Union, Tuple
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
|
|
||||||
|
|
||||||
|
|
||||||
class MPWrapper(RawInterfaceWrapper):
|
|
||||||
|
|
||||||
@property
|
|
||||||
def context_mask(self):
|
|
||||||
return np.concatenate([
|
|
||||||
[False] * self.env.n_links, # cos
|
|
||||||
[False] * self.env.n_links, # sin
|
|
||||||
[True] * 2, # goal position
|
|
||||||
[False] * self.env.n_links, # angular velocity
|
|
||||||
[False] * 3, # goal distance
|
|
||||||
# self.get_body_com("target"), # only return target to make problem harder
|
|
||||||
[False], # step
|
|
||||||
])
|
|
||||||
|
|
||||||
@property
|
|
||||||
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
|
|
||||||
return self.env.sim.data.qpos.flat[:self.env.n_links]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
|
||||||
return self.env.sim.data.qvel.flat[:self.env.n_links]
|
|
@ -11,7 +11,7 @@ from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
|
|||||||
from alr_envs.utils.utils import get_numpy
|
from alr_envs.utils.utils import get_numpy
|
||||||
|
|
||||||
|
|
||||||
class BlackBoxWrapper(gym.ObservationWrapper, ABC):
|
class BlackBoxWrapper(gym.ObservationWrapper):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
env: RawInterfaceWrapper,
|
env: RawInterfaceWrapper,
|
||||||
@ -34,9 +34,8 @@ class BlackBoxWrapper(gym.ObservationWrapper, ABC):
|
|||||||
reward_aggregation: function that takes the np.ndarray of step rewards as input and returns the trajectory
|
reward_aggregation: function that takes the np.ndarray of step rewards as input and returns the trajectory
|
||||||
reward, default summation over all values.
|
reward, default summation over all values.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__(env)
|
||||||
|
|
||||||
self.env = env
|
|
||||||
self.duration = duration
|
self.duration = duration
|
||||||
self.learn_sub_trajectories = learn_sub_trajectories
|
self.learn_sub_trajectories = learn_sub_trajectories
|
||||||
self.replanning_schedule = replanning_schedule
|
self.replanning_schedule = replanning_schedule
|
||||||
|
Loading…
Reference in New Issue
Block a user