Porting Metaworld Bug Mitigations
This commit is contained in:
		
							parent
							
								
									20510d8f68
								
							
						
					
					
						commit
						9ce040d110
					
				| @ -1,3 +1,4 @@ | ||||
| import random | ||||
| from typing import Iterable, Type, Union, Optional | ||||
| 
 | ||||
| import numpy as np | ||||
| @ -13,12 +14,10 @@ from fancy_gym.utils.env_compatibility import EnvCompatibility | ||||
| try: | ||||
|     import metaworld | ||||
| except Exception: | ||||
|     # catch Exception as Import error does not catch missing mujoco-py | ||||
|     # TODO: Print info? | ||||
|     pass | ||||
|     print('[FANCY GYM] Metaworld not avaible') | ||||
| 
 | ||||
| 
 | ||||
| class MujocoMapSpacesWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs): | ||||
| class FixMetaworldHasIncorrectObsSpaceWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs): | ||||
|     def __init__(self, env: gym.Env): | ||||
|         gym.utils.RecordConstructorArgs.__init__(self) | ||||
|         gym.Wrapper.__init__(self, env) | ||||
| @ -29,11 +28,11 @@ class MujocoMapSpacesWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs): | ||||
|         Obs_Space_Class = getattr(gym.spaces, str(eos.__class__).split("'")[1].split('.')[-1]) | ||||
|         Act_Space_Class = getattr(gym.spaces, str(eas.__class__).split("'")[1].split('.')[-1]) | ||||
| 
 | ||||
|         self.observation_space = Obs_Space_Class(low=eos.low, high=eos.high, dtype=eos.dtype) | ||||
|         self.observation_space = Obs_Space_Class(low=eos.low-np.inf, high=eos.high+np.inf, dtype=eos.dtype) | ||||
|         self.action_space = Act_Space_Class(low=eas.low, high=eas.high, dtype=eas.dtype) | ||||
| 
 | ||||
| 
 | ||||
| class MitigateMetaworldBug(gym.Wrapper, gym.utils.RecordConstructorArgs): | ||||
| class FixMetaworldIncorrectResetPathLengthWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs): | ||||
|     def __init__(self, env: gym.Env): | ||||
|         gym.utils.RecordConstructorArgs.__init__(self) | ||||
|         gym.Wrapper.__init__(self, env) | ||||
| @ -50,13 +49,12 @@ class MitigateMetaworldBug(gym.Wrapper, gym.utils.RecordConstructorArgs): | ||||
|         return ret | ||||
| 
 | ||||
| 
 | ||||
| class MetaworldResetFix(gym.Wrapper, gym.utils.RecordConstructorArgs): | ||||
| class FixMetaworldIgnoresSeedOnResetWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs): | ||||
|     def __init__(self, env: gym.Env): | ||||
|         gym.utils.RecordConstructorArgs.__init__(self) | ||||
|         gym.Wrapper.__init__(self, env) | ||||
| 
 | ||||
|     def reset(self, **kwargs): | ||||
|         self.env.reset(**kwargs) | ||||
|         if 'seed' in kwargs: | ||||
|             self.env.seed(kwargs['seed']) | ||||
|         return self.env.reset(**kwargs) | ||||
| @ -66,32 +64,19 @@ def make_metaworld(underlying_id: str, seed: int = 1, render_mode: Optional[str] | ||||
|     if underlying_id not in metaworld.ML1.ENV_NAMES: | ||||
|         raise ValueError(f'Specified environment "{underlying_id}" not present in metaworld ML1.') | ||||
| 
 | ||||
|     _env = metaworld.envs.ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[underlying_id + "-goal-observable"](seed=seed, **kwargs) | ||||
|     env = metaworld.envs.ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[underlying_id + "-goal-observable"](seed=seed, **kwargs) | ||||
| 
 | ||||
|     # setting this avoids generating the same initialization after each reset | ||||
|     _env._freeze_rand_vec = False | ||||
|     env._freeze_rand_vec = False | ||||
|     # New argument to use global seeding | ||||
|     _env.seeded_rand_vec = True | ||||
|     env.seeded_rand_vec = True | ||||
| 
 | ||||
|     max_episode_steps = _env.max_path_length | ||||
| 
 | ||||
|     # TODO remove this as soon as there is support for the new API | ||||
|     _env = EnvCompatibility(_env, render_mode) | ||||
|     env = _env | ||||
| 
 | ||||
|     # gym_id = '_metaworld_compat_' + uuid.uuid4().hex + '-v0' | ||||
|     # gym_register( | ||||
|     #    id=gym_id, | ||||
|     #    entry_point=lambda: _env, | ||||
|     #    max_episode_steps=max_episode_steps, | ||||
|     # ) | ||||
| 
 | ||||
|     # TODO enable checker when the incorrect dtype of obs and observation space are fixed by metaworld | ||||
|     # env = gym.make(gym_id, disable_env_checker=True) | ||||
|     env = MujocoMapSpacesWrapper(env) | ||||
|     # TODO remove, when this has been fixed upstream | ||||
|     env = MitigateMetaworldBug(env) | ||||
|     env = MetaworldResetFix(env) | ||||
|     env = FixMetaworldHasIncorrectObsSpaceWrapper(env) | ||||
|     # TODO remove, when this has been fixed upstream | ||||
|     # env = FixMetaworldIncorrectResetPathLengthWrapper(env) | ||||
|     # TODO remove, when this has been fixed upstream | ||||
|     env = FixMetaworldIgnoresSeedOnResetWrapper(env) | ||||
|     return env | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user