Fixed TableTennisWind and TableTennisGoalSwitching not correctly passing args to parent class

This commit is contained in:
Dominik Moritz Roth 2023-10-29 12:53:06 +01:00
parent d7ea1f80a0
commit dc773c6c10

View File

@ -276,11 +276,11 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
class TableTennisWind(TableTennisEnv): class TableTennisWind(TableTennisEnv):
def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4): def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4, **kwargs):
self.observation_space = spaces.Box( self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=(22,), dtype=np.float64 low=-np.inf, high=np.inf, shape=(22,), dtype=np.float64
) )
super().__init__(ctxt_dim=ctxt_dim, frame_skip=frame_skip, enable_artificial_wind=True) super().__init__(ctxt_dim=ctxt_dim, frame_skip=frame_skip, enable_artificial_wind=True, **kwargs)
def _get_obs(self): def _get_obs(self):
obs = np.concatenate([ obs = np.concatenate([
@ -298,5 +298,5 @@ class TableTennisWind(TableTennisEnv):
class TableTennisGoalSwitching(TableTennisEnv): class TableTennisGoalSwitching(TableTennisEnv):
def __init__(self, frame_skip: int = 4, goal_switching_step: int = 99): def __init__(self, frame_skip: int = 4, goal_switching_step: int = 99, **kwargs):
super().__init__(frame_skip=frame_skip, goal_switching_step=goal_switching_step) super().__init__(frame_skip=frame_skip, goal_switching_step=goal_switching_step, **kwargs)