Fix: SimpleReacher and ViaPointReacher did not seed correctly

This commit is contained in:
Dominik Moritz Roth 2023-08-28 18:38:33 +02:00
parent 820e781a0c
commit 155807207f
2 changed files with 11 additions and 2 deletions

View File

@ -45,9 +45,13 @@ class SimpleReacherEnv(BaseReacherTorqueEnv):
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \ def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
-> Tuple[ObsType, Dict[str, Any]]: -> Tuple[ObsType, Dict[str, Any]]:
ret = super().reset(seed=seed, options=options) # Reset twice to ensure we return obs after generating goal and generating goal after executing seeded reset.
# (Env will not behave deterministically otherwise)
# Yes, there is probably a more elegant solution to this problem...
self._generate_goal() self._generate_goal()
return ret super().reset(seed=seed, options=options)
self._generate_goal()
return super().reset(seed=seed, options=options)
def _get_reward(self, action: np.ndarray): def _get_reward(self, action: np.ndarray):
diff = self.end_effector - self._goal diff = self.end_effector - self._goal

View File

@ -44,6 +44,11 @@ class ViaPointReacherEnv(BaseReacherDirectEnv):
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \ def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
-> Tuple[ObsType, Dict[str, Any]]: -> Tuple[ObsType, Dict[str, Any]]:
# Reset twice so that the goal is generated after the seeded reset and the returned obs is computed after the goal is generated.
# (Env will not behave deterministically otherwise)
# Yes, there is probably a more elegant solution to this problem...
self._generate_goal()
super().reset(seed=seed, options=options)
self._generate_goal() self._generate_goal()
return super().reset(seed=seed, options=options) return super().reset(seed=seed, options=options)