2021-06-28 17:25:53 +02:00
|
|
|
import re
|
2021-07-02 13:09:56 +02:00
|
|
|
from typing import Union
|
2021-06-28 17:25:53 +02:00
|
|
|
|
|
|
|
import gym
|
|
|
|
from gym.envs.registration import register
|
|
|
|
|
|
|
|
|
|
|
|
def make(
        id: str,
        seed: int = 1,
        visualize_reward: bool = True,
        from_pixels: bool = False,
        height: int = 84,
        width: int = 84,
        camera_id: int = 0,
        frame_skip: int = 1,
        episode_length: Union[None, int] = None,
        environment_kwargs: Union[None, dict] = None,
        time_limit: Union[None, float] = None,
        channels_first: bool = True
):
    """Register (if necessary) and create a gym environment for a DMC task.

    The environment is registered under the id
    ``dmc_{domain_name}_{task_name}_{seed}-v1`` with the entry point
    ``alr_envs.utils.dmc2gym_wrapper:DMCWrapper`` and then instantiated via
    ``gym.make``.

    Args:
        id: Task identifier of the form ``'domain_name-task_name'``. Both parts
            may only consist of word characters (letters, digits, underscore).
        seed: Passed to the task as ``task_kwargs['random']`` and also part of
            the registered environment id.
        visualize_reward: Forwarded to the wrapper. Must be ``False`` when
            ``from_pixels`` is ``True``.
        from_pixels: Forwarded to the wrapper.
        height: Forwarded to the wrapper (presumably the rendered image height
            -- confirm against DMCWrapper).
        width: Forwarded to the wrapper (presumably the rendered image width
            -- confirm against DMCWrapper).
        camera_id: Forwarded to the wrapper.
        frame_skip: Forwarded to the wrapper; also used to convert
            ``episode_length`` into ``max_episode_steps``.
        episode_length: Episode length before frame skipping. When ``None``,
            250 is used for the ``manipulation`` domain and 1000 otherwise.
        environment_kwargs: Forwarded to the wrapper. ``None`` (the default)
            is replaced by a new empty dict on every call.
        time_limit: If not ``None``, passed to the task as
            ``task_kwargs['time_limit']``.
        channels_first: Forwarded to the wrapper.

    Returns:
        The environment returned by ``gym.make`` for the registered id.

    Raises:
        AssertionError: If ``id`` does not have the form
            ``'domain_name-task_name'``, or if ``from_pixels`` and
            ``visualize_reward`` are both enabled.
    """
    # Adopted from: https://github.com/denisyarats/dmc2gym/blob/master/dmc2gym/__init__.py
    # License: MIT
    # Copyright (c) 2020 Denis Yarats

    # fullmatch (instead of match) also rejects ids with trailing garbage or more than one
    # dash, which would otherwise fail with an unrelated unpacking error in the split below.
    # Raised explicitly so the check is not stripped when running with `python -O`.
    if not re.fullmatch(r"\w+-\w+", id):
        raise AssertionError("env_id does not have the following structure: 'domain_name-task_name'")
    domain_name, task_name = id.split("-")

    # NOTE(review): the registry key only encodes domain, task and seed. A second call with
    # the same three values but different remaining options reuses the first registration.
    env_id = f'dmc_{domain_name}_{task_name}_{seed}-v1'

    if from_pixels and visualize_reward:
        raise AssertionError('cannot use visualize reward when learning from pixels')

    # Avoid a shared mutable default: every call without explicit kwargs gets its own dict.
    if environment_kwargs is None:
        environment_kwargs = {}

    # shorten episode length
    if episode_length is None:
        # Default lengths for benchmarking suite is 1000 and for manipulation tasks 250
        episode_length = 250 if domain_name == "manipulation" else 1000

    # Ceiling division: number of agent steps needed to cover episode_length simulator steps.
    max_episode_steps = (episode_length + frame_skip - 1) // frame_skip

    if env_id not in gym.envs.registry.env_specs:
        task_kwargs = {'random': seed}
        if time_limit is not None:
            task_kwargs['time_limit'] = time_limit
        register(
            id=env_id,
            entry_point='alr_envs.utils.dmc2gym_wrapper:DMCWrapper',
            kwargs=dict(
                domain_name=domain_name,
                task_name=task_name,
                task_kwargs=task_kwargs,
                environment_kwargs=environment_kwargs,
                visualize_reward=visualize_reward,
                from_pixels=from_pixels,
                height=height,
                width=width,
                camera_id=camera_id,
                frame_skip=frame_skip,
                channels_first=channels_first,
            ),
            max_episode_steps=max_episode_steps,
        )
    return gym.make(env_id)
|