From b7400c477d383108887fb4821f40928b43897a9a Mon Sep 17 00:00:00 2001 From: Maximilian Huettenrauch Date: Thu, 14 Jan 2021 17:10:03 +0100 Subject: [PATCH] updates --- alr_envs/classic_control/hole_reacher.py | 10 ++++--- alr_envs/classic_control/utils.py | 34 ++++++++++++++++++++++++ alr_envs/utils/dmp_async_vec_env.py | 3 ++- alr_envs/utils/dmp_env_wrapper.py | 3 ++- setup.py | 3 ++- 5 files changed, 47 insertions(+), 6 deletions(-) create mode 100644 alr_envs/classic_control/utils.py diff --git a/alr_envs/classic_control/hole_reacher.py b/alr_envs/classic_control/hole_reacher.py index 441e0cf..f19915b 100644 --- a/alr_envs/classic_control/hole_reacher.py +++ b/alr_envs/classic_control/hole_reacher.py @@ -1,7 +1,5 @@ import gym import numpy as np -import matplotlib -matplotlib.use('TkAgg') import matplotlib.pyplot as plt from matplotlib import patches @@ -112,7 +110,7 @@ class HoleReacher(gym.Env): if self._is_collided: reward -= self.collision_penalty - info = {} + info = {"is_collided": self._is_collided} self._steps += 1 @@ -286,6 +284,10 @@ class HoleReacher(gym.Env): plt.pause(0.01) + def close(self): + if self.fig is not None: + plt.close(self.fig) + if __name__ == '__main__': nl = 5 @@ -306,3 +308,5 @@ if __name__ == '__main__': if done: break + + env.close() diff --git a/alr_envs/classic_control/utils.py b/alr_envs/classic_control/utils.py new file mode 100644 index 0000000..61156f1 --- /dev/null +++ b/alr_envs/classic_control/utils.py @@ -0,0 +1,34 @@ +from alr_envs.classic_control.hole_reacher import HoleReacher +from alr_envs.utils.dmp_env_wrapper import DmpEnvWrapperVel + + +def make_env(rank, seed=0): + """ + Utility function for multiprocessed env. + + :param env_id: (str) the environment ID + :param num_env: (int) the number of environments you wish to have in subprocesses + :param seed: (int) the initial seed for RNG + :param rank: (int) index of the subprocess + :returns a function that generates an environment + """ + + def _init(): + env = HoleReacher(num_links=5, + allow_self_collision=False, + allow_wall_collision=False, + hole_width=0.15, + hole_depth=1, + hole_x=1, + collision_penalty=100000) + + env = DmpEnvWrapperVel(env, + num_dof=5, + num_basis=5, + duration=2, + dt=env._dt, + learn_goal=True) + env.seed(seed + rank) + return env + + return _init diff --git a/alr_envs/utils/dmp_async_vec_env.py b/alr_envs/utils/dmp_async_vec_env.py index f576277..56f4ca7 100644 --- a/alr_envs/utils/dmp_async_vec_env.py +++ b/alr_envs/utils/dmp_async_vec_env.py @@ -96,7 +96,7 @@ class DmpAsyncVectorEnv(gym.vector.AsyncVectorEnv): # return (deepcopy(self.observations) if self.copy else self.observations, # np.array(rewards), np.array(dones, dtype=np.bool_), infos) - return np.array(rewards) + return np.array(rewards), infos def rollout(self, actions): self.rollout_async(actions) @@ -134,6 +134,7 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue): env.seed(data) pipe.send((None, True)) elif command == 'close': + env.close() pipe.send((None, True)) break elif command == 'idle': diff --git a/alr_envs/utils/dmp_env_wrapper.py b/alr_envs/utils/dmp_env_wrapper.py index cf73ae5..9b1bdf1 100644 --- a/alr_envs/utils/dmp_env_wrapper.py +++ b/alr_envs/utils/dmp_env_wrapper.py @@ -113,18 +113,19 @@ class DmpEnvWrapperVel(DmpEnvWrapperBase): trajectory, velocities = self.dmp.reference_trajectory(self.t) rews = [] + infos = [] self.env.reset() for t, vel in enumerate(velocities): obs, rew, done, info = self.env.step(vel) rews.append(rew) + infos.append(info) if render: self.env.render(mode="human") if done: break reward = np.sum(rews) - info = {} return obs, reward, done, info diff --git a/setup.py b/setup.py index 6f30786..88324de 100644 --- a/setup.py +++ b/setup.py @@ -2,5 +2,6 @@ from setuptools import setup setup(name='alr_envs', version='0.0.1', - install_requires=['gym', 'PyQt5', 'matplotlib'] # And any other dependencies foo needs + install_requires=['gym', 'PyQt5', 'matplotlib', + 'mp_lib @ git+https://git@github.com/maxhuettenrauch/mp_lib@master#egg=mp_lib',], # And any other dependencies foo needs )