diff --git a/fancy_gym/meta/metaworld_adapter.py b/fancy_gym/meta/metaworld_adapter.py index 20519d5..8685cad 100644 --- a/fancy_gym/meta/metaworld_adapter.py +++ b/fancy_gym/meta/metaworld_adapter.py @@ -51,12 +51,29 @@ class FixMetaworldIgnoresSeedOnResetWrapper(gym.Wrapper, gym.utils.RecordConstru gym.Wrapper.__init__(self, env) def reset(self, **kwargs): - print('[!] You just called .reset on a Metaworld env and supplied a seed. Metaworld curretly does not correctly implement seeding. Do not rely on deterministic behavior.') if 'seed' in kwargs: + print('[!] You just called .reset on a Metaworld env and supplied a seed. Metaworld curretly does not correctly implement seeding. Do not rely on deterministic behavior.') self.env.seed(kwargs['seed']) return self.env.reset(**kwargs) +class FixMetaworldRenderOnStep(gym.Wrapper, gym.utils.RecordConstructorArgs): + def __init__(self, env: gym.Env): + gym.utils.RecordConstructorArgs.__init__(self) + gym.Wrapper.__init__(self, env) + self.render_active = False + + def render(self, *args, **kwargs): + self.render_active = True + return self.env.render(*args, **kwargs) + + def step(self, *args, **kwargs): + ret = self.env.step(*args, **kwargs) + if self.render_active: + self.env.render() + return ret + + def make_metaworld(underlying_id: str, seed: int = 1, render_mode: Optional[str] = None, **kwargs): if underlying_id not in metaworld.ML1.ENV_NAMES: raise ValueError(f'Specified environment "{underlying_id}" not present in metaworld ML1.') @@ -68,11 +85,9 @@ def make_metaworld(underlying_id: str, seed: int = 1, render_mode: Optional[str] # New argument to use global seeding env.seeded_rand_vec = True - # TODO remove, when this has been fixed upstream env = FixMetaworldHasIncorrectObsSpaceWrapper(env) - # TODO remove, when this has been fixed upstream # env = FixMetaworldIncorrectResetPathLengthWrapper(env) - # TODO remove, when this has been fixed upstream + env = FixMetaworldRenderOnStep(env) env = FixMetaworldIgnoresSeedOnResetWrapper(env) return env