2021-05-12 09:52:25 +02:00
|
|
|
from abc import abstractmethod
|
|
|
|
from typing import Union
|
|
|
|
|
|
|
|
import gym
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
2021-05-21 15:44:49 +02:00
|
|
|
class AlrEnv(gym.Env):
    """Gym environment base class exposing the quantities used for the MP.

    Subclasses provide:

    * ``active_obs`` -- which observation entries are kept in the contextual case,
    * ``start_pos`` -- the starting position of the joints,
    * ``goal_pos`` -- the final joint position (defaults to ``start_pos``).
    """

    @property
    @abstractmethod
    def active_obs(self) -> np.ndarray:
        """Boolean mask over the observation entries.

        An entry is ``True`` when that observation is returned for the
        contextual case and ``False`` when it is filtered out of the full
        step-based observation.

        Returns:
            A boolean array shaped like ``self.observation_space.shape``;
            the default marks every entry as active.
        """
        # Declared abstract, yet a default is still supplied so overriding
        # subclasses can reuse it through ``super().active_obs``.
        mask_shape = self.observation_space.shape
        return np.full(mask_shape, True)

    @property
    @abstractmethod
    def start_pos(self) -> Union[float, int, np.ndarray]:
        """Starting position of the joints.

        Must be supplied by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    @property
    def goal_pos(self) -> Union[float, int, np.ndarray]:
        """Current final position of the joints for the MP.

        Unless a subclass overrides it, the goal coincides with
        ``start_pos``.
        """
        return self.start_pos
|