This commit is contained in:
Maximilian Huettenrauch 2021-01-14 17:10:03 +01:00
parent 104281fe16
commit b7400c477d
5 changed files with 47 additions and 6 deletions

View File

@ -1,7 +1,5 @@
import gym
import numpy as np
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
from matplotlib import patches
@ -112,7 +110,7 @@ class HoleReacher(gym.Env):
if self._is_collided:
reward -= self.collision_penalty
info = {}
info = {"is_collided": self._is_collided}
self._steps += 1
@ -286,6 +284,10 @@ class HoleReacher(gym.Env):
plt.pause(0.01)
def close(self):
if self.fig is not None:
plt.close(self.fig)
if __name__ == '__main__':
nl = 5
@ -306,3 +308,5 @@ if __name__ == '__main__':
if done:
break
env.close()

View File

@ -0,0 +1,34 @@
from alr_envs.classic_control.hole_reacher import HoleReacher
from alr_envs.utils.dmp_env_wrapper import DmpEnvWrapperVel
def make_env(rank, seed=0):
    """
    Factory for a multiprocessed HoleReacher environment.

    Returns a zero-argument callable (not the env itself) so that each
    subprocess can construct its own environment instance lazily — the
    standard pattern for SubprocVecEnv-style vectorized environments.

    :param rank: (int) index of the subprocess; offsets the seed so every
        worker gets a distinct RNG stream
    :param seed: (int) base seed for the RNG (per-worker seed is seed + rank)
    :returns: a function that, when called, builds and seeds the wrapped env
    """
    def _init():
        env = HoleReacher(num_links=5,
                          allow_self_collision=False,
                          allow_wall_collision=False,
                          hole_width=0.15,
                          hole_depth=1,
                          hole_x=1,
                          collision_penalty=100000)
        # NOTE(review): reaches into the private attribute `env._dt`;
        # consider exposing dt publicly on HoleReacher instead.
        env = DmpEnvWrapperVel(env,
                               num_dof=5,
                               num_basis=5,
                               duration=2,
                               dt=env._dt,
                               learn_goal=True)
        env.seed(seed + rank)
        return env
    return _init

View File

@ -96,7 +96,7 @@ class DmpAsyncVectorEnv(gym.vector.AsyncVectorEnv):
# return (deepcopy(self.observations) if self.copy else self.observations,
# np.array(rewards), np.array(dones, dtype=np.bool_), infos)
return np.array(rewards)
return np.array(rewards), infos
def rollout(self, actions):
self.rollout_async(actions)
@ -134,6 +134,7 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
env.seed(data)
pipe.send((None, True))
elif command == 'close':
env.close()
pipe.send((None, True))
break
elif command == 'idle':

View File

@ -113,18 +113,19 @@ class DmpEnvWrapperVel(DmpEnvWrapperBase):
trajectory, velocities = self.dmp.reference_trajectory(self.t)
rews = []
infos = []
self.env.reset()
for t, vel in enumerate(velocities):
obs, rew, done, info = self.env.step(vel)
rews.append(rew)
infos.append(info)
if render:
self.env.render(mode="human")
if done:
break
reward = np.sum(rews)
info = {}
return obs, reward, done, info

View File

@ -2,5 +2,6 @@ from setuptools import setup
setup(name='alr_envs',
version='0.0.1',
install_requires=['gym', 'PyQt5', 'matplotlib'] # And any other dependencies foo needs
install_requires=['gym', 'PyQt5', 'matplotlib',
'mp_lib @ git+https://git@github.com/maxhuettenrauch/mp_lib@master#egg=mp_lib',], # And any other dependencies foo needs
)