updated simple reacher example to new structure

ottofabian 2021-06-24 15:24:54 +02:00
parent f3d837349a
commit e1dc3eeddf


@@ -6,6 +6,7 @@ from gym import spaces
 from gym.utils import seeding
 from mp_env_api.envs.mp_env import MpEnv
+from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper


 class SimpleReacherEnv(MpEnv):
@@ -19,7 +20,7 @@ class SimpleReacherEnv(MpEnv):
         super().__init__()
         self.link_lengths = np.ones(n_links)
         self.n_links = n_links
-        self.dt = 0.1
+        self._dt = 0.1

         self.random_start = random_start
@@ -53,6 +54,10 @@ class SimpleReacherEnv(MpEnv):
         self._steps = 0
         self.seed()

+    @property
+    def dt(self) -> Union[float, int]:
+        return self._dt
+
     def step(self, action: np.ndarray):
         """
         A single step with action in torque space
@@ -172,24 +177,6 @@ class SimpleReacherEnv(MpEnv):
         self.fig.canvas.draw()
         self.fig.canvas.flush_events()

-    @property
-    def active_obs(self):
-        return np.hstack([
-            [self.random_start] * self.n_links,  # cos
-            [self.random_start] * self.n_links,  # sin
-            [self.random_start] * self.n_links,  # velocity
-            [True] * 2,  # x-y coordinates of target distance
-            [False]  # env steps
-        ])
-
-    @property
-    def start_pos(self):
-        return self._start_pos
-
-    @property
-    def goal_pos(self):
-        raise ValueError("Goal position is not available and has to be learnt based on the environment.")
-
     def seed(self, seed=None):
         self.np_random, seed = seeding.np_random(seed)
         return [seed]
@@ -202,24 +189,25 @@ class SimpleReacherEnv(MpEnv):
         return self._joints[self.n_links].T


-if __name__ == '__main__':
-    nl = 5
-    render_mode = "human"  # "human" or "partial" or "final"
-    env = SimpleReacherEnv(n_links=nl)
-    obs = env.reset()
-    print("First", obs)
-
-    for i in range(2000):
-        # objective.load_result("/tmp/cma")
-        # test with random actions
-        ac = 2 * env.action_space.sample()
-        # ac = np.ones(env.action_space.shape)
-        obs, rew, d, info = env.step(ac)
-        env.render(mode=render_mode)
-
-        print(obs[env.active_obs].shape)
-
-        if d or i % 200 == 0:
-            env.reset()
-
-    env.close()
+class SimpleReacherMPWrapper(MPEnvWrapper):
+    @property
+    def active_obs(self):
+        return np.hstack([
+            [self.env.random_start] * self.env.n_links,  # cos
+            [self.env.random_start] * self.env.n_links,  # sin
+            [self.env.random_start] * self.env.n_links,  # velocity
+            [True] * 2,  # x-y coordinates of target distance
+            [False]  # env steps
+        ])
+
+    @property
+    def start_pos(self):
+        return self._start_pos
+
+    @property
+    def goal_pos(self):
+        raise ValueError("Goal position is not available and has to be learnt based on the environment.")
+
+    @property
+    def dt(self) -> Union[float, int]:
+        return self.env.dt
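
In the new structure, SimpleReacherEnv keeps only the simulation itself (including the dt property backed by _dt), while the motion-primitive metadata (active_obs, start_pos, goal_pos, and the dt passthrough) moves into SimpleReacherMPWrapper. A minimal usage sketch of the refactored pair, assuming MPEnvWrapper follows the usual gym.Wrapper convention of taking the wrapped env in its constructor and delegating reset/step/action_space (none of which is shown in this diff):

# Sketch only: MPEnvWrapper's constructor and delegation behaviour are
# assumptions (gym.Wrapper-style); module paths are omitted because the
# diff does not show the file name.
env = SimpleReacherEnv(n_links=5)
wrapped = SimpleReacherMPWrapper(env)

obs = wrapped.reset()
mask = wrapped.active_obs      # boolean mask over the observation vector
print(obs[mask].shape)         # replaces the old obs[env.active_obs] demo
print(wrapped.dt)              # forwarded to env.dt, i.e. env._dt == 0.1

for i in range(200):
    ac = 2 * wrapped.action_space.sample()  # random torques, as in the removed demo
    obs, rew, done, info = wrapped.step(ac)
    if done:
        obs = wrapped.reset()

wrapped.close()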