This commit is contained in:
Fabian 2022-07-07 10:47:04 +02:00
parent 4a3134d7be
commit fc00cf8a87
24 changed files with 235 additions and 302 deletions

View File

@ -157,60 +157,36 @@ register(
id='ALRAntJump-v0', id='ALRAntJump-v0',
entry_point='alr_envs.alr.mujoco:AntJumpEnv', entry_point='alr_envs.alr.mujoco:AntJumpEnv',
max_episode_steps=MAX_EPISODE_STEPS_ANTJUMP, max_episode_steps=MAX_EPISODE_STEPS_ANTJUMP,
kwargs={
"max_episode_steps": MAX_EPISODE_STEPS_ANTJUMP,
"context": True
}
) )
register( register(
id='ALRHalfCheetahJump-v0', id='ALRHalfCheetahJump-v0',
entry_point='alr_envs.alr.mujoco:ALRHalfCheetahJumpEnv', entry_point='alr_envs.alr.mujoco:ALRHalfCheetahJumpEnv',
max_episode_steps=MAX_EPISODE_STEPS_HALFCHEETAHJUMP, max_episode_steps=MAX_EPISODE_STEPS_HALFCHEETAHJUMP,
kwargs={
"max_episode_steps": MAX_EPISODE_STEPS_HALFCHEETAHJUMP,
"context": True
}
) )
register( register(
id='HopperJumpOnBox-v0', id='HopperJumpOnBox-v0',
entry_point='alr_envs.alr.mujoco:ALRHopperJumpOnBoxEnv', entry_point='alr_envs.alr.mujoco:HopperJumpOnBoxEnv',
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMPONBOX, max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMPONBOX,
kwargs={
"max_episode_steps": MAX_EPISODE_STEPS_HOPPERJUMPONBOX,
"context": True
}
) )
register( register(
id='ALRHopperThrow-v0', id='ALRHopperThrow-v0',
entry_point='alr_envs.alr.mujoco:ALRHopperThrowEnv', entry_point='alr_envs.alr.mujoco:ALRHopperThrowEnv',
max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROW, max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROW,
kwargs={
"max_episode_steps": MAX_EPISODE_STEPS_HOPPERTHROW,
"context": True
}
) )
register( register(
id='ALRHopperThrowInBasket-v0', id='ALRHopperThrowInBasket-v0',
entry_point='alr_envs.alr.mujoco:ALRHopperThrowInBasketEnv', entry_point='alr_envs.alr.mujoco:ALRHopperThrowInBasketEnv',
max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROWINBASKET, max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROWINBASKET,
kwargs={
"max_episode_steps": MAX_EPISODE_STEPS_HOPPERTHROWINBASKET,
"context": True
}
) )
register( register(
id='ALRWalker2DJump-v0', id='ALRWalker2DJump-v0',
entry_point='alr_envs.alr.mujoco:ALRWalker2dJumpEnv', entry_point='alr_envs.alr.mujoco:ALRWalker2dJumpEnv',
max_episode_steps=MAX_EPISODE_STEPS_WALKERJUMP, max_episode_steps=MAX_EPISODE_STEPS_WALKERJUMP,
kwargs={
"max_episode_steps": MAX_EPISODE_STEPS_WALKERJUMP,
"context": True
}
) )
register( register(
@ -403,46 +379,48 @@ for _v in _versions:
## Table Tennis needs to be fixed according to Zhou's implementation ## Table Tennis needs to be fixed according to Zhou's implementation
######################################################################################################################## # TODO: Add later when finished
# ########################################################################################################################
## AntJump #
_versions = ['ALRAntJump-v0'] # ## AntJump
for _v in _versions: # _versions = ['ALRAntJump-v0']
_name = _v.split("-") # for _v in _versions:
_env_id = f'{_name[0]}ProMP-{_name[1]}' # _name = _v.split("-")
kwargs_dict_ant_jump_promp = deepcopy(DEFAULT_BB_DICT_ProMP) # _env_id = f'{_name[0]}ProMP-{_name[1]}'
kwargs_dict_ant_jump_promp['wrappers'].append(mujoco.ant_jump.MPWrapper) # kwargs_dict_ant_jump_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
kwargs_dict_ant_jump_promp['name'] = _v # kwargs_dict_ant_jump_promp['wrappers'].append(mujoco.ant_jump.MPWrapper)
register( # kwargs_dict_ant_jump_promp['name'] = _v
id=_env_id, # register(
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper', # id=_env_id,
kwargs=kwargs_dict_ant_jump_promp # entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
) # kwargs=kwargs_dict_ant_jump_promp
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) # )
# ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
######################################################################################################################## #
# ########################################################################################################################
## HalfCheetahJump #
_versions = ['ALRHalfCheetahJump-v0'] # ## HalfCheetahJump
for _v in _versions: # _versions = ['ALRHalfCheetahJump-v0']
_name = _v.split("-") # for _v in _versions:
_env_id = f'{_name[0]}ProMP-{_name[1]}' # _name = _v.split("-")
kwargs_dict_halfcheetah_jump_promp = deepcopy(DEFAULT_BB_DICT_ProMP) # _env_id = f'{_name[0]}ProMP-{_name[1]}'
kwargs_dict_halfcheetah_jump_promp['wrappers'].append(mujoco.half_cheetah_jump.MPWrapper) # kwargs_dict_halfcheetah_jump_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
kwargs_dict_halfcheetah_jump_promp['name'] = _v # kwargs_dict_halfcheetah_jump_promp['wrappers'].append(mujoco.half_cheetah_jump.MPWrapper)
register( # kwargs_dict_halfcheetah_jump_promp['name'] = _v
id=_env_id, # register(
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper', # id=_env_id,
kwargs=kwargs_dict_halfcheetah_jump_promp # entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
) # kwargs=kwargs_dict_halfcheetah_jump_promp
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) # )
# ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
######################################################################################################################## #
# ########################################################################################################################
## HopperJump ## HopperJump
_versions = ['HopperJump-v0', 'HopperJumpSparse-v0', 'ALRHopperJumpOnBox-v0', 'ALRHopperThrow-v0', _versions = ['HopperJump-v0', 'HopperJumpSparse-v0',
'ALRHopperThrowInBasket-v0'] # 'ALRHopperJumpOnBox-v0', 'ALRHopperThrow-v0', 'ALRHopperThrowInBasket-v0'
]
# TODO: Check if all environments work with the same MPWrapper # TODO: Check if all environments work with the same MPWrapper
for _v in _versions: for _v in _versions:
_name = _v.split("-") _name = _v.split("-")
@ -457,23 +435,23 @@ for _v in _versions:
) )
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
######################################################################################################################## # ########################################################################################################################
#
#
## Walker2DJump # ## Walker2DJump
_versions = ['ALRWalker2DJump-v0'] # _versions = ['ALRWalker2DJump-v0']
for _v in _versions: # for _v in _versions:
_name = _v.split("-") # _name = _v.split("-")
_env_id = f'{_name[0]}ProMP-{_name[1]}' # _env_id = f'{_name[0]}ProMP-{_name[1]}'
kwargs_dict_walker2d_jump_promp = deepcopy(DEFAULT_BB_DICT_ProMP) # kwargs_dict_walker2d_jump_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
kwargs_dict_walker2d_jump_promp['wrappers'].append(mujoco.walker_2d_jump.MPWrapper) # kwargs_dict_walker2d_jump_promp['wrappers'].append(mujoco.walker_2d_jump.MPWrapper)
kwargs_dict_walker2d_jump_promp['name'] = _v # kwargs_dict_walker2d_jump_promp['name'] = _v
register( # register(
id=_env_id, # id=_env_id,
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper', # entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_walker2d_jump_promp # kwargs=kwargs_dict_walker2d_jump_promp
) # )
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) # ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
### Depricated, we will not provide non random starts anymore ### Depricated, we will not provide non random starts anymore
""" """
@ -639,7 +617,7 @@ for i in _vs:
register( register(
id='ALRHopperJumpOnBox-v0', id='ALRHopperJumpOnBox-v0',
entry_point='alr_envs.alr.mujoco:ALRHopperJumpOnBoxEnv', entry_point='alr_envs.alr.mujoco:HopperJumpOnBoxEnv',
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMPONBOX, max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMPONBOX,
kwargs={ kwargs={
"max_episode_steps": MAX_EPISODE_STEPS_HOPPERJUMPONBOX, "max_episode_steps": MAX_EPISODE_STEPS_HOPPERJUMPONBOX,

View File

@ -1,8 +1,9 @@
from .beerpong.beerpong import BeerPongEnv, BeerPongEnvFixedReleaseStep, BeerPongEnvStepBasedEpisodicReward from .beerpong.beerpong import BeerPongEnv, BeerPongEnvFixedReleaseStep, BeerPongEnvStepBasedEpisodicReward
from .ant_jump.ant_jump import AntJumpEnv from .ant_jump.ant_jump import AntJumpEnv
from .half_cheetah_jump.half_cheetah_jump import ALRHalfCheetahJumpEnv from .half_cheetah_jump.half_cheetah_jump import ALRHalfCheetahJumpEnv
from .hopper_jump.hopper_jump_on_box import ALRHopperJumpOnBoxEnv from .hopper_jump.hopper_jump_on_box import HopperJumpOnBoxEnv
from .hopper_throw.hopper_throw import ALRHopperThrowEnv from .hopper_throw.hopper_throw import ALRHopperThrowEnv
from .hopper_throw.hopper_throw_in_basket import ALRHopperThrowInBasketEnv from .hopper_throw.hopper_throw_in_basket import ALRHopperThrowInBasketEnv
from .reacher.reacher import ReacherEnv from .reacher.reacher import ReacherEnv
from .walker_2d_jump.walker_2d_jump import ALRWalker2dJumpEnv from .walker_2d_jump.walker_2d_jump import ALRWalker2dJumpEnv
from .hopper_jump.hopper_jump import HopperJumpEnv

View File

@ -7,7 +7,8 @@ from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper
class MPWrapper(RawInterfaceWrapper): class MPWrapper(RawInterfaceWrapper):
def get_context_mask(self): @property
def context_mask(self) -> np.ndarray:
return np.hstack([ return np.hstack([
[False] * 7, # cos [False] * 7, # cos
[False] * 7, # sin [False] * 7, # sin
@ -15,16 +16,16 @@ class MPWrapper(RawInterfaceWrapper):
[False] * 3, # cup_goal_diff_final [False] * 3, # cup_goal_diff_final
[False] * 3, # cup_goal_diff_top [False] * 3, # cup_goal_diff_top
[True] * 2, # xy position of cup [True] * 2, # xy position of cup
[False] # env steps # [False] # env steps
]) ])
@property @property
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
return self.env.sim.data.qpos[0:7].copy() return self.env.data.qpos[0:7].copy()
@property @property
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
return self.env.sim.data.qvel[0:7].copy() return self.env.data.qvel[0:7].copy()
# TODO: Fix this # TODO: Fix this
def _episode_callback(self, action: np.ndarray, mp) -> Tuple[np.ndarray, Union[np.ndarray, None]]: def _episode_callback(self, action: np.ndarray, mp) -> Tuple[np.ndarray, Union[np.ndarray, None]]:

View File

@ -69,7 +69,7 @@ class ALRHalfCheetahJumpEnv(HalfCheetahEnv):
options: Optional[dict] = None, ) -> Union[ObsType, Tuple[ObsType, dict]]: options: Optional[dict] = None, ) -> Union[ObsType, Tuple[ObsType, dict]]:
self.max_height = 0 self.max_height = 0
self.current_step = 0 self.current_step = 0
self.goal = np.random.uniform(1.1, 1.6, 1) # 1.1 1.6 self.goal = self.np_random.uniform(1.1, 1.6, 1) # 1.1 1.6
return super().reset() return super().reset()
# overwrite reset_model to make it deterministic # overwrite reset_model to make it deterministic

View File

@ -1,2 +1 @@
from .mp_wrapper import MPWrapper from .mp_wrapper import MPWrapper

View File

@ -1,10 +1,9 @@
import copy import copy
from typing import Optional
from gym.envs.mujoco.hopper_v3 import HopperEnv
import numpy as np
import os import os
import numpy as np
from gym.envs.mujoco.hopper_v3 import HopperEnv
MAX_EPISODE_STEPS_HOPPERJUMP = 250 MAX_EPISODE_STEPS_HOPPERJUMP = 250
@ -23,10 +22,10 @@ class HopperJumpEnv(HopperEnv):
xml_file='hopper_jump.xml', xml_file='hopper_jump.xml',
forward_reward_weight=1.0, forward_reward_weight=1.0,
ctrl_cost_weight=1e-3, ctrl_cost_weight=1e-3,
healthy_reward=2.0, # 1 step healthy_reward=2.0,
contact_weight=2.0, # 0 step contact_weight=2.0,
height_weight=10.0, # 3 step height_weight=10.0,
dist_weight=3.0, # 3 step dist_weight=3.0,
terminate_when_unhealthy=False, terminate_when_unhealthy=False,
healthy_state_range=(-100.0, 100.0), healthy_state_range=(-100.0, 100.0),
healthy_z_range=(0.5, float('inf')), healthy_z_range=(0.5, float('inf')),
@ -42,7 +41,7 @@ class HopperJumpEnv(HopperEnv):
self._contact_weight = contact_weight self._contact_weight = contact_weight
self.max_height = 0 self.max_height = 0
self.goal = 0 self.goal = np.zeros(3, )
self._steps = 0 self._steps = 0
self.contact_with_floor = False self.contact_with_floor = False
@ -58,6 +57,10 @@ class HopperJumpEnv(HopperEnv):
# increase initial height # increase initial height
self.init_qpos[1] = 1.5 self.init_qpos[1] = 1.5
@property
def exclude_current_positions_from_observation(self):
return self._exclude_current_positions_from_observation
def step(self, action): def step(self, action):
self._steps += 1 self._steps += 1
@ -80,7 +83,7 @@ class HopperJumpEnv(HopperEnv):
costs = ctrl_cost costs = ctrl_cost
done = False done = False
goal_dist = np.linalg.norm(site_pos_after - np.array([self.goal, 0, 0])) goal_dist = np.linalg.norm(site_pos_after - self.goal)
if self.contact_dist is None and self.contact_with_floor: if self.contact_dist is None and self.contact_with_floor:
self.contact_dist = goal_dist self.contact_dist = goal_dist
@ -99,7 +102,7 @@ class HopperJumpEnv(HopperEnv):
height=height_after, height=height_after,
x_pos=site_pos_after, x_pos=site_pos_after,
max_height=self.max_height, max_height=self.max_height,
goal=self.goal, goal=self.goal[:1],
goal_dist=goal_dist, goal_dist=goal_dist,
height_rew=self.max_height, height_rew=self.max_height,
healthy_reward=self.healthy_reward * 2, healthy_reward=self.healthy_reward * 2,
@ -109,14 +112,15 @@ class HopperJumpEnv(HopperEnv):
return observation, reward, done, info return observation, reward, done, info
def _get_obs(self): def _get_obs(self):
goal_dist = self.data.get_site_xpos('foot_site') - np.array([self.goal, 0, 0]) goal_dist = self.data.get_site_xpos('foot_site') - self.goal
return np.concatenate((super(HopperJumpEnv, self)._get_obs(), goal_dist.copy(), self.goal.copy())) return np.concatenate((super(HopperJumpEnv, self)._get_obs(), goal_dist.copy(), self.goal[:1]))
def reset_model(self): def reset_model(self):
super(HopperJumpEnv, self).reset_model() super(HopperJumpEnv, self).reset_model()
self.goal = self.np_random.uniform(0.3, 1.35, 1)[0] # self.goal = self.np_random.uniform(0.3, 1.35, 1)[0]
self.sim.model.body_pos[self.sim.model.body_name2id('goal_site_body')] = np.array([self.goal, 0, 0]) self.goal = np.concatenate([self.np_random.uniform(0.3, 1.35, 1), np.zeros(2, )])
self.sim.model.body_pos[self.sim.model.body_name2id('goal_site_body')] = self.goal
self.max_height = 0 self.max_height = 0
self._steps = 0 self._steps = 0

View File

@ -6,7 +6,7 @@ import os
MAX_EPISODE_STEPS_HOPPERJUMPONBOX = 250 MAX_EPISODE_STEPS_HOPPERJUMPONBOX = 250
class ALRHopperJumpOnBoxEnv(HopperEnv): class HopperJumpOnBoxEnv(HopperEnv):
""" """
Initialization changes to normal Hopper: Initialization changes to normal Hopper:
- healthy_reward: 1.0 -> 0.01 -> 0.001 - healthy_reward: 1.0 -> 0.01 -> 0.001
@ -153,7 +153,7 @@ class ALRHopperJumpOnBoxEnv(HopperEnv):
if __name__ == '__main__': if __name__ == '__main__':
render_mode = "human" # "human" or "partial" or "final" render_mode = "human" # "human" or "partial" or "final"
env = ALRHopperJumpOnBoxEnv() env = HopperJumpOnBoxEnv()
obs = env.reset() obs = env.reset()
for i in range(2000): for i in range(2000):

View File

@ -14,7 +14,8 @@ class MPWrapper(RawInterfaceWrapper):
[False] * (2 + int(not self.exclude_current_positions_from_observation)), # position [False] * (2 + int(not self.exclude_current_positions_from_observation)), # position
[True] * 3, # set to true if randomize initial pos [True] * 3, # set to true if randomize initial pos
[False] * 6, # velocity [False] * 6, # velocity
[True] [True] * 3, # goal distance
[True] # goal
]) ])
@property @property

View File

@ -67,7 +67,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
def observation(self, observation): def observation(self, observation):
# return context space if we are # return context space if we are
return observation[self.env.context_mask] if self.return_context_observation else observation obs = observation[self.env.context_mask] if self.return_context_observation else observation
# cast dtype because metaworld returns incorrect that throws gym error
return obs.astype(self.observation_space.dtype)
def get_trajectory(self, action: np.ndarray) -> Tuple: def get_trajectory(self, action: np.ndarray) -> Tuple:
clipped_params = np.clip(action, self.traj_gen_action_space.low, self.traj_gen_action_space.high) clipped_params = np.clip(action, self.traj_gen_action_space.low, self.traj_gen_action_space.high)
@ -147,7 +149,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
infos[k] = elems infos[k] = elems
if self.render_kwargs: if self.render_kwargs:
self.render(**self.render_kwargs) self.env.render(**self.render_kwargs)
if done or self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action, if done or self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
t + 1 + self.current_traj_steps): t + 1 + self.current_traj_steps):
@ -170,13 +172,13 @@ class BlackBoxWrapper(gym.ObservationWrapper):
def render(self, **kwargs): def render(self, **kwargs):
"""Only set render options here, such that they can be used during the rollout. """Only set render options here, such that they can be used during the rollout.
This only needs to be called once""" This only needs to be called once"""
self.render_kwargs = kwargs or self.render_kwargs self.render_kwargs = kwargs
# self.env.render(mode=self.render_mode, **self.render_kwargs) # self.env.render(mode=self.render_mode, **self.render_kwargs)
self.env.render(**self.render_kwargs) # self.env.render(**self.render_kwargs)
def reset(self, *, seed: Optional[int] = None, return_info: bool = False, options: Optional[dict] = None): def reset(self, *, seed: Optional[int] = None, return_info: bool = False, options: Optional[dict] = None):
self.current_traj_steps = 0 self.current_traj_steps = 0
return super(BlackBoxWrapper, self).reset(seed=seed, return_info=return_info, options=options) return super(BlackBoxWrapper, self).reset()
def plot_trajs(self, des_trajs, des_vels): def plot_trajs(self, des_trajs, des_vels):
import matplotlib.pyplot as plt import matplotlib.pyplot as plt

View File

@ -10,13 +10,12 @@ class MetaWorldController(BaseController):
Unlike the other Controllers, this is a special tracking_controller for MetaWorld environments. Unlike the other Controllers, this is a special tracking_controller for MetaWorld environments.
They use a position delta for the xyz coordinates and a raw position for the gripper opening. They use a position delta for the xyz coordinates and a raw position for the gripper opening.
:param env: A position environment
""" """
def get_action(self, des_pos, des_vel, c_pos, c_vel): def get_action(self, des_pos, des_vel, c_pos, c_vel):
gripper_pos = des_pos[-1] gripper_pos = des_pos[-1]
cur_pos = env.current_pos[:-1] cur_pos = c_pos[:-1]
xyz_pos = des_pos[:-1] xyz_pos = des_pos[:-1]
assert xyz_pos.shape == cur_pos.shape, \ assert xyz_pos.shape == cur_pos.shape, \

View File

@ -63,16 +63,16 @@ def example_custom_mp(env_name="Reacher5dProMP-v0", seed=1, iterations=1, render
# mp_dict.update({'black_box_kwargs': {'learn_sub_trajectories': True}}) # mp_dict.update({'black_box_kwargs': {'learn_sub_trajectories': True}})
# mp_dict.update({'black_box_kwargs': {'do_replanning': lambda pos, vel, t: lambda t: t % 100}}) # mp_dict.update({'black_box_kwargs': {'do_replanning': lambda pos, vel, t: lambda t: t % 100}})
rewards = 0
obs = env.reset()
# This time rendering every trajectory # This time rendering every trajectory
if render: if render:
env.render(mode="human") env.render(mode="human")
rewards = 0
obs = env.reset()
# number of samples/full trajectories (multiple environment steps) # number of samples/full trajectories (multiple environment steps)
for i in range(iterations): for i in range(iterations):
ac = env.action_space.sample() ac = env.action_space.sample() * 1000
obs, reward, done, info = env.step(ac) obs, reward, done, info = env.step(ac)
rewards += reward rewards += reward
@ -139,7 +139,7 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
if __name__ == '__main__': if __name__ == '__main__':
render = False render = True
# # DMP # # DMP
# example_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=1, render=render) # example_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=1, render=render)
# #
@ -150,7 +150,7 @@ if __name__ == '__main__':
# example_mp("alr_envs:HoleReacherDetPMP-v1", seed=10, iterations=1, render=render) # example_mp("alr_envs:HoleReacherDetPMP-v1", seed=10, iterations=1, render=render)
# Altered basis functions # Altered basis functions
example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=10, render=render) example_custom_mp("HopperJumpSparseProMP-v0", seed=10, iterations=10, render=render)
# Custom MP # Custom MP
# example_fully_custom_mp(seed=10, iterations=1, render=render) # example_fully_custom_mp(seed=10, iterations=1, render=render)

View File

@ -1,3 +1,5 @@
from copy import deepcopy
from gym import register from gym import register
from . import goal_object_change_mp_wrapper, goal_change_mp_wrapper, goal_endeffector_change_mp_wrapper, \ from . import goal_object_change_mp_wrapper, goal_change_mp_wrapper, goal_endeffector_change_mp_wrapper, \
@ -7,27 +9,39 @@ ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
# MetaWorld # MetaWorld
DEFAULT_BB_DICT_ProMP = {
"name": 'EnvName',
"wrappers": [],
"trajectory_generator_kwargs": {
'trajectory_generator_type': 'promp'
},
"phase_generator_kwargs": {
'phase_generator_type': 'linear'
},
"controller_kwargs": {
'controller_type': 'metaworld',
},
"basis_generator_kwargs": {
'basis_generator_type': 'zero_rbf',
'num_basis': 5,
'num_basis_zero_start': 1
}
}
_goal_change_envs = ["assembly-v2", "pick-out-of-hole-v2", "plate-slide-v2", "plate-slide-back-v2", _goal_change_envs = ["assembly-v2", "pick-out-of-hole-v2", "plate-slide-v2", "plate-slide-back-v2",
"plate-slide-side-v2", "plate-slide-back-side-v2"] "plate-slide-side-v2", "plate-slide-back-side-v2"]
for _task in _goal_change_envs: for _task in _goal_change_envs:
task_id_split = _task.split("-") task_id_split = _task.split("-")
name = "".join([s.capitalize() for s in task_id_split[:-1]]) name = "".join([s.capitalize() for s in task_id_split[:-1]])
_env_id = f'{name}ProMP-{task_id_split[-1]}' _env_id = f'{name}ProMP-{task_id_split[-1]}'
kwargs_dict_goal_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
kwargs_dict_goal_change_promp['wrappers'].append(goal_change_mp_wrapper.MPWrapper)
kwargs_dict_goal_change_promp['name'] = _task
register( register(
id=_env_id, id=_env_id,
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
kwargs={ kwargs=kwargs_dict_goal_change_promp
"name": _task,
"wrappers": [goal_change_mp_wrapper.MPWrapper],
"traj_gen_kwargs": {
"num_dof": 4,
"num_basis": 5,
"duration": 6.25,
"post_traj_time": 0,
"zero_start": True,
"policy_type": "metaworld",
}
}
) )
ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
@ -36,21 +50,13 @@ for _task in _object_change_envs:
task_id_split = _task.split("-") task_id_split = _task.split("-")
name = "".join([s.capitalize() for s in task_id_split[:-1]]) name = "".join([s.capitalize() for s in task_id_split[:-1]])
_env_id = f'{name}ProMP-{task_id_split[-1]}' _env_id = f'{name}ProMP-{task_id_split[-1]}'
kwargs_dict_object_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
kwargs_dict_object_change_promp['wrappers'].append(object_change_mp_wrapper.MPWrapper)
kwargs_dict_object_change_promp['name'] = _task
register( register(
id=_env_id, id=_env_id,
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
kwargs={ kwargs=kwargs_dict_object_change_promp
"name": _task,
"wrappers": [object_change_mp_wrapper.MPWrapper],
"traj_gen_kwargs": {
"num_dof": 4,
"num_basis": 5,
"duration": 6.25,
"post_traj_time": 0,
"zero_start": True,
"policy_type": "metaworld",
}
}
) )
ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
@ -69,21 +75,14 @@ for _task in _goal_and_object_change_envs:
task_id_split = _task.split("-") task_id_split = _task.split("-")
name = "".join([s.capitalize() for s in task_id_split[:-1]]) name = "".join([s.capitalize() for s in task_id_split[:-1]])
_env_id = f'{name}ProMP-{task_id_split[-1]}' _env_id = f'{name}ProMP-{task_id_split[-1]}'
kwargs_dict_goal_and_object_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
kwargs_dict_goal_and_object_change_promp['wrappers'].append(goal_object_change_mp_wrapper.MPWrapper)
kwargs_dict_goal_and_object_change_promp['name'] = _task
register( register(
id=_env_id, id=_env_id,
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
kwargs={ kwargs=kwargs_dict_goal_and_object_change_promp
"name": _task,
"wrappers": [goal_object_change_mp_wrapper.MPWrapper],
"traj_gen_kwargs": {
"num_dof": 4,
"num_basis": 5,
"duration": 6.25,
"post_traj_time": 0,
"zero_start": True,
"policy_type": "metaworld",
}
}
) )
ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
@ -92,20 +91,13 @@ for _task in _goal_and_endeffector_change_envs:
task_id_split = _task.split("-") task_id_split = _task.split("-")
name = "".join([s.capitalize() for s in task_id_split[:-1]]) name = "".join([s.capitalize() for s in task_id_split[:-1]])
_env_id = f'{name}ProMP-{task_id_split[-1]}' _env_id = f'{name}ProMP-{task_id_split[-1]}'
kwargs_dict_goal_and_endeffector_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
kwargs_dict_goal_and_endeffector_change_promp['wrappers'].append(goal_endeffector_change_mp_wrapper.MPWrapper)
kwargs_dict_goal_and_endeffector_change_promp['name'] = _task
register( register(
id=_env_id, id=_env_id,
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
kwargs={ kwargs=kwargs_dict_goal_and_endeffector_change_promp
"name": _task,
"wrappers": [goal_endeffector_change_mp_wrapper.MPWrapper],
"traj_gen_kwargs": {
"num_dof": 4,
"num_basis": 5,
"duration": 6.25,
"post_traj_time": 0,
"zero_start": True,
"policy_type": "metaworld",
}
}
) )
ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)

View File

@ -0,0 +1,21 @@
from abc import ABC
from typing import Tuple, Union
import numpy as np
from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper
class BaseMetaworldMPWrapper(RawInterfaceWrapper, ABC):
@property
def current_pos(self) -> Union[float, int, np.ndarray]:
r_close = self.env.data.get_joint_qpos("r_close")
# TODO check if this is correct
# return np.hstack([self.env.data.get_body_xpos('hand').flatten() / self.env.action_scale, r_close])
return np.hstack([self.env.data.mocap_pos.flatten() / self.env.action_scale, r_close])
@property
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
# TODO check if this is correct
return np.zeros(4, )
# raise NotImplementedError("Velocity cannot be retrieved.")

View File

@ -1,11 +1,9 @@
from typing import Tuple, Union
import numpy as np import numpy as np
from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper from alr_envs.meta.base_metaworld_mp_wrapper import BaseMetaworldMPWrapper
class MPWrapper(RawInterfaceWrapper): class MPWrapper(BaseMetaworldMPWrapper):
""" """
This Wrapper is for environments where merely the goal changes in the beginning This Wrapper is for environments where merely the goal changes in the beginning
and no secondary objects or end effectors are altered at the start of an episode. and no secondary objects or end effectors are altered at the start of an episode.
@ -49,20 +47,3 @@ class MPWrapper(RawInterfaceWrapper):
# Goal # Goal
[True] * 3, # goal position [True] * 3, # goal position
]) ])
@property
def current_pos(self) -> Union[float, int, np.ndarray]:
r_close = self.env.data.get_joint_qpos("r_close")
return np.hstack([self.env.data.mocap_pos.flatten() / self.env.action_scale, r_close])
@property
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
raise NotImplementedError("Velocity cannot be retrieved.")
@property
def goal_pos(self) -> Union[float, int, np.ndarray, Tuple]:
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
@property
def dt(self) -> Union[float, int]:
return self.env.dt

View File

@ -1,11 +1,9 @@
from typing import Tuple, Union
import numpy as np import numpy as np
from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper from alr_envs.meta.base_metaworld_mp_wrapper import BaseMetaworldMPWrapper
class MPWrapper(RawInterfaceWrapper): class MPWrapper(BaseMetaworldMPWrapper):
""" """
This Wrapper is for environments where merely the goal changes in the beginning This Wrapper is for environments where merely the goal changes in the beginning
and no secondary objects or end effectors are altered at the start of an episode. and no secondary objects or end effectors are altered at the start of an episode.
@ -49,20 +47,3 @@ class MPWrapper(RawInterfaceWrapper):
# Goal # Goal
[True] * 3, # goal position [True] * 3, # goal position
]) ])
@property
def current_pos(self) -> Union[float, int, np.ndarray]:
r_close = self.env.data.get_joint_qpos("r_close")
return np.hstack([self.env.data.mocap_pos.flatten() / self.env.action_scale, r_close])
@property
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
raise NotImplementedError("Velocity cannot be retrieved.")
@property
def goal_pos(self) -> Union[float, int, np.ndarray, Tuple]:
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
@property
def dt(self) -> Union[float, int]:
return self.env.dt

View File

@ -1,11 +1,9 @@
from typing import Tuple, Union
import numpy as np import numpy as np
from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper from alr_envs.meta.base_metaworld_mp_wrapper import BaseMetaworldMPWrapper
class MPWrapper(RawInterfaceWrapper): class MPWrapper(BaseMetaworldMPWrapper):
""" """
This Wrapper is for environments where merely the goal changes in the beginning This Wrapper is for environments where merely the goal changes in the beginning
and no secondary objects or end effectors are altered at the start of an episode. and no secondary objects or end effectors are altered at the start of an episode.
@ -49,20 +47,3 @@ class MPWrapper(RawInterfaceWrapper):
# Goal # Goal
[True] * 3, # goal position [True] * 3, # goal position
]) ])
@property
def current_pos(self) -> Union[float, int, np.ndarray]:
r_close = self.env.data.get_joint_qpos("r_close")
return np.hstack([self.env.data.mocap_pos.flatten() / self.env.action_scale, r_close])
@property
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
raise NotImplementedError("Velocity cannot be retrieved.")
@property
def goal_pos(self) -> Union[float, int, np.ndarray, Tuple]:
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
@property
def dt(self) -> Union[float, int]:
return self.env.dt

View File

@ -1,11 +1,9 @@
from typing import Tuple, Union
import numpy as np import numpy as np
from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper from alr_envs.meta.base_metaworld_mp_wrapper import BaseMetaworldMPWrapper
class MPWrapper(RawInterfaceWrapper): class MPWrapper(BaseMetaworldMPWrapper):
""" """
This Wrapper is for environments where merely the goal changes in the beginning This Wrapper is for environments where merely the goal changes in the beginning
and no secondary objects or end effectors are altered at the start of an episode. and no secondary objects or end effectors are altered at the start of an episode.
@ -49,20 +47,3 @@ class MPWrapper(RawInterfaceWrapper):
# Goal # Goal
[True] * 3, # goal position [True] * 3, # goal position
]) ])
@property
def current_pos(self) -> Union[float, int, np.ndarray]:
r_close = self.env.data.get_joint_qpos("r_close")
return np.hstack([self.env.data.mocap_pos.flatten() / self.env.action_scale, r_close])
@property
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
raise NotImplementedError("Velocity cannot be retrieved.")
@property
def goal_pos(self) -> Union[float, int, np.ndarray, Tuple]:
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
@property
def dt(self) -> Union[float, int]:
return self.env.dt

View File

@ -20,7 +20,7 @@ def make_dmc(
environment_kwargs: dict = {}, environment_kwargs: dict = {},
time_limit: Union[None, float] = None, time_limit: Union[None, float] = None,
channels_first: bool = True channels_first: bool = True
): ):
# Adopted from: https://github.com/denisyarats/dmc2gym/blob/master/dmc2gym/__init__.py # Adopted from: https://github.com/denisyarats/dmc2gym/blob/master/dmc2gym/__init__.py
# License: MIT # License: MIT
# Copyright (c) 2020 Denis Yarats # Copyright (c) 2020 Denis Yarats
@ -32,12 +32,10 @@ def make_dmc(
env_id = f'dmc_{domain_name}_{task_name}_{seed}-v1' env_id = f'dmc_{domain_name}_{task_name}_{seed}-v1'
if from_pixels: if from_pixels:
assert not visualize_reward, 'cannot use visualize reward when learning from pixels' assert not visualize_reward, 'Cannot use visualize reward when learning from pixels.'
# shorten episode length
if episode_length is None:
# Default lengths for benchmarking suite is 1000 and for manipulation tasks 250 # Default lengths for benchmarking suite is 1000 and for manipulation tasks 250
episode_length = 250 if domain_name == "manipulation" else 1000 episode_length = episode_length or (250 if domain_name == "manipulation" else 1000)
max_episode_steps = (episode_length + frame_skip - 1) // frame_skip max_episode_steps = (episode_length + frame_skip - 1) // frame_skip
if env_id not in gym.envs.registry.env_specs: if env_id not in gym.envs.registry.env_specs:

View File

@ -8,7 +8,7 @@ from gym.envs.registration import EnvSpec, registry
from gym.wrappers import TimeAwareObservation from gym.wrappers import TimeAwareObservation
from alr_envs.black_box.black_box_wrapper import BlackBoxWrapper from alr_envs.black_box.black_box_wrapper import BlackBoxWrapper
from alr_envs.black_box.controller.controller_factory import get_controller from alr_envs.black_box.factory.controller_factory import get_controller
from alr_envs.black_box.factory.basis_generator_factory import get_basis_generator from alr_envs.black_box.factory.basis_generator_factory import get_basis_generator
from alr_envs.black_box.factory.phase_generator_factory import get_phase_generator from alr_envs.black_box.factory.phase_generator_factory import get_phase_generator
from alr_envs.black_box.factory.trajectory_generator_factory import get_trajectory_generator from alr_envs.black_box.factory.trajectory_generator_factory import get_trajectory_generator
@ -43,11 +43,7 @@ def make_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwa
def make(env_id, seed, **kwargs): def make(env_id, seed, **kwargs):
# This access is required to allow for nested dict updates return _make(env_id, seed, **kwargs)
spec = registry.get(env_id)
all_kwargs = deepcopy(spec.kwargs)
nested_update(all_kwargs, kwargs)
return _make(env_id, seed, **all_kwargs)
def _make(env_id: str, seed, **kwargs): def _make(env_id: str, seed, **kwargs):
@ -62,12 +58,25 @@ def _make(env_id: str, seed, **kwargs):
Returns: Gym environment Returns: Gym environment
""" """
if any(deprec in env_id for deprec in ["DetPMP", "detpmp"]):
warnings.warn("DetPMP is deprecated and converted to ProMP") # 'dmc:domain-task'
env_id = env_id.replace("DetPMP", "ProMP") # 'gym:name-vX'
env_id = env_id.replace("detpmp", "promp") # 'meta:name-vX'
# 'meta:bb:name-vX'
# 'hand:name-vX'
# 'name-vX'
# 'bb:name-vX'
#
# env_id.split(':')
# if 'dmc' :
try: try:
# This access is required to allow for nested dict updates for BB envs
spec = registry.get(env_id)
all_kwargs = deepcopy(spec.kwargs)
nested_update(all_kwargs, kwargs)
kwargs = all_kwargs
# Add seed to kwargs in case it is a predefined gym+dmc hybrid environment. # Add seed to kwargs in case it is a predefined gym+dmc hybrid environment.
if env_id.startswith("dmc"): if env_id.startswith("dmc"):
kwargs.update({"seed": seed}) kwargs.update({"seed": seed})
@ -77,22 +86,25 @@ def _make(env_id: str, seed, **kwargs):
env.seed(seed) env.seed(seed)
env.action_space.seed(seed) env.action_space.seed(seed)
env.observation_space.seed(seed) env.observation_space.seed(seed)
except gym.error.Error: except (gym.error.Error, AttributeError):
# MetaWorld env # MetaWorld env
import metaworld import metaworld
if env_id in metaworld.ML1.ENV_NAMES: if env_id in metaworld.ML1.ENV_NAMES:
env = metaworld.envs.ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_id + "-goal-observable"](seed=seed, **kwargs) env = metaworld.envs.ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_id + "-goal-observable"](seed=seed, **kwargs)
# setting this avoids generating the same initialization after each reset # setting this avoids generating the same initialization after each reset
env._freeze_rand_vec = False env._freeze_rand_vec = False
env.seeded_rand_vec = True
# Manually set spec, as metaworld environments are not registered via gym # Manually set spec, as metaworld environments are not registered via gym
env.unwrapped.spec = EnvSpec(env_id) env.unwrapped.spec = EnvSpec(env_id)
# Set Timelimit based on the maximum allowed path length of the environment # Set Timelimit based on the maximum allowed path length of the environment
env = gym.wrappers.TimeLimit(env, max_episode_steps=env.max_path_length) env = gym.wrappers.TimeLimit(env, max_episode_steps=env.max_path_length)
env.seed(seed) # env.seed(seed)
env.action_space.seed(seed) # env.action_space.seed(seed)
env.observation_space.seed(seed) # env.observation_space.seed(seed)
env.goal_space.seed(seed) # env.goal_space.seed(seed)
else: else:
# DMC # DMC

View File

@ -1,10 +1,10 @@
import itertools import itertools
from setuptools import setup from setuptools import setup, find_packages
# Environment-specific dependencies for dmc and metaworld # Environment-specific dependencies for dmc and metaworld
extras = { extras = {
"dmc": ["dm_control"], "dmc": ["dm_control==1.0.1"],
"meta": ["metaworld @ git+https://github.com/rlworkgroup/metaworld.git@master#egg=metaworld"], "meta": ["metaworld @ git+https://github.com/rlworkgroup/metaworld.git@master#egg=metaworld"],
"mujoco": ["mujoco==2.2.0", "imageio>=2.14.1"], "mujoco": ["mujoco==2.2.0", "imageio>=2.14.1"],
} }
@ -16,12 +16,28 @@ extras["all"] = list(set(itertools.chain.from_iterable(map(lambda group: extras[
setup( setup(
author='Fabian Otto, Onur Celik, Marcel Sandermann, Maximilian Huettenrauch', author='Fabian Otto, Onur Celik, Marcel Sandermann, Maximilian Huettenrauch',
name='simple_gym', name='simple_gym',
version='0.0.1', version='0.1',
packages=['alr_envs', 'alr_envs.alr', 'alr_envs.open_ai', 'alr_envs.dmc', 'alr_envs.meta', 'alr_envs.utils'], classifiers=[
# Python 3.6 is minimally supported (only with basic gym environments and API)
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
],
extras_require=extras,
install_requires=[ install_requires=[
'gym', 'gym>=0.24.0',
"mujoco_py<2.2,>=2.1", "mujoco_py<2.2,>=2.1",
], ],
packages=[package for package in find_packages() if package.startswith("alr_envs")],
# packages=['alr_envs', 'alr_envs.alr', 'alr_envs.open_ai', 'alr_envs.dmc', 'alr_envs.meta', 'alr_envs.utils'],
package_data={
"alr_envs": [
"alr/mujoco/*/assets/*.xml",
]
},
python_requires=">=3.6",
url='https://github.com/ALRhub/alr_envs/', url='https://github.com/ALRhub/alr_envs/',
# license='AGPL-3.0 license', # license='AGPL-3.0 license',
author_email='', author_email='',

View File

@ -34,12 +34,7 @@ class TestMPEnvironments(unittest.TestCase):
obs = env.reset() obs = env.reset()
self._verify_observations(obs, env.observation_space, "reset()") self._verify_observations(obs, env.observation_space, "reset()")
length = env.spec.max_episode_steps iterations = iterations or (env.spec.max_episode_steps or 1)
if iterations is None:
if length is None:
iterations = 1
else:
iterations = length
# number of samples(multiple environment steps) # number of samples(multiple environment steps)
for i in range(iterations): for i in range(iterations):
@ -76,7 +71,7 @@ class TestMPEnvironments(unittest.TestCase):
traj2 = self._run_env(env_id, seed=seed) traj2 = self._run_env(env_id, seed=seed)
for i, time_step in enumerate(zip(*traj1, *traj2)): for i, time_step in enumerate(zip(*traj1, *traj2)):
obs1, rwd1, done1, obs2, rwd2, done2 = time_step obs1, rwd1, done1, obs2, rwd2, done2 = time_step
self.assertTrue(np.array_equal(obs1, obs2), f"Observations [{i}] {obs1} and {obs2} do not match.") self.assertTrue(np.allclose(obs1, obs2), f"Observations [{i}] {obs1} and {obs2} do not match.")
self.assertEqual(rwd1, rwd2, f"Rewards [{i}] {rwd1} and {rwd2} do not match.") self.assertEqual(rwd1, rwd2, f"Rewards [{i}] {rwd1} and {rwd2} do not match.")
self.assertEqual(done1, done2, f"Dones [{i}] {done1} and {done2} do not match.") self.assertEqual(done1, done2, f"Dones [{i}] {done1} and {done2} do not match.")

View File

@ -36,12 +36,7 @@ class TestStepDMCEnvironments(unittest.TestCase):
obs = env.reset() obs = env.reset()
self._verify_observations(obs, env.observation_space, "reset()") self._verify_observations(obs, env.observation_space, "reset()")
length = env.spec.max_episode_steps iterations = iterations or (env.spec.max_episode_steps or 1)
if iterations is None:
if length is None:
iterations = 1
else:
iterations = length
# number of samples(multiple environment steps) # number of samples(multiple environment steps)
for i in range(iterations): for i in range(iterations):

View File

@ -35,12 +35,7 @@ class TestStepMetaWorlEnvironments(unittest.TestCase):
obs = env.reset() obs = env.reset()
self._verify_observations(obs, env.observation_space, "reset()") self._verify_observations(obs, env.observation_space, "reset()")
length = env.max_path_length iterations = iterations or (env.spec.max_episode_steps or 1)
if iterations is None:
if length is None:
iterations = 1
else:
iterations = length
# number of samples(multiple environment steps) # number of samples(multiple environment steps)
for i in range(iterations): for i in range(iterations):