Maximilian Huettenrauch 2021-01-14 17:10:03 +01:00
parent 104281fe16
commit b7400c477d
5 changed files with 47 additions and 6 deletions

alr_envs/classic_control/hole_reacher.py

@@ -1,7 +1,5 @@
 import gym
 import numpy as np
-import matplotlib
-matplotlib.use('TkAgg')
 import matplotlib.pyplot as plt
 from matplotlib import patches
@@ -112,7 +110,7 @@ class HoleReacher(gym.Env):
         if self._is_collided:
             reward -= self.collision_penalty
 
-        info = {}
+        info = {"is_collided": self._is_collided}
 
         self._steps += 1
@@ -286,6 +284,10 @@ class HoleReacher(gym.Env):
         plt.pause(0.01)
 
+    def close(self):
+        if self.fig is not None:
+            plt.close(self.fig)
+
 
 if __name__ == '__main__':
     nl = 5
@@ -306,3 +308,5 @@ if __name__ == '__main__':
         if done:
             break
+
+    env.close()
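
With the changes above, a driver loop can read the collision flag out of step() and release the render window via close(). A minimal sketch, assuming the constructor arguments from the new helper file below and the usual gym action_space attribute:

from alr_envs.classic_control.hole_reacher import HoleReacher

env = HoleReacher(num_links=5, allow_self_collision=False, allow_wall_collision=False,
                  hole_width=0.15, hole_depth=1, hole_x=1, collision_penalty=100000)
env.reset()
for _ in range(200):
    obs, reward, done, info = env.step(env.action_space.sample())
    if done or info["is_collided"]:  # "is_collided" is the field added in this commit
        break
env.close()  # new: closes the matplotlib figure if one was opened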

View File

@@ -0,0 +1,34 @@
+from alr_envs.classic_control.hole_reacher import HoleReacher
+from alr_envs.utils.dmp_env_wrapper import DmpEnvWrapperVel
+
+
+def make_env(rank, seed=0):
+    """
+    Utility function for a multiprocessed env.
+
+    :param rank: (int) index of the subprocess
+    :param seed: (int) the initial seed for the RNG
+    :return: a function that creates an environment
+    """
+    def _init():
+        env = HoleReacher(num_links=5,
+                          allow_self_collision=False,
+                          allow_wall_collision=False,
+                          hole_width=0.15,
+                          hole_depth=1,
+                          hole_x=1,
+                          collision_penalty=100000)
+
+        env = DmpEnvWrapperVel(env,
+                               num_dof=5,
+                               num_basis=5,
+                               duration=2,
+                               dt=env._dt,
+                               learn_goal=True)
+        env.seed(seed + rank)
+        return env
+
+    return _init
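
This factory is built to feed the vectorized rollout path changed below. A usage sketch, assuming the module path for DmpAsyncVectorEnv, that it takes a list of env factories like gym.vector.AsyncVectorEnv, and that its rollout() forwards the (rewards, infos) pair from the hunk below; the parameter layout (5 dofs x 5 basis functions plus 5 learned goal dimensions) is likewise a guess:

import numpy as np

from alr_envs.utils.dmp_async_vector_env import DmpAsyncVectorEnv  # assumed path

n_envs = 4
envs = DmpAsyncVectorEnv([make_env(rank) for rank in range(n_envs)])

# one DMP parameter vector per env (guessed layout, see above)
params = np.random.randn(n_envs, 5 * 5 + 5)
rewards, infos = envs.rollout(params)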

View File

@@ -96,7 +96,7 @@ class DmpAsyncVectorEnv(gym.vector.AsyncVectorEnv):
         # return (deepcopy(self.observations) if self.copy else self.observations,
         #         np.array(rewards), np.array(dones, dtype=np.bool_), infos)
-        return np.array(rewards)
+        return np.array(rewards), infos
 
     def rollout(self, actions):
         self.rollout_async(actions)
@@ -134,6 +134,7 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
             env.seed(data)
             pipe.send((None, True))
         elif command == 'close':
+            env.close()
             pipe.send((None, True))
             break
         elif command == 'idle':
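
Since every worker env now reports "is_collided" (see the hole_reacher.py hunk above), the infos handed back next to the rewards can be reduced to a collision mask. A hypothetical continuation of the earlier sketch, assuming infos is a list with one final-step dict per env:

import numpy as np

# rewards, infos as returned by envs.rollout(params) in the sketch above
collided = np.array([info.get("is_collided", False) for info in infos])
mean_reward = rewards[~collided].mean()  # average over collision-free rollouts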

alr_envs/utils/dmp_env_wrapper.py

@@ -113,18 +113,19 @@ class DmpEnvWrapperVel(DmpEnvWrapperBase):
         trajectory, velocities = self.dmp.reference_trajectory(self.t)
 
         rews = []
+        infos = []
 
         self.env.reset()
 
         for t, vel in enumerate(velocities):
             obs, rew, done, info = self.env.step(vel)
             rews.append(rew)
+            infos.append(info)
             if render:
                 self.env.render(mode="human")
             if done:
                 break
 
         reward = np.sum(rews)
-        info = {}
 
         return obs, reward, done, info
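
The wrapper now records one info dict per simulation step but still returns only the final step's dict. If the per-step values are ever needed downstream, a hypothetical helper (not part of this commit) could flatten the collected list:

def merge_infos(infos):
    """Merge a list of per-step info dicts into one dict of lists."""
    merged = {}
    for step_info in infos:
        for key, value in step_info.items():
            merged.setdefault(key, []).append(value)
    return merged

# merge_infos([{"is_collided": False}, {"is_collided": True}])
# -> {'is_collided': [False, True]}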

setup.py

@@ -2,5 +2,6 @@ from setuptools import setup
 setup(name='alr_envs',
       version='0.0.1',
-      install_requires=['gym', 'PyQt5', 'matplotlib']  # And any other dependencies foo needs
+      install_requires=['gym', 'PyQt5', 'matplotlib',
+                        'mp_lib @ git+https://git@github.com/maxhuettenrauch/mp_lib@master#egg=mp_lib',],  # And any other dependencies foo needs
       )
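
Note that the mp_lib entry is a PEP 508 direct-URL requirement; pip resolves these inside install_requires only from version 18.1 onwards, so older pip installations will fail to install the package.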