added async DMP example
parent 0097fe4f99
commit dee2fad263
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from collections import defaultdict
 
 import gym
 import numpy as np
@@ -49,24 +50,26 @@ class MPWrapper(gym.Wrapper, ABC):
         # self._velocity = velocity
 
         rewards = 0
-        infos = []
+        # infos = defaultdict(list)
 
         # TODO: @Max Why do we need this configure, states should be part of the model
         # self.env.configure(context)
         obs = self.env.reset()
+        info = {}
 
         for t, pos_vel in enumerate(zip(trajectory, velocity)):
             ac = self.policy.get_action(pos_vel[0], pos_vel[1])
             obs, rew, done, info = self.env.step(ac)
             rewards += rew
-            infos.append(info)
+            # TODO return all dicts?
+            # [infos[k].append(v) for k, v in info.items()]
             if self.render_mode:
                 self.env.render(mode=self.render_mode, **self.render_kwargs)
             if done:
                 break
 
         done = True
-        return obs, rewards, done, infos
+        return obs, rewards, done, info
 
     def render(self, mode='human', **kwargs):
         """Only set render options here, such that they can be used during the rollout.
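The commented-out defaultdict lines above sketch collecting every per-step info dict instead of returning only the last one. A minimal standalone sketch of that aggregation pattern (illustrative only, with made-up info dicts; not part of this commit):

from collections import defaultdict

infos = defaultdict(list)
for info in ({'dist': 0.3}, {'dist': 0.1}):  # stand-ins for per-step env infos
    for k, v in info.items():
        infos[k].append(v)
print(dict(infos))  # {'dist': [0.3, 0.1]}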
example.py (42 changed lines)
@@ -1,4 +1,7 @@
+from collections import defaultdict
+
 import gym
+import numpy as np
 
 
 def example_mujoco():
@@ -36,7 +39,7 @@ def example_dmp():
         # render full DMP trajectory
         # render can only be called once in the beginning as well. That would render every trajectory.
         # Calling it after every trajectory allows to modify the mode. mode=None disables rendering.
-        env.render(mode="partial")
+        env.render(mode="human")
 
         if done:
             print(rewards)
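The comments in this hunk pin down the render contract the wrapper uses: render() only stores the mode, the stored mode is applied during each rollout inside step(), and mode=None turns rendering off again. A hedged sketch of toggling rendering between rollouts (env id taken from the hunk below; the loop length and toggle point are illustrative):

import gym

env = gym.make("alr_envs:HoleReacherDMP-v0")
env.render(mode="human")  # store the mode; actual rendering happens during step()
for i in range(10):
    obs, reward, done, info = env.step(env.action_space.sample())  # one full DMP rollout
    if i == 4:
        env.render(mode=None)  # disable rendering for the remaining rollouts
    if done:
        obs = env.reset()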
@@ -44,5 +47,40 @@ def example_dmp():
             obs = env.reset()
 
 
+def example_async(n_cpu=4, seed=int('533D', 16)):
+    def make_env(env_id, seed, rank):
+        env = gym.make(env_id)
+        env.seed(seed + rank)
+        return lambda: env
+
+    def sample(env: gym.vector.VectorEnv, n_samples=100):
+        # for plotting
+        rewards = np.zeros(n_cpu)
+
+        # this would generate more samples than requested if n_samples % num_envs != 0
+        repeat = int(np.ceil(n_samples / env.num_envs))
+        vals = defaultdict(list)
+        for i in range(repeat):
+            obs, reward, done, info = env.step(env.action_space.sample())
+            vals['obs'].append(obs)
+            vals['reward'].append(reward)
+            vals['done'].append(done)
+            vals['info'].append(info)
+            rewards += reward
+            if np.any(done):
+                print(rewards[done])
+                rewards[done] = 0
+
+        # do not return more samples than requested
+        return (*map(lambda v: np.stack(v)[:n_samples], vals.values()),)
+
+    envs = gym.vector.AsyncVectorEnv([make_env("alr_envs:HoleReacherDMP-v0", seed, i) for i in range(n_cpu)])
+
+    obs = envs.reset()
+    print(sample(envs, 16))
+
+
 if __name__ == '__main__':
-    example_dmp()
+    # example_mujoco()
+    # example_dmp()
+    example_async()
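The comment in sample() flags the oversampling: every vector step returns num_envs transitions at once, so the step count is rounded up with np.ceil and the surplus is trimmed from the stacked result. A quick standalone check of that arithmetic (stand-in arrays and hypothetical sizes, no environment needed):

import numpy as np

num_envs, n_samples = 4, 10                            # hypothetical sizes
repeat = int(np.ceil(n_samples / num_envs))            # 3 steps instead of 2.5
batches = [np.zeros(num_envs) for _ in range(repeat)]  # stand-in per-step reward batches
print(repeat * num_envs)                               # 12 transitions for 10 requested
print(np.stack(batches).shape)                         # (3, 4): steps x envs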