updated simple reacher example to new structure
This commit is contained in:
parent
f3d837349a
commit
e1dc3eeddf
@ -6,6 +6,7 @@ from gym import spaces
|
|||||||
from gym.utils import seeding
|
from gym.utils import seeding
|
||||||
|
|
||||||
from mp_env_api.envs.mp_env import MpEnv
|
from mp_env_api.envs.mp_env import MpEnv
|
||||||
|
from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper
|
||||||
|
|
||||||
|
|
||||||
class SimpleReacherEnv(MpEnv):
|
class SimpleReacherEnv(MpEnv):
|
||||||
@ -19,7 +20,7 @@ class SimpleReacherEnv(MpEnv):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.link_lengths = np.ones(n_links)
|
self.link_lengths = np.ones(n_links)
|
||||||
self.n_links = n_links
|
self.n_links = n_links
|
||||||
self.dt = 0.1
|
self._dt = 0.1
|
||||||
|
|
||||||
self.random_start = random_start
|
self.random_start = random_start
|
||||||
|
|
||||||
@ -53,6 +54,10 @@ class SimpleReacherEnv(MpEnv):
|
|||||||
self._steps = 0
|
self._steps = 0
|
||||||
self.seed()
|
self.seed()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dt(self) -> Union[float, int]:
|
||||||
|
return self._dt
|
||||||
|
|
||||||
def step(self, action: np.ndarray):
|
def step(self, action: np.ndarray):
|
||||||
"""
|
"""
|
||||||
A single step with action in torque space
|
A single step with action in torque space
|
||||||
@ -172,24 +177,6 @@ class SimpleReacherEnv(MpEnv):
|
|||||||
self.fig.canvas.draw()
|
self.fig.canvas.draw()
|
||||||
self.fig.canvas.flush_events()
|
self.fig.canvas.flush_events()
|
||||||
|
|
||||||
@property
|
|
||||||
def active_obs(self):
|
|
||||||
return np.hstack([
|
|
||||||
[self.random_start] * self.n_links, # cos
|
|
||||||
[self.random_start] * self.n_links, # sin
|
|
||||||
[self.random_start] * self.n_links, # velocity
|
|
||||||
[True] * 2, # x-y coordinates of target distance
|
|
||||||
[False] # env steps
|
|
||||||
])
|
|
||||||
|
|
||||||
@property
|
|
||||||
def start_pos(self):
|
|
||||||
return self._start_pos
|
|
||||||
|
|
||||||
@property
|
|
||||||
def goal_pos(self):
|
|
||||||
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
|
|
||||||
|
|
||||||
def seed(self, seed=None):
|
def seed(self, seed=None):
|
||||||
self.np_random, seed = seeding.np_random(seed)
|
self.np_random, seed = seeding.np_random(seed)
|
||||||
return [seed]
|
return [seed]
|
||||||
@ -202,24 +189,25 @@ class SimpleReacherEnv(MpEnv):
|
|||||||
return self._joints[self.n_links].T
|
return self._joints[self.n_links].T
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
class SimpleReacherMPWrapper(MPEnvWrapper):
|
||||||
nl = 5
|
@property
|
||||||
render_mode = "human" # "human" or "partial" or "final"
|
def active_obs(self):
|
||||||
env = SimpleReacherEnv(n_links=nl)
|
return np.hstack([
|
||||||
obs = env.reset()
|
[self.env.random_start] * self.env.n_links, # cos
|
||||||
print("First", obs)
|
[self.env.random_start] * self.env.n_links, # sin
|
||||||
|
[self.env.random_start] * self.env.n_links, # velocity
|
||||||
|
[True] * 2, # x-y coordinates of target distance
|
||||||
|
[False] # env steps
|
||||||
|
])
|
||||||
|
|
||||||
for i in range(2000):
|
@property
|
||||||
# objective.load_result("/tmp/cma")
|
def start_pos(self):
|
||||||
# test with random actions
|
return self._start_pos
|
||||||
ac = 2 * env.action_space.sample()
|
|
||||||
# ac = np.ones(env.action_space.shape)
|
|
||||||
obs, rew, d, info = env.step(ac)
|
|
||||||
env.render(mode=render_mode)
|
|
||||||
|
|
||||||
print(obs[env.active_obs].shape)
|
@property
|
||||||
|
def goal_pos(self):
|
||||||
|
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
|
||||||
|
|
||||||
if d or i % 200 == 0:
|
@property
|
||||||
env.reset()
|
def dt(self) -> Union[float, int]:
|
||||||
|
return self.env.dt
|
||||||
env.close()
|
|
||||||
|
Loading…
Reference in New Issue
Block a user