174 lines
6.9 KiB
Python
174 lines
6.9 KiB
Python
import gym
|
|
from gym.error import (AlreadyPendingCallError, NoAsyncCallError)
|
|
from gym.vector.utils import concatenate, create_empty_array
|
|
from gym.vector.async_vector_env import AsyncState
|
|
import numpy as np
|
|
import multiprocessing as mp
|
|
import sys
|
|
|
|
|
|
def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
|
|
assert shared_memory is None
|
|
env = env_fn()
|
|
parent_pipe.close()
|
|
try:
|
|
while True:
|
|
command, data = pipe.recv()
|
|
if command == 'reset':
|
|
observation = env.reset()
|
|
pipe.send((observation, True))
|
|
elif command == 'step':
|
|
observation, reward, done, info = env.step(data)
|
|
if done:
|
|
observation = env.reset()
|
|
pipe.send(((observation, reward, done, info), True))
|
|
elif command == 'rollout':
|
|
rewards = []
|
|
infos = []
|
|
for p, c in zip(*data):
|
|
reward, info = env.rollout(p, c)
|
|
rewards.append(reward)
|
|
infos.append(info)
|
|
pipe.send(((rewards, infos), (True, ) * len(rewards)))
|
|
elif command == 'seed':
|
|
env.seed(data)
|
|
pipe.send((None, True))
|
|
elif command == 'close':
|
|
env.close()
|
|
pipe.send((None, True))
|
|
break
|
|
elif command == 'idle':
|
|
pipe.send((None, True))
|
|
elif command == '_check_observation_space':
|
|
pipe.send((data == env.observation_space, True))
|
|
else:
|
|
raise RuntimeError('Received unknown command `{0}`. Must '
|
|
'be one of {`reset`, `step`, `seed`, `close`, '
|
|
'`_check_observation_space`}.'.format(command))
|
|
except (KeyboardInterrupt, Exception):
|
|
error_queue.put((index,) + sys.exc_info()[:2])
|
|
pipe.send((None, False))
|
|
finally:
|
|
env.close()
|
|
|
|
|
|
class DmpAsyncVectorEnv(gym.vector.AsyncVectorEnv):
|
|
def __init__(self, env_fns, n_samples, observation_space=None, action_space=None,
|
|
shared_memory=False, copy=True, context="spawn", daemon=True, worker=_worker):
|
|
super(DmpAsyncVectorEnv, self).__init__(env_fns,
|
|
observation_space=observation_space,
|
|
action_space=action_space,
|
|
shared_memory=shared_memory,
|
|
copy=copy,
|
|
context=context,
|
|
daemon=daemon,
|
|
worker=worker)
|
|
|
|
# we need to overwrite the number of samples as we may sample more than num_envs
|
|
self.observations = create_empty_array(self.single_observation_space,
|
|
n=n_samples,
|
|
fn=np.zeros)
|
|
|
|
def __call__(self, params, contexts=None):
|
|
return self.rollout(params, contexts)
|
|
|
|
def rollout_async(self, params, contexts):
|
|
"""
|
|
Parameters
|
|
----------
|
|
params : iterable of samples from `action_space`
|
|
List of actions.
|
|
"""
|
|
self._assert_is_running()
|
|
if self._state != AsyncState.DEFAULT:
|
|
raise AlreadyPendingCallError('Calling `rollout_async` while waiting '
|
|
'for a pending call to `{0}` to complete.'.format(
|
|
self._state.value), self._state.value)
|
|
|
|
params = np.atleast_2d(params)
|
|
split_params = np.array_split(params, np.minimum(len(params), self.num_envs))
|
|
if contexts is None:
|
|
split_contexts = np.array_split([None, ] * len(params), np.minimum(len(params), self.num_envs))
|
|
else:
|
|
split_contexts = np.array_split(contexts, np.minimum(len(contexts), self.num_envs))
|
|
|
|
assert np.all([len(p) == len(c) for p, c in zip(split_params, split_contexts)])
|
|
for pipe, param, context in zip(self.parent_pipes, split_params, split_contexts):
|
|
pipe.send(('rollout', (param, context)))
|
|
for pipe in self.parent_pipes[len(split_params):]:
|
|
pipe.send(('idle', None))
|
|
self._state = AsyncState.WAITING_ROLLOUT
|
|
|
|
def rollout_wait(self, timeout=None):
|
|
"""
|
|
Parameters
|
|
----------
|
|
timeout : int or float, optional
|
|
Number of seconds before the call to `step_wait` times out. If
|
|
`None`, the call to `step_wait` never times out.
|
|
|
|
Returns
|
|
-------
|
|
observations : sample from `observation_space`
|
|
A batch of observations from the vectorized environment.
|
|
|
|
rewards : `np.ndarray` instance (dtype `np.float_`)
|
|
A vector of rewards from the vectorized environment.
|
|
|
|
dones : `np.ndarray` instance (dtype `np.bool_`)
|
|
A vector whose entries indicate whether the episode has ended.
|
|
|
|
infos : list of dict
|
|
A list of auxiliary diagnostic information.
|
|
"""
|
|
self._assert_is_running()
|
|
if self._state != AsyncState.WAITING_ROLLOUT:
|
|
raise NoAsyncCallError('Calling `rollout_wait` without any prior call '
|
|
'to `rollout_async`.', AsyncState.WAITING_ROLLOUT.value)
|
|
|
|
if not self._poll(timeout):
|
|
self._state = AsyncState.DEFAULT
|
|
raise mp.TimeoutError('The call to `rollout_wait` has timed out after '
|
|
'{0} second{1}.'.format(timeout, 's' if timeout > 1 else ''))
|
|
|
|
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
|
|
results = [r for r in results if r is not None]
|
|
self._raise_if_errors(successes)
|
|
self._state = AsyncState.DEFAULT
|
|
|
|
rewards, infos = [_flatten_list(r) for r in zip(*results)]
|
|
|
|
# for now, we ignore the observations and only return the rewards
|
|
|
|
# if not self.shared_memory:
|
|
# self.observations = concatenate(observations_list, self.observations,
|
|
# self.single_observation_space)
|
|
|
|
# return (deepcopy(self.observations) if self.copy else self.observations,
|
|
# np.array(rewards), np.array(dones, dtype=np.bool_), infos)
|
|
|
|
return np.array(rewards), infos
|
|
|
|
def rollout(self, actions, contexts):
|
|
self.rollout_async(actions, contexts)
|
|
return self.rollout_wait()
|
|
|
|
|
|
def _flatten_obs(obs):
|
|
assert isinstance(obs, (list, tuple))
|
|
assert len(obs) > 0
|
|
|
|
if isinstance(obs[0], dict):
|
|
keys = obs[0].keys()
|
|
return {k: np.stack([o[k] for o in obs]) for k in keys}
|
|
else:
|
|
return np.stack(obs)
|
|
|
|
|
|
def _flatten_list(l):
|
|
assert isinstance(l, (list, tuple))
|
|
assert len(l) > 0
|
|
assert all([len(l_) > 0 for l_ in l])
|
|
|
|
return [l__ for l_ in l for l__ in l_]
|