From dee2fad2637d666b721070c9c360ccb08a32455a Mon Sep 17 00:00:00 2001
From: ottofabian
Date: Fri, 26 Mar 2021 16:37:38 +0100
Subject: [PATCH] added async DMP example

---
 alr_envs/utils/wrapper/mp_wrapper.py |  9 ++++--
 example.py                           | 42 ++++++++++++++++++++++++++--
 2 files changed, 46 insertions(+), 5 deletions(-)

diff --git a/alr_envs/utils/wrapper/mp_wrapper.py b/alr_envs/utils/wrapper/mp_wrapper.py
index 6234f0c..f705643 100644
--- a/alr_envs/utils/wrapper/mp_wrapper.py
+++ b/alr_envs/utils/wrapper/mp_wrapper.py
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from collections import defaultdict
 
 import gym
 import numpy as np
@@ -49,24 +50,26 @@ class MPWrapper(gym.Wrapper, ABC):
         # self._velocity = velocity
 
         rewards = 0
-        infos = []
+        # infos = defaultdict(list)
 
         # TODO: @Max Why do we need this configure, states should be part of the model
         # self.env.configure(context)
         obs = self.env.reset()
+        info = {}
 
         for t, pos_vel in enumerate(zip(trajectory, velocity)):
             ac = self.policy.get_action(pos_vel[0], pos_vel[1])
             obs, rew, done, info = self.env.step(ac)
             rewards += rew
-            infos.append(info)
+            # TODO return all dicts?
+            # [infos[k].append(v) for k, v in info.items()]
             if self.render_mode:
                 self.env.render(mode=self.render_mode, **self.render_kwargs)
             if done:
                 break
 
         done = True
-        return obs, rewards, done, infos
+        return obs, rewards, done, info
 
     def render(self, mode='human', **kwargs):
         """Only set render options here, such that they can be used during the rollout.
diff --git a/example.py b/example.py
index 0167971..0ce713f 100644
--- a/example.py
+++ b/example.py
@@ -1,4 +1,7 @@
+from collections import defaultdict
+
 import gym
+import numpy as np
 
 
 def example_mujoco():
@@ -36,7 +39,7 @@ def example_dmp():
         # render full DMP trajectory
         # render can only be called once in the beginning as well. That would render every trajectory
         # Calling it after every trajectory allows to modify the mode. mode=None, disables rendering.
-        env.render(mode="partial")
+        env.render(mode="human")
 
         if done:
             print(rewards)
@@ -44,5 +47,40 @@
             obs = env.reset()
 
 
+def example_async(n_cpu=4, seed=int('533D', 16)):
+    def make_env(env_id, seed, rank):
+        env = gym.make(env_id)
+        env.seed(seed + rank)
+        return lambda: env
+
+    def sample(env: gym.vector.VectorEnv, n_samples=100):
+        # for plotting
+        rewards = np.zeros(env.num_envs)
+
+        # this would generate more samples than requested if n_samples % num_envs != 0
+        repeat = int(np.ceil(n_samples / env.num_envs))
+        vals = defaultdict(list)
+        for i in range(repeat):
+            obs, reward, done, info = env.step(env.action_space.sample())
+            vals['obs'].append(obs)
+            vals['reward'].append(reward)
+            vals['done'].append(done)
+            vals['info'].append(info)
+            rewards += reward
+            if np.any(done):
+                print(rewards[done])
+                rewards[done] = 0
+
+        # do not return more samples than requested
+        return (*map(lambda v: np.stack(v)[:n_samples], vals.values()),)
+
+    envs = gym.vector.AsyncVectorEnv([make_env("alr_envs:HoleReacherDMP-v0", seed, i) for i in range(n_cpu)])
+
+    obs = envs.reset()
+    print(sample(envs, 16))
+
+
 if __name__ == '__main__':
-    example_dmp()
+    # example_mujoco()
+    # example_dmp()
+    example_async()
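
Note on the "TODO return all dicts?" left in mp_wrapper.py: a rough sketch of how the commented-out defaultdict aggregation could look if per-step infos should be returned instead of only the last one. The helper name collect_infos is illustrative and not part of the patch:

    from collections import defaultdict

    def collect_infos(step_infos):
        # merge a list of per-step info dicts into one dict of lists,
        # e.g. [{"success": 0}, {"success": 1}] -> {"success": [0, 1]}
        infos = defaultdict(list)
        for info in step_infos:
            for k, v in info.items():
                infos[k].append(v)
        return dict(infos)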