added rendering to DMC envs and updated examples
parent 7c04b25eec
commit eae149f838
@@ -6,19 +6,24 @@ def example_dmc(env_name="fish-swim", seed=1, iterations=1000):
     env = make_env(env_name, seed)
     rewards = 0
     obs = env.reset()
-    print(obs)
+    print("observation shape:", env.observation_space.shape)
+    print("action shape:", env.action_space.shape)
 
     # number of samples(multiple environment steps)
-    for i in range(10):
+    for i in range(iterations):
         ac = env.action_space.sample()
         obs, reward, done, info = env.step(ac)
         rewards += reward
 
+        env.render("human")
+
         if done:
-            print(rewards)
+            print(env_name, rewards)
             rewards = 0
             obs = env.reset()
 
+    env.close()
+
 
 def example_custom_dmc_and_mp(seed=1):
     """
@@ -50,12 +55,13 @@ def example_custom_dmc_and_mp(seed=1):
         "weights_scale": 50,
         "goal_scale": 0.1
     }
-    env = make_dmp_env(base_env, wrappers=wrappers, seed=seed, **mp_kwargs)
+    env = make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
     # OR for a deterministic ProMP:
     # env = make_detpmp_env(base_env, wrappers=wrappers, seed=seed, **mp_args)
 
     rewards = 0
     obs = env.reset()
+    env.render("human")
 
     # number of samples/full trajectories (multiple environment steps)
     for i in range(10):
@@ -64,17 +70,26 @@ def example_custom_dmc_and_mp(seed=1):
         rewards += reward
 
         if done:
-            print(rewards)
+            print(base_env, rewards)
             rewards = 0
             obs = env.reset()
 
+    env.close()
+
 
 if __name__ == '__main__':
     # Disclaimer: DMC environments require the seed to be specified in the beginning.
     # Adjusting it afterwards with env.seed() is not recommended as it does not affect the underlying physics.
 
-    # Standard DMC task
-    example_dmc("fish_swim", seed=10, iterations=1000)
+    # For rendering DMC
+    # export MUJOCO_GL="osmesa"
+
+    # Standard DMC Suite tasks
+    example_dmc("fish-swim", seed=10, iterations=100)
+
+    # Manipulation tasks
+    # The vision versions are currently not integrated
+    example_dmc("manipulation-reach_site_features", seed=10, iterations=100)
 
     # Gym + DMC hybrid task provided in the MP framework
     example_dmc("dmc_ball_in_cup_dmp-v0", seed=10, iterations=10)
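For orientation (an illustrative sketch, not part of the diff), the updated example_dmc above boils down to roughly the following standalone loop. The names make_env, the "fish-swim" id, render("human") and close() are taken from the hunks above; the MUJOCO_GL="osmesa" export mentioned in __main__ is only needed for headless software rendering, and the iteration count here is arbitrary:

    from alr_envs.utils.make_env_helpers import make_env

    env = make_env("fish-swim", seed=10)
    obs = env.reset()
    for _ in range(100):
        obs, reward, done, info = env.step(env.action_space.sample())
        env.render("human")  # uses the viewer added to DMCWrapper further down in this commit
        if done:
            obs = env.reset()
    env.close()  # closes the viewer again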
@@ -8,7 +8,7 @@ from alr_envs.utils.make_env_helpers import make_env
 from alr_envs.utils.mp_env_async_sampler import AlrContextualMpEnvSampler, AlrMpEnvSampler, DummyDist
 
 
-def example_general(env_id='alr_envs:ALRReacher-v0', seed=1):
+def example_general(env_id: str, seed=1, iterations=1000):
     """
     Example for running any env in the step based setting.
     This also includes DMC environments when leveraging our custom make_env function.
@@ -17,16 +17,16 @@ def example_general(env_id='alr_envs:ALRReacher-v0', seed=1):
     env = make_env(env_id, seed)
     rewards = 0
     obs = env.reset()
-    print("Observation shape: ", obs.shape)
+    print("Observation shape: ", env.observation_space.shape)
     print("Action shape: ", env.action_space.shape)
 
     # number of environment steps
-    for i in range(10000):
+    for i in range(iterations):
         obs, reward, done, info = env.step(env.action_space.sample())
         rewards += reward
 
-        # if i % 1 == 0:
-        # env.render()
+        if i % 1 == 0:
+            env.render()
 
         if done:
             print(rewards)
@@ -65,10 +65,5 @@ def example_async(env_id="alr_envs:HoleReacherDMP-v0", n_cpu=4, seed=int('533D',
 
 
 if __name__ == '__main__':
-    # DMC
-    # example_general("fish-swim")
-
-    # custom mujoco env
-    # example_general("alr_envs:ALRReacher-v0")
-
-    example_general("ball_in_cup-catch")
+    # Mujoco task from framework
+    example_general("alr_envs:ALRReacher-v0")
@@ -83,12 +83,17 @@ def example_custom_mp(seed=1):
         "weights_scale": 50,
         "goal_scale": 0.1
     }
-    env = make_dmp_env(base_env, wrappers=wrappers, seed=seed, **mp_kwargs)
+    env = make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
     # OR for a deterministic ProMP:
     # env = make_detpmp_env(base_env, wrappers=wrappers, seed=seed)
 
     rewards = 0
-    # env.render(mode=None)
+    # render full DMP trajectory
+    # It is only required to call render() once in the beginning, which renders every consecutive trajectory.
+    # Resetting to no rendering can be achieved by render(mode=None).
+    # It is also possible to change the mode multiple times when
+    # e.g. only every nth trajectory should be displayed.
+    env.render(mode="human")
    obs = env.reset()
 
     # number of samples/full trajectories (multiple environment steps)
@@ -97,12 +102,6 @@ def example_custom_mp(seed=1):
         obs, reward, done, info = env.step(ac)
         rewards += reward
 
-        if i % 1 == 0:
-            # render full DMP trajectory
-            # render can only be called once in the beginning as well. That would render every trajectory
-            # Calling it after every trajectory allows to modify the mode. mode=None, disables rendering.
-            env.render(mode="human")
-
         if done:
             print(rewards)
             rewards = 0
@@ -26,7 +26,7 @@ def make_contextual_env(rank, seed=0):
     return _init
 
 
-def make_env(rank, seed=0):
+def _make_env(rank, seed=0):
     """
     Utility function for multiprocessed env.
 
@@ -26,7 +26,7 @@ def make_contextual_env(rank, seed=0):
     return _init
 
 
-def make_env(rank, seed=0):
+def _make_env(rank, seed=0):
     """
     Utility function for multiprocessed env.
 
@@ -1,11 +1,12 @@
 # Adopted from: https://github.com/denisyarats/dmc2gym/blob/master/dmc2gym/wrappers.py
 # License: MIT
 # Copyright (c) 2020 Denis Yarats
-import matplotlib.pyplot as plt
-from gym import core, spaces
-from dm_control import suite, manipulation
-from dm_env import specs
+from typing import Any, Dict, Tuple
+
 import numpy as np
+from dm_control import manipulation, suite
+from dm_env import specs
+from gym import core, spaces
 
 
 def _spec_to_box(spec):
@@ -43,8 +44,8 @@ class DMCWrapper(core.Env):
             self,
             domain_name,
             task_name,
-            task_kwargs=None,
-            visualize_reward={},
+            task_kwargs={},
+            visualize_reward=True,
             from_pixels=False,
             height=84,
             width=84,
@@ -65,49 +66,23 @@ class DMCWrapper(core.Env):
         if domain_name == "manipulation":
             assert not from_pixels, \
                 "TODO: Vision interface for manipulation is different to suite and needs to be implemented"
-            self._env = manipulation.load(
-                environment_name=task_name,
-                seed=task_kwargs['random']
-            )
+            self._env = manipulation.load(environment_name=task_name, seed=task_kwargs['random'])
         else:
-            self._env = suite.load(
-                domain_name=domain_name,
-                task_name=task_name,
-                task_kwargs=task_kwargs,
-                visualize_reward=visualize_reward,
-                environment_kwargs=environment_kwargs
-            )
+            self._env = suite.load(domain_name=domain_name, task_name=task_name, task_kwargs=task_kwargs,
+                                   visualize_reward=visualize_reward, environment_kwargs=environment_kwargs)
 
-        # true and normalized action spaces
-        self._true_action_space = _spec_to_box([self._env.action_spec()])
-        self._norm_action_space = spaces.Box(
-            low=-1.0,
-            high=1.0,
-            shape=self._true_action_space.shape,
-            dtype=np.float32
-        )
+        # action and observation space
+        self._action_space = _spec_to_box([self._env.action_spec()])
+        self._observation_space = _spec_to_box(self._env.observation_spec().values())
 
-        # create observation space
-        if from_pixels:
-            shape = [3, height, width] if channels_first else [height, width, 3]
-            self._observation_space = spaces.Box(
-                low=0, high=255, shape=shape, dtype=np.uint8
-            )
-        else:
-            self._observation_space = _spec_to_box(
-                self._env.observation_spec().values()
-            )
-
-        self._state_space = _spec_to_box(
-            self._env.observation_spec().values()
-        )
-
-        self.current_state = None
+        self._last_observation = None
+        self.viewer = None
 
         # set seed
         self.seed(seed=task_kwargs.get('random', 1))
 
     def __getattr__(self, name):
+        """Delegate attribute access to underlying environment."""
         return getattr(self._env, name)
 
     def _get_obs(self, time_step):
@@ -124,59 +99,72 @@ class DMCWrapper(core.Env):
         obs = _flatten_obs(time_step.observation)
         return obs
 
-    def _convert_action(self, action):
-        action = action.astype(float)
-        true_delta = self._true_action_space.high - self._true_action_space.low
-        norm_delta = self._norm_action_space.high - self._norm_action_space.low
-        action = (action - self._norm_action_space.low) / norm_delta
-        action = action * true_delta + self._true_action_space.low
-        action = action.astype(np.float32)
-        return action
-
     @property
     def observation_space(self):
         return self._observation_space
 
-    @property
-    def state_space(self):
-        return self._state_space
-
     @property
     def action_space(self):
-        return self._norm_action_space
+        return self._action_space
 
-    def seed(self, seed):
-        self._true_action_space.seed(seed)
-        self._norm_action_space.seed(seed)
+    def seed(self, seed=None):
+        self._action_space.seed(seed)
         self._observation_space.seed(seed)
 
-    def step(self, action):
-        assert self._norm_action_space.contains(action)
-        action = self._convert_action(action)
-        assert self._true_action_space.contains(action)
+    def step(self, action) -> Tuple[np.ndarray, float, bool, Dict[str, Any]]:
+        assert self._action_space.contains(action)
         reward = 0
         extra = {'internal_state': self._env.physics.get_state().copy()}
 
         for _ in range(self._frame_skip):
             time_step = self._env.step(action)
-            reward += time_step.reward or 0
+            reward += time_step.reward or 0.
             done = time_step.last()
             if done:
                 break
 
+        self._last_observation = _flatten_obs(time_step.observation)
         obs = self._get_obs(time_step)
-        self.current_state = _flatten_obs(time_step.observation)
         extra['discount'] = time_step.discount
         return obs, reward, done, extra
 
-    def reset(self):
+    def reset(self) -> np.ndarray:
         time_step = self._env.reset()
-        self.current_state = _flatten_obs(time_step.observation)
+        self._last_observation = _flatten_obs(time_step.observation)
         obs = self._get_obs(time_step)
         return obs
 
     def render(self, mode='rgb_array', height=None, width=None, camera_id=0):
-        assert mode == 'rgb_array', 'only support rgb_array mode, given %s' % mode
-        height = height or self._height
-        width = width or self._width
-        camera_id = camera_id or self._camera_id
-        return self._env.physics.render(height=height, width=width, camera_id=camera_id)
+        if self._last_observation is None:
+            raise ValueError('Environment not ready to render. Call reset() first.')
+
+        # assert mode == 'rgb_array', 'only support rgb_array mode, given %s' % mode
+        if mode == "rgb_array":
+            height = height or self._height
+            width = width or self._width
+            camera_id = camera_id or self._camera_id
+            return self._env.physics.render(height=height, width=width, camera_id=camera_id)
+
+        elif mode == 'human':
+            if self.viewer is None:
+                # pylint: disable=import-outside-toplevel
+                # pylint: disable=g-import-not-at-top
+                from gym.envs.classic_control import rendering
+                self.viewer = rendering.SimpleImageViewer()
+            # Render max available buffer size. Larger is only possible by altering the XML.
+            img = self._env.physics.render(height=self._env.physics.model.vis.global_.offheight,
+                                           width=self._env.physics.model.vis.global_.offwidth)
+            self.viewer.imshow(img)
+            return self.viewer.isopen
+
+    def close(self):
+        super().close()
+        if self.viewer is not None and self.viewer.isopen:
+            self.viewer.close()
+
+    @property
+    def reward_range(self) -> Tuple[float, float]:
+        reward_spec = self._env.reward_spec()
+        if isinstance(reward_spec, specs.BoundedArray):
+            return reward_spec.minimum, reward_spec.maximum
+        return -float('inf'), float('inf')
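To illustrate the two render paths introduced above, here is a small usage sketch. It assumes an environment built through the make_env helper shown in the examples; reset() has to run first, otherwise the new ValueError is raised:

    env = make_env("ball_in_cup-catch", seed=1)
    obs = env.reset()
    frame = env.render(mode="rgb_array")  # offscreen frame from physics.render
    env.render(mode="human")              # SimpleImageViewer window at the maximum buffer size
    env.close()                           # also closes the viewer if one was opened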
@@ -82,7 +82,7 @@ def _make_wrapped_env(env_id: str, wrappers: Iterable[Type[gym.Wrapper]], seed=1
     return _env
 
 
-def make_dmp_env(env_id: str, wrappers: Iterable, seed=1, **mp_kwargs):
+def make_dmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs):
     """
     This can also be used standalone for manually building a custom DMP environment.
     Args:
@@ -95,11 +95,11 @@ def make_dmp_env(env_id: str, wrappers: Iterable, seed=1, **mp_kwargs):
 
     """
 
-    _env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed)
+    _env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
     return DmpWrapper(_env, **mp_kwargs)
 
 
-def make_detpmp_env(env_id: str, wrappers: Iterable, seed=1, **mp_kwargs):
+def make_detpmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs):
     """
     This can also be used standalone for manually building a custom Det ProMP environment.
     Args:
@@ -111,7 +111,7 @@ def make_detpmp_env(env_id: str, wrappers: Iterable, seed=1, **mp_kwargs):
 
     """
 
-    _env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed)
+    _env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
     return DetPMPWrapper(_env, **mp_kwargs)
 
 
@@ -129,9 +129,9 @@ def make_dmp_env_helper(**kwargs):
     Returns: DMP wrapped gym env
 
     """
-    seed = kwargs.get("seed", None)
+    seed = kwargs.pop("seed", None)
     return make_dmp_env(env_id=kwargs.pop("name"), wrappers=kwargs.pop("wrappers"), seed=seed,
-                        **kwargs.get("mp_kwargs"))
+                        mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs)
 
 
 def make_detpmp_env_helper(**kwargs):
@@ -149,12 +149,13 @@ def make_detpmp_env_helper(**kwargs):
     Returns: DMP wrapped gym env
 
     """
-    seed = kwargs.get("seed", None)
+    seed = kwargs.pop("seed", None)
     return make_detpmp_env(env_id=kwargs.pop("name"), wrappers=kwargs.pop("wrappers"), seed=seed,
-                           **kwargs.get("mp_kwargs"))
+                           mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs)
 
 
 def make_contextual_env(env_id, context, seed, rank):
-    env = gym.make(env_id, context=context)
-    env.seed(seed + rank)
+    env = make_env(env_id, seed + rank, context=context)
+    # env = gym.make(env_id, context=context)
+    # env.seed(seed + rank)
     return lambda: env
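The signature changes above mean motion-primitive parameters now travel as one explicit mp_kwargs dict, while any remaining **kwargs reach the wrapped base environment. A hedged before/after sketch, with base_env and wrappers assumed to be defined as in the example functions earlier in this commit:

    mp_kwargs = {"weights_scale": 50, "goal_scale": 0.1}

    # before this commit the parameters were splatted directly:
    # env = make_dmp_env(base_env, wrappers=wrappers, seed=seed, **mp_kwargs)

    # after this commit mp_kwargs stays a dict and **kwargs is forwarded to _make_wrapped_env:
    env = make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)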
@@ -3,11 +3,7 @@ from gym.vector.async_vector_env import AsyncVectorEnv
 import numpy as np
 from _collections import defaultdict
 
-
-def make_env(env_id, rank, seed=0, **env_kwargs):
-    env = gym.make(env_id, **env_kwargs)
-    env.seed(seed + rank)
-    return lambda: env
+from alr_envs.utils.make_env_helpers import make_env_rank
 
 
 def split_array(ary, size):
@@ -55,9 +51,10 @@ class AlrMpEnvSampler:
     An asynchronous sampler for non contextual MPWrapper environments. A sampler object can be called with a set of
     parameters and returns the corresponding final obs, rewards, dones and info dicts.
     """
+
     def __init__(self, env_id, num_envs, seed=0, **env_kwargs):
         self.num_envs = num_envs
-        self.env = AsyncVectorEnv([make_env(env_id, seed, i, **env_kwargs) for i in range(num_envs)])
+        self.env = AsyncVectorEnv([make_env_rank(env_id, seed, i, **env_kwargs) for i in range(num_envs)])
 
     def __call__(self, params):
         params = np.atleast_2d(params)
@@ -74,8 +71,8 @@ class AlrMpEnvSampler:
             vals['info'].append(info)
 
         # do not return values above threshold
-        return np.vstack(vals['obs'])[:n_samples], np.hstack(vals['reward'])[:n_samples],\
+        return np.vstack(vals['obs'])[:n_samples], np.hstack(vals['reward'])[:n_samples], \
             _flatten_list(vals['done'])[:n_samples], _flatten_list(vals['info'])[:n_samples]
 
 
 class AlrContextualMpEnvSampler:
@@ -83,12 +80,12 @@ class AlrContextualMpEnvSampler:
     An asynchronous sampler for contextual MPWrapper environments. A sampler object can be called with a set of
     parameters and returns the corresponding final obs, rewards, dones and info dicts.
     """
+
     def __init__(self, env_id, num_envs, seed=0, **env_kwargs):
         self.num_envs = num_envs
         self.env = AsyncVectorEnv([make_env(env_id, seed, i, **env_kwargs) for i in range(num_envs)])
 
     def __call__(self, dist, n_samples):
-
         repeat = int(np.ceil(n_samples / self.env.num_envs))
         vals = defaultdict(list)
 
@@ -106,8 +103,8 @@ class AlrContextualMpEnvSampler:
 
         # do not return values above threshold
         return np.vstack(vals['new_samples'])[:n_samples], \
             np.vstack(vals['obs'])[:n_samples], np.hstack(vals['reward'])[:n_samples], \
             _flatten_list(vals['done'])[:n_samples], _flatten_list(vals['info'])[:n_samples]
 
 
 if __name__ == "__main__":
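For context, calling the non-contextual sampler after the make_env_rank switch could look roughly like the sketch below. The environment id is taken from the examples above; num_envs and the parameter dimension are placeholders chosen purely for illustration:

    import numpy as np

    sampler = AlrMpEnvSampler("dmc_ball_in_cup_dmp-v0", num_envs=4, seed=0)
    params = np.random.randn(8, 15)  # 15 stands in for the actual DMP parameter dimension
    obs, rewards, dones, infos = sampler(params)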