updates
This commit is contained in:
parent
104281fe16
commit
b7400c477d
@ -1,7 +1,5 @@
|
||||
import gym
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
matplotlib.use('TkAgg')
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib import patches
|
||||
|
||||
@ -112,7 +110,7 @@ class HoleReacher(gym.Env):
|
||||
if self._is_collided:
|
||||
reward -= self.collision_penalty
|
||||
|
||||
info = {}
|
||||
info = {"is_collided": self._is_collided}
|
||||
|
||||
self._steps += 1
|
||||
|
||||
@ -286,6 +284,10 @@ class HoleReacher(gym.Env):
|
||||
|
||||
plt.pause(0.01)
|
||||
|
||||
def close(self):
    """
    Close the matplotlib figure used for rendering, if one exists.

    After closing, the figure handle is reset to ``None`` so the
    environment does not hold on to a destroyed figure and repeated
    calls to ``close`` are harmless no-ops.
    """
    if self.fig is None:
        # Nothing was ever rendered (or the figure was already closed).
        return

    plt.close(self.fig)
    # NOTE(review): presumably render() recreates the figure when
    # ``self.fig`` is None -- confirm against the render implementation.
    self.fig = None
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
nl = 5
|
||||
@ -306,3 +308,5 @@ if __name__ == '__main__':
|
||||
|
||||
if done:
|
||||
break
|
||||
|
||||
env.close()
|
||||
|
34
alr_envs/classic_control/utils.py
Normal file
34
alr_envs/classic_control/utils.py
Normal file
@ -0,0 +1,34 @@
|
||||
from alr_envs.classic_control.hole_reacher import HoleReacher
|
||||
from alr_envs.utils.dmp_env_wrapper import DmpEnvWrapperVel
|
||||
|
||||
|
||||
def make_env(rank, seed=0):
    """
    Utility function for multiprocessed env.

    Returns a factory instead of an environment so that each subprocess
    can construct its own instance. The environment is a ``HoleReacher``
    wrapped in a ``DmpEnvWrapperVel`` and seeded with ``seed + rank``.

    :param rank: (int) index of the subprocess
    :param seed: (int) the initial seed for RNG
    :returns: a zero-argument function that generates an environment
    """

    def _init():
        env = HoleReacher(num_links=5,
                          allow_self_collision=False,
                          allow_wall_collision=False,
                          hole_width=0.15,
                          hole_depth=1,
                          hole_x=1,
                          collision_penalty=100000)

        # ``env._dt`` is read from the raw environment before ``env`` is
        # rebound to the wrapper.
        env = DmpEnvWrapperVel(env,
                               num_dof=5,
                               num_basis=5,
                               duration=2,
                               dt=env._dt,
                               learn_goal=True)
        # Offset the seed by the worker index so each worker gets its own seed.
        env.seed(seed + rank)
        return env

    return _init
|
@ -96,7 +96,7 @@ class DmpAsyncVectorEnv(gym.vector.AsyncVectorEnv):
|
||||
# return (deepcopy(self.observations) if self.copy else self.observations,
|
||||
# np.array(rewards), np.array(dones, dtype=np.bool_), infos)
|
||||
|
||||
return np.array(rewards)
|
||||
return np.array(rewards), infos
|
||||
|
||||
def rollout(self, actions):
|
||||
self.rollout_async(actions)
|
||||
@ -134,6 +134,7 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
|
||||
env.seed(data)
|
||||
pipe.send((None, True))
|
||||
elif command == 'close':
|
||||
env.close()
|
||||
pipe.send((None, True))
|
||||
break
|
||||
elif command == 'idle':
|
||||
|
@ -113,18 +113,19 @@ class DmpEnvWrapperVel(DmpEnvWrapperBase):
|
||||
trajectory, velocities = self.dmp.reference_trajectory(self.t)
|
||||
|
||||
rews = []
|
||||
infos = []
|
||||
|
||||
self.env.reset()
|
||||
|
||||
for t, vel in enumerate(velocities):
|
||||
obs, rew, done, info = self.env.step(vel)
|
||||
rews.append(rew)
|
||||
infos.append(info)
|
||||
if render:
|
||||
self.env.render(mode="human")
|
||||
if done:
|
||||
break
|
||||
|
||||
reward = np.sum(rews)
|
||||
info = {}
|
||||
|
||||
return obs, reward, done, info
|
||||
|
3
setup.py
3
setup.py
@ -2,5 +2,6 @@ from setuptools import setup
|
||||
|
||||
setup(name='alr_envs',
      version='0.0.1',
      # Runtime dependencies. mp_lib is not a plain index requirement: it is
      # referenced directly by its Git URL (master branch).
      install_requires=[
          'gym',
          'PyQt5',
          'matplotlib',
          'mp_lib @ git+https://git@github.com/maxhuettenrauch/mp_lib@master#egg=mp_lib',
      ],
      )
|
||||
|
Loading…
Reference in New Issue
Block a user