updates
This commit is contained in:
parent
104281fe16
commit
b7400c477d
@ -1,7 +1,5 @@
|
|||||||
import gym
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import matplotlib
|
|
||||||
matplotlib.use('TkAgg')
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from matplotlib import patches
|
from matplotlib import patches
|
||||||
|
|
||||||
@ -112,7 +110,7 @@ class HoleReacher(gym.Env):
|
|||||||
if self._is_collided:
|
if self._is_collided:
|
||||||
reward -= self.collision_penalty
|
reward -= self.collision_penalty
|
||||||
|
|
||||||
info = {}
|
info = {"is_collided": self._is_collided}
|
||||||
|
|
||||||
self._steps += 1
|
self._steps += 1
|
||||||
|
|
||||||
@ -286,6 +284,10 @@ class HoleReacher(gym.Env):
|
|||||||
|
|
||||||
plt.pause(0.01)
|
plt.pause(0.01)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
if self.fig is not None:
|
||||||
|
plt.close(self.fig)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
nl = 5
|
nl = 5
|
||||||
@ -306,3 +308,5 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
if done:
|
if done:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
env.close()
|
||||||
|
34
alr_envs/classic_control/utils.py
Normal file
34
alr_envs/classic_control/utils.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
from alr_envs.classic_control.hole_reacher import HoleReacher
|
||||||
|
from alr_envs.utils.dmp_env_wrapper import DmpEnvWrapperVel
|
||||||
|
|
||||||
|
|
||||||
|
def make_env(rank, seed=0):
|
||||||
|
"""
|
||||||
|
Utility function for multiprocessed env.
|
||||||
|
|
||||||
|
:param env_id: (str) the environment ID
|
||||||
|
:param num_env: (int) the number of environments you wish to have in subprocesses
|
||||||
|
:param seed: (int) the initial seed for RNG
|
||||||
|
:param rank: (int) index of the subprocess
|
||||||
|
:returns a function that generates an environment
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _init():
|
||||||
|
env = HoleReacher(num_links=5,
|
||||||
|
allow_self_collision=False,
|
||||||
|
allow_wall_collision=False,
|
||||||
|
hole_width=0.15,
|
||||||
|
hole_depth=1,
|
||||||
|
hole_x=1,
|
||||||
|
collision_penalty=100000)
|
||||||
|
|
||||||
|
env = DmpEnvWrapperVel(env,
|
||||||
|
num_dof=5,
|
||||||
|
num_basis=5,
|
||||||
|
duration=2,
|
||||||
|
dt=env._dt,
|
||||||
|
learn_goal=True)
|
||||||
|
env.seed(seed + rank)
|
||||||
|
return env
|
||||||
|
|
||||||
|
return _init
|
@ -96,7 +96,7 @@ class DmpAsyncVectorEnv(gym.vector.AsyncVectorEnv):
|
|||||||
# return (deepcopy(self.observations) if self.copy else self.observations,
|
# return (deepcopy(self.observations) if self.copy else self.observations,
|
||||||
# np.array(rewards), np.array(dones, dtype=np.bool_), infos)
|
# np.array(rewards), np.array(dones, dtype=np.bool_), infos)
|
||||||
|
|
||||||
return np.array(rewards)
|
return np.array(rewards), infos
|
||||||
|
|
||||||
def rollout(self, actions):
|
def rollout(self, actions):
|
||||||
self.rollout_async(actions)
|
self.rollout_async(actions)
|
||||||
@ -134,6 +134,7 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
|
|||||||
env.seed(data)
|
env.seed(data)
|
||||||
pipe.send((None, True))
|
pipe.send((None, True))
|
||||||
elif command == 'close':
|
elif command == 'close':
|
||||||
|
env.close()
|
||||||
pipe.send((None, True))
|
pipe.send((None, True))
|
||||||
break
|
break
|
||||||
elif command == 'idle':
|
elif command == 'idle':
|
||||||
|
@ -113,18 +113,19 @@ class DmpEnvWrapperVel(DmpEnvWrapperBase):
|
|||||||
trajectory, velocities = self.dmp.reference_trajectory(self.t)
|
trajectory, velocities = self.dmp.reference_trajectory(self.t)
|
||||||
|
|
||||||
rews = []
|
rews = []
|
||||||
|
infos = []
|
||||||
|
|
||||||
self.env.reset()
|
self.env.reset()
|
||||||
|
|
||||||
for t, vel in enumerate(velocities):
|
for t, vel in enumerate(velocities):
|
||||||
obs, rew, done, info = self.env.step(vel)
|
obs, rew, done, info = self.env.step(vel)
|
||||||
rews.append(rew)
|
rews.append(rew)
|
||||||
|
infos.append(info)
|
||||||
if render:
|
if render:
|
||||||
self.env.render(mode="human")
|
self.env.render(mode="human")
|
||||||
if done:
|
if done:
|
||||||
break
|
break
|
||||||
|
|
||||||
reward = np.sum(rews)
|
reward = np.sum(rews)
|
||||||
info = {}
|
|
||||||
|
|
||||||
return obs, reward, done, info
|
return obs, reward, done, info
|
||||||
|
3
setup.py
3
setup.py
@ -2,5 +2,6 @@ from setuptools import setup
|
|||||||
|
|
||||||
setup(name='alr_envs',
|
setup(name='alr_envs',
|
||||||
version='0.0.1',
|
version='0.0.1',
|
||||||
install_requires=['gym', 'PyQt5', 'matplotlib'] # And any other dependencies foo needs
|
install_requires=['gym', 'PyQt5', 'matplotlib',
|
||||||
|
'mp_lib @ git+https://git@github.com/maxhuettenrauch/mp_lib@master#egg=mp_lib',], # And any other dependencies foo needs
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user