Mitigation: Allow seeding Metaworld on reset
This commit is contained in:
parent
6d80201a03
commit
15e1bdc218
@ -50,6 +50,18 @@ class MitigateMetaworldBug(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
class MetaworldResetFix(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||||
|
def __init__(self, env: gym.Env):
|
||||||
|
gym.utils.RecordConstructorArgs.__init__(self)
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
|
def reset(self, **kwargs):
|
||||||
|
ret = self.env.reset(**kwargs)
|
||||||
|
if 'seed' in kwargs:
|
||||||
|
self.env.seed(kwargs['seed'])
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def make_metaworld(underlying_id: str, seed: int = 1, render_mode: Optional[str] = None, **kwargs):
|
def make_metaworld(underlying_id: str, seed: int = 1, render_mode: Optional[str] = None, **kwargs):
|
||||||
if underlying_id not in metaworld.ML1.ENV_NAMES:
|
if underlying_id not in metaworld.ML1.ENV_NAMES:
|
||||||
raise ValueError(f'Specified environment "{underlying_id}" not present in metaworld ML1.')
|
raise ValueError(f'Specified environment "{underlying_id}" not present in metaworld ML1.')
|
||||||
@ -79,6 +91,7 @@ def make_metaworld(underlying_id: str, seed: int = 1, render_mode: Optional[str]
|
|||||||
env = MujocoMapSpacesWrapper(env)
|
env = MujocoMapSpacesWrapper(env)
|
||||||
# TODO remove, when this has been fixed upstream
|
# TODO remove, when this has been fixed upstream
|
||||||
env = MitigateMetaworldBug(env)
|
env = MitigateMetaworldBug(env)
|
||||||
|
env = MetaworldResetFix(env)
|
||||||
return env
|
return env
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user