BugFix: DId not pass kwargs down in new TT envs
This commit is contained in:
parent
db8221ebb2
commit
3bc0a23ec2
@ -8,7 +8,7 @@ from .hopper_throw.hopper_throw_in_basket import HopperThrowInBasketEnv
|
|||||||
from .reacher.reacher import ReacherEnv
|
from .reacher.reacher import ReacherEnv
|
||||||
from .walker_2d_jump.walker_2d_jump import Walker2dJumpEnv
|
from .walker_2d_jump.walker_2d_jump import Walker2dJumpEnv
|
||||||
from .box_pushing.box_pushing_env import BoxPushingDense, BoxPushingTemporalSparse, BoxPushingTemporalSpatialSparse
|
from .box_pushing.box_pushing_env import BoxPushingDense, BoxPushingTemporalSparse, BoxPushingTemporalSpatialSparse
|
||||||
from .table_tennis.table_tennis_env import TableTennisEnv, TableTennisWind, TableTennisGoalSwitching
|
from .table_tennis.table_tennis_env import TableTennisEnv, TableTennisWind, TableTennisGoalSwitching, TableTennisMarkovian, TableTennisRandomInit
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from .air_hockey.air_hockey_env_wrapper import AirHockeyEnv
|
from .air_hockey.air_hockey_env_wrapper import AirHockeyEnv
|
||||||
|
@ -41,6 +41,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
enable_artificial_wind: bool = False,
|
enable_artificial_wind: bool = False,
|
||||||
random_pos_scale: float = 0.0,
|
random_pos_scale: float = 0.0,
|
||||||
random_vel_scale: float = 0.0,
|
random_vel_scale: float = 0.0,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
utils.EzPickle.__init__(**locals())
|
utils.EzPickle.__init__(**locals())
|
||||||
self._steps = 0
|
self._steps = 0
|
||||||
@ -490,8 +491,9 @@ class TableTennisGoalSwitching(TableTennisEnv):
|
|||||||
class TableTennisRandomInit(TableTennisEnv):
|
class TableTennisRandomInit(TableTennisEnv):
|
||||||
def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4,
|
def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4,
|
||||||
random_pos_scale: float = 1.0,
|
random_pos_scale: float = 1.0,
|
||||||
random_vel_scale: float = 0.0):
|
random_vel_scale: float = 0.0,
|
||||||
|
**kwargs):
|
||||||
super().__init__(ctxt_dim=ctxt_dim, frame_skip=frame_skip,
|
super().__init__(ctxt_dim=ctxt_dim, frame_skip=frame_skip,
|
||||||
random_pos_scale=random_pos_scale,
|
random_pos_scale=random_pos_scale,
|
||||||
random_vel_scale=random_vel_scale)
|
random_vel_scale=random_vel_scale,
|
||||||
|
**kwargs)
|
Loading…
Reference in New Issue
Block a user